yjh0410 1 年之前
父节点
当前提交
e9d09ceb0e
共有 2 个文件被更改，包括 2 次插入和 2 次删除
  1. 1 1
      yolo/train.py
  2. 1 1
      yolo/train.sh

+ 1 - 1
yolo/train.py

@@ -196,7 +196,7 @@ def train():
         trainer.eval(model_eval)
         trainer.eval(model_eval)
         return
         return
 
 
-    # garbage = torch.randn(640, 1024, 75, 75).to(device) # 15 G
+    garbage = torch.randn(640, 1024, 75, 75).to(device) # 15 G
 
 
     # ---------------------------- Train pipeline ----------------------------
     # ---------------------------- Train pipeline ----------------------------
     trainer.train(model)
     trainer.train(model)

+ 1 - 1
yolo/train.sh

@@ -10,7 +10,7 @@ RESUME=$7
 
 
 # -------------------------- Train Pipeline --------------------------
 # -------------------------- Train Pipeline --------------------------
 if [[ $WORLD_SIZE == 1 ]]; then
 if [[ $WORLD_SIZE == 1 ]]; then
-    python train.py \
+    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
             --cuda \
             --cuda \
             --dataset ${DATASET} \
             --dataset ${DATASET} \
             --root ${DATASET_ROOT} \
             --root ${DATASET_ROOT} \