yjh0410 1 year ago
Commit
9228f75dce

+ 4 - 0
config/__init__.py

@@ -87,6 +87,7 @@ from .model_config.yolov8_config import yolov8_cfg
 from .model_config.yolox_config import yolox_cfg
 ## My RTCDet series
 from .model_config.rtcdet_config import rtcdet_cfg, rtcdet_seg_cfg, rtcdet_pos_cfg, rtcdet_seg_pos_cfg
+from .model_config.ctrnet_config import ctrnet_cfg
 
 def build_model_config(args):
     print('==============================')
@@ -118,6 +119,9 @@ def build_model_config(args):
     # RTCDet
     elif args.model in ['rtcdet_n', 'rtcdet_t', 'rtcdet_s', 'rtcdet_m', 'rtcdet_l', 'rtcdet_x']:
         cfg = rtcdet_cfg[args.model]
+    # CenterNet
+    elif args.model in ['ctrnet_n', 'ctrnet_t', 'ctrnet_s', 'ctrnet_m', 'ctrnet_l', 'ctrnet_x']:
+        cfg = ctrnet_cfg[args.model]
 
     return cfg
 

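A minimal usage sketch of the new dispatch (hedged: assumes the repo root is on PYTHONPATH; the flag value is illustrative):

from types import SimpleNamespace
from config import build_model_config

args = SimpleNamespace(model='ctrnet_n')
cfg = build_model_config(args)   # resolves to ctrnet_cfg['ctrnet_n']
print(cfg['out_stride'])         # 4, per the ctrnet_config.py diff below
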
+ 2 - 3
config/model_config/ctrnet_config.py

@@ -11,7 +11,6 @@ ctrnet_cfg = {
         'bk_depthwise': False,
         'width': 0.25,
         'depth': 0.34,
-        'ratio': 2.0,
         'max_stride': 32,
         'out_stride': 4,
         ## Neck
@@ -27,8 +26,8 @@ ctrnet_cfg = {
         'dec_depthwise': False,
         ## Head
         'head': 'decoupled_head',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
+        'num_cls_head': 4,
+        'num_reg_head': 4,
         'head_act': 'silu',
         'head_norm': 'BN',
         'head_depthwise': False,  

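For context on the deepened heads: 'num_cls_head' and 'num_reg_head' typically set the number of stacked convs in each branch of a decoupled head. A hedged, self-contained sketch (the 64-channel width and exact layer layout are illustrative, not the repo's actual head):

import torch.nn as nn

def make_branch(dim, num_layers):
    # one conv + BN + SiLU per layer, matching head_act='silu', head_norm='BN'
    return nn.Sequential(*[nn.Sequential(nn.Conv2d(dim, dim, 3, padding=1),
                                         nn.BatchNorm2d(dim),
                                         nn.SiLU(inplace=True))
                           for _ in range(num_layers)])

cls_branch = make_branch(64, 4)  # deepened from 2 to 4 convs in this commit
reg_branch = make_branch(64, 4)
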
+ 5 - 0
models/detectors/__init__.py

@@ -13,6 +13,7 @@ from .yolov8.build import build_yolov8
 from .yolox.build import build_yolox
 # My RTCDet series
 from .rtcdet.build import build_rtcdet
+from .ctrnet.build import build_ctrnet
 
 
 # build object detector
@@ -58,6 +59,10 @@ def build_model(args,
     elif args.model in ['rtcdet_n', 'rtcdet_t', 'rtcdet_s', 'rtcdet_m', 'rtcdet_l', 'rtcdet_x']:
         model, criterion = build_rtcdet(
             args, model_cfg, device, num_classes, trainable, deploy)
+    # CenterNet
+    elif args.model in ['ctrnet_n', 'ctrnet_t', 'ctrnet_s', 'ctrnet_m', 'ctrnet_l', 'ctrnet_x']:
+        model, criterion = build_ctrnet(
+            args, model_cfg, device, num_classes, trainable, deploy)
 
     if trainable:
         # Load pretrained weight

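Putting the two dispatch tables together, a hedged end-to-end build sketch (flag values are illustrative; with trainable=False, no pretrained weights are loaded and criterion is None):

import torch
from types import SimpleNamespace
from config import build_model_config
from models.detectors import build_model

args = SimpleNamespace(model='ctrnet_n', conf_thresh=0.3, topk=100,
                       no_multi_labels=False, nms_class_agnostic=False)
cfg = build_model_config(args)
model, criterion = build_model(args, cfg, torch.device('cpu'),
                               num_classes=80, trainable=False, deploy=False)
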
+ 43 - 0
models/detectors/ctrnet/build.py

@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from .loss import build_criterion
+from .ctrnet import CenterNet
+
+
+# build object detector
+def build_ctrnet(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    # -------------- Build CenterNet --------------
+    model = CenterNet(cfg                = cfg,
+                      device             = device, 
+                      num_classes        = num_classes,
+                      trainable          = trainable,
+                      conf_thresh        = args.conf_thresh,
+                      topk               = args.topk,
+                      deploy             = deploy,
+                      no_multi_labels    = args.no_multi_labels,
+                      nms_class_agnostic = args.nms_class_agnostic
+                      )
+
+    # -------------- Initialize CenterNet --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(args, cfg, device, num_classes)
+        
+    return model, criterion

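The BatchNorm override above (eps=1e-3, momentum=0.03, versus PyTorch defaults of 1e-5 and 0.1) is a common YOLO-style tweak. The same pattern in isolation:

import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16))
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eps = 1e-3       # default 1e-5
        m.momentum = 0.03  # default 0.1
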
+ 4 - 4
models/detectors/ctrnet/ctrnet.py

@@ -13,7 +13,7 @@ from .ctrnet_pred    import build_det_pred
 
 
 # CenterNet
-class CenterNet():
+class CenterNet(nn.Module):
     def __init__(self,
                  cfg,
                  device,
@@ -42,7 +42,7 @@ class CenterNet():
         
         # ---------------- Network Parameters ----------------
         ## Encoder
-        self.encoder, feat_dims = build_encoder(cfg, pretrained=cfg['bk_pretrained']&trainable)
+        self.encoder, feat_dims = build_encoder(cfg)
 
         ## Neck
         self.neck = build_neck(cfg, feat_dims[-1], feat_dims[-1])
@@ -135,9 +135,9 @@ class CenterNet():
         feat = self.decoder(feat)
 
         # ---------------- Head ----------------
-        outputs = self.det_head(x)
+        outputs = self.det_head(feat)
         if self.trainable:
-            outputs['aux_outputs'] = self.aux_det_head(x)
+            outputs['aux_outputs'] = self.aux_det_head(feat)
 
         # ---------------- Post-process ----------------
         if not self.trainable:

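The two head fixes matter: the head previously consumed the raw input image x instead of the decoder output. A toy stand-in showing the corrected wiring (layer choices are illustrative, not the repo's modules):

import torch
import torch.nn as nn

class TinyCenterNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 16, 3, stride=4, padding=1)
        self.neck = nn.Conv2d(16, 16, 1)
        self.decoder = nn.ConvTranspose2d(16, 16, 4, stride=2, padding=1)
        self.det_head = nn.Conv2d(16, 80, 1)

    def forward(self, x):
        feat = self.encoder(x)
        feat = self.neck(feat)
        feat = self.decoder(feat)
        return self.det_head(feat)  # the fix: feat, not x

print(TinyCenterNet()(torch.randn(1, 3, 64, 64)).shape)  # torch.Size([1, 80, 32, 32])
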
+ 1 - 1
models/detectors/ctrnet/ctrnet_decoder.py

@@ -36,7 +36,7 @@ class CTRDecoder(nn.Module):
         layers = []
         for _ in range(self.num_layers):
             layer = nn.Sequential(
-                RTCBlock(in_dim, out_dim, 1, False, act_type, norm_type, depthwise),
+                RTCBlock(in_dim, out_dim, 3, False, act_type, norm_type, depthwise),
                 DeConv(out_dim, out_dim, kernel_size=4, stride=2, act_type=act_type, norm_type=norm_type)
             )
             layers.append(layer)

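Each decoder layer ends in a stride-2 DeConv, so reaching the config's out_stride of 4 from the encoder's max_stride of 32 takes log2(32 / 4) = 3 upsampling stages (a hedged inference from the config values above, not a line from the repo):

import math

max_stride, out_stride = 32, 4
num_layers = int(math.log2(max_stride / out_stride))
print(num_layers)  # 3 upsampling stages
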
+ 3 - 64
models/detectors/ctrnet/ctrnet_encoder.py

@@ -7,27 +7,15 @@ except:
     from ctrnet_basic import Conv, RTCBlock
 
 
-# MIM-pretrained weights
-model_urls = {
-    "rtcnet_n": None,
-    "rtcnet_t": None,
-    "rtcnet_s": None,
-    "rtcnet_m": None,
-    "rtcnet_l": None,
-    "rtcnet_x": None,
-}
-
-
 # ---------------------------- Basic functions ----------------------------
 ## Real-time Convolutional Backbone
 class CTREncoder(nn.Module):
-    def __init__(self, width=1.0, depth=1.0, ratio=1.0, act_type='silu', norm_type='BN', depthwise=False):
+    def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
         super(CTREncoder, self).__init__()
         # ---------------- Basic parameters ----------------
         self.width_factor = width
         self.depth_factor = depth
-        self.last_stage_factor = ratio
-        self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(512 * width * ratio)]
+        self.feat_dims = [round(64 * width), round(128 * width), round(256 * width), round(512 * width), round(1024 * width)]
         # ---------------- Network parameters ----------------
         ## P1/2
         self.layer_1 = Conv(3, self.feat_dims[0], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type)
@@ -90,67 +78,18 @@ class CTREncoder(nn.Module):
 
 # ---------------------------- Functions ----------------------------
 ## build Backbone
-def build_encoder(cfg, pretrained=False): 
+def build_encoder(cfg): 
     # build backbone model
     backbone = CTREncoder(width=cfg['width'],
                           depth=cfg['depth'],
-                          ratio=cfg['ratio'],
                           act_type=cfg['bk_act'],
                           norm_type=cfg['bk_norm'],
                           depthwise=cfg['bk_depthwise']
                           )
     feat_dims = backbone.feat_dims[-3:]
-
-    # load pretrained weight
-    if pretrained:
-        backbone = load_pretrained_weight(backbone)
         
     return backbone, feat_dims
 
-## load pretrained weight
-def load_pretrained_weight(model):
-    # Model name
-    width, depth, ratio = model.width_factor, model.depth_factor, model.last_stage_factor
-    if width == 0.25 and depth == 0.34 and ratio == 2.0:
-        model_name = "rtcnet_n"
-    elif width == 0.375 and depth == 0.34 and ratio == 2.0:
-        model_name = "rtcnet_t"
-    elif width == 0.50 and depth == 0.34 and ratio == 2.0:
-        model_name = "rtcnet_s"
-    elif width == 0.75 and depth == 0.67 and ratio == 1.5:
-        model_name = "rtcnet_m"
-    elif width == 1.0 and depth == 1.0 and ratio == 1.0:
-        model_name = "rtcnet_l"
-    elif width == 1.25 and depth == 1.34 and ratio == 1.0:
-        model_name = "rtcnet_x"
-    
-    # Load pretrained weight
-    url = model_urls[model_name]
-    if url is not None:
-        print('Loading pretrained weight ...')
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url=url, map_location="cpu", check_hash=True)
-        # checkpoint state dict
-        checkpoint_state_dict = checkpoint.pop("model")
-        # model state dict
-        model_state_dict = model.state_dict()
-        # check
-        for k in list(checkpoint_state_dict.keys()):
-            if k in model_state_dict:
-                shape_model = tuple(model_state_dict[k].shape)
-                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
-                if shape_model != shape_checkpoint:
-                    checkpoint_state_dict.pop(k)
-            else:
-                checkpoint_state_dict.pop(k)
-                print(k)
-        # load the weight
-        model.load_state_dict(checkpoint_state_dict)
-    else:
-        print('No backbone pretrained for {}.'.format(model_name))
-
-    return model
-
 
 if __name__ == '__main__':
     import time

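With 'ratio' gone, the last stage width is a plain 1024 * width instead of 512 * width * ratio; for scales that used ratio=2.0 the result is unchanged, since 512 * 2.0 == 1024. Checking the 'n' scale (width=0.25 per the config diff):

width = 0.25
feat_dims = [round(c * width) for c in (64, 128, 256, 512, 1024)]
print(feat_dims)       # [16, 32, 64, 128, 256]
print(feat_dims[-3:])  # build_encoder returns the last three: [64, 128, 256]
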
+ 2 - 1
models/detectors/ctrnet/ctrnet_pred.py

@@ -60,7 +60,8 @@ class SDetPDLayer(nn.Module):
 
         return anchors
         
-    def forward(self, cls_feat, reg_feat):
+    def forward(self, inputs):
+        cls_feat, reg_feat = inputs['cls_feat'], inputs['reg_feat']
         # pred
         cls_pred = self.cls_pred(cls_feat)
         reg_pred = self.reg_pred(reg_feat)

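The prediction layer now takes a single dict, so the head's output can be passed straight through. A self-contained stand-in for the new calling convention (module names and dims are illustrative; the 'cls_feat'/'reg_feat' keys follow the diff):

import torch
import torch.nn as nn

class TinyPredLayer(nn.Module):
    def __init__(self, dim=16, num_classes=80):
        super().__init__()
        self.cls_pred = nn.Conv2d(dim, num_classes, 1)
        self.reg_pred = nn.Conv2d(dim, 4, 1)

    def forward(self, inputs):
        # new convention: unpack both branches from one dict
        cls_feat, reg_feat = inputs['cls_feat'], inputs['reg_feat']
        return self.cls_pred(cls_feat), self.reg_pred(reg_feat)

layer = TinyPredLayer()
feats = {'cls_feat': torch.randn(1, 16, 8, 8), 'reg_feat': torch.randn(1, 16, 8, 8)}
cls_out, reg_out = layer(feats)  # [1, 80, 8, 8] and [1, 4, 8, 8]
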
+ 17 - 17
models/detectors/ctrnet/loss.py

@@ -102,13 +102,13 @@ class Criterion(object):
                     'losses':  (torch.Tensor) It is a scalar.),
                 }
         """
-        bs = outputs['pred_cls'][0].shape[0]
-        device = outputs['pred_cls'][0].device
-        fpn_strides = outputs['strides']
+        bs = outputs['pred_cls'].shape[0]
+        device = outputs['pred_cls'].device
+        stride = outputs['stride']
         anchors = outputs['anchors']
         # preds: [B, M, C]
-        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
-        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        cls_preds = outputs['pred_cls']
+        box_preds = outputs['pred_box']
         
         # --------------- label assignment ---------------
         cls_targets = []
@@ -118,15 +118,15 @@ class Criterion(object):
             tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
             tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
             if not aux_loss:
-                assigned_result = self.matcher(fpn_strides=fpn_strides,
-                                            anchors=anchors,
-                                            pred_cls=cls_preds[batch_idx].detach(),
-                                            pred_box=box_preds[batch_idx].detach(),
-                                            gt_labels=tgt_labels,
-                                            gt_bboxes=tgt_bboxes
-                                            )
+                assigned_result = self.matcher(stride=stride,
+                                               anchors=anchors,
+                                               pred_cls=cls_preds[batch_idx].detach(),
+                                               pred_box=box_preds[batch_idx].detach(),
+                                               gt_labels=tgt_labels,
+                                               gt_bboxes=tgt_bboxes
+                                               )
             else:
-                assigned_result = self.aux_matcher(fpn_strides=fpn_strides,
+                assigned_result = self.aux_matcher(stride=stride,
                                                    anchors=anchors,
                                                    pred_cls=cls_preds[batch_idx].detach(),
                                                    pred_box=box_preds[batch_idx].detach(),
@@ -170,13 +170,13 @@ class Criterion(object):
         loss_box_aux = None
         if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
             ## reg_preds
-            reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+            reg_preds = outputs['pred_reg']
             reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
             ## anchor tensors
-            anchors_tensors = torch.cat(outputs['anchors'], dim=0)[None].repeat(bs, 1, 1)
+            anchors_tensors = outputs['anchors'][None].repeat(bs, 1, 1)
             anchors_tensors_pos = anchors_tensors.view(-1, 2)[pos_inds]
             ## stride tensors
-            stride_tensors = torch.cat(outputs['stride_tensors'], dim=0)[None].repeat(bs, 1, 1)
+            stride_tensors = outputs['stride_tensors'][None].repeat(bs, 1, 1)
             stride_tensors_pos = stride_tensors.view(-1, 1)[pos_inds]
             ## aux loss
             loss_box_aux = self.loss_bboxes_aux(reg_preds_pos, box_targets_pos, anchors_tensors_pos, stride_tensors_pos)
@@ -216,7 +216,7 @@ class Criterion(object):
                 loss_dict[k] = main_loss_dict[k]
         for k in aux_loss_dict:
             if k != 'losses':
-                loss_dict[k] = main_loss_dict[k]
+                loss_dict[k+'_aux'] = aux_loss_dict[k]
         
         return loss_dict
 

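The last hunk fixes a real bug: auxiliary losses were overwritten by main losses under identical keys. A standalone trace of the corrected merge (loss values are illustrative):

main_loss_dict = {'loss_cls': 1.0, 'loss_box': 2.0, 'losses': 3.0}
aux_loss_dict  = {'loss_cls': 0.5, 'loss_box': 0.8, 'losses': 1.3}

loss_dict = {}
for k in main_loss_dict:
    if k != 'losses':
        loss_dict[k] = main_loss_dict[k]
for k in aux_loss_dict:
    if k != 'losses':
        loss_dict[k + '_aux'] = aux_loss_dict[k]  # was main_loss_dict[k] under key k

print(loss_dict)
# {'loss_cls': 1.0, 'loss_box': 2.0, 'loss_cls_aux': 0.5, 'loss_box_aux': 0.8}
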
+ 3 - 6
models/detectors/ctrnet/matcher.py

@@ -16,18 +16,15 @@ class AlignedSimOTA(object):
 
     @torch.no_grad()
     def __call__(self, 
-                 fpn_strides, 
+                 stride, 
                  anchors, 
                  pred_cls, 
                  pred_box, 
                  gt_labels,
                  gt_bboxes):
         # [M,]
-        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
-                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
-        # List[F, M, 2] -> [M, 2]
+        stride_tensor = torch.ones_like(anchors[:, 0]) * stride
         num_gt = len(gt_labels)
-        anchors = torch.cat(anchors, dim=0)
 
         # check gt
         if num_gt == 0 or gt_bboxes.max().item() == 0.:
@@ -46,7 +43,7 @@ class AlignedSimOTA(object):
         # ----------------------------------- soft center prior -----------------------------------
         gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
         distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
-                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
+                    ).pow(2).sum(-1).sqrt() / stride_tensor.unsqueeze(0)  # [N, M]
         distance = distance * valid_mask.unsqueeze(0)
         soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
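
With a single output level, the matcher builds its stride tensor by broadcasting one scalar over all M anchors instead of concatenating per-FPN-level tensors. A runnable excerpt of the new shapes (toy values):

import torch

anchors = torch.rand(100, 2) * 128                    # [M, 2] anchor centers
stride_tensor = torch.ones_like(anchors[:, 0]) * 4.0  # [M,], out_stride = 4

gt_bboxes = torch.tensor([[10., 10., 50., 60.]])      # [N, 4], xyxy
gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
            ).pow(2).sum(-1).sqrt() / stride_tensor.unsqueeze(0)
print(distance.shape)  # torch.Size([1, 100]) -> [N, M], in stride units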