train.sh

# Args setting
MODEL=$1
DATASET=$2
DATA_ROOT=$3
WORLD_SIZE=$4
MASTER_PORT=$5

# Per-model batch size and evaluation interval
if [[ $MODEL == *"yolof"* ]]; then
    BATCH_SIZE=64
    EVAL_EPOCH=2
elif [[ $MODEL == *"fcos"* ]]; then
    BATCH_SIZE=16
    EVAL_EPOCH=2
elif [[ $MODEL == *"retinanet"* ]]; then
    BATCH_SIZE=16
    EVAL_EPOCH=2
elif [[ $MODEL == *"plain_detr"* ]]; then
    BATCH_SIZE=16
    EVAL_EPOCH=2
elif [[ $MODEL == *"rtdetr"* ]]; then
    BATCH_SIZE=16
    EVAL_EPOCH=1
else
    echo "Unknown model: ${MODEL}"
    exit 1
fi

# -------------------------- Train Pipeline --------------------------
if [[ $WORLD_SIZE -eq 1 ]]; then
    # Single-GPU training
    python main.py \
            --cuda \
            --dataset ${DATASET} \
            --root ${DATA_ROOT} \
            --model ${MODEL} \
            --batch_size ${BATCH_SIZE} \
            --eval_epoch ${EVAL_EPOCH}
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
    # Single-node, multi-GPU training (DDP)
    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} \
            main.py \
            --cuda \
            --distributed \
            --dataset ${DATASET} \
            --root ${DATA_ROOT} \
            --model ${MODEL} \
            --batch_size ${BATCH_SIZE} \
            --eval_epoch ${EVAL_EPOCH}
else
    echo "WORLD_SIZE is greater than 8, which implies multi-node, multi-GPU training; this script does not support that mode."
    exit 1
fi
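
A usage sketch: the script takes five positional arguments (MODEL, DATASET, DATA_ROOT, WORLD_SIZE, MASTER_PORT). The model keys below match the substrings checked in the script; the dataset name, data path, and port are placeholders, and the exact model identifiers accepted depend on what main.py registers.

    # Single-GPU run of a YOLOF-style model (dataset/path/port are hypothetical examples)
    bash train.sh yolof coco /path/to/dataset 1 1700

    # Single-node, 8-GPU DDP run of an RT-DETR-style model
    bash train.sh rtdetr coco /path/to/dataset 8 1700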