#!/usr/bin/env bash
# train.sh — single-node training launcher (single- or multi-GPU).
  1. # Args setting
  2. MODEL=$1
  3. DATASET=$2
  4. DATA_ROOT=$3
  5. BATCH_SIZE=$4
  6. WORLD_SIZE=$5
  7. MASTER_PORT=$6
  8. # -------------------------- Train Pipeline --------------------------
  9. if [ $WORLD_SIZE == 1 ]; then
  10. python train.py \
  11. --cuda \
  12. --dataset ${DATASET} \
  13. --root ${DATA_ROOT} \
  14. --model ${MODEL} \
  15. --batch_size ${BATCH_SIZE}
  16. elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
  17. python -m torch.distributed.run --nproc_per_node=$WORLD_SIZE --master_port ${MASTER_PORT} \
  18. train.py \
  19. --cuda \
  20. --distributed \
  21. --dataset ${DATASET} \
  22. --root ${DATA_ROOT} \
  23. --model ${MODEL} \
  24. --batch_size ${BATCH_SIZE}
  25. else
  26. echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
  27. multi-card training mode, which is currently unsupported."
  28. exit 1
  29. fi