yjh0410 1 year ago
Parent
Commit
9496a5117f

+ 18 - 8
engine.py

@@ -1133,7 +1133,6 @@ class RTRTrainer(object):
         self.world_size = world_size
         self.grad_accumulate = args.grad_accumulate
         self.clip_grad = 0.1
-        self.args.fp16 = False   # No AMP for DETR
         # path to save model
         self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
         os.makedirs(self.path_to_save, exist_ok=True)
@@ -1141,6 +1140,7 @@ class RTRTrainer(object):
         # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
         self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
         self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.1}
+        self.ema_dict = {'ema_decay': 0.9999, 'ema_tau': 2000}
 
         # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
         self.data_cfg  = data_cfg
@@ -1173,6 +1173,13 @@ class RTRTrainer(object):
         if self.args.resume and self.args.resume != 'None':
             self.lr_scheduler.step()
 
+        # ---------------------------- Build Model-EMA ----------------------------
+        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
+            print('Build ModelEMA ...')
+            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
+        else:
+            self.model_ema = None
+
     def train(self, model):
         for epoch in range(self.start_epoch, self.args.max_epoch):
             if self.args.distributed:
@@ -1188,6 +1195,9 @@ class RTRTrainer(object):
                 self.eval(model_eval)
 
     def eval(self, model):
+        # check model: evaluate the EMA weights when ModelEMA is enabled
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
         if distributed_utils.is_main_process():
             # check evaluator
             if self.evaluator is None:
@@ -1195,7 +1205,7 @@ class RTRTrainer(object):
                 print('Saving state, epoch: {}'.format(self.epoch))
                 weight_name = '{}_no_eval.pth'.format(self.args.model)
                 checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                torch.save({'model': model.state_dict(),
+                torch.save({'model': model_eval.state_dict(),
                             'mAP': -1.,
                             'optimizer': self.optimizer.state_dict(),
                             'epoch': self.epoch,
@@ -1204,12 +1214,12 @@ class RTRTrainer(object):
             else:
                 print('eval ...')
                 # set eval mode
-                model.trainable = False
-                model.eval()
+                model_eval.trainable = False
+                model_eval.eval()
 
                 # evaluate
                 with torch.no_grad():
-                    self.evaluator.evaluate(model)
+                    self.evaluator.evaluate(model_eval)
 
                 # save model
                 cur_map = self.evaluator.map
@@ -1220,7 +1230,7 @@ class RTRTrainer(object):
                     print('Saving state, epoch:', self.epoch)
                     weight_name = '{}_best.pth'.format(self.args.model)
                     checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                    torch.save({'model': model.state_dict(),
+                    torch.save({'model': model_eval.state_dict(),
                                 'mAP': round(self.best_map*100, 1),
                                 'optimizer': self.optimizer.state_dict(),
                                 'epoch': self.epoch,
@@ -1228,8 +1238,8 @@ class RTRTrainer(object):
                                 checkpoint_path)                      
 
                 # set train mode.
-                model.trainable = True
-                model.train()
+                model_eval.trainable = True
+                model_eval.train()
 
         if self.args.distributed:
             # wait for all processes to synchronize
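The EMA hook above relies on a ModelEMA helper configured by ema_dict. A minimal sketch of such a helper, assuming the common YOLOv5-style schedule in which the effective decay ramps from 0 toward ema_decay with time constant ema_tau (the repo's actual ModelEMA may differ in detail):

import math
import copy
import torch

class ModelEMA(object):
    """Keep an exponential moving average of model weights for evaluation."""
    def __init__(self, ema_dict, model, updates=0):
        self.ema = copy.deepcopy(model).eval()   # EMA copy, used only for eval
        self.updates = updates                   # resume-aware update counter
        decay, tau = ema_dict['ema_decay'], ema_dict['ema_tau']
        # ramp the decay from 0 up to `decay` so early, noisy weights count less
        self.decay_fn = lambda n: decay * (1. - math.exp(-n / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        self.updates += 1
        d = self.decay_fn(self.updates)
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(d).add_(msd[k].detach(), alpha=1. - d)

Passing self.start_epoch * len(self.train_loader) as the initial update count keeps the decay schedule consistent when training resumes from a checkpoint.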

+ 16 - 6
models/detectors/rtdetr/basic_modules/basic.py

@@ -1,8 +1,16 @@
 import math
+import copy
 import torch
 import torch.nn as nn
 
 
+def get_clones(module, N):
+    if N <= 0:
+        return None
+    else:
+        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+
 # ----------------- MLP modules -----------------
 class MLP(nn.Module):
     def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
@@ -50,6 +58,8 @@ def get_activation(act_type=None):
         return nn.Mish(inplace=True)
     elif act_type == 'silu':
         return nn.SiLU(inplace=True)
+    elif act_type == 'gelu':
+        return nn.GELU()
     elif act_type is None:
         return nn.Identity()
     else:
@@ -262,18 +272,18 @@ class TransformerLayer(nn.Module):
         return tensor if pos is None else tensor + pos
 
 
-    def forward(self, src, pos):
+    def forward(self, src, pos_embed):
         """
         Input:
-            src: [torch.Tensor] -> [B, N, C]
-            pos: [torch.Tensor] -> [B, N, C]
+            src:       [torch.Tensor] -> [B, N, C]
+            pos_embed: [torch.Tensor] -> [B, N, C]
         Output:
-            src: [torch.Tensor] -> [B, N, C]
+            src:       [torch.Tensor] -> [B, N, C]
         """
-        q = k = self.with_pos_embed(src, pos)
+        q = k = self.with_pos_embed(src, pos_embed)
 
         # -------------- MHSA --------------
-        src2 = self.self_attn(q, k, value=src)
+        src2 = self.self_attn(q, k, value=src)[0]
         src = src + self.dropout(src2)
         src = self.norm(src)
 

+ 229 - 0
models/detectors/rtdetr/basic_modules/fpn.py

@@ -0,0 +1,229 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+try:
+    from .basic import get_clones, BasicConv, RTCBlock, TransformerLayer
+except ImportError:
+    from  basic import get_clones, BasicConv, RTCBlock, TransformerLayer
+
+
+# Build PaFPN
+def build_fpn(cfg, in_dims, out_dim):
+    if cfg['fpn'] == 'hybrid_encoder':
+        return HybridEncoder(in_dims     = in_dims,
+                             out_dim     = out_dim,
+                             width       = cfg['width'],
+                             depth       = cfg['depth'],
+                             act_type    = cfg['fpn_act'],
+                             norm_type   = cfg['fpn_norm'],
+                             depthwise   = cfg['fpn_depthwise'],
+                             num_heads   = cfg['en_num_heads'],
+                             num_layers  = cfg['en_num_layers'],
+                             mlp_ratio   = cfg['en_mlp_ratio'],
+                             dropout     = cfg['en_dropout'],
+                             pe_temperature = cfg['pe_temperature'],
+                             en_act_type    = cfg['en_act'],
+                             )
+    else:
+        raise NotImplementedError("Unknown PaFPN: <{}>".format(cfg['fpn']))
+
+
+# ----------------- Feature Pyramid Network -----------------
+## Real-time Convolutional PaFPN
+class HybridEncoder(nn.Module):
+    def __init__(self, 
+                 in_dims     :List  = [256, 512, 512],
+                 out_dim     :int   = 256,
+                 width       :float = 1.0,
+                 depth       :float = 1.0,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'BN',
+                 depthwise   :bool  = False,
+                 # Transformer's parameters
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 mlp_ratio      :float = 4.0,
+                 dropout        :float = 0.1,
+                 pe_temperature :float = 10000.,
+                 en_act_type    :str   = 'gelu'
+                 ) -> None:
+        super(HybridEncoder, self).__init__()
+        print('==============================')
+        print('FPN: {}'.format("RTC-PaFPN"))
+        # ---------------- Basic parameters ----------------
+        self.in_dims = in_dims
+        self.out_dim = round(out_dim * width)
+        self.width = width
+        self.depth = depth
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.mlp_ratio = mlp_ratio
+        self.pe_temperature = pe_temperature
+        self.pos_embed = None
+        c3, c4, c5 = in_dims
+
+        # ---------------- Input projs ----------------
+        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+        # ---------------- Downsample ----------------
+        self.downsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
+        self.downsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
+
+        # ---------------- Transformer Encoder ----------------
+        self.transformer_encoder = get_clones(
+            TransformerLayer(self.out_dim, num_heads, mlp_ratio, dropout, en_act_type), num_layers)
+
+        # ---------------- Top down FPN ----------------
+        ## P5 -> P4
+        self.top_down_layer_1 = RTCBlock(in_dim       = self.out_dim * 2,
+                                         out_dim      = self.out_dim,
+                                         num_blocks   = round(3*depth),
+                                         shortcut     = False,
+                                         act_type     = act_type,
+                                         norm_type    = norm_type,
+                                         depthwise    = depthwise,
+                                         )
+        ## P4 -> P3
+        self.top_down_layer_2 = RTCBlock(in_dim       = self.out_dim * 2,
+                                         out_dim      = self.out_dim,
+                                         num_blocks   = round(3*depth),
+                                         shortcut     = False,
+                                         act_type     = act_type,
+                                         norm_type    = norm_type,
+                                         depthwise    = depthwise,
+                                         )
+        
+        # ---------------- Bottom up PAN ----------------
+        ## P3 -> P4
+        self.bottom_up_layer_1 = RTCBlock(in_dim       = self.out_dim * 2,
+                                          out_dim      = self.out_dim,
+                                          num_blocks   = round(3*depth),
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise,
+                                          )
+        ## P4 -> P5
+        self.bottom_up_layer_2 = RTCBlock(in_dim       = self.out_dim * 2,
+                                          out_dim      = self.out_dim,
+                                          num_blocks   = round(3*depth),
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise,
+                                          )
+
+        self.init_weights()
+  
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
+        assert embed_dim % 4 == 0, \
+            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        
+        # ----------- Check cached pos_embed -----------
+        # pos_embed is cached as [1, N, C] with N = H*W, so compare token counts
+        if self.pos_embed is not None and \
+            self.pos_embed.shape[1] == int(w) * int(h):
+            return self.pos_embed
+        
+        # ----------- Generate grid coords -----------
+        grid_w = torch.arange(int(w), dtype=torch.float32)
+        grid_h = torch.arange(int(h), dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # each with shape [W, H]
+
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
+        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
+
+        # shape: [1, N, C]
+        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
+        self.pos_embed = pos_embed
+
+        return pos_embed
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # -------- Input projs --------
+        p5 = self.reduce_layer_1(c5)
+        p4 = self.reduce_layer_2(c4)
+        p3 = self.reduce_layer_3(c3)
+
+        # -------- Transformer encoder --------
+        if self.transformer_encoder is not None:
+            for encoder in self.transformer_encoder:
+                channels, fmp_h, fmp_w = p5.shape[1:]
+                # [B, C, H, W] -> [B, N, C], N=HxW
+                src_flatten = p5.flatten(2).permute(0, 2, 1)
+                pos_embed = self.build_2d_sincos_position_embedding(
+                        fmp_w, fmp_h, channels, self.pe_temperature)
+                memory = encoder(src_flatten, pos_embed=pos_embed.to(src_flatten.device))
+                # [B, N, C] -> [B, C, N] -> [B, C, H, W]
+                p5 = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
+
+        # -------- Top down FPN --------
+        p5_up = F.interpolate(p5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
+
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
+
+        # -------- Bottom up PAN --------
+        p3_ds = self.downsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
+
+        p4_ds = self.downsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
+
+        out_feats = [p3, p4, p5]
+        
+        return out_feats
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'width': 1.0,
+        'depth': 1.0,
+        'fpn': 'hybrid_encoder',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        'en_num_heads': 8,
+        'en_num_layers': 1,
+        'en_mlp_ratio': 4.0,
+        'en_dropout': 0.1,
+        'pe_temperature': 10000.,
+        'en_act': 'gelu',
+    }
+    fpn_dims = [256, 512, 1024]
+    out_dim = 256
+    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80), torch.randn(1, fpn_dims[1], 40, 40), torch.randn(1, fpn_dims[2], 20, 20)]
+    model = build_fpn(cfg, fpn_dims, out_dim)
+
+    t0 = time.time()
+    outputs = model(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
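For reference, build_2d_sincos_position_embedding boils down to the following standalone computation; running it confirms the [1, N, C] layout (N = H*W) that both the cache check and the encoder flattening rely on (a sketch mirroring the method above):

import torch

def sincos_pos_embed_2d(w, h, embed_dim=256, temperature=10000.):
    assert embed_dim % 4 == 0
    grid_w, grid_h = torch.meshgrid([
        torch.arange(w, dtype=torch.float32),
        torch.arange(h, dtype=torch.float32)])          # each [W, H]
    pos_dim = embed_dim // 4
    omega = 1. / temperature ** (torch.arange(pos_dim, dtype=torch.float32) / pos_dim)
    out_w = grid_w.flatten()[:, None] * omega[None]     # [N, C/4]
    out_h = grid_h.flatten()[:, None] * omega[None]     # [N, C/4]
    return torch.cat([out_w.sin(), out_w.cos(),
                      out_h.sin(), out_h.cos()], dim=1)[None]  # [1, N, C]

print(sincos_pos_embed_2d(20, 20).shape)  # torch.Size([1, 400, 256])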

+ 0 - 117
models/detectors/rtdetr/basic_modules/pafpn.py

@@ -1,117 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from .basic import BasicConv, RTCBlock
-
-
-# Build PaFPN
-def build_pafpn(cfg, in_dims, out_dim):
-    return
-
-
-# ----------------- Feature Pyramid Network -----------------
-## Real-time Convolutional PaFPN
-class HybridEncoder(nn.Module):
-    def __init__(self, 
-                 in_dims   = [256, 512, 512],
-                 out_dim   = 256,
-                 width     = 1.0,
-                 depth     = 1.0,
-                 act_type  = 'silu',
-                 norm_type = 'BN',
-                 depthwise = False):
-        super(HybridEncoder, self).__init__()
-        print('==============================')
-        print('FPN: {}'.format("RTC-PaFPN"))
-        # ---------------- Basic parameters ----------------
-        self.in_dims = in_dims
-        self.out_dim = round(out_dim * width)
-        self.width = width
-        self.depth = depth
-        c3, c4, c5 = in_dims
-
-        # ---------------- Input projs ----------------
-        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-
-        # ---------------- Downsample ----------------
-        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
-        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
-
-        # ---------------- Top dwon FPN ----------------
-        ## P5 -> P4
-        self.top_down_layer_1 = RTCBlock(in_dim       = self.out_dim * 2,
-                                         out_dim      = self.out_dim,
-                                         num_blocks   = round(3*depth),
-                                         shortcut     = False,
-                                         act_type     = act_type,
-                                         norm_type    = norm_type,
-                                         depthwise    = depthwise,
-                                         )
-        ## P4 -> P3
-        self.top_down_layer_2 = RTCBlock(in_dim       = self.out_dim * 2,
-                                         out_dim      = self.out_dim,
-                                         num_blocks   = round(3*depth),
-                                         shortcut     = False,
-                                         act_type     = act_type,
-                                         norm_type    = norm_type,
-                                         depthwise    = depthwise,
-                                         )
-        
-        # ---------------- Bottom up PAN----------------
-        ## P3 -> P4
-        self.bottom_up_layer_1 = RTCBlock(in_dim       = self.out_dim * 2,
-                                          out_dim      = self.out_dim,
-                                          num_blocks   = round(3*depth),
-                                          shortcut     = False,
-                                          act_type     = act_type,
-                                          norm_type    = norm_type,
-                                          depthwise    = depthwise,
-                                          )
-        ## P4 -> P5
-        self.bottom_up_layer_2 = RTCBlock(in_dim       = self.out_dim * 2,
-                                          out_dim      = self.out_dim,
-                                          num_blocks   = round(3*depth),
-                                          shortcut     = False,
-                                          act_type     = act_type,
-                                          norm_type    = norm_type,
-                                          depthwise    = depthwise,
-                                          )
-
-        self.init_weights()
-        
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, features):
-        c3, c4, c5 = features
-
-        # -------- Input projs --------
-        p5 = self.reduce_layer_1(c5)
-        p4 = self.reduce_layer_2(c4)
-        p3 = self.reduce_layer_3(c3)
-
-        # -------- Top down FPN --------
-        p5_up = F.interpolate(p5, scale_factor=2.0)
-        p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
-
-        p4_up = F.interpolate(p4, scale_factor=2.0)
-        p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
-
-        # -------- Bottom up PAN --------
-        p3_ds = self.dowmsample_layer_1(p3)
-        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
-
-        p4_ds = self.dowmsample_layer_2(p4)
-        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
-
-        out_feats = [p3, p4, p5]
-        
-        return out_feats

+ 71 - 8
models/detectors/rtdetr/rtdetr_encoder.py

@@ -2,18 +2,81 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .basic_modules.backbone import build_backbone
-from .basic_modules.pafpn    import build_pafpn
+try:
+    from .basic_modules.backbone import build_backbone
+    from .basic_modules.fpn      import build_fpn
+except ImportError:
+    from  basic_modules.backbone import build_backbone
+    from  basic_modules.fpn      import build_fpn
 
 
 # ----------------- Image Encoder -----------------
+def build_image_encoder(cfg, trainable=False):
+    return ImageEncoder(cfg, trainable)
+
 class ImageEncoder(nn.Module):
-    def __init__(self, ):
+    def __init__(self, cfg, trainable=False):
         super().__init__()
-        self.backbone = None
-        self.neck = None
-        self.fpn = None
+        # ---------------- Basic settings ----------------
+        ## Basic parameters
+        self.cfg = cfg
+        ## Network parameters
+        self.strides = cfg['out_stride']
+        self.hidden_dim = cfg['hidden_dim']
+        self.num_levels = len(self.strides)
+        
+        # ---------------- Network settings ----------------
+        ## Backbone Network
+        self.backbone, fpn_feat_dims = build_backbone(cfg, pretrained=cfg['pretrained'] and trainable)
 
+        ## Feature Pyramid Network
+        self.fpn = build_fpn(cfg, fpn_feat_dims, self.hidden_dim)
+        
     def forward(self, x):
-        return
-    
+        pyramid_feats = self.backbone(x)
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        return pyramid_feats
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'width': 1.0,
+        'depth': 1.0,
+        'out_stride': [8, 16, 32],
+        # Image Encoder - Backbone
+        'backbone': 'resnet18',
+        'backbone_norm': 'BN',
+        'res5_dilation': False,
+        'pretrained': True,
+        'pretrained_weight': 'imagenet1k_v1',
+        # Image Encoder - FPN
+        'fpn': 'hybrid_encoder',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        'hidden_dim': 256,
+        'en_num_heads': 8,
+        'en_num_layers': 1,
+        'en_mlp_ratio': 4.0,
+        'en_dropout': 0.1,
+        'pe_temperature': 10000.,
+        'en_act': 'gelu',
+    }
+    x = torch.rand(2, 3, 640, 640)
+    model = build_image_encoder(cfg, True)
+
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
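build_image_encoder assumes build_backbone returns both the backbone module and the channel dims of its C3/C4/C5 feature maps. The repo's actual factory lives in basic_modules/backbone.py and is not part of this diff; a hypothetical torchvision-based stand-in with the same contract would look like:

import torch.nn as nn
import torchvision

class ResNetBackbone(nn.Module):
    # hypothetical stand-in; returns C3/C4/C5 features at strides 8/16/32
    def __init__(self, name='resnet18', pretrained=False):
        super().__init__()
        weights = 'IMAGENET1K_V1' if pretrained else None
        resnet = getattr(torchvision.models, name)(weights=weights)
        self.stem = nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool)
        self.layer1, self.layer2 = resnet.layer1, resnet.layer2
        self.layer3, self.layer4 = resnet.layer3, resnet.layer4

    def forward(self, x):
        c2 = self.layer1(self.stem(x))
        c3 = self.layer2(c2)   # stride 8
        c4 = self.layer3(c3)   # stride 16
        c5 = self.layer4(c4)   # stride 32
        return [c3, c4, c5]

def build_backbone(cfg, pretrained=False):
    model = ResNetBackbone(cfg['backbone'], pretrained)
    feat_dims = [128, 256, 512] if cfg['backbone'] in ('resnet18', 'resnet34') \
                else [512, 1024, 2048]
    return model, feat_dims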