|
@@ -5,6 +5,7 @@ DATA_ROOT=$3
|
|
|
BATCH_SIZE=$4
|
|
BATCH_SIZE=$4
|
|
|
WORLD_SIZE=$5
|
|
WORLD_SIZE=$5
|
|
|
MASTER_PORT=$6
|
|
MASTER_PORT=$6
|
|
|
|
|
+RESUME=$7
|
|
|
|
|
|
|
|
# -------------------------- Train Pipeline --------------------------
|
|
# -------------------------- Train Pipeline --------------------------
|
|
|
if [ $WORLD_SIZE == 1 ]; then
|
|
if [ $WORLD_SIZE == 1 ]; then
|
|
@@ -13,7 +14,8 @@ if [ $WORLD_SIZE == 1 ]; then
|
|
|
--dataset ${DATASET} \
|
|
--dataset ${DATASET} \
|
|
|
--root ${DATA_ROOT} \
|
|
--root ${DATA_ROOT} \
|
|
|
--model ${MODEL} \
|
|
--model ${MODEL} \
|
|
|
- --batch_size ${BATCH_SIZE}
|
|
|
|
|
|
|
+ --batch_size ${BATCH_SIZE} \
|
|
|
|
|
+ --resume ${RESUME}
|
|
|
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
|
|
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
|
|
|
python -m torch.distributed.run --nproc_per_node=$WORLD_SIZE --master_port ${MASTER_PORT} \
|
|
python -m torch.distributed.run --nproc_per_node=$WORLD_SIZE --master_port ${MASTER_PORT} \
|
|
|
train.py \
|
|
train.py \
|
|
@@ -22,7 +24,8 @@ elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
|
|
|
--dataset ${DATASET} \
|
|
--dataset ${DATASET} \
|
|
|
--root ${DATA_ROOT} \
|
|
--root ${DATA_ROOT} \
|
|
|
--model ${MODEL} \
|
|
--model ${MODEL} \
|
|
|
- --batch_size ${BATCH_SIZE}
|
|
|
|
|
|
|
+ --batch_size ${BATCH_SIZE} \
|
|
|
|
|
+ --resume ${RESUME}
|
|
|
else
|
|
else
|
|
|
echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
|
|
echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
|
|
|
multi-card training mode, which is currently unsupported."
|
|
multi-card training mode, which is currently unsupported."
|