|
@@ -2,17 +2,9 @@
|
|
|
MODEL=$1
|
|
MODEL=$1
|
|
|
DATASET=$2
|
|
DATASET=$2
|
|
|
DATA_ROOT=$3
|
|
DATA_ROOT=$3
|
|
|
-WORLD_SIZE=$4
|
|
|
|
|
-MASTER_PORT=$5
|
|
|
|
|
-if [[ $MODEL == *"yolof"* ]]; then
|
|
|
|
|
- # Epoch setting
|
|
|
|
|
- BATCH_SIZE=64
|
|
|
|
|
- EVAL_EPOCH=2
|
|
|
|
|
-elif [[ $MODEL == *"fcos"* ]]; then
|
|
|
|
|
- # Epoch setting
|
|
|
|
|
- BATCH_SIZE=16
|
|
|
|
|
- EVAL_EPOCH=2
|
|
|
|
|
-fi
|
|
|
|
|
|
|
+BATCH_SIZE=$4
|
|
|
|
|
+WORLD_SIZE=$5
|
|
|
|
|
+MASTER_PORT=$6
|
|
|
|
|
|
|
|
# -------------------------- Train Pipeline --------------------------
|
|
# -------------------------- Train Pipeline --------------------------
|
|
|
if [ $WORLD_SIZE == 1 ]; then
|
|
if [ $WORLD_SIZE == 1 ]; then
|
|
@@ -21,8 +13,7 @@ if [ $WORLD_SIZE == 1 ]; then
|
|
|
--dataset ${DATASET} \
|
|
--dataset ${DATASET} \
|
|
|
--root ${DATA_ROOT} \
|
|
--root ${DATA_ROOT} \
|
|
|
--model ${MODEL} \
|
|
--model ${MODEL} \
|
|
|
- --batch_size ${BATCH_SIZE} \
|
|
|
|
|
- --eval_epoch ${EVAL_EPOCH}
|
|
|
|
|
|
|
+ --batch_size ${BATCH_SIZE}
|
|
|
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
|
|
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
|
|
|
python -m torch.distributed.run --nproc_per_node=$WORLD_SIZE --master_port ${MASTER_PORT} \
|
|
python -m torch.distributed.run --nproc_per_node=$WORLD_SIZE --master_port ${MASTER_PORT} \
|
|
|
train.py \
|
|
train.py \
|
|
@@ -31,8 +22,7 @@ elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
|
|
|
--dataset ${DATASET} \
|
|
--dataset ${DATASET} \
|
|
|
--root ${DATA_ROOT} \
|
|
--root ${DATA_ROOT} \
|
|
|
--model ${MODEL} \
|
|
--model ${MODEL} \
|
|
|
- --batch_size ${BATCH_SIZE} \
|
|
|
|
|
- --eval_epoch ${EVAL_EPOCH}
|
|
|
|
|
|
|
+ --batch_size ${BATCH_SIZE}
|
|
|
else
|
|
else
|
|
|
echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
|
|
echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
|
|
|
multi-card training mode, which is currently unsupported."
|
|
multi-card training mode, which is currently unsupported."
|