train.sh 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
  # ------------------------------ train.sh ------------------------------
# Launch train.py on COCO with a per-model epoch schedule.
#
# Usage: bash train.sh <model> <batch_size> <world_size> <master_port>
#   model       - model name forwarded as --model; its name also selects the
#                 epoch schedule
#   batch_size  - forwarded as --batch_size
#   world_size  - GPUs on this machine: 1 runs train.py directly, 2..8 use
#                 torch.distributed.run
#   master_port - rendezvous port; only used when world_size > 1

# This script relies on bash-only syntax ([[ ... ]]). Under a POSIX sh such as
# dash, `[[` is "command not found", so every model test would fail and the
# model-specific branches would be skipped. Refuse to run instead.
if [ -z "${BASH_VERSION:-}" ]; then
  echo "train.sh must be run with bash, e.g.: bash train.sh <model> <batch_size> <world_size> <master_port>" >&2
  exit 1
fi

# Dataset setting
DATASET="coco"
DATA_ROOT="/data/datasets/"   # hard-coded absolute path; adjust for your machine
# Model setting
MODEL=$1                      # 1st argument: model name
RESUME="None"                 # forwarded literally as --resume None (presumably "no checkpoint" in train.py — confirm)
BATCH_SIZE=$2                 # 2nd argument: batch size
IMAGE_SIZE=640                # forwarded as --img_size
  # --------------------------- Epoch Setting ---------------------------
# Select the epoch schedule for a model name. The four values are stored in
# the globals MAX_EPOCH, WP_EPOCH, EVAL_EPOCH and NO_AUG_EPOCH, which are
# forwarded to train.py as --max_epoch, --wp_epoch, --eval_epoch and
# --no_aug_epoch.
#
# Matching is by substring of the model name, and the first matching arm wins,
# so "yolov8" is tested before the shared 300-epoch family. Names that match
# nothing get a 150-epoch schedule with NO_AUG_EPOCH=0.
#
# Arguments: $1 - model name (may be empty; then the default schedule applies)
# Globals:   MAX_EPOCH, WP_EPOCH, EVAL_EPOCH, NO_AUG_EPOCH (written)
set_epoch_config() {
  local model=${1-}

  # Every schedule shares these two values.
  WP_EPOCH=3
  EVAL_EPOCH=10

  case "${model}" in
    *yolov8*)
      MAX_EPOCH=500
      NO_AUG_EPOCH=20
      ;;
    *yolox* | *yolov7* | *yolov5* | *yolov4* | *yolov3* | *rtcdet* | *ctrnet*)
      MAX_EPOCH=300
      NO_AUG_EPOCH=20
      ;;
    *)
      MAX_EPOCH=150
      NO_AUG_EPOCH=0
      ;;
  esac
}

set_epoch_config "${MODEL-}"
  # -------------------------- Train Pipeline --------------------------
# Validate the remaining arguments and launch train.py.
#
# Reads:  DATASET, DATA_ROOT, MODEL, RESUME, BATCH_SIZE, IMAGE_SIZE,
#         MAX_EPOCH, WP_EPOCH, EVAL_EPOCH, NO_AUG_EPOCH (set earlier in this script)
# Args:   $3 - world size: number of processes/GPUs on this machine (1..8)
#         $4 - master port for torch.distributed.run; only used when $3 > 1.
#              Defaults to 29500, which is torch.distributed.run's own default.
# Exit:   2 on a missing or invalid argument, 1 when the world size is above 8
#         (multi-machine training is unsupported). Otherwise the python launch
#         is the last command, so the script ends with its exit status.
WORLD_SIZE=$3
MASTER_PORT=${4:-29500}

# Print an error message plus usage to stderr, then exit with status 2.
train_usage_error() {
  printf 'error: %s\n' "$1" >&2
  printf 'Usage: bash %s <model> <batch_size> <world_size> [master_port]\n' "${0##*/}" >&2
  exit 2
}

# Fail fast. An empty value would otherwise shift the command line, so train.py
# would read the next flag as that value.
if [[ -z "${MODEL}" ]]; then
  train_usage_error "missing <model> (1st argument)"
fi
if [[ ! "${BATCH_SIZE}" =~ ^[0-9]+$ ]] || (( 10#${BATCH_SIZE} == 0 )); then
  train_usage_error "<batch_size> must be a positive integer, got '${BATCH_SIZE}'"
fi
if [[ ! "${WORLD_SIZE}" =~ ^[0-9]+$ ]] || (( 10#${WORLD_SIZE} == 0 )); then
  train_usage_error "<world_size> must be a positive integer, got '${WORLD_SIZE}'"
fi
# Force base 10 so values like "08" are not parsed as (invalid) octal.
WORLD_SIZE=$(( 10#${WORLD_SIZE} ))

# Arguments shared by the single-GPU and multi-GPU launches. Each value is
# quoted so paths containing spaces survive intact.
train_args=(
  --cuda
  --dataset "${DATASET}"
  --root "${DATA_ROOT}"
  --model "${MODEL}"
  --batch_size "${BATCH_SIZE}"
  --img_size "${IMAGE_SIZE}"
  --wp_epoch "${WP_EPOCH}"
  --max_epoch "${MAX_EPOCH}"
  --eval_epoch "${EVAL_EPOCH}"
  --no_aug_epoch "${NO_AUG_EPOCH}"
  --resume "${RESUME}"
  --ema
  --fp16
  --multi_scale
)

if (( WORLD_SIZE == 1 )); then
  python train.py "${train_args[@]}"
elif (( WORLD_SIZE <= 8 )); then
  if [[ ! "${MASTER_PORT}" =~ ^[0-9]+$ ]]; then
    train_usage_error "<master_port> must be a port number, got '${MASTER_PORT}'"
  fi
  # Start one process per GPU on this machine. -dist and --sybn are passed only
  # here (presumably distributed mode and synchronized BatchNorm in train.py —
  # confirm there).
  python -m torch.distributed.run --nproc_per_node="${WORLD_SIZE}" \
    --master_port "${MASTER_PORT}" \
    train.py "${train_args[@]}" -dist --sybn
else
  echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
multi-card training mode, which is currently unsupported." >&2
  exit 1
fi