#!/usr/bin/env bash
# train.sh — launch detection-model training (single- or multi-GPU) with
# a per-model epoch schedule. See the positional-argument list below.
  1. # Args parameters
  2. MODEL=$1
  3. DATASET=$2
  4. DATASET_ROOT=$3
  5. BATCH_SIZE=$4
  6. WORLD_SIZE=$5
  7. MASTER_PORT=$6
  8. RESUME=$7
  9. # MODEL setting
  10. IMAGE_SIZE=640
  11. FIND_UNUSED_PARAMS=False
  12. if [[ $MODEL == *"rtdetr"* ]]; then
  13. # Epoch setting
  14. MAX_EPOCH=72
  15. WP_EPOCH=-1
  16. EVAL_EPOCH=1
  17. NO_AUG_EPOCH=-1
  18. FIND_UNUSED_PARAMS=True
  19. elif [[ $MODEL == *"yolov8"* ]]; then
  20. # Epoch setting
  21. MAX_EPOCH=500
  22. WP_EPOCH=3
  23. EVAL_EPOCH=10
  24. NO_AUG_EPOCH=20
  25. elif [[ $MODEL == *"yolox"* ]]; then
  26. # Epoch setting
  27. MAX_EPOCH=300
  28. WP_EPOCH=3
  29. EVAL_EPOCH=10
  30. NO_AUG_EPOCH=20
  31. elif [[ $MODEL == *"yolov7"* ]]; then
  32. # Epoch setting
  33. MAX_EPOCH=300
  34. WP_EPOCH=3
  35. EVAL_EPOCH=10
  36. NO_AUG_EPOCH=20
  37. elif [[ $MODEL == *"yolov5"* ]]; then
  38. # Epoch setting
  39. MAX_EPOCH=300
  40. WP_EPOCH=3
  41. EVAL_EPOCH=10
  42. NO_AUG_EPOCH=20
  43. elif [[ $MODEL == *"yolov4"* ]]; then
  44. # Epoch setting
  45. MAX_EPOCH=300
  46. WP_EPOCH=3
  47. EVAL_EPOCH=10
  48. NO_AUG_EPOCH=20
  49. elif [[ $MODEL == *"yolov3"* ]]; then
  50. # Epoch setting
  51. MAX_EPOCH=300
  52. WP_EPOCH=3
  53. EVAL_EPOCH=10
  54. NO_AUG_EPOCH=20
  55. else
  56. # Epoch setting
  57. MAX_EPOCH=150
  58. WP_EPOCH=3
  59. EVAL_EPOCH=10
  60. NO_AUG_EPOCH=10
  61. fi
  62. # -------------------------- Train Pipeline --------------------------
  63. if [ $WORLD_SIZE == 1 ]; then
  64. python train.py \
  65. --cuda \
  66. --dataset ${DATASET} \
  67. --root ${DATASET_ROOT} \
  68. --model ${MODEL} \
  69. --batch_size ${BATCH_SIZE} \
  70. --img_size ${IMAGE_SIZE} \
  71. --wp_epoch ${WP_EPOCH} \
  72. --max_epoch ${MAX_EPOCH} \
  73. --eval_epoch ${EVAL_EPOCH} \
  74. --no_aug_epoch ${NO_AUG_EPOCH} \
  75. --resume ${RESUME} \
  76. --ema \
  77. --fp16 \
  78. --find_unused_parameters ${FIND_UNUSED_PARAMS} \
  79. --multi_scale
  80. elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
  81. python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
  82. --cuda \
  83. -dist \
  84. --dataset ${DATASET} \
  85. --root ${DATASET_ROOT} \
  86. --model ${MODEL} \
  87. --batch_size ${BATCH_SIZE} \
  88. --img_size ${IMAGE_SIZE} \
  89. --wp_epoch ${WP_EPOCH} \
  90. --max_epoch ${MAX_EPOCH} \
  91. --eval_epoch ${EVAL_EPOCH} \
  92. --no_aug_epoch ${NO_AUG_EPOCH} \
  93. --resume ${RESUME} \
  94. --ema \
  95. --fp16 \
  96. --find_unused_parameters ${FIND_UNUSED_PARAMS} \
  97. --multi_scale \
  98. --sybn
  99. else
  100. echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
  101. multi-card training mode, which is currently unsupported."
  102. exit 1
  103. fi