yjh0410 1 year ago
parent
commit
b1ed050e0e

+ 1 - 1
config/__init__.py

@@ -86,7 +86,7 @@ from .model_config.yolov7_config import yolov7_cfg
 from .model_config.yolov8_config import yolov8_cfg
 from .model_config.yolox_config import yolox_cfg
 ## My RTCDet series
-from .model_config.rtcdet_config import rtcdet_cfg
+from .model_config.rtcdet_config import rtcdet_cfg, rtcdet_seg_cfg, rtcdet_pos_cfg, rtcdet_seg_pos_cfg
 
 def build_model_config(args):
     print('==============================')
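Note: `build_model_config` itself is unchanged in this hunk; only the import grows. A minimal sketch of how the newly imported dicts could be dispatched on the model name (hypothetical; the actual function body is not shown here):

```python
# Hypothetical dispatch sketch, not the repo's actual code. It assumes
# args.model carries names like 'rtcdet_n' / 'rtcdet_seg_n' / 'rtcdet_pos_n'.
def build_model_config(args):
    print('==============================')
    if 'rtcdet_seg_pos' in args.model:    # check the most specific prefix first
        return rtcdet_seg_pos_cfg[args.model]
    elif 'rtcdet_seg' in args.model:
        return rtcdet_seg_cfg[args.model]
    elif 'rtcdet_pos' in args.model:
        return rtcdet_pos_cfg[args.model]
    elif 'rtcdet' in args.model:
        return rtcdet_cfg[args.model]
    raise NotImplementedError(args.model)
```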

+ 243 - 32
config/model_config/rtcdet_config.py

@@ -1,6 +1,7 @@
-# rtcdet Config
+# Real-time Convolutional Object Detector
 
 
+# ------------------- Det task --------------------
 rtcdet_cfg = {
     'rtcdet_n':{
         # ---------------- Model config ----------------
@@ -27,11 +28,17 @@ rtcdet_cfg = {
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
         ## Head
-        'num_cls_head': 2,
-        'num_reg_head': 2,
-        'head_act': 'silu',
-        'head_norm': 'BN',
-        'head_depthwise': False,
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
@@ -74,11 +81,17 @@ rtcdet_cfg = {
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
         ## Head
-        'head_act': 'silu',
-        'head_norm': 'BN',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
-        'head_depthwise': False,
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
@@ -121,11 +134,17 @@ rtcdet_cfg = {
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
         ## Head
-        'head_act': 'silu',
-        'head_norm': 'BN',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
-        'head_depthwise': False,
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
@@ -168,11 +187,17 @@ rtcdet_cfg = {
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
         ## Head
-        'head_act': 'silu',
-        'head_norm': 'BN',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
-        'head_depthwise': False,
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
@@ -215,11 +240,17 @@ rtcdet_cfg = {
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
         ## Head
-        'head_act': 'silu',
-        'head_norm': 'BN',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
-        'head_depthwise': False,
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
@@ -262,11 +293,17 @@ rtcdet_cfg = {
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
         ## Head
-        'head_act': 'silu',
-        'head_norm': 'BN',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
-        'head_depthwise': False,
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
         # ---------------- Train config ----------------
         ## input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
@@ -284,4 +321,178 @@ rtcdet_cfg = {
         'trainer_type': 'rtcdet',
     },
 
-}
+}
+
+
+# ------------------- Det + Seg task -------------------
+rtcdet_seg_cfg = {
+    'rtcdet_seg_n':{
+        # ---------------- Model config ----------------
+        ## Backbone
+        'bk_pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_depthwise': False,
+        'width': 0.25,
+        'depth': 0.34,
+        'ratio': 2.0,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        'max_stride': 32,
+        ## Neck: SPP
+        'neck': 'sppf',
+        'neck_expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        ## Neck: PaFPN
+        'fpn': 'rtcdet_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        ## Head
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
+        # ---------------- Train config ----------------
+        ## input
+        'multi_scale': [0.5, 1.25],   # 320 -> 800
+        'trans_type': 'yolox_n',
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        'matcher': "aligned_simota",
+        'matcher_hpy': {'soft_center_radius': 3.0,
+                        'topk_candidates': 13},
+        # ---------------- Loss config ----------------
+        ## loss weight
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 2.0,
+        # ---------------- Train config ----------------
+        'trainer_type': 'rtcdet',
+    },
+
+}
+
+
+# ------------------- Det + Pos task -------------------
+rtcdet_pos_cfg = {
+    'rtcdet_pos_n':{
+        # ---------------- Model config ----------------
+        ## Backbone
+        'bk_pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_depthwise': False,
+        'width': 0.25,
+        'depth': 0.34,
+        'ratio': 2.0,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        'max_stride': 32,
+        ## Neck: SPP
+        'neck': 'sppf',
+        'neck_expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        ## Neck: PaFPN
+        'fpn': 'rtcdet_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        ## Head
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
+        # ---------------- Train config ----------------
+        ## input
+        'multi_scale': [0.5, 1.25],   # 320 -> 800
+        'trans_type': 'yolox_n',
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        'matcher': "aligned_simota",
+        'matcher_hpy': {'soft_center_radius': 3.0,
+                        'topk_candidates': 13},
+        # ---------------- Loss config ----------------
+        ## loss weight
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 2.0,
+        # ---------------- Train config ----------------
+        'trainer_type': 'rtcdet',
+    },
+
+}
+
+
+# ------------------- Det + Seg + Pos task -------------------
+rtcdet_seg_pos_cfg = {
+    'rtcdet_seg_pos_n':{
+        # ---------------- Model config ----------------
+        ## Backbone
+        'bk_pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_depthwise': False,
+        'width': 0.25,
+        'depth': 0.34,
+        'ratio': 2.0,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        'max_stride': 32,
+        ## Neck: SPP
+        'neck': 'sppf',
+        'neck_expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        ## Neck: PaFPN
+        'fpn': 'rtcdet_pafpn',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        ## Head
+        'det_head': {'name': 'decoupled_head',
+                     'num_cls_head': 2,
+                     'num_reg_head': 2,
+                     'head_act': 'silu',
+                     'head_norm': 'BN',
+                     'head_depthwise': False,  
+                     },
+        'seg_head': {'name': None,
+                     },
+        'pos_head': {'name': None,
+                     },
+        # ---------------- Train config ----------------
+        ## input
+        'multi_scale': [0.5, 1.25],   # 320 -> 800
+        'trans_type': 'yolox_n',
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        'matcher': "aligned_simota",
+        'matcher_hpy': {'soft_center_radius': 3.0,
+                        'topk_candidates': 13},
+        # ---------------- Loss config ----------------
+        ## loss weight
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 2.0,
+        # ---------------- Train config ----------------
+        'trainer_type': 'rtcdet',
+    },
+
+}
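The head options that used to sit flat in each model dict now live in nested `det_head` / `seg_head` / `pos_head` sub-dicts. A short sketch of how a consumer reads the new layout, using only keys visible above:

```python
# Reading the restructured config (keys taken from the dicts above).
cfg = rtcdet_cfg['rtcdet_n']

det_cfg = cfg['det_head']                    # nested head config
assert det_cfg['name'] == 'decoupled_head'
num_cls_head = det_cfg['num_cls_head']       # previously cfg['num_cls_head']

# Auxiliary heads are disabled when their 'name' is None; note that in this
# commit the seg/pos config variants above still set both names to None.
has_seg = cfg['seg_head']['name'] is not None
has_pos = cfg['pos_head']['name'] is not None
print(has_seg, has_pos)                      # False False
```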

+ 12 - 3
demo.py

@@ -114,7 +114,10 @@ def detect(args,
                 
                 # inference
                 t0 = time.time()
-                bboxes, scores, labels = model(x)
+                outputs = model(x)
+                scores = outputs['scores']
+                labels = outputs['labels']
+                bboxes = outputs['bboxes']
                 t1 = time.time()
                 print("Infer time: {:.1f} ms. ".format((t1 - t0) * 1000))
 
@@ -180,7 +183,10 @@ def detect(args,
 
                 # inference
                 t0 = time.time()
-                bboxes, scores, labels = model(x)
+                outputs = model(x)
+                scores = outputs['scores']
+                labels = outputs['labels']
+                bboxes = outputs['bboxes']
                 t1 = time.time()
                 print("Infer time: {:.1f} ms. ".format((t1 - t0) * 1000))
 
@@ -234,7 +240,10 @@ def detect(args,
 
             # inference
             t0 = time.time()
-            bboxes, scores, labels = model(x)
+            outputs = model(x)
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
             t1 = time.time()
             print("Infer time: {:.1f} ms. ".format((t1 - t0) * 1000))
 

+ 681 - 18
engine.py

@@ -22,7 +22,8 @@ from utils.solver.lr_scheduler import build_lr_scheduler
 from dataset.build import build_dataset, build_transform
 
 
-# YOLOv8 Trainer
+# ----------------------- Det trainers -----------------------
+## YOLOv8 Trainer
 class Yolov8Trainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- basic parameters -------------------
@@ -393,8 +394,7 @@ class Yolov8Trainer(object):
 
         return images, targets, new_img_size
 
-
-# YOLOX Trainer
+## YOLOX Trainer
 class YoloxTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- basic parameters -------------------
@@ -758,8 +758,7 @@ class YoloxTrainer(object):
 
         return images, targets, new_img_size
 
-
-# RTCDet Trainer
+## RTCDet Trainer
 class RTCTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- basic parameters -------------------
@@ -1129,8 +1128,7 @@ class RTCTrainer(object):
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
    
-
-# RTRDet Trainer
+## RTRDet Trainer
 class RTRTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- Basic parameters -------------------
@@ -1519,16 +1517,681 @@ class RTRTrainer(object):
         self.train_loader.dataset.transform = self.train_transform
         
 
-# Build Trainer
-def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
-    if model_cfg['trainer_type'] == 'yolov8':
-        return Yolov8Trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
-    elif model_cfg['trainer_type'] == 'yolox':
-        return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
-    elif model_cfg['trainer_type'] == 'rtcdet':
-        return RTCTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
-    elif model_cfg['trainer_type'] == 'rtrdet':
-        return RTRTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+# ----------------------- Det + Seg trainers -----------------------
+## RTCDet Trainer for Det + Seg
+class RTCTrainerDS(object):
+    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.epoch = 0
+        self.best_map = -1.
+        self.device = device
+        self.criterion = criterion
+        self.world_size = world_size
+        self.grad_accumulate = args.grad_accumulate
+        self.clip_grad = 35
+        self.heavy_eval = False
+        # weak augmentation stage
+        self.second_stage = False
+        self.third_stage = False
+        self.second_stage_epoch = args.no_aug_epoch
+        self.third_stage_epoch = args.no_aug_epoch // 2
+        # path to save model
+        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
+        os.makedirs(self.path_to_save, exist_ok=True)
+
+        # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
+        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
+        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
+        self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
+        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}        
+
+        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
+        self.data_cfg = data_cfg
+        self.model_cfg = model_cfg
+        self.trans_cfg = trans_cfg
+
+        # ---------------------------- Build Transform ----------------------------
+        self.train_transform, self.trans_cfg = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.val_transform, _ = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
+
+        # ---------------------------- Build Dataset & Dataloader ----------------------------
+        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
+        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
+
+        # ---------------------------- Build Evaluator ----------------------------
+        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
+
+        # ---------------------------- Build Grad. Scaler ----------------------------
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+        # ---------------------------- Build Optimizer ----------------------------
+        self.optimizer_dict['lr0'] *= args.batch_size * self.grad_accumulate / 64
+        self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
+        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
+        if self.args.resume and self.args.resume != 'None':
+            self.lr_scheduler.step()
+
+        # ---------------------------- Build Model-EMA ----------------------------
+        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
+            print('Build ModelEMA ...')
+            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
+        else:
+            self.model_ema = None
+
+    def train(self, model):
+        for epoch in range(self.start_epoch, self.args.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # check second stage
+            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
+                self.check_second_stage()
+                # save model of the last mosaic epoch
+                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
+
+            # check third stage
+            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
+                self.check_third_stage()
+                # save model of the last weak-augment epoch
+                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
+
+            # train one epoch
+            self.epoch = epoch
+            self.train_one_epoch(model)
+
+            # eval one epoch
+            if self.heavy_eval:
+                model_eval = model.module if self.args.distributed else model
+                self.eval(model_eval)
+            else:
+                model_eval = model.module if self.args.distributed else model
+                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
+                    self.eval(model_eval)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
+
+    def eval(self, model):
+        # check model
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
+        if distributed_utils.is_main_process():
+            # check evaluator
+            if self.evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                print('Saving state, epoch: {}'.format(self.epoch))
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                torch.save({'model': model_eval.state_dict(),
+                            'mAP': -1.,
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)               
+            else:
+                print('eval ...')
+                # set eval mode
+                model_eval.trainable = False
+                model_eval.eval()
+
+                # evaluate
+                with torch.no_grad():
+                    self.evaluator.evaluate(model_eval)
+
+                # save model
+                cur_map = self.evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    # save model
+                    print('Saving state, epoch:', self.epoch)
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                    torch.save({'model': model_eval.state_dict(),
+                                'mAP': round(self.best_map*100, 1),
+                                'optimizer': self.optimizer.state_dict(),
+                                'epoch': self.epoch,
+                                'args': self.args}, 
+                                checkpoint_path)                      
+
+                # set train mode.
+                model_eval.trainable = True
+                model_eval.train()
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+    def train_one_epoch(self, model):
+        # basic parameters
+        epoch_size = len(self.train_loader)
+        img_size = self.args.img_size
+        t0 = time.time()
+        nw = epoch_size * self.args.wp_epoch
+
+        # Train one epoch
+        for iter_i, (images, targets) in enumerate(self.train_loader):
+            ni = iter_i + self.epoch * epoch_size
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
+                for j, x in enumerate(self.optimizer.param_groups):
+                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                    x['lr'] = np.interp(
+                        ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
+                    if 'momentum' in x:
+                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
+                                
+            # To device
+            images = images.to(self.device, non_blocking=True).float() / 255.
+
+            # Multi scale
+            if self.args.multi_scale:
+                images, targets, img_size = self.rescale_image_targets(
+                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
+            else:
+                targets = self.refine_targets(targets, self.args.min_box_size)
+                
+            # Visualize train targets
+            if self.args.vis_tgt:
+                vis_data(images*255, targets, self.data_cfg['num_classes'])
+
+            # Inference
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
+                outputs = model(images)
+                # Compute loss
+                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch, task='det_seg')
+                det_loss_dict = loss_dict['det_loss_dict']
+                seg_loss_dict = loss_dict['seg_loss_dict']
+
+                # TODO: finish the backward + optimize
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
+        # LR Schedule
+        if not self.second_stage:
+            self.lr_scheduler.step()
+
+    def refine_targets(self, targets, min_box_size):
+        # filter out boxes that are too small
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+        
+        return targets
+
+    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
+        """
+            Deployed for Multi scale trick.
+        """
+        if isinstance(stride, int):
+            max_stride = stride
+        elif isinstance(stride, list):
+            max_stride = max(stride)
+
+        # During training phase, the shape of input image is square.
+        old_img_size = images.shape[-1]
+        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))  # randrange needs int args
+        new_img_size = new_img_size // max_stride * max_stride  # size
+        if new_img_size / old_img_size != 1:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images, 
+                                size=new_img_size, 
+                                mode='bilinear', 
+                                align_corners=False)
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            boxes = torch.clamp(boxes, 0, old_img_size)
+            # rescale box
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+
+        return images, targets, new_img_size
+
+    def check_second_stage(self):
+        # set second stage
+        print('============== Second stage of Training ==============')
+        self.second_stage = True
+
+        # close mosaic augmentation
+        if self.train_loader.dataset.mosaic_prob > 0.:
+            print(' - Close < Mosaic Augmentation > ...')
+            self.train_loader.dataset.mosaic_prob = 0.
+            self.heavy_eval = True
+
+        # close mixup augmentation
+        if self.train_loader.dataset.mixup_prob > 0.:
+            print(' - Close < Mixup Augmentation > ...')
+            self.train_loader.dataset.mixup_prob = 0.
+            self.heavy_eval = True
+
+        # close rotation augmentation
+        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
+            print(' - Close < degrees of rotation > ...')
+            self.trans_cfg['degrees'] = 0.0
+        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
+            print(' - Close < shear of rotation > ...')
+            self.trans_cfg['shear'] = 0.0
+        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
+            print(' - Close < perspective of rotation > ...')
+            self.trans_cfg['perspective'] = 0.0
+
+        # build a new transform for second stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+        
+    def check_third_stage(self):
+        # set third stage
+        print('============== Third stage of Training ==============')
+        self.third_stage = True
+
+        # close random affine
+        if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
+            print(' - Close < translate of affine > ...')
+            self.trans_cfg['translate'] = 0.0
+        if 'scale' in self.trans_cfg.keys():
+            print(' - Close < scale of affine > ...')
+            self.trans_cfg['scale'] = [1.0, 1.0]
+
+        # build a new transform for third stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+   
+
+# ----------------------- Det + Seg + Pos trainers -----------------------
+## RTCDet Trainer for Det + Seg + HumanPose
+class RTCTrainerDSP(object):
+    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.epoch = 0
+        self.best_map = -1.
+        self.device = device
+        self.criterion = criterion
+        self.world_size = world_size
+        self.grad_accumulate = args.grad_accumulate
+        self.clip_grad = 35
+        self.heavy_eval = False
+        # weak augmentation stage
+        self.second_stage = False
+        self.third_stage = False
+        self.second_stage_epoch = args.no_aug_epoch
+        self.third_stage_epoch = args.no_aug_epoch // 2
+        # path to save model
+        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
+        os.makedirs(self.path_to_save, exist_ok=True)
+
+        # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
+        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 5e-2, 'lr0': 0.001}
+        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
+        self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
+        self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}        
+
+        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
+        self.data_cfg = data_cfg
+        self.model_cfg = model_cfg
+        self.trans_cfg = trans_cfg
+
+        # ---------------------------- Build Transform ----------------------------
+        self.train_transform, self.trans_cfg = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.val_transform, _ = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
+
+        # ---------------------------- Build Dataset & Dataloader ----------------------------
+        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
+        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
+
+        # ---------------------------- Build Evaluator ----------------------------
+        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
+
+        # ---------------------------- Build Grad. Scaler ----------------------------
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+        # ---------------------------- Build Optimizer ----------------------------
+        self.optimizer_dict['lr0'] *= args.batch_size * self.grad_accumulate / 64
+        self.optimizer, self.start_epoch = build_yolo_optimizer(self.optimizer_dict, model, args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
+        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
+        if self.args.resume and self.args.resume != 'None':
+            self.lr_scheduler.step()
+
+        # ---------------------------- Build Model-EMA ----------------------------
+        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
+            print('Build ModelEMA ...')
+            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
+        else:
+            self.model_ema = None
+
+    def train(self, model):
+        for epoch in range(self.start_epoch, self.args.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # check second stage
+            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
+                self.check_second_stage()
+                # save model of the last mosaic epoch
+                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
+
+            # check third stage
+            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
+                self.check_third_stage()
+                # save model of the last weak-augment epoch
+                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
+
+            # train one epoch
+            self.epoch = epoch
+            self.train_one_epoch(model)
+
+            # eval one epoch
+            if self.heavy_eval:
+                model_eval = model.module if self.args.distributed else model
+                self.eval(model_eval)
+            else:
+                model_eval = model.module if self.args.distributed else model
+                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
+                    self.eval(model_eval)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
+
+    def eval(self, model):
+        # check model
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
+        if distributed_utils.is_main_process():
+            # check evaluator
+            if self.evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                print('Saving state, epoch: {}'.format(self.epoch))
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                torch.save({'model': model_eval.state_dict(),
+                            'mAP': -1.,
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)               
+            else:
+                print('eval ...')
+                # set eval mode
+                model_eval.trainable = False
+                model_eval.eval()
+
+                # evaluate
+                with torch.no_grad():
+                    self.evaluator.evaluate(model_eval)
+
+                # save model
+                cur_map = self.evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    # save model
+                    print('Saving state, epoch:', self.epoch)
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                    torch.save({'model': model_eval.state_dict(),
+                                'mAP': round(self.best_map*100, 1),
+                                'optimizer': self.optimizer.state_dict(),
+                                'epoch': self.epoch,
+                                'args': self.args}, 
+                                checkpoint_path)                      
+
+                # set train mode.
+                model_eval.trainable = True
+                model_eval.train()
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+    def train_one_epoch(self, model):
+        # basic parameters
+        epoch_size = len(self.train_loader)
+        img_size = self.args.img_size
+        t0 = time.time()
+        nw = epoch_size * self.args.wp_epoch
+
+        # Train one epoch
+        for iter_i, (images, targets) in enumerate(self.train_loader):
+            ni = iter_i + self.epoch * epoch_size
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
+                for j, x in enumerate(self.optimizer.param_groups):
+                    # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                    x['lr'] = np.interp(
+                        ni, xi, [self.warmup_dict['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
+                    if 'momentum' in x:
+                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
+                                
+            # To device
+            images = images.to(self.device, non_blocking=True).float() / 255.
+
+            # Multi scale
+            if self.args.multi_scale:
+                images, targets, img_size = self.rescale_image_targets(
+                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
+            else:
+                targets = self.refine_targets(targets, self.args.min_box_size)
+                
+            # Visualize train targets
+            if self.args.vis_tgt:
+                vis_data(images*255, targets, self.data_cfg['num_classes'])
+
+            # Inference
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
+                outputs = model(images)
+                # Compute loss
+                loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch, task='det_seg_pos')
+                det_loss_dict = loss_dict['det_loss_dict']
+                seg_loss_dict = loss_dict['seg_loss_dict']
+                pos_loss_dict = loss_dict['pos_loss_dict']
+                
+                # TODO: finish the backward + optimize
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
+        # LR Schedule
+        if not self.second_stage:
+            self.lr_scheduler.step()
+
+    def refine_targets(self, targets, min_box_size):
+        # filter out boxes that are too small
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+        
+        return targets
+
+    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
+        """
+            Deployed for Multi scale trick.
+        """
+        if isinstance(stride, int):
+            max_stride = stride
+        elif isinstance(stride, list):
+            max_stride = max(stride)
+
+        # During training phase, the shape of input image is square.
+        old_img_size = images.shape[-1]
+        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1] + max_stride))  # randrange needs int args
+        new_img_size = new_img_size // max_stride * max_stride  # size
+        if new_img_size / old_img_size != 1:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images, 
+                                size=new_img_size, 
+                                mode='bilinear', 
+                                align_corners=False)
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            boxes = torch.clamp(boxes, 0, old_img_size)
+            # rescale box
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+
+            tgt["boxes"] = boxes[keep]
+            tgt["labels"] = labels[keep]
+
+        return images, targets, new_img_size
+
+    def check_second_stage(self):
+        # set second stage
+        print('============== Second stage of Training ==============')
+        self.second_stage = True
+
+        # close mosaic augmentation
+        if self.train_loader.dataset.mosaic_prob > 0.:
+            print(' - Close < Mosaic Augmentation > ...')
+            self.train_loader.dataset.mosaic_prob = 0.
+            self.heavy_eval = True
+
+        # close mixup augmentation
+        if self.train_loader.dataset.mixup_prob > 0.:
+            print(' - Close < Mixup Augmentation > ...')
+            self.train_loader.dataset.mixup_prob = 0.
+            self.heavy_eval = True
+
+        # close rotation augmentation
+        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
+            print(' - Close < degrees of rotation > ...')
+            self.trans_cfg['degrees'] = 0.0
+        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
+            print(' - Close < shear of rotation > ...')
+            self.trans_cfg['shear'] = 0.0
+        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
+            print(' - Close < perspective of rotation > ...')
+            self.trans_cfg['perspective'] = 0.0
+
+        # build a new transform for second stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+        
+    def check_third_stage(self):
+        # set third stage
+        print('============== Third stage of Training ==============')
+        self.third_stage = True
+
+        # close random affine
+        if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
+            print(' - Close < translate of affine > ...')
+            self.trans_cfg['translate'] = 0.0
+        if 'scale' in self.trans_cfg.keys():
+            print(' - Close < scale of affine > ...')
+            self.trans_cfg['scale'] = [1.0, 1.0]
+
+        # build a new transform for third stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+   
+
+# Build Trainer
+def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
+    # ----------------------- Det trainers -----------------------
+    if   model_cfg['trainer_type'] == 'yolov8':
+        return Yolov8Trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+    elif model_cfg['trainer_type'] == 'yolox':
+        return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+    elif model_cfg['trainer_type'] == 'rtcdet':
+        return RTCTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+    elif model_cfg['trainer_type'] == 'rtrdet':
+        return RTRTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+    
+    # ----------------------- Det + Seg trainers -----------------------
+    elif model_cfg['trainer_type'] == 'rtcdet_ds':
+        return RTCTrainerDS(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+
+    # ----------------------- Det + Seg + Pos trainers -----------------------
+    elif model_cfg['trainer_type'] == 'rtcdet_dsp':
+        return RTCTrainerDSP(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+
     else:
-        raise NotImplementedError
+        raise NotImplementedError(model_cfg['trainer_type'])
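The new `RTCTrainerDS` / `RTCTrainerDSP` classes stop at `# TODO: finish the backward + optimize`. A hedged sketch of that missing step, mirroring the AMP and gradient-accumulation pattern of the single-task trainers earlier in this file; how the per-task loss dicts are weighted and summed is an assumption:

```python
# Sketch only: slots into train_one_epoch after the criterion call.
# Assumes each task dict exposes its total under the 'losses' key, like the
# detection loss does, and that task losses are simply summed.
losses = det_loss_dict['losses']
if seg_loss_dict is not None:
    losses = losses + seg_loss_dict['losses']
losses = losses / self.grad_accumulate

self.scaler.scale(losses).backward()
if (ni + 1) % self.grad_accumulate == 0:
    if self.clip_grad > 0:
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
    self.scaler.step(self.optimizer)
    self.scaler.update()
    self.optimizer.zero_grad()
    if self.model_ema is not None:
        self.model_ema.update(model)
```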
     

+ 4 - 4
eval.py

@@ -23,13 +23,13 @@ from models.detectors import build_model
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
-    # basic
+    # Basic setting
     parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='the max size of input image')
     parser.add_argument('--cuda', action='store_true', default=False,
                         help='Use cuda')
 
-    # model
+    # Model setting
     parser.add_argument('-m', '--model', default='yolov1', type=str,
                         help='build yolo')
     parser.add_argument('--weight', default=None,
@@ -49,8 +49,8 @@ def parse_args():
     parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
                         help='Perform NMS operations regardless of category.')
 
-    # dataset
-    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
+    # Data setting
+    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
                         help='data root')
     parser.add_argument('-d', '--dataset', default='coco',
                         help='coco, voc.')

+ 3 - 1
evaluator/coco_evaluator.py

@@ -76,7 +76,9 @@ class COCOAPIEvaluator():
 
             # inference
             outputs = model(x)
-            bboxes, scores, labels = outputs
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
 
             # rescale bboxes
             bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

+ 3 - 1
evaluator/crowdhuman_evaluator.py

@@ -80,7 +80,9 @@ class CrowdHumanEvaluator():
             
             # inference
             outputs = model(x)
-            bboxes, scores, labels = outputs
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
             
             # rescale bboxes
             bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

+ 3 - 1
evaluator/customed_evaluator.py

@@ -59,7 +59,9 @@ class CustomedEvaluator():
 
             # inference
             outputs = model(x)
-            bboxes, scores, labels = outputs
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
 
             # rescale bboxes
             bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

+ 4 - 1
evaluator/voc_evaluator.py

@@ -71,7 +71,10 @@ class VOCAPIEvaluator():
 
             # forward
             t0 = time.time()
-            bboxes, scores, labels = net(x)
+            outputs = net(x)
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
             detect_time = time.time() - t0
 
             # rescale bboxes

+ 3 - 1
evaluator/widerface_evaluator.py

@@ -71,7 +71,9 @@ class WiderFaceEvaluator():
 
             # inference
             outputs = model(x)
-            bboxes, scores, labels = outputs
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
 
             # rescale bboxes
             bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)

+ 5 - 0
models/detectors/centernet/centernet.py

@@ -0,0 +1,5 @@
+# Objects as Points
+
+
+class CenterNet():
+    pass

+ 79 - 2
models/detectors/rtcdet/loss.py

@@ -76,7 +76,7 @@ class Criterion(object):
 
 
     # -------------------- Task loss functions --------------------
-    def __call__(self, outputs, targets, epoch=0):        
+    def compute_det_loss(self, outputs, targets, epoch=0):
         """
             Input:
                 outputs: (Dict) -> {
@@ -187,7 +187,84 @@ class Criterion(object):
                     )
 
         return loss_dict
-    
+
+    def compute_seg_loss(self, outputs, targets, epoch=0):
+        """
+            Input:
+                outputs: (Dict) -> {
+                    'pred_cls': (List[torch.Tensor] -> [B, M, Nc]),
+                    'pred_reg': (List[torch.Tensor] -> [B, M, 4]),
+                    'pred_box': (List[torch.Tensor] -> [B, M, 4]),
+                    'strides':  (List[Int])
+                }
+                target: (List[Dict]) [
+                    {'boxes':  (torch.Tensor) -> [N, 4], 
+                     'labels': (torch.Tensor) -> [N,],
+                     ...}, ...
+                     ]
+            Output:
+                loss_dict: (Dict) -> {
+                    'loss_cls': (torch.Tensor) It is a scalar.),
+                    'loss_box': (torch.Tensor) It is a scalar.),
+                    'loss_box_aux': (torch.Tensor) It is a scalar.),
+                    'losses':  (torch.Tensor) It is a scalar.),
+                }
+        """
+
+    def compute_pos_loss(self, outputs, targets, epoch=0):
+        """
+            Input:
+                outputs: (Dict) -> {
+                    'pred_cls': (List[torch.Tensor] -> [B, M, Nc]),
+                    'pred_reg': (List[torch.Tensor] -> [B, M, 4]),
+                    'pred_box': (List[torch.Tensor] -> [B, M, 4]),
+                    'strides':  (List[Int])
+                }
+                target: (List[Dict]) [
+                    {'boxes':  (torch.Tensor) -> [N, 4], 
+                     'labels': (torch.Tensor) -> [N,],
+                     ...}, ...
+                     ]
+            Output:
+                loss_dict: (Dict) -> {
+                    'loss_cls': (torch.Tensor) It is a scalar.),
+                    'loss_box': (torch.Tensor) It is a scalar.),
+                    'loss_box_aux': (torch.Tensor) It is a scalar.),
+                    'losses':  (torch.Tensor) It is a scalar.),
+                }
+        """
+
+    def __call__(self, outputs, targets, epoch=0, task='det'):
+        # -------------- Detection loss --------------
+        det_loss_dict = None
+        if outputs['det_outputs'] is not None:
+            det_loss_dict = self.compute_det_loss(outputs['det_outputs'], targets, epoch)
+        # -------------- Segmentation loss --------------
+        seg_loss_dict = None
+        if outputs['seg_outputs'] is not None:
+            seg_loss_dict = self.compute_seg_loss(outputs['seg_outputs'], targets, epoch)
+        # -------------- Human pose loss --------------
+        pos_loss_dict = None
+        if outputs['pos_outputs'] is not None:
+            pos_loss_dict = self.compute_pos_loss(outputs['pos_outputs'], targets, epoch)
+
+        # Loss dict
+        if task == 'det':
+            return det_loss_dict
+        
+        if task == 'det_seg':
+            return {'det_loss_dict': det_loss_dict,
+                    'seg_loss_dict': seg_loss_dict}
+        
+        if task == 'det_pos':
+            return {'det_loss_dict': det_loss_dict,
+                    'pos_loss_dict': pos_loss_dict}
+        
+        if task == 'det_seg_pos':
+            return {'det_loss_dict': det_loss_dict,
+                    'seg_loss_dict': seg_loss_dict,
+                    'pos_loss_dict': pos_loss_dict}
+
 
 def build_criterion(args, cfg, device, num_classes):
     criterion = Criterion(args, cfg, device, num_classes)
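Usage sketch of the task-routed criterion above; the seg/pos dicts will stay `None` until their `compute_*` methods are implemented:

```python
# criterion, outputs, targets, epoch come from the trainer; `task` selects
# which loss dicts are returned.
loss_dict = criterion(outputs=outputs, targets=targets, epoch=epoch, task='det_seg')
det_losses = loss_dict['det_loss_dict']['losses']   # total detection loss
seg_loss_dict = loss_dict['seg_loss_dict']          # currently None (placeholder)
```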

+ 58 - 41
models/detectors/rtcdet/rtcdet.py

@@ -1,3 +1,5 @@
+# Real-time Convolutional Object Detector
+
 # --------------- Torch components ---------------
 import torch
 import torch.nn as nn
@@ -6,8 +8,8 @@ import torch.nn as nn
 from .rtcdet_backbone import build_backbone
 from .rtcdet_neck import build_neck
 from .rtcdet_pafpn import build_fpn
-from .rtcdet_head import build_head
-from .rtcdet_pred import build_pred
+from .rtcdet_head import build_det_head, build_seg_head, build_pose_head
+from .rtcdet_pred import build_det_pred, build_seg_pred, build_pose_pred
 
 # --------------- External components ---------------
 from utils.misc import multiclass_nms
@@ -57,10 +59,18 @@ class RTCDet(nn.Module):
         self.fpn_dims = self.fpn.out_dim
 
         ## ----------- Head -----------
-        self.head = build_head(cfg, self.fpn_dims, self.head_dim, self.num_levels)
-
-        ## ----------- Pred -----------
-        self.pred = build_pred(self.head_dim, self.head_dim, self.strides, num_classes, 4, self.num_levels)
+        self.det_head = nn.Sequential(
+            build_det_head(cfg['det_head'], self.fpn_dims, self.head_dim, self.num_levels),
+            build_det_pred(self.head_dim, self.head_dim, self.strides, num_classes, 4, self.num_levels)
+        )
+        self.seg_head = nn.Sequential(
+            build_seg_head(cfg['seg_head']),
+            build_seg_pred()
+        ) if cfg['seg_head']['name'] is not None else None
+        self.pos_head = nn.Sequential(
+            build_pose_head(cfg['pos_head']),
+            build_pose_pred()
+        ) if cfg['pos_head']['name'] is not None else None
 
     # Post process
     def post_process(self, cls_preds, box_preds):
@@ -141,32 +151,6 @@ class RTCDet(nn.Module):
 
         return bboxes, scores, labels
     
-    def forward_det_task(self, x):
-        # ---------------- Heads ----------------
-        outputs = self.head['det'](x)
-
-        # ---------------- Post-process ----------------
-        if self.trainable:
-            return outputs
-        else:
-            all_cls_preds = outputs['pred_cls']
-            all_box_preds = outputs['pred_box']
-
-            if self.deploy:
-                cls_preds = torch.cat(all_cls_preds, dim=1)[0]
-                box_preds = torch.cat(all_box_preds, dim=1)[0]
-                scores = cls_preds.sigmoid()
-                bboxes = box_preds
-                # [n_anchors_all, 4 + C]
-                outputs = torch.cat([bboxes, scores], dim=-1)
-
-                return outputs
-            else:
-                # post process
-                bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
-            
-                return bboxes, scores, labels
-
     # Main process
     def forward(self, x):
         # ---------------- Backbone ----------------
@@ -179,15 +163,37 @@ class RTCDet(nn.Module):
         pyramid_feats = self.fpn(pyramid_feats)
 
         # ---------------- Head ----------------
-        pyramid_feats = self.head(pyramid_feats)
+        det_outputs = self.forward_det_head(pyramid_feats)
+        seg_outputs = self.forward_seg_head(pyramid_feats)
+        pos_outputs = self.forward_pos_head(pyramid_feats)
+
+        if not self.trainable:
+            # Eval: merge the optional task outputs into the detection outputs
+            if seg_outputs is not None:
+                det_outputs.update(seg_outputs)
+            if pos_outputs is not None:
+                det_outputs.update(pos_outputs)
+            outputs = det_outputs
+        else:
+            # Train: keep per-task outputs separate for the criterion
+            outputs = {
+                'det_outputs': det_outputs,
+                'seg_outputs': seg_outputs,
+                'pos_outputs': pos_outputs
+            }
 
-        # ---------------- Pred ----------------
-        outputs = self.pred(pyramid_feats)
+        return outputs
+
+    def forward_det_head(self, x):
+        # ---------------- Heads ----------------
+        outputs = self.det_head(x)
 
         # ---------------- Post-process ----------------
-        if self.trainable:
-            return outputs
-        else:
+        if not self.trainable:
             all_cls_preds = outputs['pred_cls']
             all_box_preds = outputs['pred_box']
 
@@ -199,11 +205,22 @@ class RTCDet(nn.Module):
                 # [n_anchors_all, 4 + C]
                 outputs = torch.cat([bboxes, scores], dim=-1)
 
-                return outputs
             else:
                 # post process
                 bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+
+                outputs = {
+                    "scores": scores,
+                    "labels": labels,
+                    "bboxes": bboxes
+                }
             
-                return bboxes, scores, labels
+        return outputs
 
-    
+    def forward_seg_head(self, x):
+        if self.seg_head is None:
+            return None
+        return self.seg_head(x)
+
+    def forward_pos_head(self, x):
+        if self.pos_head is None:
+            return None
+        return self.pos_head(x)

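Note: a minimal training-step sketch wiring the train-mode output dict above into the Criterion from this commit; images, targets, optimizer and the single-task loss lookup are assumptions, not part of the change.

    def train_step(model, criterion, optimizer, images, targets, epoch, task='det'):
        # trainable=True -> forward() returns {'det_outputs', 'seg_outputs', 'pos_outputs'}
        outputs = model(images)
        loss_dict = criterion(outputs, targets, epoch=epoch, task=task)
        # for task='det' the criterion returns the detection dict with its total under 'losses'
        losses = loss_dict['losses'] if task == 'det' else loss_dict['det_loss_dict']['losses']
        optimizer.zero_grad()
        losses.backward()
        optimizer.step()
        return loss_dict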
+ 26 - 2
models/detectors/rtcdet/rtcdet_head.py

@@ -7,11 +7,17 @@ except:
     from rtcdet_basic import Conv
 
 
-def build_head(cfg, in_dims, out_dim, num_levels=3):
+def build_det_head(cfg, in_dims, out_dim, num_levels=3):
     head = MDetHead(cfg, in_dims, out_dim, num_levels)
 
     return head
 
+def build_seg_head(cfg, in_dims=None, out_dim=None):
+    # Placeholder builder; rtcdet.py calls it with only cfg, so the dims default to None
+    return MaskHead()
+
+def build_pose_head(cfg, in_dims=None, out_dim=None):
+    # Placeholder builder; rtcdet.py calls it with only cfg, so the dims default to None
+    return PoseHead()
+
+
 
 # ---------------------------- Detection Head ----------------------------
 ## Single-level Detection Head
@@ -135,6 +141,24 @@ class MDetHead(nn.Module):
         return outputs
 
 
+# ---------------------------- Segmentation Head ----------------------------
+class MaskHead(nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        # Placeholder: the segmentation head is not implemented yet
+        return
+
+
+# ---------------------------- Human-Pose Head ----------------------------
+class PoseHead(nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+
+    def forward(self, x):
+        # Placeholder: the pose head is not implemented yet
+        return
+
+
 if __name__ == '__main__':
     import time
     from thop import profile
@@ -150,7 +174,7 @@ if __name__ == '__main__':
     fpn_dims = [256, 256, 256]
     out_dim = 256
     # Head-1
-    model = build_head(cfg, fpn_dims, out_dim, num_levels=3)
+    model = build_det_head(cfg, fpn_dims, out_dim, num_levels=3)
     print(model)
     fpn_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
     t0 = time.time()

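Note: MaskHead above is an empty placeholder. For illustration only, a sketch of what a prototype-style mask head could look like; the YOLACT-like design, the name ProtoMaskHead and the num_protos parameter are assumptions, not part of this commit.

    import torch.nn as nn

    class ProtoMaskHead(nn.Module):
        """Hypothetical mask head: maps the stride-8 FPN level to K mask prototypes."""
        def __init__(self, in_dim=256, num_protos=32):
            super().__init__()
            self.convs = nn.Sequential(
                nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1), nn.SiLU(inplace=True),
                nn.Conv2d(in_dim, in_dim, kernel_size=3, padding=1), nn.SiLU(inplace=True),
            )
            self.proj = nn.Conv2d(in_dim, num_protos, kernel_size=1)

        def forward(self, pyramid_feats):
            # use only the highest-resolution level (P3) for mask prototypes
            return self.proj(self.convs(pyramid_feats[0]))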
+ 25 - 1
models/detectors/rtcdet/rtcdet_pred.py

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 
 
-def build_pred(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
+def build_det_pred(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=3):
     pred_layers = MDetPDLayer(cls_dim     = cls_dim,
                               reg_dim     = reg_dim,
                               strides     = strides,
@@ -13,6 +13,12 @@ def build_pred(cls_dim, reg_dim, strides, num_classes, num_coords=4, num_levels=
 
     return pred_layers
 
+def build_seg_pred():
+    return MaskPDLayer()
+
+def build_pose_pred():
+    return PosePDLayer()
+
 
 # ---------------------------- Detection predictor ----------------------------
 ## Single-level Detection Prediction Layer
@@ -153,3 +159,21 @@ class MDetPDLayer(nn.Module):
                    }
 
         return outputs
+
+
+# -------------------- Segmentation predictor --------------------
+class MaskPDLayer(nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+    
+    def forward(self, x):
+        # Placeholder: the segmentation predictor is not implemented yet
+        return
+
+
+# -------------------- Human-Pose predictor --------------------
+class PosePDLayer(nn.Module):
+    def __init__(self, *args, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+    
+    def forward(self, x):
+        # Placeholder: the pose predictor is not implemented yet
+        return

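Note: MaskPDLayer and PosePDLayer are empty placeholders. A sketch of one possible pose predictor, regressing (x, y, conf) for K keypoints per anchor and flattening each level's map the way the detection predictor flattens its class/box maps; the design and the names KptPDLayer / pred_kpt are assumptions.

    import torch.nn as nn

    class KptPDLayer(nn.Module):
        """Hypothetical pose predictor: (x, y, conf) for K keypoints per anchor."""
        def __init__(self, in_dim=256, num_kpts=17):
            super().__init__()
            self.kpt_pred = nn.Conv2d(in_dim, num_kpts * 3, kernel_size=1)

        def forward(self, feats):
            preds = []
            for feat in feats:                               # one map per FPN level
                p = self.kpt_pred(feat)                      # [B, K*3, H, W]
                preds.append(p.flatten(2).permute(0, 2, 1))  # [B, H*W, K*3]
            return {'pred_kpt': preds}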
+ 5 - 0
models/detectors/rtrdet/rtrdet.py

@@ -0,0 +1,5 @@
+# Real-time Transformer Object Detector
+
+
+class RTRDet:
+    pass

+ 7 - 2
models/detectors/yolov1/yolov1.py

@@ -156,7 +156,6 @@ class YOLOv1(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # Move predictions to the CPU for post-processing
             scores = scores.cpu().numpy()
@@ -165,7 +164,13 @@ class YOLOv1(nn.Module):
             # post-process
             bboxes, scores, labels = self.postprocess(bboxes, scores)
 
-        return bboxes, scores, labels
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+
+        return outputs
 
 
     def forward(self, x):

+ 7 - 2
models/detectors/yolov2/yolov2.py

@@ -211,13 +211,18 @@ class YOLOv2(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # post process
             bboxes, scores, labels = self.postprocess(
                 obj_pred, cls_pred, reg_pred, anchors)
 
-            return bboxes, scores, labels
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+
+        return outputs
 
 
     def forward(self, x):

+ 6 - 3
models/detectors/yolov3/yolov3.py

@@ -235,14 +235,17 @@ class YOLOv3(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # post process
             bboxes, scores, labels = self.post_process(
                 all_obj_preds, all_cls_preds, all_box_preds)
-        
-            return bboxes, scores, labels
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
 
+        return outputs
 
     # ---------------------- Main Process for Training ----------------------
     def forward(self, x):

+ 7 - 3
models/detectors/yolov4/yolov4.py

@@ -236,13 +236,17 @@ class YOLOv4(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # post process
             bboxes, scores, labels = self.post_process(
                 all_obj_preds, all_cls_preds, all_box_preds)
-        
-            return bboxes, scores, labels
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+
+        return outputs
 
 
     # ---------------------- Main Process for Training ----------------------

+ 9 - 5
models/detectors/yolov5/yolov5.py

@@ -225,13 +225,17 @@ class YOLOv5(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # post process
-            bboxes, scores, labels = self.post_process(all_obj_preds, all_cls_preds, all_box_preds)
-        
-            return bboxes, scores, labels
-
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+
+        return outputs
 
     # ---------------------- Main Process for Training ----------------------
     def forward(self, x):

+ 6 - 3
models/detectors/yolov7/yolov7.py

@@ -221,14 +221,17 @@ class YOLOv7(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # post process
             bboxes, scores, labels = self.post_process(
                 all_obj_preds, all_cls_preds, all_box_preds)
-        
-            return bboxes, scores, labels
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
 
+        return outputs
 
     # ---------------------- Main Process for Training ----------------------
     def forward(self, x):

+ 6 - 2
models/detectors/yolov8/yolov8.py

@@ -175,12 +175,16 @@ class YOLOv8(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # post process
             bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
         
-            return bboxes, scores, labels
+        return outputs
 
     def forward(self, x):
         if not self.trainable:

+ 6 - 3
models/detectors/yolox/yolox.py

@@ -213,14 +213,17 @@ class YOLOX(nn.Module):
             # [n_anchors_all, 4 + C]
             outputs = torch.cat([bboxes, scores], dim=-1)
 
-            return outputs
         else:
             # post process
             bboxes, scores, labels = self.post_process(
                 all_obj_preds, all_cls_preds, all_box_preds)
-        
-            return bboxes, scores, labels
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
 
+        return outputs
 
     # ---------------------- Main Process for Training ----------------------
     def forward(self, x):

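Note: after this commit every detector above returns the same eval-time structure, so downstream code can consume them uniformly. A minimal consumer sketch (model and x stand for any built detector and a preprocessed input batch):

    outputs = model(x)
    if isinstance(outputs, dict):
        # deploy=False: post-processed results
        bboxes, scores, labels = outputs['bboxes'], outputs['scores'], outputs['labels']
    else:
        # deploy=True: a raw [n_anchors_all, 4 + C] tensor of boxes and class scores
        bboxes, scores = outputs[..., :4], outputs[..., 4:]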
+ 47 - 22
test.py

@@ -20,8 +20,7 @@ from models.detectors import build_model
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
-
-    # basic
+    # Basic setting
     parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='the max size of input image')
     parser.add_argument('--show', action='store_true', default=False,
@@ -37,7 +36,7 @@ def parse_args():
     parser.add_argument('--resave', action='store_true', default=False, 
                         help='resave checkpoints without optimizer state dict.')
 
-    # model
+    # Model setting
     parser.add_argument('-m', '--model', default='yolov1', type=str,
                         help='build yolo')
     parser.add_argument('--weight', default=None,
@@ -57,7 +56,7 @@ def parse_args():
     parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
                         help='Perform NMS operations regardless of category.')
 
-    # dataset
+    # Data setting
     parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
                         help='data root')
     parser.add_argument('-d', '--dataset', default='coco',
@@ -71,18 +70,22 @@ def parse_args():
     parser.add_argument('--load_cache', action='store_true', default=False,
                         help='load data into memory.')
 
+    # Task setting
+    parser.add_argument('-t', '--task', default='det', choices=['det', 'det_seg', 'det_pos', 'det_seg_pos'],
+                        help='task type.')
+
     return parser.parse_args()
 
 
 @torch.no_grad()
-def test(args,
-         model, 
-         device, 
-         dataset,
-         transform=None,
-         class_colors=None, 
-         class_names=None, 
-         class_indexs=None):
+def test_det(args,
+             model, 
+             device, 
+             dataset,
+             transform=None,
+             class_colors=None, 
+             class_names=None, 
+             class_indexs=None):
     num_images = len(dataset)
     save_path = os.path.join('det_results/', args.dataset, args.model)
     os.makedirs(save_path, exist_ok=True)
@@ -99,7 +102,10 @@ def test(args,
 
         t0 = time.time()
         # inference
-        bboxes, scores, labels = model(x)
+        outputs = model(x)
+        scores = outputs['scores']
+        labels = outputs['labels']
+        bboxes = outputs['bboxes']
         print("detection time used ", time.time() - t0, "s")
         
         # rescale bboxes
@@ -125,6 +131,18 @@ def test(args,
             # save result
             cv2.imwrite(os.path.join(save_path, str(index).zfill(6) +'.jpg'), img_processed)
 
+@torch.no_grad()
+def test_det_seg():
+    pass
+
+@torch.no_grad()
+def test_det_pos():
+    pass
+
+@torch.no_grad()
+def test_det_seg_pos():
+    pass
+
 
 if __name__ == '__main__':
     args = parse_args()
@@ -181,12 +199,19 @@ if __name__ == '__main__':
         
     print("================= DETECT =================")
     # run
-    test(args=args,
-         model=model, 
-         device=device, 
-         dataset=dataset,
-         transform=val_transform,
-         class_colors=class_colors,
-         class_names=dataset_info['class_names'],
-         class_indexs=dataset_info['class_indexs'],
-         )
+    if args.task == "det":
+        test_det(args=args,
+                 model=model,
+                 device=device,
+                 dataset=dataset,
+                 transform=val_transform,
+                 class_colors=class_colors,
+                 class_names=dataset_info['class_names'],
+                 class_indexs=dataset_info['class_indexs'],
+                 )
+    elif args.task == "det_seg":
+        test_det_seg()
+    elif args.task == "det_pos":
+        test_det_pos()
+    elif args.task == "det_seg_pos":
+        test_det_seg_pos()
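Note: the three multi-task test routines are stubs for now. A sketch of a mask-overlay helper that test_det_seg would need once segmentation outputs exist; the mask format (binary HxW arrays paired with per-instance scores) is an assumption.

    import numpy as np

    def overlay_masks(img, masks, scores, score_thresh=0.35, alpha=0.45):
        """Hypothetical helper: alpha-blend binary instance masks onto a BGR image."""
        canvas = img.copy()
        for mask, score in zip(masks, scores):
            if score < score_thresh:
                continue
            color = np.random.randint(0, 255, size=3, dtype=np.uint8)
            sel = mask.astype(bool)
            canvas[sel] = (alpha * color + (1 - alpha) * canvas[sel]).astype(np.uint8)
        return canvas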