Bläddra i källkod

add MSDeformableAttn

yjh0410 1 år sedan
förälder
incheckning
5629b894c1

+ 366 - 2
models/detectors/rtdetr/basic_modules/basic.py

@@ -2,6 +2,7 @@ import math
 import copy
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
 
 def get_clones(module, N):
@@ -243,8 +244,177 @@ class RTCBlock(nn.Module):
 
 
 # ----------------- Transformer modules -----------------
-## Transformer layer
-class TransformerLayer(nn.Module):
+## Basic ops of Deformable Attn
+def deformable_attention_core_func(value, value_spatial_shapes,
+                                   value_level_start_index, sampling_locations,
+                                   attention_weights):
+    """
+    Args:
+        value (Tensor): [bs, value_length, n_head, c]
+        value_spatial_shapes (Tensor|List): [n_levels, 2]
+        value_level_start_index (Tensor|List): [n_levels]
+        sampling_locations (Tensor): [bs, query_length, n_head, n_levels, n_points, 2]
+        attention_weights (Tensor): [bs, query_length, n_head, n_levels, n_points]
+
+    Returns:
+        output (Tensor): [bs, Length_{query}, C]
+    """
+    bs, _, n_head, c = value.shape
+    _, Len_q, _, n_levels, n_points, _ = sampling_locations.shape
+
+    split_shape = [h * w for h, w in value_spatial_shapes]
+    value_list = value.split(split_shape, axis=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for level, (h, w) in enumerate(value_spatial_shapes):
+        # N_, H_*W_, M_, D_ -> N_, H_*W_, M_*D_ -> N_, M_*D_, H_*W_ -> N_*M_, D_, H_, W_
+        value_l_ = value_list[level].flatten(2).transpose(
+            [0, 2, 1]).reshape([bs * n_head, c, h, w])
+        # N_, Lq_, M_, P_, 2 -> N_, M_, Lq_, P_, 2 -> N_*M_, Lq_, P_, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(
+            [0, 2, 1, 3, 4]).flatten(0, 1)
+        # N_*M_, D_, Lq_, P_
+        sampling_value_l_ = F.grid_sample(
+            value_l_,
+            sampling_grid_l_,
+            mode='bilinear',
+            padding_mode='zeros',
+            align_corners=False)
+        sampling_value_list.append(sampling_value_l_)
+    # (N_, Lq_, M_, L_, P_) -> (N_, M_, Lq_, L_, P_) -> (N_*M_, 1, Lq_, L_*P_)
+    attention_weights = attention_weights.transpose([0, 2, 1, 3, 4]).reshape(
+        [bs * n_head, 1, Len_q, n_levels * n_points])
+    output = (torch.stack(
+        sampling_value_list, axis=-2).flatten(-2) *
+              attention_weights).sum(-1).reshape([bs, n_head * c, Len_q])
+
+    return output.transpose([0, 2, 1])
+
+class MSDeformableAttention(nn.Layer):
+    def __init__(self,
+                 embed_dim=256,
+                 num_heads=8,
+                 num_levels=4,
+                 num_points=4,
+                 lr_mult=0.1):
+        """
+        Multi-Scale Deformable Attention Module
+        """
+        super(MSDeformableAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.total_points = num_heads * num_levels * num_points
+
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+        self.sampling_offsets = nn.Linear(
+            embed_dim,
+            self.total_points * 2,
+            weight_attr=ParamAttr(learning_rate=lr_mult),
+            bias_attr=ParamAttr(learning_rate=lr_mult))
+
+        self.attention_weights = nn.Linear(embed_dim, self.total_points)
+        self.value_proj = nn.Linear(embed_dim, embed_dim)
+        self.output_proj = nn.Linear(embed_dim, embed_dim)
+        try:
+            # use cuda op
+            from deformable_detr_ops import ms_deformable_attn
+            self.ms_deformable_attn_core = ms_deformable_attn
+        except:
+            # use paddle func
+            self.ms_deformable_attn_core = deformable_attention_core_func
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        # sampling_offsets
+        constant_(self.sampling_offsets.weight)
+        thetas = paddle.arange(
+            self.num_heads,
+            dtype=paddle.float32) * (2.0 * math.pi / self.num_heads)
+        grid_init = paddle.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = grid_init / grid_init.abs().max(-1, keepdim=True)
+        grid_init = grid_init.reshape([self.num_heads, 1, 1, 2]).tile(
+            [1, self.num_levels, self.num_points, 1])
+        scaling = paddle.arange(
+            1, self.num_points + 1,
+            dtype=paddle.float32).reshape([1, 1, -1, 1])
+        grid_init *= scaling
+        self.sampling_offsets.bias.set_value(grid_init.flatten())
+        # attention_weights
+        constant_(self.attention_weights.weight)
+        constant_(self.attention_weights.bias)
+        # proj
+        xavier_uniform_(self.value_proj.weight)
+        constant_(self.value_proj.bias)
+        xavier_uniform_(self.output_proj.weight)
+        constant_(self.output_proj.bias)
+
+    def forward(self,
+                query,
+                reference_points,
+                value,
+                value_spatial_shapes,
+                value_level_start_index,
+                value_mask=None):
+        """
+        Args:
+            query (Tensor): [bs, query_length, C]
+            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
+                bottom-right (1, 1), including padding area
+            value (Tensor): [bs, value_length, C]
+            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+            value_level_start_index (Tensor(int64)): [n_levels], [0, H_0*W_0, H_0*W_0+H_1*W_1, ...]
+            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
+
+        Returns:
+            output (Tensor): [bs, Length_{query}, C]
+        """
+        bs, Len_q = query.shape[:2]
+        Len_v = value.shape[1]
+        assert int(value_spatial_shapes.prod(1).sum()) == Len_v
+
+        value = self.value_proj(value)
+        if value_mask is not None:
+            value_mask = value_mask.astype(value.dtype).unsqueeze(-1)
+            value *= value_mask
+        value = value.reshape([bs, Len_v, self.num_heads, self.head_dim])
+
+        sampling_offsets = self.sampling_offsets(query).reshape(
+            [bs, Len_q, self.num_heads, self.num_levels, self.num_points, 2])
+        attention_weights = self.attention_weights(query).reshape(
+            [bs, Len_q, self.num_heads, self.num_levels * self.num_points])
+        attention_weights = F.softmax(attention_weights).reshape(
+            [bs, Len_q, self.num_heads, self.num_levels, self.num_points])
+
+        if reference_points.shape[-1] == 2:
+            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
+                [1, 1, 1, self.num_levels, 1, 2])
+            sampling_locations = reference_points.reshape([
+                bs, Len_q, 1, self.num_levels, 1, 2
+            ]) + sampling_offsets / offset_normalizer
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :2] + sampling_offsets /
+                self.num_points * reference_points[:, :, None, :, None, 2:] *
+                0.5)
+        else:
+            raise ValueError(
+                "Last dim of reference_points must be 2 or 4, but get {} instead.".
+                format(reference_points.shape[-1]))
+
+        output = self.ms_deformable_attn_core(
+            value, value_spatial_shapes, value_level_start_index,
+            sampling_locations, attention_weights)
+        output = self.output_proj(output)
+
+        return output
+
+## Transformer Encoder layer
+class TransformerEncoderLayer(nn.Module):
     def __init__(self,
                  d_model         :int   = 256,
                  num_heads       :int   = 8,
@@ -291,3 +461,197 @@ class TransformerLayer(nn.Module):
         src = self.ffn(src)
         
         return src
+
+## Transformer Encoder
+class TransformerEncoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 mlp_ratio      :float = 4.0,
+                 pe_temperature : float = 10000.,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pe_temperature = pe_temperature
+        self.pos_embed = None
+        # ----------- Basic parameters -----------
+        self.encoder_layers = get_clones(
+            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)
+
+    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
+        assert embed_dim % 4 == 0, \
+            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        
+        # ----------- Check cahed pos_embed -----------
+        if self.pos_embed is not None and \
+            self.pos_embed.shape[2:] == [h, w]:
+            return self.pos_embed
+        
+        # ----------- Generate grid coords -----------
+        grid_w = torch.arange(int(w), dtype=torch.float32)
+        grid_h = torch.arange(int(h), dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [H, W]
+
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
+        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
+
+        # shape: [1, N, C]
+        pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], axis=1)[None, :, :]
+        self.pos_embed = pos_embed
+
+        return pos_embed
+
+    def forward(self, src):
+        """
+        Input:
+            src:       [torch.Tensor] -> [B, C, H, W]
+        Output:
+            src:       [torch.Tensor] -> [B, N, C]
+        """
+        # -------- Transformer encoder --------
+        for encoder in self.encoder_layers:
+            channels, fmp_h, fmp_w = src.shape[1:]
+            # [B, C, H, W] -> [B, N, C], N=HxW
+            src_flatten = src.flatten(2).permute(0, 2, 1)
+            pos_embed = self.build_2d_sincos_position_embedding(
+                    fmp_w, fmp_h, channels, self.pe_temperature)
+            memory = encoder(src_flatten, pos_embed=pos_embed)
+            # [B, N, C] -> [B, C, N] -> [B, C, H, W]
+            src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
+
+        return src
+
+## Transformer Decoder layer
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self,
+                 d_model         :int   = 256,
+                 num_heads       :int   = 8,
+                 num_levels      :int   = 3,
+                 num_points      :int   = 4,
+                 mlp_ratio       :float = 4.0,
+                 dropout         :float = 0.1,
+                 act_type        :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        # ---------------- Network parameters ----------------
+        ## Multi-head Self-Attn
+        self.self_attn  = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.dropout1 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        ## CrossAttention
+        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points, 1.0)
+        self.dropout2 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+        ## FFN
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self,
+                tgt,
+                reference_points,
+                memory,
+                memory_spatial_shapes,
+                memory_level_start_index,
+                attn_mask=None,
+                memory_mask=None,
+                query_pos_embed=None):
+        # ---------------- MSHA for Object Query -----------------
+        q = k = self.with_pos_embed(tgt, query_pos_embed)
+        if attn_mask is not None:
+            attn_mask = torch.where(
+                attn_mask.astype('bool'),
+                torch.zeros(attn_mask.shape, tgt.dtype),
+                torch.full(attn_mask.shape, float("-inf"), tgt.dtype))
+        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # ---------------- CMHA for Object Query and Image-feature -----------------
+        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
+                               reference_points,
+                               memory,
+                               memory_spatial_shapes,
+                               memory_level_start_index,
+                               memory_mask)
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # ---------------- FeedForward Network -----------------
+        tgt = self.ffn(tgt)
+
+        return tgt
+
+## Transformer Decoder
+class TransformerDecoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 mlp_ratio      :float = 4.0,
+                 pe_temperature :float = 10000.,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pe_temperature = pe_temperature
+        self.pos_embed = None
+        # ----------- Basic parameters -----------
+        self.decoder_layers = None
+
+    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
+        assert embed_dim % 4 == 0, \
+            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        
+        # ----------- Check cahed pos_embed -----------
+        if self.pos_embed is not None and \
+            self.pos_embed.shape[2:] == [h, w]:
+            return self.pos_embed
+        
+        # ----------- Generate grid coords -----------
+        grid_w = torch.arange(int(w), dtype=torch.float32)
+        grid_h = torch.arange(int(h), dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [H, W]
+
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
+        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
+
+        # shape: [1, N, C]
+        pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], axis=1)[None, :, :]
+        self.pos_embed = pos_embed
+
+        return pos_embed
+

+ 12 - 44
models/detectors/rtdetr/basic_modules/fpn.py

@@ -4,9 +4,9 @@ import torch.nn.functional as F
 from typing import List
 
 try:
-    from .basic import get_clones, BasicConv, RTCBlock, TransformerLayer
+    from .basic import get_clones, BasicConv, RTCBlock, TransformerEncoder
 except:
-    from  basic import get_clones, BasicConv, RTCBlock, TransformerLayer
+    from  basic import get_clones, BasicConv, RTCBlock, TransformerEncoder
 
 
 # Build PaFPN
@@ -31,7 +31,7 @@ def build_fpn(cfg, in_dims, out_dim):
 
 
 # ----------------- Feature Pyramid Network -----------------
-## Real-time Convolutional PaFPN
+## Hybrid Encoder (Transformer encoder + Convolutional PaFPN)
 class HybridEncoder(nn.Module):
     def __init__(self, 
                  in_dims     :List  = [256, 512, 512],
@@ -60,8 +60,6 @@ class HybridEncoder(nn.Module):
         self.num_heads = num_heads
         self.num_layers = num_layers
         self.mlp_ratio = mlp_ratio
-        self.pe_temperature = pe_temperature
-        self.pos_embed = None
         c3, c4, c5 = in_dims
 
         # ---------------- Input projs ----------------
@@ -74,8 +72,14 @@ class HybridEncoder(nn.Module):
         self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
 
         # ---------------- Transformer Encoder ----------------
-        self.transformer_encoder = get_clones(
-            TransformerLayer(self.out_dim, num_heads, mlp_ratio, dropout, en_act_type), num_layers)
+        self.transformer_encoder = TransformerEncoder(d_model        = self.out_dim,
+                                                      num_heads      = num_heads,
+                                                      num_layers     = num_layers,
+                                                      mlp_ratio      = mlp_ratio,
+                                                      pe_temperature = pe_temperature,
+                                                      dropout        = dropout,
+                                                      act_type       = en_act_type
+                                                      )
 
         # ---------------- Top dwon FPN ----------------
         ## P5 -> P4
@@ -127,33 +131,6 @@ class HybridEncoder(nn.Module):
                 # reset the Conv2d initialization parameters
                 m.reset_parameters()
 
-    def build_2d_sincos_position_embedding(self, w, h, embed_dim=256, temperature=10000.):
-        assert embed_dim % 4 == 0, \
-            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
-        
-        # ----------- Check cahed pos_embed -----------
-        if self.pos_embed is not None and \
-            self.pos_embed.shape[2:] == [h, w]:
-            return self.pos_embed
-        
-        # ----------- Generate grid coords -----------
-        grid_w = torch.arange(int(w), dtype=torch.float32)
-        grid_h = torch.arange(int(h), dtype=torch.float32)
-        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [H, W]
-
-        pos_dim = embed_dim // 4
-        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
-        omega = 1. / (temperature**omega)
-
-        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
-        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
-
-        # shape: [1, N, C]
-        pos_embed = torch.concat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], axis=1)[None, :, :]
-        self.pos_embed = pos_embed
-
-        return pos_embed
-
     def forward(self, features):
         c3, c4, c5 = features
 
@@ -163,16 +140,7 @@ class HybridEncoder(nn.Module):
         p3 = self.reduce_layer_3(c3)
 
         # -------- Transformer encoder --------
-        if self.transformer_encoder is not None:
-            for encoder in self.transformer_encoder:
-                channels, fmp_h, fmp_w = p5.shape[1:]
-                # [B, C, H, W] -> [B, N, C], N=HxW
-                src_flatten = p5.flatten(2).permute(0, 2, 1)
-                pos_embed = self.build_2d_sincos_position_embedding(
-                        fmp_w, fmp_h, channels, self.pe_temperature)
-                memory = encoder(src_flatten, pos_embed=pos_embed)
-                # [B, N, C] -> [B, C, N] -> [B, C, H, W]
-                p5 = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
+        p5 = self.transformer_encoder(p5)
 
         # -------- Top down FPN --------
         p5_up = F.interpolate(p5, scale_factor=2.0)

+ 0 - 7
models/detectors/rtdetr/basic_modules/neck.py

@@ -1,7 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-# Build neck
-def build_neck(cfg, in_dim, out_dim):
-    return

+ 1 - 0
models/detectors/rtdetr/rtdetr_decoder.py

@@ -4,6 +4,7 @@ import torch.nn.functional as F
 
 
 # ----------------- Dencoder for Detection task -----------------
+## RTDETR's Transformer
 class DetDecoder(nn.Module):
     def __init__(self, ):
         super().__init__()

+ 3 - 2
utils/solver/optimizer.py

@@ -61,13 +61,14 @@ def build_detr_optimizer(cfg, model, resume=None):
 
     # ------------- Divide model's parameters -------------
     param_dicts = [], [], [], [], [], []
+    norm_names = ["norm"] + ["norm{}".format(i) for i in range(10000)]
     for n, p in model.named_parameters():
         # Non-Backbone's learnable parameters
         if "backbone" not in n and p.requires_grad:
             if "bias" == n.split(".")[-1]:
                 param_dicts[0].append(p)      # no weight decay for all layers' bias
             else:
-                if "norm" == n.split(".")[-2]:
+                if n.split(".")[-2] in norm_names:
                     param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
                 else:
                     param_dicts[2].append(p)  # weight decay for all Non-NormLayers' weight
@@ -76,7 +77,7 @@ def build_detr_optimizer(cfg, model, resume=None):
             if "bias" == n.split(".")[-1]:
                 param_dicts[3].append(p)      # no weight decay for all layers' bias
             else:
-                if "norm" == n.split(".")[-2]:
+                if n.split(".")[-2] in norm_names:
                     param_dicts[4].append(p)  # no weight decay for all NormLayers' weight
                 else:
                     param_dicts[5].append(p)  # weight decay for all Non-NormLayers' weight