瀏覽代碼

debug train

yjh0410 2 年之前
父節點
當前提交
ba57cea04f
共有 3 個文件被更改,包括 12 次插入12 次删除
  1. 1 1
      engine.py
  2. 8 8
      models/yolov1/loss.py
  3. 3 3
      train.sh

+ 1 - 1
engine.py

@@ -69,7 +69,7 @@ def train_one_epoch(epoch,
     for iter_i, (images, targets) in enumerate(dataloader):
         ni = iter_i + epoch * epoch_size
         # Warmup
-        if ni <= nw:
+        if ni < nw:
             warmup_scheduler.warmup(ni, optimizer)
                             
         # visualize train targets

+ 8 - 8
models/yolov1/loss.py

@@ -70,15 +70,15 @@ class Criterion(object):
                              targets=targets)
         # List[B, M, C] -> [B, M, C] -> [BM, C]
         batch_size = outputs['pred_obj'].shape[0]
-        pred_obj = outputs['pred_obj'].view(-1)
-        pred_cls = outputs['pred_cls'].view(-1, self.num_classes)
-        pred_txty = outputs['pred_txty'].view(-1, 2)
-        pred_twth = outputs['pred_twth'].view(-1, 2)
+        pred_obj = outputs['pred_obj'].view(-1)                     # [BM,]
+        pred_cls = outputs['pred_cls'].view(-1, self.num_classes)   # [BM, C]
+        pred_txty = outputs['pred_txty'].view(-1, 2)                # [BM, 2]
+        pred_twth = outputs['pred_twth'].view(-1, 2)                # [BM, 2]
        
-        gt_objectness = gt_objectness.view(-1).to(device).float()
-        gt_labels = gt_labels.view(-1).to(device).long()
-        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()
-        gt_box_weight = gt_box_weight.view(-1).to(device).float()
+        gt_objectness = gt_objectness.view(-1).to(device).float()   # [BM,]
+        gt_labels = gt_labels.view(-1).to(device).long()            # [BM,]
+        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()        # [BM, 4]
+        gt_box_weight = gt_box_weight.view(-1).to(device).float()   # [BM,]
 
         pos_masks = (gt_objectness > 0)
 

+ 3 - 3
train.sh

@@ -7,9 +7,9 @@ python train.py \
         -bs 16 \
         -size 640 \
         --wp_epoch 1 \
-        --max_epoch 4 \
-        --step_epoch 2 3 \
-        --eval_epoch 1 \
+        --max_epoch 150 \
+        --step_epoch 90 120 \
+        --eval_epoch 10 \
         --ema \
         --fp16 \
         --multi_scale \