yjh0410 1 жил өмнө
parent
commit
5b2a227e64
1 өөрчлөгдсөн 3 нэмэгдсэн , 11 устгасан
  1. 3 11
      train.sh

+ 3 - 11
train.sh

@@ -5,66 +5,58 @@ DATA_ROOT="/data/datasets/"
 # MODEL setting
 # MODEL setting
 MODEL=$1
 MODEL=$1
 RESUME="None"
 RESUME="None"
+BATCH_SIZE=$2
 IMAGE_SIZE=640
 IMAGE_SIZE=640
 if [[ $MODEL == *"yolov8"* ]]; then
 if [[ $MODEL == *"yolov8"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=500
     MAX_EPOCH=500
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 elif [[ $MODEL == *"yolox"* ]]; then
 elif [[ $MODEL == *"yolox"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=300
     MAX_EPOCH=300
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 elif [[ $MODEL == *"yolov7"* ]]; then
 elif [[ $MODEL == *"yolov7"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=300
     MAX_EPOCH=300
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 elif [[ $MODEL == *"yolov5"* ]]; then
 elif [[ $MODEL == *"yolov5"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=300
     MAX_EPOCH=300
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 elif [[ $MODEL == *"yolov4"* ]]; then
 elif [[ $MODEL == *"yolov4"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=300
     MAX_EPOCH=300
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 elif [[ $MODEL == *"yolov3"* ]]; then
 elif [[ $MODEL == *"yolov3"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=300
     MAX_EPOCH=300
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 elif [[ $MODEL == *"rtcdet"* ]]; then
 elif [[ $MODEL == *"rtcdet"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=300
     MAX_EPOCH=300
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 elif [[ $MODEL == *"ctrnet"* ]]; then
 elif [[ $MODEL == *"ctrnet"* ]]; then
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=300
     MAX_EPOCH=300
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
     NO_AUG_EPOCH=20
 else
 else
     # Epoch setting
     # Epoch setting
-    BATCH_SIZE=128
     MAX_EPOCH=150
     MAX_EPOCH=150
     WP_EPOCH=3
     WP_EPOCH=3
     EVAL_EPOCH=10
     EVAL_EPOCH=10
@@ -72,8 +64,8 @@ else
 fi
 fi
 
 
 # -------------------------- Train Pipeline --------------------------
 # -------------------------- Train Pipeline --------------------------
-WORLD_SIZE=$2
-MASTER_PORT=$3
+WORLD_SIZE=$3
+MASTER_PORT=$4
 if [ $WORLD_SIZE == 1 ]; then
 if [ $WORLD_SIZE == 1 ]; then
     python train.py \
     python train.py \
             --cuda \
             --cuda \