yjh0410 1 jaar geleden
bovenliggende
commit
e87e86fb0a
2 gewijzigde bestanden met toevoegingen van 15 en 13 verwijderingen
  1. 13 11
      yolo/engine.py
  2. 2 2
      yolo/models/yolov8/yolov8_backbone.py

+ 13 - 11
yolo/engine.py

@@ -63,7 +63,8 @@ class YoloTrainer(object):
         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
         # ---------------------------- Build Optimizer ----------------------------
-        cfg.base_lr = cfg.per_image_lr * args.batch_size
+        self.grad_accumulate = max(64 // args.batch_size, 1)
+        cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
         cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
         self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
 
@@ -216,16 +217,17 @@ class YoloTrainer(object):
             self.scaler.scale(losses).backward()
 
             # Optimize
-            if self.cfg.clip_max_norm > 0:
-                self.scaler.unscale_(self.optimizer)
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-            self.optimizer.zero_grad()
-
-            # ModelEMA
-            if self.model_ema is not None:
-                self.model_ema.update(model)
+            if (iter_i + 1) % self.grad_accumulate == 0:
+                if self.cfg.clip_max_norm > 0:
+                    self.scaler.unscale_(self.optimizer)
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.optimizer.zero_grad()
+
+                # ModelEMA
+                if self.model_ema is not None:
+                    self.model_ema.update(model)
 
             # Update log
             metric_logger.update(**loss_dict_reduced)

+ 2 - 2
yolo/models/yolov8/yolov8_backbone.py

@@ -8,8 +8,8 @@ except:
 
 # IN1K pretrained weight
 pretrained_urls = {
-    'n': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/rtcnet_n_in1k_62.1.pth",
-    's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/rtcnet_s_in1k_71.3.pth",
+    'n': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_n_in1k_62.1.pth",
+    's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_s_in1k_71.3.pth",
     'm': None,
     'l': None,
     'x': None,