#!/usr/bin/env bash
# train.sh — launch train.py on a single machine with one or more GPUs.
#
# Usage: bash train.sh <MODEL> <BATCH_SIZE> <WORLD_SIZE> <MASTER_PORT> <RESUME>
#   MODEL        forwarded to train.py --model; substrings of the name such
#                as "yolox" or "rtcdet" also select the epoch schedule.
#   BATCH_SIZE   forwarded to train.py --batch_size.
#   WORLD_SIZE   number of GPUs: 1 runs a single process, 2..8 launch
#                torch.distributed.run; larger values are rejected.
#   MASTER_PORT  forwarded to torch.distributed.run --master_port.
#   RESUME       forwarded to train.py --resume.
#
# Requires bash (the script uses [[ ]]); invoke it with bash, not sh.
#
# Optional environment overrides:
#   DATASET    dataset name passed to --dataset  (default: coco)
#   DATA_ROOT  dataset root passed to --root     (default: /data/datasets/)

set -euo pipefail

# Args parameters. ${N-} expands a missing argument to an empty string
# instead of aborting under `set -u`.
MODEL=${1-}
BATCH_SIZE=${2-}
WORLD_SIZE=${3-}
MASTER_PORT=${4-}
RESUME=${5-}

# Dataset setting: defaults match the previously hard-coded values, but can
# now be overridden from the environment without editing the script.
DATASET=${DATASET:-coco}
DATA_ROOT=${DATA_ROOT:-/data/datasets/}

# MODEL setting
IMAGE_SIZE=640

readonly DATASET DATA_ROOT IMAGE_SIZE
  12. if [[ $MODEL == *"rtcdet"* ]]; then
  13. # Epoch setting
  14. MAX_EPOCH=500
  15. WP_EPOCH=3
  16. EVAL_EPOCH=10
  17. NO_AUG_EPOCH=20
  18. elif [[ $MODEL == *"yolov8"* ]]; then
  19. # Epoch setting
  20. MAX_EPOCH=500
  21. WP_EPOCH=3
  22. EVAL_EPOCH=10
  23. NO_AUG_EPOCH=20
  24. elif [[ $MODEL == *"yolox"* ]]; then
  25. # Epoch setting
  26. MAX_EPOCH=300
  27. WP_EPOCH=3
  28. EVAL_EPOCH=10
  29. NO_AUG_EPOCH=20
  30. elif [[ $MODEL == *"yolov7"* ]]; then
  31. # Epoch setting
  32. MAX_EPOCH=300
  33. WP_EPOCH=3
  34. EVAL_EPOCH=10
  35. NO_AUG_EPOCH=20
  36. elif [[ $MODEL == *"yolov5"* ]]; then
  37. # Epoch setting
  38. MAX_EPOCH=300
  39. WP_EPOCH=3
  40. EVAL_EPOCH=10
  41. NO_AUG_EPOCH=20
  42. elif [[ $MODEL == *"yolov4"* ]]; then
  43. # Epoch setting
  44. MAX_EPOCH=300
  45. WP_EPOCH=3
  46. EVAL_EPOCH=10
  47. NO_AUG_EPOCH=20
  48. elif [[ $MODEL == *"yolov3"* ]]; then
  49. # Epoch setting
  50. MAX_EPOCH=300
  51. WP_EPOCH=3
  52. EVAL_EPOCH=10
  53. NO_AUG_EPOCH=20
  54. else
  55. # Epoch setting
  56. MAX_EPOCH=150
  57. WP_EPOCH=3
  58. EVAL_EPOCH=10
  59. NO_AUG_EPOCH=10
  60. fi
  61. # -------------------------- Train Pipeline --------------------------
  62. if [ $WORLD_SIZE == 1 ]; then
  63. python train.py \
  64. --cuda \
  65. --dataset ${DATASET} \
  66. --root ${DATA_ROOT} \
  67. --model ${MODEL} \
  68. --batch_size ${BATCH_SIZE} \
  69. --img_size ${IMAGE_SIZE} \
  70. --wp_epoch ${WP_EPOCH} \
  71. --max_epoch ${MAX_EPOCH} \
  72. --eval_epoch ${EVAL_EPOCH} \
  73. --no_aug_epoch ${NO_AUG_EPOCH} \
  74. --resume ${RESUME} \
  75. --ema \
  76. --fp16 \
  77. --multi_scale
  78. elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
  79. python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
  80. --cuda \
  81. -dist \
  82. --dataset ${DATASET} \
  83. --root ${DATA_ROOT} \
  84. --model ${MODEL} \
  85. --batch_size ${BATCH_SIZE} \
  86. --img_size ${IMAGE_SIZE} \
  87. --wp_epoch ${WP_EPOCH} \
  88. --max_epoch ${MAX_EPOCH} \
  89. --eval_epoch ${EVAL_EPOCH} \
  90. --no_aug_epoch ${NO_AUG_EPOCH} \
  91. --resume ${RESUME} \
  92. --ema \
  93. --fp16 \
  94. --multi_scale \
  95. --sybn
  96. else
  97. echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
  98. multi-card training mode, which is currently unsupported."
  99. exit 1
  100. fi