@@ -2,295 +2,318 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_, uniform_, normal_
-from typing import List

 try:
-    from .basic_modules.basic import BasicConv, MLP
-    from .basic_modules.transformer import PlainTransformerDecoder
+    from .basic_modules.basic import LayerNorm2D
+    from .basic_modules.transformer import GlobalDecoder
 except:
-    from basic_modules.basic import BasicConv, MLP
-    from basic_modules.transformer import PlainTransformerDecoder
+    from basic_modules.basic import LayerNorm2D
+    from basic_modules.transformer import GlobalDecoder

-
-def build_transformer(cfg, in_dims, num_classes, return_intermediate=False):
+def build_transformer(cfg, return_intermediate=False):
     if cfg['transformer'] == 'plain_detr_transformer':
-        return PlainDETRTransformer(in_dims = in_dims,
-                                    hidden_dim = cfg['hidden_dim'],
-                                    strides = cfg['out_stride'],
-                                    num_classes = num_classes,
-                                    num_queries = cfg['num_queries'],
-                                    pos_embed_type = 'sine',
-                                    num_heads = cfg['de_num_heads'],
-                                    num_layers = cfg['de_num_layers'],
-                                    num_levels = len(cfg['out_stride']),
-                                    num_points = cfg['de_num_points'],
-                                    mlp_ratio = cfg['de_mlp_ratio'],
-                                    dropout = cfg['de_dropout'],
-                                    act_type = cfg['de_act'],
-                                    return_intermediate = return_intermediate,
-                                    num_denoising = cfg['dn_num_denoising'],
-                                    label_noise_ratio = cfg['dn_label_noise_ratio'],
-                                    box_noise_scale = cfg['dn_box_noise_scale'],
-                                    learnt_init_query = cfg['learnt_init_query'],
-                                    )
+        return PlainDETRTransformer(d_model = cfg['hidden_dim'],
+                                    num_heads = cfg['de_num_heads'],
+                                    mlp_ratio = cfg['de_mlp_ratio'],
+                                    dropout = cfg['de_dropout'],
+                                    act_type = cfg['de_act'],
+                                    pre_norm = cfg['de_pre_norm'],
+                                    rpe_hidden_dim = cfg['rpe_hidden_dim'],
+                                    feature_stride = cfg['out_stride'],
+                                    num_layers = cfg['de_num_layers'],
+                                    return_intermediate = return_intermediate,
+                                    use_checkpoint = cfg['use_checkpoint'],
+                                    num_queries_one2one = cfg['num_queries_one2one'],
+                                    num_queries_one2many = cfg['num_queries_one2many'],
+                                    proposal_feature_levels = cfg['proposal_feature_levels'],
+                                    proposal_in_stride = cfg['out_stride'],
+                                    proposal_tgt_strides = cfg['proposal_tgt_strides'],
+                                    )


 # ----------------- Decoder for Detection task -----------------
-## RTDETR's Transformer for Detection task
+## PlainDETR's Transformer for Detection task
 class PlainDETRTransformer(nn.Module):
     def __init__(self,
-                 # basic parameters
-                 in_dims :List = [256, 512, 1024],
-                 hidden_dim :int = 256,
-                 strides :List = [8, 16, 32],
-                 num_classes :int = 80,
-                 num_queries :int = 300,
-                 pos_embed_type :str = 'sine',
-                 # transformer parameters
+                 # Decoder layer params
+                 d_model :int = 256,
                  num_heads :int = 8,
-                 num_layers :int = 1,
-                 num_levels :int = 3,
-                 num_points :int = 4,
                  mlp_ratio :float = 4.0,
                  dropout :float = 0.1,
                  act_type :str = "relu",
+                 pre_norm :bool = False,
+                 rpe_hidden_dim :int = 512,
+                 feature_stride :int = 16,
+                 num_layers :int = 6,
+                 # Decoder params
                  return_intermediate :bool = False,
-                 # Denoising parameters
-                 num_denoising :int = 100,
-                 label_noise_ratio :float = 0.5,
-                 box_noise_scale :float = 1.0,
-                 learnt_init_query :bool = True,
+                 use_checkpoint :bool = False,
+                 num_queries_one2one :int = 300,
+                 num_queries_one2many :int = 1500,
+                 proposal_feature_levels :int = 3,
+                 proposal_in_stride :int = 16,
+                 proposal_tgt_strides :list = [8, 16, 32],
                  ):
         super().__init__()
-        # --------------- Basic setting ---------------
-        ## Basic parameters
-        self.in_dims = in_dims
-        self.strides = strides
-        self.num_queries = num_queries
-        self.pos_embed_type = pos_embed_type
-        self.num_classes = num_classes
-        self.eps = 1e-2
-        ## Transformer parameters
-        self.num_heads = num_heads
+        # ------------ Basic setting ------------
+        ## Model
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.rpe_hidden_dim = rpe_hidden_dim
+        self.mlp_ratio = mlp_ratio
+        self.act_type = act_type
         self.num_layers = num_layers
-        self.num_levels = num_levels
-        self.num_points = num_points
-        self.mlp_ratio = mlp_ratio
-        self.dropout = dropout
-        self.act_type = act_type
         self.return_intermediate = return_intermediate
-        ## Denoising parameters
-        self.num_denoising = num_denoising
-        self.label_noise_ratio = label_noise_ratio
-        self.box_noise_scale = box_noise_scale
-        self.learnt_init_query = learnt_init_query
+        ## Trick
+        self.use_checkpoint = use_checkpoint
+        self.num_queries_one2one = num_queries_one2one
+        self.num_queries_one2many = num_queries_one2many
+        self.proposal_feature_levels = proposal_feature_levels
+        self.proposal_tgt_strides = proposal_tgt_strides
+        self.proposal_in_stride = proposal_in_stride
+        self.proposal_min_size = 50
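+        # proposal_min_size: base proposal box size (in pixels) at the first pyramid level; scaled by 2**lvl below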

         # --------------- Network setting ---------------
-        ## Input proj layers
-        self.input_proj_layers = nn.ModuleList(
-            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
-            for i in range(num_levels)
-        )
-
-        ## Deformable transformer decoder
-        self.decoder = PlainTransformerDecoder(
-            d_model = hidden_dim,
-            num_heads = num_heads,
-            num_layers = num_layers,
-            num_levels = num_levels,
-            num_points = num_points,
-            mlp_ratio = mlp_ratio,
-            dropout = dropout,
-            act_type = act_type,
-            return_intermediate = return_intermediate
-        )
+        ## Global Decoder
+        self.decoder = GlobalDecoder(d_model, num_heads, mlp_ratio, dropout, act_type, pre_norm,
+                                     rpe_hidden_dim, feature_stride, num_layers, return_intermediate,
+                                     use_checkpoint,)

-        ## Detection head for Encoder
-        self.enc_output = nn.Sequential(
-            nn.Linear(hidden_dim, hidden_dim),
-            nn.LayerNorm(hidden_dim)
-        )
-        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
-        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
-
-        ## Detection head for Decoder
-        self.dec_class_head = nn.ModuleList([
-            nn.Linear(hidden_dim, num_classes)
-            for _ in range(num_layers)
-        ])
-        self.dec_bbox_head = nn.ModuleList([
-            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
-            for _ in range(num_layers)
-        ])
-
-        ## Object query
-        if learnt_init_query:
-            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
-        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
-
-        ## Denoising part
-        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
+        ## Two stage
+        self.enc_output = nn.Linear(d_model, d_model)
+        self.enc_output_norm = nn.LayerNorm(d_model)
+        self.pos_trans = nn.Linear(d_model * 2, d_model * 2)
+        self.pos_trans_norm = nn.LayerNorm(d_model * 2)
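+        # enc_output(+norm) refines encoder tokens into region features; pos_trans(+norm) maps proposal
+        # position embeddings into decoder query embeddings for the two-stage scheme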
+
+        ## Expand layers
+        if proposal_feature_levels > 1:
+            assert len(proposal_tgt_strides) == proposal_feature_levels
+
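+            # one projection per target stride: Identity at the input stride, strided Conv2d blocks to reach
+            # coarser strides, ConvTranspose2d blocks to reach finer ones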
+            self.enc_output_proj = nn.ModuleList([])
+            for stride in proposal_tgt_strides:
+                if stride == proposal_in_stride:
+                    self.enc_output_proj.append(nn.Identity())
+                elif stride > proposal_in_stride:
+                    scale = int(math.log2(stride / proposal_in_stride))
+                    layers = []
+                    for _ in range(scale - 1):
+                        layers += [
+                            nn.Conv2d(d_model, d_model, kernel_size=2, stride=2),
+                            LayerNorm2D(d_model),
+                            nn.GELU()
+                        ]
+                    layers.append(nn.Conv2d(d_model, d_model, kernel_size=2, stride=2))
+                    self.enc_output_proj.append(nn.Sequential(*layers))
+                else:
+                    scale = int(math.log2(proposal_in_stride / stride))
+                    layers = []
+                    for _ in range(scale - 1):
+                        layers += [
+                            nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2),
+                            LayerNorm2D(d_model),
+                            nn.GELU()
+                        ]
+                    layers.append(nn.ConvTranspose2d(d_model, d_model, kernel_size=2, stride=2))
+                    self.enc_output_proj.append(nn.Sequential(*layers))

         self._reset_parameters()

     def _reset_parameters(self):
-        def linear_init_(module):
-            bound = 1 / math.sqrt(module.weight.shape[0])
-            uniform_(module.weight, -bound, bound)
-            if hasattr(module, "bias") and module.bias is not None:
-                uniform_(module.bias, -bound, bound)
-
-        # class and bbox head init
-        prior_prob = 0.01
-        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
-        linear_init_(self.enc_class_head)
-        constant_(self.enc_class_head.bias, cls_bias_init)
-        constant_(self.enc_bbox_head.layers[-1].weight, 0.)
-        constant_(self.enc_bbox_head.layers[-1].bias, 0.)
-        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
-            linear_init_(cls_)
-            constant_(cls_.bias, cls_bias_init)
-            constant_(reg_.layers[-1].weight, 0.)
-            constant_(reg_.layers[-1].bias, 0.)
-
-        linear_init_(self.enc_output[0])
-        xavier_uniform_(self.enc_output[0].weight)
-        if self.learnt_init_query:
-            xavier_uniform_(self.tgt_embed.weight)
-        xavier_uniform_(self.query_pos_head.layers[0].weight)
-        xavier_uniform_(self.query_pos_head.layers[1].weight)
-        for l in self.input_proj_layers:
-            xavier_uniform_(l.conv.weight)
-        normal_(self.denoising_class_embed.weight)
-
-    def generate_anchors(self, spatial_shapes, grid_size=0.05):
-        anchors = []
-        for lvl, (h, w) in enumerate(spatial_shapes):
-            grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
-            # [H, W, 2]
-            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
-
-            valid_WH = torch.as_tensor([w, h]).float()
-            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
-            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
-            # [H, W, 4] -> [1, N, 4], N=HxW
-            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
-        # List[L, 1, N_i, 4] -> [1, N, 4], N=N_0 + N_1 + N_2 + ...
-        anchors = torch.cat(anchors, dim=1)
-        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
-        anchors = torch.log(anchors / (1 - anchors))
-        # Equal to operation: anchors = torch.masked_fill(anchors, ~valid_mask, torch.as_tensor(float("inf")))
-        anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+        if hasattr(self.decoder, '_reset_parameters'):
+            print('decoder re-init')
+            self.decoder._reset_parameters()
+
+    def get_proposal_pos_embed(self, proposals):
+        num_pos_feats = self.d_model // 2
+        temperature = 10000
+        scale = 2 * torch.pi
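+        # sine/cosine embedding of the 4 box coordinates, num_pos_feats dims each -> d_model * 2 in total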
+
+        dim_t = torch.arange(
+            num_pos_feats, dtype=torch.float32, device=proposals.device
+        )
+        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        # N, L, 4
+        proposals = proposals * scale
+        # N, L, 4, 128
+        pos = proposals[:, :, :, None] / dim_t
+        # N, L, 4, 64, 2
+        pos = torch.stack(
+            (pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4
+        ).flatten(2)
+
+        return pos
+
+    def get_valid_ratio(self, mask):
+        _, H, W = mask.shape
+        valid_H = torch.sum(~mask[:, :, 0], 1)
+        valid_W = torch.sum(~mask[:, 0, :], 1)
+        valid_ratio_h = valid_H.float() / H
+        valid_ratio_w = valid_W.float() / W
+        valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1)
+
+        return valid_ratio
+
+    def expand_encoder_output(self, memory, memory_padding_mask, spatial_shapes):
+        assert spatial_shapes.size(0) == 1, f'Get encoder output of shape {spatial_shapes}, not sure how to expand'
+
+        bs, _, c = memory.shape
+        h, w = spatial_shapes[0]
+
+        _out_memory = memory.view(bs, h, w, c).permute(0, 3, 1, 2)
+        _out_memory_padding_mask = memory_padding_mask.view(bs, h, w)
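+        # restore the flattened single-level memory to a 2-D feature map before resampling it to each target stride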
+
+        out_memory, out_memory_padding_mask, out_spatial_shapes = [], [], []
+        for i in range(self.proposal_feature_levels):
+            mem = self.enc_output_proj[i](_out_memory)
+            mask = F.interpolate(
+                _out_memory_padding_mask[None].float(), size=mem.shape[-2:]
+            ).to(torch.bool)
+
+            out_memory.append(mem)
+            out_memory_padding_mask.append(mask.squeeze(0))
+            out_spatial_shapes.append(mem.shape[-2:])
+
+        out_memory = torch.cat([mem.flatten(2).transpose(1, 2) for mem in out_memory], dim=1)
+        out_memory_padding_mask = torch.cat([mask.flatten(1) for mask in out_memory_padding_mask], dim=1)
+        out_spatial_shapes = torch.as_tensor(out_spatial_shapes, dtype=torch.long, device=out_memory.device)

-        return anchors, valid_mask
+        return out_memory, out_memory_padding_mask, out_spatial_shapes
+
+    def gen_encoder_output_proposals(self, memory, memory_padding_mask, spatial_shapes):
+        if self.proposal_feature_levels > 1:
+            memory, memory_padding_mask, spatial_shapes = self.expand_encoder_output(
+                memory, memory_padding_mask, spatial_shapes
+            )
+        N_, S_, C_ = memory.shape
+        # base_scale = 4.0
+        proposals = []
+        _cur = 0
+        for lvl, (H_, W_) in enumerate(spatial_shapes):
+            stride = self.proposal_tgt_strides[lvl]
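+            # one proposal per feature-map cell: centered on the cell (in image pixels), sized proposal_min_size * 2**lvl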
+
+            grid_y, grid_x = torch.meshgrid(
+                torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
+                torch.linspace(0, W_ - 1, W_, dtype=torch.float32, device=memory.device),
+            )
+            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+            grid = (grid.unsqueeze(0).expand(N_, -1, -1, -1) + 0.5) * stride
+            wh = torch.ones_like(grid) * self.proposal_min_size * (2.0 ** lvl)
+            proposal = torch.cat((grid, wh), -1).view(N_, -1, 4)
+            proposals.append(proposal)
+            _cur += H_ * W_
+        output_proposals = torch.cat(proposals, 1)
+
+        H_, W_ = spatial_shapes[0]
+        stride = self.proposal_tgt_strides[0]
+        mask_flatten_ = memory_padding_mask[:, :H_*W_].view(N_, H_, W_, 1)
+        valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1, keepdim=True) * stride
+        valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1, keepdim=True) * stride
+        img_size = torch.cat([valid_W, valid_H, valid_W, valid_H], dim=-1)
+        img_size = img_size.unsqueeze(1)  # [BS, 1, 4]
+
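+        # mark a proposal valid only if every entry (center and size) lies between 1% and 99% of the valid image extent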
+        output_proposals_valid = (
+            (output_proposals > 0.01 * img_size) & (output_proposals < 0.99 * img_size)
+        ).all(-1, keepdim=True)
+        output_proposals = output_proposals.masked_fill(
+            memory_padding_mask.unsqueeze(-1).repeat(1, 1, 1),
+            max(H_, W_) * stride,
+        )
+        output_proposals = output_proposals.masked_fill(
+            ~output_proposals_valid,
+            max(H_, W_) * stride,
+        )
+
+        output_memory = memory
+        output_memory = output_memory.masked_fill(memory_padding_mask.unsqueeze(-1), float(0))
+        output_memory = output_memory.masked_fill(~output_proposals_valid, float(0))
+        output_memory = self.enc_output_norm(self.enc_output(output_memory))
+
+        max_shape = (valid_H[:, None, :], valid_W[:, None, :])
+        return output_memory, output_proposals, max_shape

-    def get_encoder_input(self, feats):
-        # get projection features
-        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
-
-        # get encoder inputs
-        feat_flatten = []
-        spatial_shapes = []
-        level_start_index = [0, ]
-        for i, feat in enumerate(proj_feats):
-            _, _, h, w = feat.shape
-            spatial_shapes.append([h, w])
-            # [l], start index of each level
-            level_start_index.append(h * w + level_start_index[-1])
-            # [B, C, H, W] -> [B, N, C], N=HxW
-            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
-
-        # [B, N, C], N = N_0 + N_1 + ...
-        feat_flatten = torch.cat(feat_flatten, dim=1)
-        level_start_index.pop()
-
-        return (feat_flatten, spatial_shapes, level_start_index)
-
-    def get_decoder_input(self,
-                          memory,
-                          spatial_shapes,
-                          denoising_class=None,
-                          denoising_bbox_unact=None):
-        bs, _, _ = memory.shape
-        # Prepare input for decoder
-        anchors, valid_mask = self.generate_anchors(spatial_shapes)
-        anchors = anchors.to(memory.device)
-        valid_mask = valid_mask.to(memory.device)
+    def get_reference_points(self, memory, mask_flatten, spatial_shapes):
+        output_memory, output_proposals, max_shape = self.gen_encoder_output_proposals(
+            memory, mask_flatten, spatial_shapes
+        )
+
+        # hack implementation for two-stage Deformable DETR
+        enc_outputs_class = self.decoder.class_embed[self.decoder.num_layers](output_memory)
+        enc_outputs_delta = self.decoder.bbox_embed[self.decoder.num_layers](output_memory)
+        enc_outputs_coord_unact = self.decoder.box_xyxy_to_cxcywh(self.decoder.delta2bbox(
+            output_proposals,
+            enc_outputs_delta,
+            max_shape
+        ))
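+        # the decoder's class_embed / bbox_embed hold num_layers + 1 heads; the extra one (index num_layers)
+        # scores and refines the encoder proposals here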
+
+        topk = self.two_stage_num_proposals
+        topk_proposals = torch.topk(enc_outputs_class[..., 0], topk, dim=1)[1]
+        topk_coords_unact = torch.gather(
+            enc_outputs_coord_unact, 1, topk_proposals.unsqueeze(-1).repeat(1, 1, 4)
+        )
+        topk_coords_unact = topk_coords_unact.detach()
+        reference_points = topk_coords_unact

-        # Process encoder's output
-        memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
-        output_memory = self.enc_output(memory)
-
-        # Head for encoder's output : [bs, num_quries, c]
-        enc_outputs_class = self.enc_class_head(output_memory)
-        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
-
-        # Topk proposals from encoder's output
-        topk = self.num_queries
-        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, num_queries]
-        enc_topk_logits = torch.gather(
-            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
-        reference_points_unact = torch.gather(
-            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))  # [bs, num_queries, 4]
-        enc_topk_bboxes = F.sigmoid(reference_points_unact)
-
-        if denoising_bbox_unact is not None:
-            reference_points_unact = torch.cat(
-                [denoising_bbox_unact, reference_points_unact], 1)
-        if self.training:
-            reference_points_unact = reference_points_unact.detach()
+        return (reference_points, max_shape, enc_outputs_class,
+                enc_outputs_coord_unact, enc_outputs_delta, output_proposals)

-        # Extract region features
-        if self.learnt_init_query:
-            # [num_queries, c] -> [b, num_queries, c]
-            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
-        else:
-            # [num_queries, c] -> [b, num_queries, c]
-            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
-            if self.training:
-                target = target.detach()
-        if denoising_class is not None:
-            target = torch.cat([denoising_class, target], dim=1)
-
-        return target, reference_points_unact, enc_topk_bboxes, enc_topk_logits
-
-    def forward(self, feats, targets=None):
-        # input projection and embedding
-        memory, spatial_shapes, _ = self.get_encoder_input(feats)
+    def forward(self, src, mask, pos_embed, query_embed=None, self_attn_mask=None):
+        # Prepare input for encoder
+        bs, c, h, w = src.shape
+        src_flatten = src.flatten(2).transpose(1, 2)
+        mask_flatten = mask.flatten(1)
+        pos_embed_flatten = pos_embed.flatten(2).transpose(1, 2)
+        spatial_shapes = torch.as_tensor([(h, w)], dtype=torch.long, device=src_flatten.device)

-        # prepare denoising training
+        # Prepare input for decoder
+        memory = src_flatten
+        bs, _, c = memory.shape
+
+        # Two stage trick
         if self.training:
-            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
-                get_contrastive_denoising_training_group(targets,
-                                                         self.num_classes,
-                                                         self.num_queries,
-                                                         self.denoising_class_embed.weight,
-                                                         self.num_denoising,
-                                                         self.label_noise_ratio,
-                                                         self.box_noise_scale)
+            self.two_stage_num_proposals = self.num_queries_one2one + self.num_queries_one2many
         else:
-            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
-
-        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
-            self.get_decoder_input(
-                memory, spatial_shapes, denoising_class, denoising_bbox_unact)
-
-        # decoder
-        out_bboxes, out_logits = self.decoder(target,
-                                              init_ref_points_unact,
-                                              memory,
-                                              spatial_shapes,
-                                              self.dec_bbox_head,
-                                              self.dec_class_head,
-                                              self.query_pos_head,
-                                              attn_mask)
-
-        return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
+            self.two_stage_num_proposals = self.num_queries_one2one
+        (reference_points, max_shape, enc_outputs_class,
+         enc_outputs_coord_unact, enc_outputs_delta, output_proposals) \
+            = self.get_reference_points(memory, mask_flatten, spatial_shapes)
+        init_reference_out = reference_points
+        pos_trans_out = torch.zeros((bs, self.two_stage_num_proposals, 2*c), device=init_reference_out.device)
+        pos_trans_out = self.pos_trans_norm(self.pos_trans(self.get_proposal_pos_embed(reference_points)))
+
+        # Mixed selection trick
+        tgt = query_embed.unsqueeze(0).expand(bs, -1, -1)
+        query_embed, _ = torch.split(pos_trans_out, c, dim=2)
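+        # mixed selection: content queries (tgt) come from the learned query_embed, while positional queries
+        # come from the position embeddings of the top-k encoder proposals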
+
+        # Decoder
+        hs, inter_references = self.decoder(tgt,
+                                            reference_points,
+                                            memory,
+                                            pos_embed_flatten,
+                                            spatial_shapes,
+                                            query_embed,
+                                            mask_flatten,
+                                            self_attn_mask,
+                                            max_shape
+                                            )
+        inter_references_out = inter_references
+
+        return (hs,
+                init_reference_out,
+                inter_references_out,
+                enc_outputs_class,
+                enc_outputs_coord_unact,
+                enc_outputs_delta,
+                output_proposals,
+                max_shape
+                )


 # ----------------- Decoder for Segmentation task -----------------
-## RTDETR's Transformer for Segmentation task
+## PlainDETR's Transformer for Segmentation task
 class SegTransformerDecoder(nn.Module):
     def __init__(self, ):
         super().__init__()
@@ -301,7 +324,7 @@ class SegTransformerDecoder(nn.Module):


 # ----------------- Decoder for Pose estimation task -----------------
-## RTDETR's Transformer for Pose estimation task
+## PlainDETR's Transformer for Pose estimation task
 class PosTransformerDecoder(nn.Module):
     def __init__(self, ):
         super().__init__()
@@ -314,50 +337,66 @@ class PosTransformerDecoder(nn.Module):
 if __name__ == '__main__':
     import time
     from thop import profile
+    from basic_modules.basic import MLP
+    from basic_modules.transformer import get_clones
+
     cfg = {
-        'out_stride': [8, 16, 32],
+        'out_stride': 16,
         # Transformer Decoder
-        'transformer': 'rtdetr_transformer',
+        'transformer': 'plain_detr_transformer',
         'hidden_dim': 256,
+        'num_queries': 300,
        'de_num_heads': 8,
         'de_num_layers': 6,
         'de_mlp_ratio': 4.0,
         'de_dropout': 0.1,
         'de_act': 'gelu',
-        'de_num_points': 4,
-        'num_queries': 300,
-        'learnt_init_query': False,
-        'pe_temperature': 10000.,
-        'dn_num_denoising': 100,
-        'dn_label_noise_ratio': 0.5,
-        'dn_box_noise_scale': 1,
+        'de_pre_norm': True,
+        'rpe_hidden_dim': 512,
+        'use_checkpoint': False,
+        'proposal_feature_levels': 3,
+        'proposal_tgt_strides': [8, 16, 32],
     }
-    bs = 1
-    hidden_dim = cfg['hidden_dim']
-    in_dims = [hidden_dim] * 3
-    targets = [{
-        'labels': torch.tensor([2, 4, 5, 8]).long(),
-        'boxes': torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float()
-    }] * bs
-    pyramid_feats = [torch.randn(bs, hidden_dim, 80, 80),
-                     torch.randn(bs, hidden_dim, 40, 40),
-                     torch.randn(bs, hidden_dim, 20, 20)]
-    model = build_transformer(cfg, in_dims, 80, True)
-    model.train()
+    feat = torch.randn(1, cfg['hidden_dim'], 40, 40)
+    mask = torch.zeros(1, 40, 40)
+    pos_embed = torch.randn(1, cfg['hidden_dim'], 40, 40)
+    query_embed = torch.randn(cfg['num_queries'], cfg['hidden_dim'])

+    model = build_transformer(cfg, True)
+
+    class_embed = nn.Linear(cfg['hidden_dim'], 80)
+    bbox_embed = MLP(cfg['hidden_dim'], cfg['hidden_dim'], 4, 3)
+    class_embed = get_clones(class_embed, cfg['de_num_layers'] + 1)
+    bbox_embed = get_clones(bbox_embed, cfg['de_num_layers'] + 1)
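+    # the decoder expects one class/bbox head per layer plus an extra head for scoring the encoder proposals,
+    # so the clones are attached to the decoder below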
+
+    model.decoder.bbox_embed = bbox_embed
+    model.decoder.class_embed = class_embed
+
+    model.train()
     t0 = time.time()
-    outputs = model(pyramid_feats, targets)
-    out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta = outputs
+    outputs = model(feat, mask, pos_embed, query_embed)
+    (hs,
+     init_reference_out,
+     inter_references_out,
+     enc_outputs_class,
+     enc_outputs_coord_unact,
+     enc_outputs_delta,
+     output_proposals,
+     max_shape
+     ) = outputs
     t1 = time.time()
     print('Time: ', t1 - t0)
-    print(out_bboxes.shape)
-    print(out_logits.shape)
-    print(enc_topk_bboxes.shape)
-    print(enc_topk_logits.shape)
+    print(hs.shape)
+    print(init_reference_out.shape)
+    print(inter_references_out.shape)
+    print(enc_outputs_class.shape)
+    print(enc_outputs_coord_unact.shape)
+    print(enc_outputs_delta.shape)
+    print(output_proposals.shape)

     print('==============================')
     model.eval()
-    flops, params = profile(model, inputs=(pyramid_feats, ), verbose=False)
+    flops, params = profile(model, inputs=(feat, mask, pos_embed, query_embed, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
     print('Params : {:.2f} M'.format(params / 1e6))