yjh0410 1 年之前
父節點
當前提交
3c28372df5
共有 4 個文件被更改,包括 6 次插入和 4 次删除
  1. 2 2
      engine.py
  2. 1 1
      models/yolov3/yolov3_pred.py
  3. 1 1
      models/yolov4/yolov4_pred.py
  4. 2 0
      train.py

+ 2 - 2
engine.py

@@ -62,7 +62,7 @@ class YoloTrainer(object):
         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
         # ---------------------------- Build Optimizer ----------------------------
-        cfg.grad_accumulate = max(64 // args.batch_size, 1)
+        cfg.grad_accumulate = max(64 // args.batch_size, 1, args.grad_accumulate)
         cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
         cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
         self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
@@ -353,7 +353,7 @@ class RTDetrTrainer(object):
         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
         # ---------------------------- Build Optimizer ----------------------------
-        cfg.grad_accumulate = max(16 // args.batch_size, 1)
+        cfg.grad_accumulate = max(16 // args.batch_size, 1, args.grad_accumulate)
         cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
         cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
         self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)

+ 1 - 1
models/yolov3/yolov3_pred.py

@@ -72,7 +72,7 @@ class DetPredLayer(nn.Module):
         
     def forward(self, cls_feat, reg_feat):
         # 预测层
-        obj_pred = self.obj_pred(cls_feat)
+        obj_pred = self.obj_pred(reg_feat)
         cls_pred = self.cls_pred(cls_feat)
         reg_pred = self.reg_pred(reg_feat)
 

+ 1 - 1
models/yolov4/yolov4_pred.py

@@ -72,7 +72,7 @@ class DetPredLayer(nn.Module):
         
     def forward(self, cls_feat, reg_feat):
         # 预测层
-        obj_pred = self.obj_pred(cls_feat)
+        obj_pred = self.obj_pred(reg_feat)
         cls_pred = self.cls_pred(cls_feat)
         reg_pred = self.reg_pred(reg_feat)
 

+ 2 - 0
train.py

@@ -61,6 +61,8 @@ def parse_args():
     # Batchsize
     parser.add_argument('-bs', '--batch_size', default=16, type=int, 
                         help='batch size on all the GPUs.')
+    parser.add_argument('-gc', '--grad_accumulate', default=1, type=int, 
+                        help='number of gradient accumulate.')
 
     # Model
     parser.add_argument('-m', '--model', default='yolo_n', type=str,