yjh0410 1 жил өмнө
parent
commit
1cdbce1ec9

+ 4 - 3
config/model_config/ctrnet_config.py

@@ -15,7 +15,7 @@ ctrnet_cfg = {
         'ratio': 2.0,
         'ratio': 2.0,
         'stride': 32,
         'stride': 32,
         'max_stride': 32,
         'max_stride': 32,
-        'out_stride': 4,
+        'out_stride': 8,
         ## Neck
         ## Neck
         'neck': 'sppf',
         'neck': 'sppf',
         'neck_expand_ratio': 0.5,
         'neck_expand_ratio': 0.5,
@@ -29,8 +29,8 @@ ctrnet_cfg = {
         'dec_depthwise': False,
         'dec_depthwise': False,
         ## Head
         ## Head
         'head': 'decoupled_head',
         'head': 'decoupled_head',
-        'num_cls_head': 2,
-        'num_reg_head': 2,
+        'num_cls_head': 4,
+        'num_reg_head': 4,
         'head_act': 'silu',
         'head_act': 'silu',
         'head_norm': 'BN',
         'head_norm': 'BN',
         'head_depthwise': False,  
         'head_depthwise': False,  
@@ -50,6 +50,7 @@ ctrnet_cfg = {
         ## loss weight
         ## loss weight
         'loss_cls_weight': 1.0,
         'loss_cls_weight': 1.0,
         'loss_box_weight': 2.0,
         'loss_box_weight': 2.0,
+        'aux_bbox_loss': False,
         # ---------------- Train config ----------------
         # ---------------- Train config ----------------
         'trainer_type': 'rtcdet',
         'trainer_type': 'rtcdet',
     },
     },

+ 6 - 6
models/detectors/ctrnet/ctrnet.py

@@ -38,7 +38,7 @@ class CenterNet(nn.Module):
         self.deploy = deploy
         self.deploy = deploy
         self.no_multi_labels = no_multi_labels
         self.no_multi_labels = no_multi_labels
         self.nms_class_agnostic = nms_class_agnostic
         self.nms_class_agnostic = nms_class_agnostic
-        self.head_dims = [round(512 * cfg['width']), round(256 * cfg['width']), round(128 * cfg['width'])]
+        self.head_dim = round(256 * cfg['width'])
         
         
         # ---------------- Network Parameters ----------------
         # ---------------- Network Parameters ----------------
         ## Encoder
         ## Encoder
@@ -49,17 +49,17 @@ class CenterNet(nn.Module):
         self.feat_dim = self.neck.out_dim
         self.feat_dim = self.neck.out_dim
         
         
         ## Decoder
         ## Decoder
-        self.decoder = build_decoder(cfg, self.feat_dim, self.head_dims)
+        self.decoder = build_decoder(cfg, self.feat_dim, self.head_dim)
 
 
         ## Head
         ## Head
         self.det_head = nn.Sequential(
         self.det_head = nn.Sequential(
-            build_det_head(cfg, self.head_dims[-1], self.head_dims[-1]),
-            build_det_pred(self.head_dims[-1], self.head_dims[-1], self.stride, num_classes, 4)
+            build_det_head(cfg, self.head_dim, self.head_dim),
+            build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
         )
         )
         ## Aux Head
         ## Aux Head
         self.aux_det_head = nn.Sequential(
         self.aux_det_head = nn.Sequential(
-            build_det_head(cfg, self.head_dims[-1], self.head_dims[-1]),
-            build_det_pred(self.head_dims[-1], self.head_dims[-1], self.stride, num_classes, 4)
+            build_det_head(cfg, self.head_dim, self.head_dim),
+            build_det_pred(self.head_dim, self.head_dim, self.stride, num_classes, 4)
         )
         )
 
 
     # Post process
     # Post process

+ 4 - 4
models/detectors/ctrnet/ctrnet_decoder.py

@@ -1,7 +1,7 @@
 import math
 import math
 import torch.nn as nn
 import torch.nn as nn
 
 
-from .ctrnet_basic import DeConv, DeformableConv
+from .ctrnet_basic import DeConv, RTCBlock
 
 
 
 
 def build_decoder(cfg, in_dim, out_dim):
 def build_decoder(cfg, in_dim, out_dim):
@@ -36,11 +36,11 @@ class CTRDecoder(nn.Module):
         layers = []
         layers = []
         for i in range(self.num_layers):
         for i in range(self.num_layers):
             layer = nn.Sequential(
             layer = nn.Sequential(
-                DeformableConv(in_dim, out_dim[i], kernel_size=3, padding=1, stride=1),
-                DeConv(out_dim[i], out_dim[i], kernel_size=4, stride=2, act_type=act_type, norm_type=norm_type)
+                RTCBlock(in_dim, out_dim, 3, False, act_type, norm_type, depthwise),
+                DeConv(out_dim, out_dim, kernel_size=4, stride=2, act_type=act_type, norm_type=norm_type)
             )
             )
             layers.append(layer)
             layers.append(layer)
-            in_dim = out_dim[i]
+            in_dim = out_dim
         self.layers = nn.Sequential(*layers)
         self.layers = nn.Sequential(*layers)
 
 
     def forward(self, x):
     def forward(self, x):

+ 4 - 4
models/detectors/ctrnet/loss.py

@@ -15,7 +15,7 @@ class Criterion(object):
         self.num_classes = num_classes
         self.num_classes = num_classes
         self.max_epoch = args.max_epoch
         self.max_epoch = args.max_epoch
         self.no_aug_epoch = args.no_aug_epoch
         self.no_aug_epoch = args.no_aug_epoch
-        self.aux_bbox_loss = False
+        self.aux_bbox_loss = cfg['aux_bbox_loss']
         # --------------- Loss config ---------------
         # --------------- Loss config ---------------
         self.loss_cls_weight = cfg['loss_cls_weight']
         self.loss_cls_weight = cfg['loss_cls_weight']
         self.loss_box_weight = cfg['loss_box_weight']
         self.loss_box_weight = cfg['loss_box_weight']
@@ -168,7 +168,7 @@ class Criterion(object):
 
 
         # ------------------ Aux regression loss ------------------
         # ------------------ Aux regression loss ------------------
         loss_box_aux = None
         loss_box_aux = None
-        if epoch >= (self.max_epoch - self.no_aug_epoch - 1):
+        if epoch >= (self.max_epoch - self.no_aug_epoch - 1) and self.aux_bbox_loss:
             ## reg_preds
             ## reg_preds
             reg_preds = outputs['pred_reg']
             reg_preds = outputs['pred_reg']
             reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
             reg_preds_pos = reg_preds.view(-1, 4)[pos_inds]
@@ -203,10 +203,10 @@ class Criterion(object):
 
 
     def __call__(self, outputs, targets, epoch=0):
     def __call__(self, outputs, targets, epoch=0):
         # -------------- Main loss --------------
         # -------------- Main loss --------------
-        main_loss_dict = self.compute_loss(outputs, targets, epoch)
+        main_loss_dict = self.compute_loss(outputs, targets, False, epoch)
         
         
         # -------------- Aux loss --------------
         # -------------- Aux loss --------------
-        aux_loss_dict = self.compute_loss(outputs['aux_outputs'], targets, epoch)
+        aux_loss_dict = self.compute_loss(outputs['aux_outputs'], targets, True, epoch)
 
 
         # Reformat loss dict
         # Reformat loss dict
         loss_dict = dict()
         loss_dict = dict()