yjh0410 1 年之前
父节点
当前提交
e90574b663
共有 5 个文件被更改,包括 307 次插入21 次删除
  1. 5 5
      yolo/config/fcos_config.py
  2. 224 1
      yolo/engine.py
  3. 20 10
      yolo/train.py
  4. 6 4
      yolo/utils/solver/lr_scheduler.py
  5. 52 1
      yolo/utils/solver/optimizer.py

+ 5 - 5
yolo/config/fcos_config.py

@@ -33,7 +33,7 @@ class FcosBaseConfig(object):
         # ---------------- Assignment config ----------------
         ## Matcher
         self.center_sampling_radius = 1.5
-        self.object_sizes_of_interest = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, float('inf')]]
+        self.object_sizes_of_interest = [[-1, 64], [64, 128], [128, 256], [256, float('inf')]]
 
         ## Loss weight
         self.focal_loss_alpha = 0.25
@@ -43,12 +43,12 @@ class FcosBaseConfig(object):
         self.loss_ctn = 1.0
 
         # ---------------- ModelEMA config ----------------
-        self.use_ema   = True
+        self.use_ema   = False
         self.ema_decay = 0.9998
         self.ema_tau   = 2000
 
         # ---------------- Optimizer config ----------------
-        self.trainer      = 'yolo'
+        self.trainer      = 'simple'
         self.optimizer    = 'sgd'
         self.base_lr      = 0.01     # base_lr = per_image_lr * batch_size
         self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
@@ -57,10 +57,10 @@ class FcosBaseConfig(object):
         self.weight_decay = 0.0001
         self.clip_max_norm   = 10.0
         self.warmup_bias_lr  = 0.0
-        self.warmup_momentum = 0.8
+        self.warmup_momentum = 0.9
 
         # ---------------- Lr Scheduler config ----------------
-        self.warmup_epoch = 3
+        self.warmup_iters = 500
         self.lr_scheduler = "cosine"
         self.max_epoch    = 150
         self.eval_epoch   = 10

+ 224 - 1
yolo/engine.py

@@ -10,7 +10,7 @@ from utils.misc import MetricLogger, SmoothedValue
 from utils.vis_tools import vis_data
 
 # ----------------- Optimizer & LrScheduler Components -----------------
-from utils.solver.optimizer import build_yolo_optimizer
+from utils.solver.optimizer import build_simple_optimizer, build_yolo_optimizer
 from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
 
 
@@ -295,3 +295,226 @@ class YoloTrainer(object):
         if self.train_loader.dataset.copy_paste > 0.:
             print(' - Close < Copy-paste Augmentation > ...')
             self.train_loader.dataset.copy_paste = 0.
+
+class SimpleTrainer(object):
+    """Epoch-based trainer with iteration-level warmup and gradient accumulation.
+
+    Compared with ``YoloTrainer`` this trainer takes no EMA model: the model
+    that is trained is also the one that is evaluated and checkpointed.
+
+    The trainer owns the optimizer (``build_simple_optimizer``), the linear
+    warmup scheduler and the epoch-level LR scheduler (``build_lr_scheduler``).
+    Constructing it mutates ``cfg``: ``cfg.base_lr`` is rescaled to the
+    effective batch size and ``cfg.min_lr`` is derived from it.
+    """
+    def __init__(self,
+                 # Basic parameters
+                 args,
+                 cfg,
+                 device,
+                 # Model parameters
+                 model,
+                 criterion,
+                 # Data parameters
+                 train_loader,
+                 evaluator,
+                 ):
+        """Store the training components and build optimizer / LR schedulers.
+
+        Args:
+            args: Parsed command-line arguments. Reads ``save_folder``,
+                ``dataset``, ``model``, ``batch_size`` and ``resume`` here.
+            cfg: Experiment config. ``base_lr`` and ``min_lr`` are overwritten.
+            device: Device the input images are moved to during training.
+            model: Model whose parameters are handed to the optimizer.
+            criterion: Callable invoked as ``criterion(outputs=..., targets=...)``.
+            train_loader: Iterable yielding ``(images, targets)`` batches.
+            evaluator: Object exposing ``evaluate(model)`` and ``map``, or
+                ``None`` to skip evaluation and only save checkpoints.
+        """
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.cfg  = cfg
+        self.epoch = 0
+        self.best_map = -1.
+        self.device = device
+        self.criterion = criterion
+
+        # path to save model
+        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
+        os.makedirs(self.path_to_save, exist_ok=True)
+
+        # ---------------------------- Dataset & Dataloader ----------------------------
+        self.train_loader = train_loader
+
+        # ---------------------------- Evaluator ----------------------------
+        self.evaluator = evaluator
+
+        # ---------------------------- Build Optimizer ----------------------------
+        # Accumulate gradients until the effective batch reaches batch_size_base.
+        self.grad_accumulate = max(cfg.batch_size_base // args.batch_size, 1)
+        cfg.base_lr = cfg.base_lr / cfg.batch_size_base * args.batch_size * self.grad_accumulate  # Auto scale learning rate
+        cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
+        self.optimizer, self.start_epoch = build_simple_optimizer(cfg, model, args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        self.lr_scheduler_warmup = LinearWarmUpLrScheduler(cfg.warmup_iters, cfg.base_lr, cfg.warmup_bias_lr)
+        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
+
+        # build_simple_optimizer() stores the best metric on cfg as a percentage.
+        self.best_map = cfg.best_map / 100.0
+        print("Best mAP metric: {}".format(self.best_map))
+
+    def train(self, model):
+        """Run the training loop from ``start_epoch`` up to ``cfg.max_epoch``.
+
+        Each epoch trains once, steps the epoch-level LR scheduler, and
+        evaluates every ``cfg.eval_epoch`` epochs as well as on the last epoch.
+        In debug mode the loop stops after the first epoch.
+        """
+        for epoch in range(self.start_epoch, self.cfg.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # train one epoch
+            self.epoch = epoch
+            self.train_one_epoch(model)
+
+            # LR Schedule
+            self.lr_scheduler.step()
+
+            # eval one epoch (unwrap the distributed wrapper first)
+            model_eval = model.module if self.args.distributed else model
+            if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
+                self.eval(model_eval)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
+
+    def eval(self, model):
+        """Evaluate ``model`` and save a checkpoint when appropriate.
+
+        Only the main process evaluates and saves. Without an evaluator the
+        model is always saved as ``<model>_no_eval.pth``; with an evaluator it
+        is saved as ``<model>_best.pth`` only when the mAP improves on
+        ``self.best_map``. The model is switched back to train mode on exit.
+        """
+        # set eval mode
+        model.eval()
+        cur_map = -1.
+        # Stays None unless this call decides that a checkpoint must be written.
+        checkpoint_path = None
+
+        if distributed_utils.is_main_process():
+            if self.evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+            else:
+                print('Eval ...')
+                # Evaluate
+                with torch.no_grad():
+                    self.evaluator.evaluate(model)
+
+                cur_map = self.evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
+
+            # Save model. The file name chosen above is kept as-is, so the
+            # no-evaluator checkpoint is not mislabelled as the "best" one.
+            if checkpoint_path is not None:
+                print('Saving state, epoch:', self.epoch)
+                state_dicts = {
+                    'model': model.state_dict(),
+                    'mAP': round(cur_map*100, 3),
+                    'optimizer':  self.optimizer.state_dict(),
+                    'lr_scheduler': self.lr_scheduler.state_dict(),
+                    'epoch': self.epoch,
+                    'args': self.args,
+                    }
+                torch.save(state_dicts, checkpoint_path)
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+        # set train mode.
+        model.train()
+
+    def train_one_epoch(self, model):
+        """Train ``model`` for one pass over ``self.train_loader``.
+
+        Per iteration: apply warmup while the global iteration index is below
+        ``cfg.warmup_iters``, rescale the batch (multi-scale), compute the
+        loss, backpropagate, and step the optimizer every
+        ``self.grad_accumulate`` iterations with gradient-norm clipping.
+        """
+        metric_logger = MetricLogger(delimiter="  ")
+        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+        metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+        metric_logger.add_meter('gnorm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
+        header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
+        print_freq = 10
+        # Last clipped gradient norm; logged unchanged on accumulation-only steps.
+        gnorm = 0.0
+
+        # basic parameters
+        epoch_size = len(self.train_loader)
+        nw = self.cfg.warmup_iters
+
+        # Train one epoch
+        for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
+            # Global iteration index across epochs.
+            ni = iter_i + self.epoch * epoch_size
+
+            # Warmup
+            if nw > 0 and ni < nw:
+                self.lr_scheduler_warmup(ni, self.optimizer)
+            elif ni == nw:
+                # Reached exactly once, on the first iteration after warmup
+                # (or on the very first iteration when warmup_iters is 0).
+                print("Warmup stage is over.")
+                self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
+
+            # To device
+            images = images.to(self.device, non_blocking=True).float()
+
+            # Multi scale
+            images, targets, img_size = self.rescale_image_targets(
+                images, targets, 32, self.cfg.multi_scale)
+
+            # Visualize train targets
+            if self.args.vis_tgt:
+                vis_data(images,
+                         targets,
+                         self.cfg.num_classes,
+                         self.cfg.pixel_mean,
+                         self.cfg.pixel_std,
+                         )
+
+            # Inference
+            outputs = model(images)
+
+            # Compute loss
+            loss_dict = self.criterion(outputs=outputs, targets=targets)
+            losses = loss_dict['losses']
+            # In-place division: the value stored in loss_dict (and therefore
+            # the logged loss) is scaled by the accumulation factor as well.
+            losses /= self.grad_accumulate
+            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+
+            # Backward
+            losses.backward()
+
+            # Optimize
+            if (iter_i + 1) % self.grad_accumulate == 0:
+                gnorm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
+                self.optimizer.step()
+                self.optimizer.zero_grad()
+
+            # Update log
+            metric_logger.update(**loss_dict_reduced)
+            metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
+            metric_logger.update(size=img_size)
+            metric_logger.update(gnorm=gnorm)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
+        # Gather the stats from all processes
+        metric_logger.synchronize_between_processes()
+        print("Averaged stats:", metric_logger)
+
+    def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=(0.5, 1.5)):
+        """Randomly resize a square batch and its boxes (multi-scale trick).
+
+        Args:
+            images: Batch tensor whose last dimension is the (square) image size.
+            targets: List of dicts with ``"boxes"`` (x1, y1, x2, y2 columns)
+                and ``"labels"``; both entries are replaced in place.
+            max_stride: Step between candidate sizes; every candidate is a
+                multiple of it.
+            multi_scale_range: ``(min_ratio, max_ratio)`` applied to the
+                current image size to obtain the sampling interval.
+
+        Returns:
+            ``(images, targets, new_img_size)`` where ``new_img_size`` is an int.
+        """
+        # During training phase, the shape of input image is square.
+        old_img_size = images.shape[-1]
+
+        # random.randrange() only accepts integers (non-integer arguments raise
+        # TypeError on Python 3.12+), so convert the scaled bounds to ints and
+        # snap them down to a multiple of the stride.
+        min_img_size = int(round(old_img_size * multi_scale_range[0])) // max_stride * max_stride
+        min_img_size = max(min_img_size, max_stride)
+        max_img_size = int(round(old_img_size * multi_scale_range[1])) // max_stride * max_stride
+        max_img_size = max(max_img_size, min_img_size)
+
+        # Choose a new image size (upper bound is inclusive)
+        new_img_size = random.randrange(min_img_size, max_img_size + max_stride, max_stride)
+
+        # Resize
+        if new_img_size != old_img_size:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images,
+                                size=new_img_size,
+                                mode='bilinear',
+                                align_corners=False)
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            boxes = torch.clamp(boxes, 0, old_img_size)
+            # rescale box
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+            # refine tgt: drop boxes whose shorter side fell below 8 pixels
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= 8)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+
+        return images, targets, new_img_size

+ 20 - 10
yolo/train.py

@@ -29,7 +29,7 @@ from evaluator.map_evaluator import MapEvaluator
 from models import build_model
 
 # ----------------- Train Components -----------------
-from engine import YoloTrainer
+from engine import YoloTrainer, SimpleTrainer
 
 
 def parse_args():
@@ -193,15 +193,25 @@ def train():
         dist.barrier()
 
     # ---------------------------- Build Trainer ----------------------------
-    trainer = YoloTrainer(args = args,
-                          cfg = cfg,
-                          device = device,
-                          model = model,
-                          model_ema = model_ema,
-                          criterion = criterion,
-                          train_loader = train_loader,
-                          evaluator = evaluator,
-                          )
+    if cfg.trainer == "simple":
+        trainer = SimpleTrainer(args = args,
+                                cfg = cfg,
+                                device = device,
+                                model = model,
+                                criterion = criterion,
+                                train_loader = train_loader,
+                                evaluator = evaluator,
+                                )
+    else:
+        trainer = YoloTrainer(args = args,
+                              cfg = cfg,
+                              device = device,
+                              model = model,
+                              model_ema = model_ema,
+                              criterion = criterion,
+                              train_loader = train_loader,
+                              evaluator = evaluator,
+                              )
     
     ## Eval before training
     if args.eval_first and distributed_utils.is_main_process():

+ 6 - 4
yolo/utils/solver/lr_scheduler.py

@@ -1,6 +1,6 @@
 import numpy as np
 import torch
-
+from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR
 
 # ------------------------- WarmUp LR Scheduler -------------------------
 ## Warmup LR Scheduler
@@ -29,10 +29,12 @@ def build_lr_scheduler(cfg, optimizer, resume=None):
     print('LR Scheduler: {}'.format(cfg.lr_scheduler))
 
     if cfg.lr_scheduler == "step":
-        lr_step = [cfg.max_epoch // 3, cfg.max_epoch // 3 * 2]
-        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_step, gamma=0.1)
+        lr_step = [cfg.max_epoch // 2, cfg.max_epoch // 3 * 4]
+        lr_scheduler = MultiStepLR(optimizer, milestones=lr_step, gamma=0.1)
+
     elif cfg.lr_scheduler == "cosine":
-        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.max_epoch - cfg.warmup_epoch - 1, eta_min=cfg.min_lr)
+        lr_scheduler = CosineAnnealingLR(optimizer, T_max=cfg.max_epoch - cfg.warmup_epoch - 1, eta_min=cfg.min_lr)
+    
     else:
         raise NotImplementedError("Unknown lr scheduler: {}".format(cfg.lr_scheduler))
         

+ 52 - 1
yolo/utils/solver/optimizer.py

@@ -1,6 +1,58 @@
 import torch
 
 
+def build_simple_optimizer(cfg, model, resume=None):
+    """Build the optimizer for the simple trainer and optionally resume it.
+
+    Trainable parameters are split into two groups: parameters whose name
+    does not contain "backbone" use ``cfg.base_lr``; parameters whose name
+    contains "backbone" use ``cfg.base_lr * cfg.bk_lr_ratio``.
+
+    Args:
+        cfg: Config providing ``optimizer`` ('sgd' or 'adamw'), ``base_lr``,
+            ``min_lr``, ``momentum``, ``weight_decay`` and ``bk_lr_ratio``.
+            ``cfg.best_map`` is written by this function (-1. unless the
+            checkpoint carries an "mAP" entry).
+        model: Module exposing ``named_parameters()``.
+        resume: Checkpoint path; ``None``, empty or the string 'None' means
+            training starts from scratch.
+
+    Returns:
+        ``(optimizer, start_epoch)``; ``start_epoch`` is 0 unless the
+        checkpoint supplied an optimizer state and an epoch.
+
+    Raises:
+        NotImplementedError: If ``cfg.optimizer`` is not a supported name.
+    """
+    print('==============================')
+    print('Optimizer: {}'.format(cfg.optimizer))
+    print('--base lr: {}'.format(cfg.base_lr))
+    print('--min lr:  {}'.format(cfg.min_lr))
+    print('--momentum: {}'.format(cfg.momentum))
+    print('--weight_decay: {}'.format(cfg.weight_decay))
+
+    # ------------- Divide model's parameters -------------
+    # Single pass over the parameters; frozen ones are left out entirely.
+    backbone_params, other_params = [], []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+        if "backbone" in name:
+            backbone_params.append(param)
+        else:
+            other_params.append(param)
+    param_dicts = [
+        {"params": other_params},
+        {
+            "params": backbone_params,
+            "lr": cfg.base_lr * cfg.bk_lr_ratio,
+        },
+    ]
+
+    if cfg.optimizer == 'sgd':
+        optimizer = torch.optim.SGD(
+            params=param_dicts,
+            lr=cfg.base_lr,
+            momentum=cfg.momentum,
+            weight_decay=cfg.weight_decay
+            )
+
+    elif cfg.optimizer == 'adamw':
+        optimizer = torch.optim.AdamW(
+            params=param_dicts,
+            lr=cfg.base_lr,
+            weight_decay=cfg.weight_decay
+            )
+
+    else:
+        # Fail here with a clear message instead of an UnboundLocalError below.
+        raise NotImplementedError("Unknown optimizer: {}".format(cfg.optimizer))
+
+    start_epoch = 0
+    cfg.best_map = -1.
+    if resume and resume != 'None':
+        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
+        # Loading onto the CPU avoids every process restoring tensors onto the
+        # device the checkpoint was saved from; load_state_dict() moves the
+        # optimizer state to the device of the matching parameters.
+        checkpoint = torch.load(resume, map_location='cpu')
+        # checkpoint state dict
+        try:
+            checkpoint_state_dict = checkpoint.pop("optimizer")
+            print('--Load optimizer from the checkpoint: ', resume)
+            optimizer.load_state_dict(checkpoint_state_dict)
+            start_epoch = checkpoint.pop("epoch") + 1
+            if "mAP" in checkpoint:
+                print('--Load best metric from the checkpoint: ', resume)
+                best_map = checkpoint["mAP"]
+                cfg.best_map = best_map
+            del checkpoint, checkpoint_state_dict
+        except Exception as err:
+            # Resuming stays best-effort, but KeyboardInterrupt/SystemExit are
+            # no longer swallowed and the actual cause is reported.
+            print("No optimzier in the given checkpoint.")
+            print('--Reason: {}: {}'.format(type(err).__name__, err))
+
+    return optimizer, start_epoch
+
 def build_yolo_optimizer(cfg, model, resume=None):
     print('==============================')
     print('Optimizer: {}'.format(cfg.optimizer))
@@ -54,7 +106,6 @@ def build_yolo_optimizer(cfg, model, resume=None):
                                                         
     return optimizer, start_epoch
 
-
 def build_rtdetr_optimizer(cfg, model, resume=None):
     print('==============================')
     print('Optimizer: {}'.format(cfg.optimizer))