
train RT-DETR-R50

yjh0410 1 year ago
parent
commit
4b5176e87e

+ 1 - 3
config/model_config/rtdetr_config.py

@@ -17,7 +17,6 @@ rtdetr_cfg = {
         ## Image Encoder - FPN
         'fpn': 'hybrid_encoder',
         'fpn_num_blocks': 3,
-        'fpn_expansion': 0.5,
         'fpn_act': 'silu',
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
@@ -73,11 +72,10 @@ rtdetr_cfg = {
         ## Image Encoder - FPN
         'fpn': 'hybrid_encoder',
         'fpn_num_blocks': 3,
-        'fpn_expansion': 1.0,
         'fpn_act': 'silu',
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
-        'hidden_dim': 256,
+        'hidden_dim': 320,
         'en_num_heads': 8,
         'en_num_layers': 1,
         'en_ffn_dim': 1024,
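
For orientation, a hedged sketch of how these keys are consumed after this commit: build_fpn (see the fpn.py hunk below) no longer reads 'fpn_expansion', and 'hidden_dim' (now 320 for this variant) sets the shared channel width of the hybrid encoder. The variant key 'rtdetr_r50' and the backbone channel dims are assumptions; the hunk headers don't show them.

```python
from config.model_config.rtdetr_config import rtdetr_cfg
from models.detectors.rtdetr.basic_modules.fpn import build_fpn

cfg = rtdetr_cfg['rtdetr_r50']  # assumed variant key
# in_dims assumed for a ResNet-50 backbone (C3/C4/C5 channels)
fpn = build_fpn(cfg, in_dims=[512, 1024, 2048], out_dim=cfg['hidden_dim'])
```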

+ 6 - 10
engine.py

@@ -1322,18 +1322,14 @@ class RTDetrTrainer(object):
                 targets = self.box_xyxy_to_cxcywh(targets)
 
             # Inference
-            with torch.autocast(device_type=str(self.device), cache_enabled=True):
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
                 outputs = model(images, targets)    
-
-            # Compute loss
-            with torch.autocast(device_type=str(self.device), enabled=False):
                 loss_dict = self.criterion(outputs, targets)
-            losses = sum(loss_dict.values())
-            # Grad Accumulate
-            if self.grad_accumulate > 1:
-                losses /= self.grad_accumulate
-
-            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+                losses = sum(loss_dict.values())
+                # Grad Accumulate
+                if self.grad_accumulate > 1:
+                    losses /= self.grad_accumulate
+                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
 
             # Backward
             self.scaler.scale(losses).backward()
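
For context, a minimal self-contained sketch of the torch.cuda.amp recipe this hunk switches to: forward and loss both run under a single autocast context gated by the fp16 flag, and the scaler handles the backward pass. The names here (model, criterion, scaler) are placeholders standing in for the trainer's attributes.

```python
import torch

def train_step(model, criterion, optimizer, scaler, images, targets, fp16=True):
    # Forward and loss both under autocast, mirroring the hunk above.
    with torch.cuda.amp.autocast(enabled=fp16):
        outputs = model(images, targets)
        loss_dict = criterion(outputs, targets)
        losses = sum(loss_dict.values())
    # scaler is a torch.cuda.amp.GradScaler(enabled=fp16).
    scaler.scale(losses).backward()
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()
```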

+ 40 - 70
models/detectors/rtdetr/basic_modules/basic.py

@@ -218,89 +218,59 @@ class BasicConv(nn.Module):
 
 
 # ----------------- CNN Modules -----------------
-class RepVggBlock(nn.Module):
-    def __init__(self, in_dim, out_dim, act_type='relu', norm_type='BN'):
-        super().__init__()
-        self.in_dim = in_dim
-        self.out_dim = out_dim
-        self.conv1 = BasicConv(in_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=norm_type)
-        self.conv2 = BasicConv(in_dim, out_dim, kernel_size=1, padding=0, act_type=None, norm_type=norm_type)
-        self.act = get_activation(act_type) 
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio = 0.5,
+                 kernel_sizes = [3, 3],
+                 shortcut     = True,
+                 act_type     = 'silu',
+                 norm_type    = 'BN',
+                 depthwise    = False,):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        paddings = [k // 2 for k in kernel_sizes]
+        self.cv1 = BasicConv(in_dim, inter_dim,
+                             kernel_size=kernel_sizes[0], padding=paddings[0],
+                             act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.cv2 = BasicConv(inter_dim, out_dim,
+                             kernel_size=kernel_sizes[1], padding=paddings[1],
+                             act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
 
     def forward(self, x):
-        if hasattr(self, 'conv'):
-            y = self.conv(x)
-        else:
-            y = self.conv1(x) + self.conv2(x)
-
-        return self.act(y)
+        h = self.cv2(self.cv1(x))
 
-    def convert_to_deploy(self):
-        if not hasattr(self, 'conv'):
-            self.conv = nn.Conv2d(self.in_dim, self.out_dim, 3, 1, padding=1)
+        return x + h if self.shortcut else h
 
-        kernel, bias = self.get_equivalent_kernel_bias()
-        self.conv.weight.data = kernel
-        self.conv.bias.data = bias 
-
-    def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
-        
-        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
-
-    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
-        if kernel1x1 is None:
-            return 0
-        else:
-            return F.pad(kernel1x1, [1, 1, 1, 1])
-
-    def _fuse_bn_tensor(self, branch: BasicConv):
-        if branch is None:
-            return 0, 0
-        kernel = branch.conv.weight
-        running_mean = branch.norm.running_mean
-        running_var = branch.norm.running_var
-        gamma = branch.norm.weight
-        beta = branch.norm.bias
-        eps = branch.norm.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-
-        return kernel * t, beta - running_mean * gamma / std
-
-class RepRTCBlock(nn.Module):
+class RTCBlock(nn.Module):
     def __init__(self,
                  in_dim,
                  out_dim,
-                 num_blocks = 3,
-                 expansion  = 1.0,
+                 num_blocks = 1,
+                 shortcut   = False,
                  act_type   = 'silu',
                  norm_type  = 'BN',
-                 ) -> None:
-        super(RepRTCBlock, self).__init__()
-        self.inter_dim = round(out_dim * expansion)
-        # ------------ Input & Output projection ------------
-        self.conv1 = BasicConv(in_dim, self.inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv2 = BasicConv(in_dim, self.inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.conv3 = BasicConv(self.inter_dim * (2 + num_blocks), out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        # ------------ Core modules ------------
-        module = nn.Sequential(RepVggBlock(self.inter_dim, self.inter_dim, act_type, norm_type),
-                               RepVggBlock(self.inter_dim, self.inter_dim, act_type, norm_type),)
-        self.module = nn.ModuleList([copy.deepcopy(module) for _ in range(num_blocks)])
-        
+                 depthwise  = False,):
+        super(RTCBlock, self).__init__()
+        self.inter_dim = out_dim // 2
+        self.input_proj = BasicConv(in_dim, self.inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.Sequential(*(
+            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
+            for _ in range(num_blocks)))
+        self.output_proj = BasicConv((2 + num_blocks) * self.inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
     def forward(self, x):
         # Input proj
-        x1 = self.conv1(x)
-        x2 = self.conv2(x)
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = list([x1, x2])
 
-        # Core module
-        out = [x1, x2]
-        for m in self.module:
-            x2 = m(x2)
-            out.append(x2)
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.m)
 
         # Output proj
-        out = self.conv3(torch.cat(out, dim=1))
+        out = self.output_proj(torch.cat(out, dim=1))
 
         return out
+    
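
A small smoke test of the new CSP-style block, assuming the module imports cleanly on its own; shapes are arbitrary. RTCBlock splits a single 1x1 input projection into two halves of out_dim // 2 channels, chains num_blocks Bottlenecks on one half, and concatenates every intermediate output before the 1x1 output projection.

```python
import torch
from models.detectors.rtdetr.basic_modules.basic import RTCBlock

block = RTCBlock(in_dim=512, out_dim=256, num_blocks=3, shortcut=False,
                 act_type='silu', norm_type='BN', depthwise=False)
x = torch.randn(1, 512, 20, 20)
y = block(x)  # concat width: (2 + 3) * 128 = 640 -> projected back to 256
print(y.shape)  # torch.Size([1, 256, 20, 20])
```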

+ 58 - 63
models/detectors/rtdetr/basic_modules/fpn.py

@@ -4,10 +4,10 @@ import torch.nn.functional as F
 from typing import List
 
 try:
-    from .basic import BasicConv, RepRTCBlock
+    from .basic import BasicConv, RTCBlock
     from .transformer import TransformerEncoder
 except:
-    from  basic import BasicConv, RepRTCBlock
+    from  basic import BasicConv, RTCBlock
     from  transformer import TransformerEncoder
 
 
@@ -17,9 +17,9 @@ def build_fpn(cfg, in_dims, out_dim):
         return HybridEncoder(in_dims     = in_dims,
                              out_dim     = out_dim,
                              num_blocks  = cfg['fpn_num_blocks'],
-                             expansion   = cfg['fpn_expansion'],
                              act_type    = cfg['fpn_act'],
                              norm_type   = cfg['fpn_norm'],
+                             depthwise   = cfg['fpn_depthwise'],
                              num_heads   = cfg['en_num_heads'],
                              num_layers  = cfg['en_num_layers'],
                              ffn_dim     = cfg['en_ffn_dim'],
@@ -38,9 +38,9 @@ class HybridEncoder(nn.Module):
                  in_dims        :List  = [256, 512, 1024],
                  out_dim        :int   = 256,
                  num_blocks     :int   = 3,
-                 expansion      :float = 1.0,
                  act_type       :str   = 'silu',
                  norm_type      :str   = 'BN',
+                 depthwise      :bool  = False,
                  # Transformer's parameters
                  num_heads      :int   = 8,
                  num_layers     :int   = 1,
@@ -62,9 +62,17 @@ class HybridEncoder(nn.Module):
         c3, c4, c5 = in_dims
 
         # ---------------- Input projs ----------------
-        self.input_proj_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
-        self.input_proj_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
-        self.input_proj_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
+        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+        # ---------------- Downsample ----------------
+        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type, depthwise=depthwise)
 
         # ---------------- Transformer Encoder ----------------
         self.transformer_encoder = TransformerEncoder(d_model        = self.out_dim,
@@ -78,51 +86,43 @@ class HybridEncoder(nn.Module):
 
         # ---------------- Top down FPN ----------------
         ## P5 -> P4
-        self.reduce_layer_1 = BasicConv(self.out_dim, self.out_dim,
-                                        kernel_size=1, padding=0, stride=1,
-                                        act_type=act_type, norm_type=norm_type)
-        self.top_down_layer_1 = RepRTCBlock(in_dim     = self.out_dim * 2,
-                                            out_dim    = self.out_dim,
-                                            num_blocks = num_blocks,
-                                            expansion  = expansion,
-                                            act_type   = act_type,
-                                            norm_type  = norm_type,
-                                           )
+        self.top_down_layer_1 = RTCBlock(in_dim      = self.out_dim * 2,
+                                         out_dim     = self.out_dim,
+                                         num_blocks  = num_blocks,
+                                         shortcut    = False,
+                                         act_type    = act_type,
+                                         norm_type   = norm_type,
+                                         depthwise   = depthwise,
+                                         )
         ## P4 -> P3
-        self.reduce_layer_2 = BasicConv(self.out_dim, self.out_dim,
-                                        kernel_size=1, padding=0, stride=1,
-                                        act_type=act_type, norm_type=norm_type)
-        self.top_down_layer_2 = RepRTCBlock(in_dim     = self.out_dim * 2,
-                                            out_dim    = self.out_dim,
-                                            num_blocks = num_blocks,
-                                            expansion  = expansion,
-                                            act_type   = act_type,
-                                            norm_type  = norm_type,
-                                            )
+        self.top_down_layer_2 = RTCBlock(in_dim      = self.out_dim * 2,
+                                         out_dim     = self.out_dim,
+                                         num_blocks  = num_blocks,
+                                         shortcut    = False,
+                                         act_type    = act_type,
+                                         norm_type   = norm_type,
+                                         depthwise   = depthwise,
+                                         )
         
         # ---------------- Bottom up PAN----------------
         ## P3 -> P4
-        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim,
-                                            kernel_size=3, padding=1, stride=2,
-                                            act_type=act_type, norm_type=norm_type)
-        self.bottom_up_layer_1 = RepRTCBlock(in_dim     = self.out_dim * 2,
-                                             out_dim    = self.out_dim,
-                                             num_blocks = num_blocks,
-                                             expansion  = expansion,
-                                             act_type   = act_type,
-                                             norm_type  = norm_type,
-                                             )
+        self.bottom_up_layer_1 = RTCBlock(in_dim      = self.out_dim * 2,
+                                          out_dim     = self.out_dim,
+                                          num_blocks  = num_blocks,
+                                          shortcut    = False,
+                                          act_type    = act_type,
+                                          norm_type   = norm_type,
+                                          depthwise   = depthwise,
+                                          )
         ## P4 -> P5
-        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim,
-                                            kernel_size=3, padding=1, stride=2,
-                                            act_type=act_type, norm_type=norm_type)
-        self.bottom_up_layer_2 = RepRTCBlock(in_dim     = self.out_dim * 2,
-                                             out_dim    = self.out_dim,
-                                             num_blocks = num_blocks,
-                                             expansion  = expansion,
-                                             act_type   = act_type,
-                                             norm_type  = norm_type,
-                                             )
+        self.bottom_up_layer_2 = RTCBlock(in_dim      = self.out_dim * 2,
+                                          out_dim     = self.out_dim,
+                                          num_blocks  = num_blocks,
+                                          shortcut    = False,
+                                          act_type    = act_type,
+                                          norm_type   = norm_type,
+                                          depthwise   = depthwise,
+                                          )
 
         self.init_weights()
   
@@ -138,31 +138,26 @@ class HybridEncoder(nn.Module):
         c3, c4, c5 = features
 
         # -------- Input projs --------
-        p5 = self.input_proj_1(c5)
-        p4 = self.input_proj_2(c4)
-        p3 = self.input_proj_3(c3)
+        p5 = self.reduce_layer_1(c5)
+        p4 = self.reduce_layer_2(c4)
+        p3 = self.reduce_layer_3(c3)
 
         # -------- Transformer encoder --------
         p5 = self.transformer_encoder(p5)
 
         # -------- Top down FPN --------
-        ## P5 -> P4
-        p5_in = self.reduce_layer_1(p5)
-        p5_up = F.interpolate(p5_in, scale_factor=2.0)
-        p4    = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
+        p5_up = F.interpolate(p5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
 
-        ## P4 -> P3
-        p4_in = self.reduce_layer_2(p4)
-        p4_up = F.interpolate(p4_in, scale_factor=2.0)
-        p3    = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
 
         # -------- Bottom up PAN --------
-        ## P3 -> P4
         p3_ds = self.dowmsample_layer_1(p3)
-        p4    = self.bottom_up_layer_1(torch.cat([p4_in, p3_ds], dim=1))
+        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
 
         p4_ds = self.dowmsample_layer_2(p4)
-        p5    = self.bottom_up_layer_2(torch.cat([p5_in, p4_ds], dim=1))
+        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
 
         out_feats = [p3, p4, p5]
         
@@ -178,7 +173,7 @@ if __name__ == '__main__':
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
         'fpn_num_blocks': 3,
-        'fpn_expansion': 1.0,
+        'fpn_expansion': 0.5,
         'en_num_heads': 8,
         'en_num_layers': 1,
         'en_ffn_dim': 1024,
@@ -202,4 +197,4 @@ if __name__ == '__main__':
     flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    print('Params : {:.2f} M'.format(params / 1e6))
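
Note that build_fpn no longer reads 'fpn_expansion', so the key kept in the test config above is inert after this commit. A minimal sketch of exercising the reworked encoder directly, with feature shapes assumed for a 640x640 input and the transformer arguments left at their defaults:

```python
import torch
from models.detectors.rtdetr.basic_modules.fpn import HybridEncoder

# The reduce layers now double as input projections, so p4/p5 feed the
# bottom-up PAN directly instead of passing through extra 1x1 convs.
encoder = HybridEncoder(in_dims=[256, 512, 1024], out_dim=256, num_blocks=3)
feats = [torch.randn(1, 256,  80, 80),   # c3, stride 8
         torch.randn(1, 512,  40, 40),   # c4, stride 16
         torch.randn(1, 1024, 20, 20)]   # c5, stride 32
p3, p4, p5 = encoder(feats)
print(p3.shape, p4.shape, p5.shape)  # each 256 channels at its own stride
```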