train.sh

# -------------------------- Positional arguments --------------------------
MODEL=$1            # model name
DATASET=$2          # dataset name
DATASET_ROOT=$3     # path to the dataset root directory
BATCH_SIZE=$4       # training batch size
GRAD_ACCUMULATE=$5  # number of gradient accumulation steps
WORLD_SIZE=$6       # number of GPUs (processes) to launch
MASTER_PORT=$7      # master port for the distributed launcher
RESUME=$8           # checkpoint to resume from
# -------------------------- Train Pipeline --------------------------
# Single-GPU training, still launched through torch.distributed.run with one process.
if [[ $WORLD_SIZE == 1 ]]; then
    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
        --cuda \
        --distributed \
        --dataset ${DATASET} \
        --root ${DATASET_ROOT} \
        --model ${MODEL} \
        --batch_size ${BATCH_SIZE} \
        --grad_accumulate ${GRAD_ACCUMULATE} \
        --resume ${RESUME} \
        --fp16
# Multi-GPU training on a single machine (2-8 GPUs); the extra --sybn flag
# presumably enables synchronized BatchNorm across processes.
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
        --cuda \
        --distributed \
        --dataset ${DATASET} \
        --root ${DATASET_ROOT} \
        --model ${MODEL} \
        --batch_size ${BATCH_SIZE} \
        --grad_accumulate ${GRAD_ACCUMULATE} \
        --resume ${RESUME} \
        --fp16 \
        --sybn
# WORLD_SIZE > 8 would require multi-machine training, which this script does not support.
else
    echo "WORLD_SIZE is greater than 8, which implies multi-machine multi-GPU training; this mode is currently unsupported."
    exit 1
fi
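
For reference, a hypothetical invocation on a single machine with 8 GPUs is shown below. The model name, dataset name, dataset path, port, and the None placeholder for --resume are illustrative values only, not ones dictated by the script; python -m torch.distributed.run is the module form of the torchrun launcher shipped with PyTorch.

# Example (all values hypothetical): train a model named "yolo" on a dataset
# named "coco" with 8 GPUs, batch size 16, gradient accumulation 1, no resume checkpoint.
bash train.sh yolo coco /path/to/coco 16 1 8 1700 None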