ソースを参照

debug RTCDet-v2

yjh0410 2 年前
コミット
5754e59167

+ 2 - 0
config/model_config/rtcdet_v2_config.py

@@ -50,6 +50,7 @@ rtcdet_v2_cfg = {
                             'beta': 6.0},
                     'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
+                    'switch_epoch': 1,
                     },
         # ---------------- Loss config ----------------
         ## Loss weight
@@ -109,6 +110,7 @@ rtcdet_v2_cfg = {
                             'beta': 6.0},
                     'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
+                    'switch_epoch': 1,
                     },
         # ---------------- Loss config ----------------
         ## Loss weight

+ 1 - 1
models/detectors/rtcdet_v1/README.md

@@ -2,7 +2,7 @@
 
 |   Model    | Scale | Batch | AP<sup>test<br>0.5:0.95 | AP<sup>test<br>0.5 | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |------------|-------|-------|-------------------------|--------------------|------------------------|-------------------|-------------------|--------------------|--------|
-| RTCDetv1-N |  640  | 8xb16 |         35.7            |        53.8        |          35.6          |        53.8       |      9.1          |        2.4         |  |
+| RTCDetv1-N |  640  | 8xb16 |         35.7            |        53.8        |          35.6          |        53.8       |      9.1          |        2.4         | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/rtcdet_v1_n_coco.pth) |
 | RTCDetv1-T |  640  | 8xb16 |         40.5            |        59.1        |          40.3          |        59.1       |      19.0         |        5.1         | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/rtcdet_v1_t_coco.pth) |
 | RTCDetv1-S |  640  | 8xb16 |         43.6            |        62.6        |          43.3          |        62.6       |      33.6         |        9.0         | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/rtcdet_v1_s_coco.pth) |
 | RTCDetv1-M |  640  | 8xb16 |         48.3            |        67.0        |          48.1          |        66.9       |      87.4         |        23.6        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/rtcdet_v1_m_coco.pth) |

+ 6 - 3
models/detectors/rtcdet_v2/loss.py

@@ -14,18 +14,21 @@ class Criterion(object):
         self.device = device
         self.num_classes = num_classes
         self.use_ema_update = cfg['ema_update']
-        # loss weight
+        # ---------------- Loss weight ----------------
         self.loss_cls_weight = cfg['loss_cls_weight']
         self.loss_box_weight = cfg['loss_box_weight']
         self.loss_dfl_weight = cfg['loss_dfl_weight']
-        # matcher
+        # ---------------- Matcher ----------------
         matcher_config = cfg['matcher']
+        self.switch_epoch = matcher_config['switch_epoch']
+        ## TAL assigner
         self.tal_matcher = TaskAlignedAssigner(
             topk=matcher_config['tal']['topk'],
             alpha=matcher_config['tal']['alpha'],
             beta=matcher_config['tal']['beta'],
             num_classes=num_classes
             )
+        ## SimOTA assigner
         self.ota_matcher = AlignedSimOTA(
             center_sampling_radius=matcher_config['ota']['center_sampling_radius'],
             topk_candidate=matcher_config['ota']['topk_candidate'],
@@ -33,7 +36,7 @@ class Criterion(object):
         )
 
     def __call__(self, outputs, targets, epoch=0):
-        if epoch < self.args.wp_epoch:
+        if epoch < self.switch_epoch:
             return self.ota_loss(outputs, targets)
         else:
             return self.tal_loss(outputs, targets)

+ 1 - 1
models/detectors/rtcdet_v2/rtcdet_v2.py

@@ -52,7 +52,7 @@ class RTCDet(nn.Module):
 
         ## ----------- Heads -----------
         self.det_heads = build_det_head(
-            cfg, self.fpn_dims, self.head_dim, num_classes, num_levels=len(self.stride))
+            cfg, self.fpn_dims, self.head_dim, num_classes, self.reg_max, num_levels=len(self.stride))
 
         ## ----------- Preds -----------
         self.pred_layers = build_pred_layer(

+ 1 - 1
models/detectors/rtcdet_v2/rtcdet_v2_backbone.py

@@ -8,7 +8,7 @@ except:
 
 
 model_urls = {
-    'mcnet_p': None,
+    'mcnet_p': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/mcnet_pico.pth",
     'mcnet_n': None,
     'mcnet_t': None,
     'mcnet_s': None,

+ 8 - 9
models/detectors/rtcdet_v2/rtcdet_v2_head.py

@@ -6,11 +6,10 @@ from .rtcdet_v2_basic import Conv
 
 # Single-level Head
 class SingleLevelHead(nn.Module):
-    def __init__(self, in_dim, out_dim, num_classes, num_cls_head, num_reg_head, act_type, norm_type, depthwise):
+    def __init__(self, in_dim, cls_head_dim, reg_head_dim, num_cls_head, num_reg_head, act_type, norm_type, depthwise):
         super().__init__()
         # --------- Basic Parameters ----------
         self.in_dim = in_dim
-        self.num_classes = num_classes
         self.num_cls_head = num_cls_head
         self.num_reg_head = num_reg_head
         self.act_type = act_type
@@ -20,7 +19,7 @@ class SingleLevelHead(nn.Module):
         # --------- Network Parameters ----------
         ## cls head
         cls_feats = []
-        self.cls_head_dim = max(out_dim, num_classes)
+        self.cls_head_dim = cls_head_dim
         for i in range(num_cls_head):
             if i == 0:
                 cls_feats.append(
@@ -38,7 +37,7 @@ class SingleLevelHead(nn.Module):
                         )      
         ## reg head
         reg_feats = []
-        self.reg_head_dim = out_dim
+        self.reg_head_dim = reg_head_dim
         for i in range(num_reg_head):
             if i == 0:
                 reg_feats.append(
@@ -70,14 +69,14 @@ class SingleLevelHead(nn.Module):
 
 # Multi-level Head
 class MultiLevelHead(nn.Module):
-    def __init__(self, cfg, in_dims, out_dim, num_classes=80, num_levels=3):
+    def __init__(self, cfg, in_dims, out_dim, num_classes=80, reg_max=16, num_levels=3):
         super().__init__()
         ## ----------- Network Parameters -----------
         self.multi_level_heads = nn.ModuleList(
             [SingleLevelHead(
                 in_dims[level],
-                out_dim,
-                num_classes,
+                max(out_dim, num_classes),   # cls head dim
+                max(out_dim//4, 4*reg_max),  # reg head dim
                 cfg['num_cls_head'],
                 cfg['num_reg_head'],
                 cfg['head_act'],
@@ -110,8 +109,8 @@ class MultiLevelHead(nn.Module):
     
 
 # build detection head
-def build_det_head(cfg, in_dim, out_dim, num_classes=80, num_levels=3):
+def build_det_head(cfg, in_dim, out_dim, num_classes=80, reg_max=16, num_levels=3):
     if cfg['head'] == 'decoupled_head':
-        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, num_levels) 
+        head = MultiLevelHead(cfg, in_dim, out_dim, num_classes, reg_max, num_levels) 
 
     return head