#!/usr/bin/env bash
# train.sh - launches train.py as a single process or via torch.distributed.run.
#
# Usage: bash train.sh <WORLD_SIZE> <MASTER_PORT>
  # ------------------------------ Dataset setting ------------------------------
  # Dataset name and root directory; passed to train.py as --dataset / --root.
  DATASET="coco"
  DATA_ROOT="/data/datasets/"

  # ------------------------------- MODEL setting -------------------------------
  # MODEL is passed to train.py as --model and is also substring-matched
  # below to choose the epoch schedule.
  MODEL="yolox_s"
  # Passed to train.py as --img_size.
  IMAGE_SIZE=640
  # Passed verbatim to train.py as --resume.
  # NOTE(review): presumably the literal string "None" means "no checkpoint"
  # to train.py - confirm against its argument handling.
  RESUME="None"
  # ------------------------------- Epoch setting -------------------------------
  #######################################
  # Choose the training schedule for a model name.
  # The name is matched by substring and the first matching family wins,
  # so "yolov8" is tested before the other families.
  # Globals:   BATCH_SIZE, MAX_EPOCH, WP_EPOCH, EVAL_EPOCH, NO_AUG_EPOCH (written)
  # Arguments: $1 - model name (e.g. "yolox_s")
  #######################################
  set_epoch_config() {
    # These three values are the same for every model family.
    BATCH_SIZE=128
    WP_EPOCH=3
    EVAL_EPOCH=10

    case "$1" in
      *yolov8*)
        MAX_EPOCH=500
        NO_AUG_EPOCH=20
        ;;
      *yolox* | *yolov7* | *yolov5* | *yolov4* | *yolov3* | *rtcdet* | *ctrnet*)
        MAX_EPOCH=300
        NO_AUG_EPOCH=20
        ;;
      *)
        # Any other model: shorter schedule and no no-augmentation epochs.
        MAX_EPOCH=150
        NO_AUG_EPOCH=0
        ;;
    esac
  }

  set_epoch_config "${MODEL}"
  # -------------------------- Train Pipeline --------------------------
  # Positional arguments:
  #   $1 WORLD_SIZE  - number of processes to launch on this machine (1-8)
  #   $2 MASTER_PORT - port handed to torch.distributed.run; required when
  #                    WORLD_SIZE is greater than 1
  WORLD_SIZE=${1:-}
  MASTER_PORT=${2:-}

  # A missing, non-numeric or zero WORLD_SIZE is a usage error. Without this
  # check such values would fall through to the "greater than 8" message.
  if [[ ! "${WORLD_SIZE}" =~ ^[0-9]+$ ]] || (( 10#${WORLD_SIZE} < 1 )); then
    printf 'Usage: %s <WORLD_SIZE: 1-8> <MASTER_PORT>\n' "${0##*/}" >&2
    exit 1
  fi
  # Force base 10 so a value such as "08" is not parsed as octal.
  WORLD_SIZE=$(( 10#${WORLD_SIZE} ))

  # Arguments shared by the single-process and distributed launches, kept in
  # an array so each value stays one word regardless of its content.
  TRAIN_ARGS=(
    --dataset "${DATASET}"
    --root "${DATA_ROOT}"
    --model "${MODEL}"
    --batch_size "${BATCH_SIZE}"
    --img_size "${IMAGE_SIZE}"
    --wp_epoch "${WP_EPOCH}"
    --max_epoch "${MAX_EPOCH}"
    --eval_epoch "${EVAL_EPOCH}"
    --no_aug_epoch "${NO_AUG_EPOCH}"
    --resume "${RESUME}"
    --ema
    --fp16
    --multi_scale
  )

  if (( WORLD_SIZE == 1 )); then
    # Single process: run train.py directly.
    python train.py --cuda "${TRAIN_ARGS[@]}"
  elif (( WORLD_SIZE <= 8 )); then
    # Without a port, --master_port would consume "train.py" as its value.
    if [[ -z "${MASTER_PORT}" ]]; then
      echo "MASTER_PORT (second argument) is required when WORLD_SIZE > 1." >&2
      exit 1
    fi
    # One process per worker via torch.distributed.run.
    # NOTE(review): "-dist" keeps the original single-dash spelling; presumably
    # a short alias defined by train.py - confirm.
    python -m torch.distributed.run \
      --nproc_per_node="${WORLD_SIZE}" \
      --master_port "${MASTER_PORT}" \
      train.py --cuda -dist "${TRAIN_ARGS[@]}" --sybn
  else
    echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine" \
      "multi-card training mode, which is currently unsupported." >&2
    exit 1
  fi