train.sh

#!/bin/bash
# Args parameters
MODEL=$1          # model name passed to train.py
DATASET=$2        # dataset name
DATASET_ROOT=$3   # path to the dataset root directory
BATCH_SIZE=$4     # batch size forwarded to train.py
WORLD_SIZE=$5     # number of GPUs: 1 = single-GPU, 2-8 = single-node DDP
MASTER_PORT=$6    # port used by the distributed master process
RESUME=$7         # checkpoint argument forwarded to --resume
# -------------------------- Train Pipeline --------------------------
if [[ $WORLD_SIZE == 1 ]]; then
    # Single-GPU training
    python train.py \
        --cuda \
        --dataset ${DATASET} \
        --root ${DATASET_ROOT} \
        --model ${MODEL} \
        --batch_size ${BATCH_SIZE} \
        --resume ${RESUME}
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
    # Single-node multi-GPU training (2-8 GPUs) launched with torch.distributed.run,
    # with mixed precision (--fp16) and synchronized BatchNorm (--sybn) enabled
    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
        --cuda \
        --distributed \
        --dataset ${DATASET} \
        --root ${DATASET_ROOT} \
        --model ${MODEL} \
        --batch_size ${BATCH_SIZE} \
        --resume ${RESUME} \
        --fp16 \
        --sybn
else
    # More than 8 GPUs would require multi-machine multi-card training, which this script does not support
    echo "WORLD_SIZE is greater than 8, which implies multi-machine multi-card training; this mode is currently unsupported."
    exit 1
fi
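
For reference, the script is driven entirely by the seven positional arguments above, passed in order: MODEL, DATASET, DATASET_ROOT, BATCH_SIZE, WORLD_SIZE, MASTER_PORT, RESUME. The invocations below are illustrative only; the model name, dataset, path, port, and resume value are placeholders, not settings taken from this repository.

# Single-GPU run (WORLD_SIZE=1 takes the plain "python train.py" branch)
bash train.sh my_model voc /path/to/dataset 16 1 1700 None

# Single-node multi-GPU run (WORLD_SIZE=4 takes the torch.distributed.run branch)
bash train.sh my_model voc /path/to/dataset 16 4 1700 None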
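
The else branch rejects WORLD_SIZE values above 8 because they would imply multi-machine multi-card training, which train.sh does not implement. Purely as an illustration of what such a launch generally looks like with torch.distributed.run (the node count, node addresses, and per-node GPU count below are assumptions, and train.py itself would also have to handle a multi-node setup), a two-node sketch could be:

# Hypothetical node 0 (master), 8 GPUs per node -- not supported by train.sh as written
python -m torch.distributed.run --nnodes=2 --node_rank=0 --nproc_per_node=8 \
    --master_addr=192.168.1.10 --master_port=${MASTER_PORT} \
    train.py --cuda --distributed --dataset ${DATASET} --root ${DATASET_ROOT} \
    --model ${MODEL} --batch_size ${BATCH_SIZE} --resume ${RESUME} --fp16 --sybn

# Hypothetical node 1 (worker): the same command with --node_rank=1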