浏览代码

add YOLOX-Style Trainer

yjh0410 2 年之前
父节点
当前提交
492cdb5d06

+ 11 - 9
README.md

@@ -141,15 +141,17 @@ python train.py --cuda -d coco --root path/to/COCO -m yolov1 -bs 16 --max_epoch
 
 * My YOLO:
 
-| Model    | Scale | Epoch | AP<sup>test<br>0.5:0.95 | AP<sup>test<br>0.5 | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|----------|-------|-------|-------------------------|--------------------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOvx-P |  640  |  300  |                         |                    |                        |                   |                   |                    |  |
-| YOLOvx-N |  640  |  300  |                         |                    |                        |                   |                   |                    |  |
-| YOLOvx-T |  640  |  300  |                         |                    |                        |                   |                   |                    |  |
-| YOLOvx-S |  640  |  300  |                         |                    |                        |                   |                   |                    |  |
-| YOLOvx-M |  640  |  300  |                         |                    |                        |                   |                   |                    |  |
-| YOLOvx-L |  640  |  300  |         50.2            |        68.6        |          50.0          |        68.4       |      176.6        |        47.6        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolovx_l_coco.pth) |
-| YOLOvx-X |  640  |  300  |                         |                    |                        |                   |                   |                    |  |
+| Model    | Scale | Batch | Epoch | AP<sup>test<br>0.5:0.95 | AP<sup>test<br>0.5 | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|----------|-------|-------|-------|-------------------------|--------------------|------------------------|-------------------|-------------------|--------------------|--------|
+| YOLOvx-P |  640  | 4xb32 |  300  |                         |                    |                        |                   |                   |                    |  |
+| YOLOvx-N |  640  | 4xb32 |  300  |                         |                    |                        |                   |                   |                    |  |
+| YOLOvx-T |  640  | 4xb32 |  300  |                         |                    |                        |                   |                   |                    |  |
+| YOLOvx-S |  640  | 4xb32 |  300  |                         |                    |                        |                   |                   |                    |  |
+| YOLOvx-M |  640  | 8xb16 |  300  |                         |                    |                        |                   |                   |                    |  |
+| YOLOvx-L |  640  | 8xb16 |  300  |         50.2            |        68.6        |          50.0          |        68.4       |      176.6        |        47.6        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolovx_l_coco.pth) |
+| YOLOvx-X |  640  |       |       |                         |                    |                        |                   |                   |                    |  |
+
+Due to my limited computing resources, I can not to train `YOLOvx-X` with the setting of `batch size=128`.
 
 * Redesigned RT-DETR:
 

+ 1 - 1
config/model_config/yolov1_config.py

@@ -28,5 +28,5 @@ yolov1_cfg = {
     'loss_cls_weight': 1.0,
     'loss_box_weight': 5.0,
     # training configuration
-    'trainer_type': 'yolo',
+    'trainer_type': 'yolov8',
 }

+ 1 - 1
config/model_config/yolov2_config.py

@@ -35,5 +35,5 @@ yolov2_cfg = {
     'loss_cls_weight': 1.0,
     'loss_box_weight': 5.0,
     # training configuration
-    'trainer_type': 'yolo',
+    'trainer_type': 'yolov8',
 }

+ 2 - 2
config/model_config/yolov3_config.py

@@ -45,7 +45,7 @@ yolov3_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov3_tiny':{
@@ -92,7 +92,7 @@ yolov3_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
 }

+ 2 - 2
config/model_config/yolov4_config.py

@@ -45,7 +45,7 @@ yolov4_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov4_tiny':{
@@ -92,7 +92,7 @@ yolov4_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
 }

+ 5 - 5
config/model_config/yolov5_config.py

@@ -44,7 +44,7 @@ yolov5_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov5_s':{
@@ -90,7 +90,7 @@ yolov5_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov5_m':{
@@ -136,7 +136,7 @@ yolov5_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov5_l':{
@@ -182,7 +182,7 @@ yolov5_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov5_x':{
@@ -228,7 +228,7 @@ yolov5_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
 }

+ 3 - 3
config/model_config/yolov7_config.py

@@ -47,7 +47,7 @@ yolov7_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov7':{
@@ -96,7 +96,7 @@ yolov7_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
     'yolov7_x':{
@@ -145,7 +145,7 @@ yolov7_cfg = {
         'loss_cls_weight': 1.0,
         'loss_box_weight': 5.0,
         # ---------------- Train config ----------------
-        'trainer_type': 'yolo',
+        'trainer_type': 'yolov8',
     },
 
 }

+ 322 - 5
engine.py

@@ -22,8 +22,8 @@ from utils.solver.lr_scheduler import build_lr_scheduler
 from dataset.build import build_dataset, build_transform
 
 
-# Trainer refered to YOLOv8
-class YoloTrainer(object):
+# YOLOv8-style Trainer
+class Yolov8Trainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- basic parameters -------------------
         self.args = args
@@ -37,6 +37,10 @@ class YoloTrainer(object):
         self.world_size = world_size
         self.heavy_eval = False
         self.second_stage = False
+        """
+            The below hyperparameters refer to YOLOv8.
+        """
+
         self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.937, 'weight_decay': 5e-4, 'lr0': 0.01}
         self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
         self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
@@ -344,7 +348,315 @@ class YoloTrainer(object):
         return images, targets, new_img_size
 
 
-# Trainer refered to RTMDet
+# YOLOX-syle Trainer
+class YoloxTrainer(object):
+    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.epoch = 0
+        self.best_map = -1.
+        self.device = device
+        self.criterion = criterion
+        self.world_size = world_size
+        self.no_aug_epoch = args.no_aug_epoch
+        self.heavy_eval = False
+        self.second_stage = False
+        """
+            The below hyperparameters refer to YOLOX: https://github.com/open-mmlab/mmyolo/tree/main/configs/rtmdet.
+        """
+        self.optimizer_dict = {'optimizer': 'sgd', 'momentum': 0.9, 'weight_decay': 5e-4, 'lr0': 0.01}
+        self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
+        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
+        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}        
+
+        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
+        self.data_cfg = data_cfg
+        self.model_cfg = model_cfg
+        self.trans_cfg = trans_cfg
+
+        # ---------------------------- Build Transform ----------------------------
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.val_transform, _ = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
+
+        # ---------------------------- Build Dataset & Dataloader ----------------------------
+        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
+        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
+
+        # ---------------------------- Build Evaluator ----------------------------
+        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
+
+        # ---------------------------- Build Grad. Scaler ----------------------------
+        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
+
+        # ---------------------------- Build Optimizer ----------------------------
+        self.optimizer_dict['lr0'] *= self.args.batch_size / 64
+        self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, self.args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
+        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
+        if self.args.resume:
+            self.lr_scheduler.step()
+
+        # ---------------------------- Build Model-EMA ----------------------------
+        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
+            print('Build ModelEMA ...')
+            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
+        else:
+            self.model_ema = None
+
+
+    def train(self, model):
+        for epoch in range(self.start_epoch, self.args.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # check second stage
+            if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
+                self.check_second_stage()
+
+            # train one epoch
+            self.epoch = epoch
+            self.train_one_epoch(model)
+
+            # eval one epoch
+            if self.heavy_eval:
+                model_eval = model.module if self.args.distributed else model
+                self.eval(model_eval)
+            else:
+                model_eval = model.module if self.args.distributed else model
+                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
+                    self.eval(model_eval)
+
+
+    def eval(self, model):
+        # chech model
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
+        # path to save model
+        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
+        os.makedirs(path_to_save, exist_ok=True)
+
+        if distributed_utils.is_main_process():
+            # check evaluator
+            if self.evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                print('Saving state, epoch: {}'.format(self.epoch + 1))
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(path_to_save, weight_name)
+                torch.save({'model': model_eval.state_dict(),
+                            'mAP': -1.,
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)               
+            else:
+                print('eval ...')
+                # set eval mode
+                model_eval.trainable = False
+                model_eval.eval()
+
+                # evaluate
+                with torch.no_grad():
+                    self.evaluator.evaluate(model_eval)
+
+                # save model
+                cur_map = self.evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    # save model
+                    print('Saving state, epoch:', self.epoch + 1)
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(path_to_save, weight_name)
+                    torch.save({'model': model_eval.state_dict(),
+                                'mAP': round(self.best_map*100, 1),
+                                'optimizer': self.optimizer.state_dict(),
+                                'epoch': self.epoch,
+                                'args': self.args}, 
+                                checkpoint_path)                      
+
+                # set train mode.
+                model_eval.trainable = True
+                model_eval.train()
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+
+    def train_one_epoch(self, model):
+        # basic parameters
+        epoch_size = len(self.train_loader)
+        img_size = self.args.img_size
+        t0 = time.time()
+        nw = epoch_size * self.args.wp_epoch
+
+        # Train one epoch
+        for iter_i, (images, targets) in enumerate(self.train_loader):
+            ni = iter_i + self.epoch * epoch_size
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
+                for j, x in enumerate(self.optimizer.param_groups):
+                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                    x['lr'] = np.interp(
+                        ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
+                    if 'momentum' in x:
+                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
+                                
+            # To device
+            images = images.to(self.device, non_blocking=True).float() / 255.
+
+            # Multi scale
+            if self.args.multi_scale:
+                images, targets, img_size = self.rescale_image_targets(
+                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
+            else:
+                targets = self.refine_targets(targets, self.args.min_box_size)
+                
+            # Visualize train targets
+            if self.args.vis_tgt:
+                vis_data(images*255, targets)
+
+            # Inference
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
+                outputs = model(images)
+                # Compute loss
+                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
+                losses = loss_dict['losses']
+
+                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+
+            # Backward
+            self.scaler.scale(losses).backward()
+
+            # Optimize
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+            self.optimizer.zero_grad()
+            # ema
+            if self.model_ema is not None:
+                self.model_ema.update(model)
+
+            # Logs
+            if distributed_utils.is_main_process() and iter_i % 10 == 0:
+                t1 = time.time()
+                cur_lr = [param_group['lr']  for param_group in self.optimizer.param_groups]
+                # basic infor
+                log =  '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
+                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
+                log += '[lr: {:.6f}]'.format(cur_lr[2])
+                # loss infor
+                for k in loss_dict_reduced.keys():
+                    log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
+
+                # other infor
+                log += '[time: {:.2f}]'.format(t1 - t0)
+                log += '[size: {}]'.format(img_size)
+
+                # print log infor
+                print(log, flush=True)
+                
+                t0 = time.time()
+        
+        # LR Schedule
+        self.lr_scheduler.step()
+        
+
+    def check_second_stage(self):
+        # set second stage
+        print('============== Second stage of Training ==============')
+        self.second_stage = True
+
+        # close mosaic augmentation
+        if self.train_loader.dataset.mosaic_prob > 0.:
+            print(' - Close < Mosaic Augmentation > ...')
+            self.train_loader.dataset.mosaic_prob = 0.
+            self.heavy_eval = True
+
+        # close mixup augmentation
+        if self.train_loader.dataset.mixup_prob > 0.:
+            print(' - Close < Mixup Augmentation > ...')
+            self.train_loader.dataset.mixup_prob = 0.
+            self.heavy_eval = True
+
+        # close rotation augmentation
+        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
+            print(' - Close < degress of rotation > ...')
+            self.trans_cfg['degrees'] = 0.0
+        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
+            print(' - Close < shear of rotation >...')
+            self.trans_cfg['shear'] = 0.0
+        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
+            print(' - Close < perspective of rotation > ...')
+            self.trans_cfg['perspective'] = 0.0
+
+        # build a new transform for second stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+        
+
+    def refine_targets(self, targets, min_box_size):
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+        
+        return targets
+
+
+    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
+        """
+            Deployed for Multi scale trick.
+        """
+        if isinstance(stride, int):
+            max_stride = stride
+        elif isinstance(stride, list):
+            max_stride = max(stride)
+
+        # During training phase, the shape of input image is square.
+        old_img_size = images.shape[-1]
+        new_img_size = random.randrange(old_img_size * multi_scale_range[0], old_img_size * multi_scale_range[1] + max_stride)
+        new_img_size = new_img_size // max_stride * max_stride  # size
+        if new_img_size / old_img_size != 1:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images, 
+                                size=new_img_size, 
+                                mode='bilinear', 
+                                align_corners=False)
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            boxes = torch.clamp(boxes, 0, old_img_size)
+            # rescale box
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+
+        return images, targets, new_img_size
+
+
+# RTMDet-syle Trainer
 class RTMTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- basic parameters -------------------
@@ -358,6 +670,9 @@ class RTMTrainer(object):
         self.clip_grad = 35
         self.heavy_eval = False
         self.second_stage = False
+        """
+            The below hyperparameters refer to RTMDet: https://github.com/open-mmlab/mmyolo/tree/main/configs/rtmdet.
+        """
         self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
         self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
         self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
@@ -990,8 +1305,10 @@ class DetrTrainer(object):
 
 # Build Trainer
 def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
-    if model_cfg['trainer_type'] == 'yolo':
-        return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+    if model_cfg['trainer_type'] == 'yolov8':
+        return Yolov8Trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+    elif model_cfg['trainer_type'] == 'yolox':
+        return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
     elif model_cfg['trainer_type'] == 'rtmdet':
         return RTMTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
     elif model_cfg['trainer_type'] == 'detr':

+ 4 - 4
train_ddp.sh

@@ -1,11 +1,11 @@
-# train YOLO with 8 GPUs
-# 使用8张GPU来训练YOLO
-python -m torch.distributed.run --nproc_per_node=8 train.py \
+# train YOLO with 4 GPUs
+# 使用4张GPU来训练YOLO
+python -m torch.distributed.run --nproc_per_node=4 train.py \
                                                     --cuda \
                                                     -dist \
                                                     -d coco \
                                                     --root /data/datasets/ \
-                                                    -m yolovx_m \
+                                                    -m yolovx_s \
                                                     -bs 128 \
                                                     -size 640 \
                                                     --wp_epoch 3 \