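#!/bin/bash
# Usage: bash train.sh <MODEL> <DATASET> <DATASET_ROOT> <WORLD_SIZE> <MASTER_PORT> <RESUME>
# e.g.:  bash train.sh vit_t cifar10 /path/to/cifar10 1 1700 none
# (illustrative values only; valid model names and the RESUME placeholder depend on train.py)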
# ------------------- Args setting -------------------
MODEL=$1
DATASET=$2
DATASET_ROOT=$3
WORLD_SIZE=$4
MASTER_PORT=$5
RESUME=$6
# ------------------- Training setting -------------------
## Epoch & batch
BATCH_SIZE=128
GRAD_ACCUMULATE=32
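# With gradient accumulation, the effective batch size is
# BATCH_SIZE x GRAD_ACCUMULATE x WORLD_SIZE (e.g., 128 x 32 x 1 = 4096),
# assuming train.py steps the optimizer once per GRAD_ACCUMULATE micro-batches.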
WP_EPOCH=10     # warmup epochs
MAX_EPOCH=100   # total training epochs
EVAL_EPOCH=5    # evaluate every 5 epochs
DROP_PATH=0.1   # stochastic depth rate
## Optimizer & LR schedule
OPTIMIZER="adamw"
LRSCHEDULER="cosine"
BASE_LR=1e-3        # 0.1 for SGD; 0.001 for AdamW
MIN_LR=1e-6
BATCH_BASE=1024     # 256 for SGD; 1024 for AdamW
MOMENTUM=0.9
WEIGHT_DECAY=0.05   # 0.0001 for SGD; 0.05 for AdamW
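# BATCH_BASE is presumably the reference batch size for linear LR scaling, i.e.
#   lr = BASE_LR * (BATCH_SIZE * GRAD_ACCUMULATE * WORLD_SIZE) / BATCH_BASE
# (an assumption about train.py; check its optimizer setup to confirm).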
# ------------------- Dataset config -------------------
if [[ $DATASET == "mnist" ]]; then
    IMG_SIZE=28
    NUM_CLASSES=10
elif [[ $DATASET == "cifar10" ]]; then
    IMG_SIZE=32
    NUM_CLASSES=10
elif [[ $DATASET == "cifar100" ]]; then
    IMG_SIZE=32
    NUM_CLASSES=100
elif [[ $DATASET == "imagenet_1k" || $DATASET == "imagenet_22k" ]]; then
    IMG_SIZE=224
    NUM_CLASSES=1000
elif [[ $DATASET == "custom" ]]; then
    IMG_SIZE=224
    NUM_CLASSES=2
else
    echo "Unknown dataset: ${DATASET}"
    exit 1
fi
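# Note: imagenet_22k is mapped to 1000 classes above; if your 22k label set
# differs (e.g., 21841 classes), adjust NUM_CLASSES accordingly.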
# ------------------- Training pipeline -------------------
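# A single GPU runs plain python; 2-8 GPUs launch one process per GPU via
# torch.distributed.run (DDP); multi-node training is not handled by this script.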
if [[ $WORLD_SIZE -eq 1 ]]; then
    python train.py \
        --cuda \
        --root ${DATASET_ROOT} \
        --dataset ${DATASET} \
        --model ${MODEL} \
        --resume ${RESUME} \
        --batch_size ${BATCH_SIZE} \
        --batch_base ${BATCH_BASE} \
        --grad_accumulate ${GRAD_ACCUMULATE} \
        --img_size ${IMG_SIZE} \
        --drop_path ${DROP_PATH} \
        --max_epoch ${MAX_EPOCH} \
        --wp_epoch ${WP_EPOCH} \
        --eval_epoch ${EVAL_EPOCH} \
        --optimizer ${OPTIMIZER} \
        --lr_scheduler ${LRSCHEDULER} \
        --base_lr ${BASE_LR} \
        --min_lr ${MIN_LR} \
        --momentum ${MOMENTUM} \
        --weight_decay ${WEIGHT_DECAY} \
        --color_jitter 0.0 \
        --reprob 0.0 \
        --mixup 0.0 \
        --cutmix 0.0
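# Zeroed color_jitter/reprob/mixup/cutmix disable those augmentations, so this
# default pipeline trains with weak augmentation only. The distributed branch
# below additionally passes --distributed and --sybn (presumably synchronized
# BatchNorm across processes; both flags are defined by train.py).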
elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
        --cuda \
        --distributed \
        --root ${DATASET_ROOT} \
        --dataset ${DATASET} \
        --model ${MODEL} \
        --resume ${RESUME} \
        --batch_size ${BATCH_SIZE} \
        --batch_base ${BATCH_BASE} \
        --grad_accumulate ${GRAD_ACCUMULATE} \
        --img_size ${IMG_SIZE} \
        --drop_path ${DROP_PATH} \
        --max_epoch ${MAX_EPOCH} \
        --wp_epoch ${WP_EPOCH} \
        --eval_epoch ${EVAL_EPOCH} \
        --optimizer ${OPTIMIZER} \
        --lr_scheduler ${LRSCHEDULER} \
        --base_lr ${BASE_LR} \
        --min_lr ${MIN_LR} \
        --momentum ${MOMENTUM} \
        --weight_decay ${WEIGHT_DECAY} \
        --sybn \
        --color_jitter 0.0 \
        --reprob 0.0 \
        --mixup 0.0 \
        --cutmix 0.0
else
    echo "WORLD_SIZE=${WORLD_SIZE} (> 8) implies multi-node, multi-GPU training, which this script does not support."
    exit 1
fi
# ------------------- Training pipeline with strong augmentations -------------------
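# This alternative pipeline enables strong augmentation, assuming train.py wires
# these flags to timm-style augmentation: "rand-m9-mstd0.5-inc1" is a RandAugment
# policy (magnitude 9, magnitude std 0.5, increasing severity with magnitude),
# --reprob 0.25 is the random-erasing probability, and --mixup 0.8 / --cutmix 1.0
# are the MixUp and CutMix alpha values. Uncomment it to replace the pipeline above.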
# if [[ $WORLD_SIZE -eq 1 ]]; then
#     python train.py \
#         --cuda \
#         --root ${DATASET_ROOT} \
#         --dataset ${DATASET} \
#         --model ${MODEL} \
#         --resume ${RESUME} \
#         --batch_size ${BATCH_SIZE} \
#         --batch_base ${BATCH_BASE} \
#         --grad_accumulate ${GRAD_ACCUMULATE} \
#         --img_size ${IMG_SIZE} \
#         --drop_path ${DROP_PATH} \
#         --max_epoch ${MAX_EPOCH} \
#         --wp_epoch ${WP_EPOCH} \
#         --eval_epoch ${EVAL_EPOCH} \
#         --optimizer ${OPTIMIZER} \
#         --lr_scheduler ${LRSCHEDULER} \
#         --base_lr ${BASE_LR} \
#         --min_lr ${MIN_LR} \
#         --weight_decay ${WEIGHT_DECAY} \
#         --aa "rand-m9-mstd0.5-inc1" \
#         --reprob 0.25 \
#         --mixup 0.8 \
#         --cutmix 1.0
# elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
#     python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
#         --cuda \
#         --distributed \
#         --root ${DATASET_ROOT} \
#         --dataset ${DATASET} \
#         --model ${MODEL} \
#         --resume ${RESUME} \
#         --batch_size ${BATCH_SIZE} \
#         --batch_base ${BATCH_BASE} \
#         --grad_accumulate ${GRAD_ACCUMULATE} \
#         --img_size ${IMG_SIZE} \
#         --drop_path ${DROP_PATH} \
#         --max_epoch ${MAX_EPOCH} \
#         --wp_epoch ${WP_EPOCH} \
#         --eval_epoch ${EVAL_EPOCH} \
#         --optimizer ${OPTIMIZER} \
#         --lr_scheduler ${LRSCHEDULER} \
#         --base_lr ${BASE_LR} \
#         --min_lr ${MIN_LR} \
#         --weight_decay ${WEIGHT_DECAY} \
#         --sybn \
#         --aa "rand-m9-mstd0.5-inc1" \
#         --reprob 0.25 \
#         --mixup 0.8 \
#         --cutmix 1.0
# else
#     echo "WORLD_SIZE=${WORLD_SIZE} (> 8) implies multi-node, multi-GPU training, which this script does not support."
#     exit 1
# fi