import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List

from .basic_modules.conv import BasicConv
from .basic_modules.mlp import MLP
from .basic_modules.transformer import DeformableTransformerDecoder
from .basic_modules.dn_compoments import get_contrastive_denoising_training_group


# ----------------- Decoder for Detection task -----------------
## RTDETR's Transformer for Detection task
class RTDetrTransformer(nn.Module):
    """RT-DETR detection transformer: input projection + query selection + deformable decoder.

    Multi-level input features are projected to ``hidden_dim`` with 1x1 convs and
    flattened into a single memory sequence. The top-``num_queries`` encoder
    proposals (scored by ``enc_class_head``) become the initial decoder queries.
    During training, contrastive denoising queries are optionally prepended.
    """
    def __init__(self,
                 # basic parameters
                 in_dims        :List  = [256, 512, 1024],
                 hidden_dim     :int   = 256,
                 strides        :List  = [8, 16, 32],
                 num_classes    :int   = 80,
                 num_queries    :int   = 300,
                 # transformer parameters
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 num_levels     :int   = 3,
                 num_points     :int   = 4,
                 ffn_dim        :int   = 1024,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 return_intermediate :bool = False,
                 # Denoising parameters
                 num_denoising       :int  = 100,
                 label_noise_ratio   :float = 0.5,
                 box_noise_scale     :float = 1.0,
                 learnt_init_query   :bool  = False,
                 aux_loss            :bool  = True
                 ):
        super().__init__()
        # --------------- Basic setting ---------------
        ## Basic parameters
        # Copy the list arguments so instances never share (and never mutate)
        # the mutable default lists of the signature.
        self.in_dims = list(in_dims)
        self.hidden_dim = hidden_dim
        self.strides = list(strides)
        self.num_queries = num_queries
        self.num_classes = num_classes
        # Anchors with any coordinate outside (eps, 1 - eps) are marked invalid.
        self.eps = 1e-2
        self.aux_loss = aux_loss
        ## Transformer parameters
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.num_levels = num_levels
        self.num_points = num_points
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        self.return_intermediate = return_intermediate
        ## Denoising parameters
        self.num_denoising = num_denoising
        self.label_noise_ratio = label_noise_ratio
        self.box_noise_scale = box_noise_scale
        self.learnt_init_query = learnt_init_query

        # --------------- Network setting ---------------
        ## Input proj layers: one 1x1 conv per level, in_dims[i] -> hidden_dim
        self.input_proj_layers = nn.ModuleList(
            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
            for i in range(num_levels)
        )

        ## Deformable transformer decoder
        self.decoder = DeformableTransformerDecoder(
            d_model             = hidden_dim,
            num_heads           = num_heads,
            num_layers          = num_layers,
            num_levels          = num_levels,
            num_points          = num_points,
            ffn_dim             = ffn_dim,
            dropout             = dropout,
            act_type            = act_type,
            return_intermediate = return_intermediate
        )

        ## Detection head for Encoder
        self.enc_output = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.LayerNorm(hidden_dim)
        )
        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)

        ## Detection head for Decoder (one head per decoder layer)
        self.dec_class_head = nn.ModuleList([
            nn.Linear(hidden_dim, num_classes)
            for _ in range(num_layers)
        ])
        self.dec_bbox_head = nn.ModuleList([
            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
            for _ in range(num_layers)
        ])

        ## Object query
        if learnt_init_query:
            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
        # Always built: forward() passes it to the decoder unconditionally.
        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)

        ## Denoising part
        if num_denoising > 0:
            # Index `num_classes` is the padding entry.
            self.denoising_class_embed = nn.Embedding(num_classes + 1, hidden_dim, padding_idx=num_classes)

        self._reset_parameters()

    def _reset_parameters(self):
        """Initialize head biases/weights.

        Class-head biases are set to the focal-loss prior for p=0.01. The last
        layer of each box regressor is zeroed so the initial prediction equals
        the reference point.
        """
        # class and bbox head init
        prior_prob = 0.01
        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
        nn.init.constant_(self.enc_class_head.bias, cls_bias_init)
        nn.init.constant_(self.enc_bbox_head.layers[-1].weight, 0.)
        nn.init.constant_(self.enc_bbox_head.layers[-1].bias, 0.)
        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
            nn.init.constant_(cls_.bias, cls_bias_init)
            nn.init.constant_(reg_.layers[-1].weight, 0.)
            nn.init.constant_(reg_.layers[-1].bias, 0.)

        nn.init.xavier_uniform_(self.enc_output[0].weight)
        if self.learnt_init_query:
            nn.init.xavier_uniform_(self.tgt_embed.weight)
        # query_pos_head always exists, so it is always initialized.
        nn.init.xavier_uniform_(self.query_pos_head.layers[0].weight)
        nn.init.xavier_uniform_(self.query_pos_head.layers[1].weight)

    @torch.jit.unused
    def _set_aux_loss(self, outputs_class, outputs_coord):
        """Pair per-layer logits/boxes into a list of prediction dicts."""
        # this is a workaround to make torchscript happy, as torchscript
        # doesn't support dictionary with non-homogeneous values, such
        # as a dict having both a Tensor and a list.
        return [{'pred_logits': a, 'pred_boxes': b}
                for a, b in zip(outputs_class, outputs_coord)]

    def generate_anchors(self, spatial_shapes, grid_size=0.05, device=None):
        """Build one anchor box per feature location, in inverse-sigmoid space.

        Args:
            spatial_shapes: sequence of ``(h, w)`` pairs, one per level.
            grid_size: base anchor width/height (normalized); doubled per level.
            device: device on which to build the tensors (``None`` -> default
                device, i.e. the previous CPU behavior).

        Returns:
            anchors: ``[1, N, 4]`` tensor of ``(cx, cy, w, h)`` logits,
                ``N = sum(h * w)``; invalid anchors are set to ``+inf``.
            valid_mask: ``[1, N, 1]`` bool tensor, True where all four
                normalized coordinates lie strictly inside ``(eps, 1 - eps)``.
        """
        anchors = []
        for lvl, (h, w) in enumerate(spatial_shapes):
            grid_y, grid_x = torch.meshgrid(
                torch.arange(h, device=device),
                torch.arange(w, device=device),
                indexing="ij")
            # [H, W, 2]
            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
            valid_WH = torch.as_tensor([w, h], device=device).float()
            # cell centers, normalized to [0, 1]: [1, H, W, 2]
            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
            wh = torch.ones_like(grid_xy) * grid_size * (2.0 ** lvl)
            # [1, H, W, 4] -> [1, N, 4], N=HxW
            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
        # List[L, 1, N_i, 4] -> [1, N, 4], N=N_0 + N_1 + N_2 + ...
        anchors = torch.cat(anchors, dim=1)
        valid_mask = ((anchors > self.eps) & (anchors < 1 - self.eps)).all(-1, keepdim=True)
        # inverse sigmoid
        anchors = torch.log(anchors / (1 - anchors))
        # Equivalent to: anchors.masked_fill(~valid_mask, float("inf"))
        anchors = torch.where(valid_mask, anchors, torch.inf)

        return anchors, valid_mask

    def get_encoder_input(self, feats):
        """Project and flatten multi-level features into one memory sequence.

        Returns:
            feat_flatten: ``[B, N, C]`` with ``N = sum(h_i * w_i)``.
            spatial_shapes: list of ``[h, w]`` per level.
            level_start_index: start offset of each level within ``N``.
        """
        # get projection features
        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]

        # get encoder inputs
        feat_flatten = []
        spatial_shapes = []
        level_start_index = [0, ]
        for feat in proj_feats:
            _, _, h, w = feat.shape
            spatial_shapes.append([h, w])
            # [l], start index of each level
            level_start_index.append(h * w + level_start_index[-1])
            # [B, C, H, W] -> [B, N, C], N=HxW
            feat_flatten.append(feat.flatten(2).permute(0, 2, 1).contiguous())

        # [B, N, C], N = N_0 + N_1 + ...
        feat_flatten = torch.cat(feat_flatten, dim=1)
        # drop the trailing total-length entry so only level starts remain
        level_start_index.pop()

        return (feat_flatten, spatial_shapes, level_start_index)

    def get_decoder_input(self,
                          memory,
                          spatial_shapes,
                          denoising_class=None,
                          denoising_bbox_unact=None):
        """Select top-k encoder proposals and build the decoder's initial queries.

        Returns:
            target: decoder content queries, ``[B, (num_dn +) num_queries, C]``.
            reference_points_unact: detached initial boxes in logit space.
            enc_topk_bboxes: sigmoid boxes of the selected proposals.
            enc_topk_logits: class logits of the selected proposals.
        """
        bs, _, _ = memory.shape
        # Prepare input for decoder (built directly on memory's device)
        anchors, valid_mask = self.generate_anchors(spatial_shapes, device=memory.device)

        # Process encoder's output: zero out features at invalid anchors
        memory = memory.masked_fill(~valid_mask, 0.)
        output_memory = self.enc_output(memory)

        # Head for encoder's output : [bs, N, nc] and [bs, N, 4]
        enc_outputs_class = self.enc_class_head(output_memory)
        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors

        # Topk proposals from encoder's output
        topk = self.num_queries
        topk_ind = torch.topk(enc_outputs_class.max(-1).values, topk, dim=1)[1]    # [bs, num_queries]
        enc_topk_logits = torch.gather(
            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
        reference_points_unact = torch.gather(
            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))        # [bs, num_queries, 4]
        enc_topk_bboxes = torch.sigmoid(reference_points_unact)

        if denoising_bbox_unact is not None:
            reference_points_unact = torch.cat(
                [denoising_bbox_unact, reference_points_unact], dim=1)

        # Extract region features
        if self.learnt_init_query:
            # [num_queries, c] -> [b, num_queries, c]
            # Not detached: the learnt query embedding must receive gradients.
            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
        else:
            # [bs, N, c] -> [bs, num_queries, c]
            target = torch.gather(output_memory, 1,
                                  topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
            target = target.detach()

        if denoising_class is not None:
            target = torch.cat([denoising_class, target], dim=1)

        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits

    def forward(self, feats, targets=None):
        """Run query selection + decoding.

        Args:
            feats: list of per-level feature maps ``[B, in_dims[i], H_i, W_i]``.
            targets: ground truth consumed by the denoising helper (training only).
                When ``None``, denoising is skipped.

        Returns:
            dict with ``pred_logits``/``pred_boxes`` of the last decoder layer and,
            in training, ``aux_outputs`` (plus ``dn_aux_outputs``/``dn_meta`` when
            denoising is active).
        """
        # input projection and embedding
        memory, spatial_shapes, _ = self.get_encoder_input(feats)

        # prepare denoising training (requires targets)
        if self.training and self.num_denoising > 0 and targets is not None:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
                get_contrastive_denoising_training_group(targets,
                                                         self.num_classes,
                                                         self.num_queries,
                                                         self.denoising_class_embed,
                                                         num_denoising=self.num_denoising,
                                                         label_noise_ratio=self.label_noise_ratio,
                                                         box_noise_scale=self.box_noise_scale,
                                                         )
        else:
            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None

        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
            self.get_decoder_input(
                memory, spatial_shapes, denoising_class, denoising_bbox_unact)

        # decoder
        out_bboxes, out_logits = self.decoder(target,
                                              init_ref_points_unact,
                                              memory,
                                              spatial_shapes,
                                              self.dec_bbox_head,
                                              self.dec_class_head,
                                              self.query_pos_head,
                                              attn_mask)

        if self.training and dn_meta is not None:
            # Separate denoising queries (prepended) from matching queries along the query dim
            dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
            dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)

        out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}

        if self.training and self.aux_loss:
            out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
            out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))

            if dn_meta is not None:
                out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
                out['dn_meta'] = dn_meta

        return out


## RTDETR's Transformer for Instance Segmentation task (not complete yet)
class MaskRTDetrTransformer(RTDetrTransformer):
    """Placeholder for an instance-segmentation variant (not implemented yet)."""
    def __init__(self,
                 # basic parameters
                 in_dims        :List  = [256, 512, 1024],
                 hidden_dim     :int   = 256,
                 strides        :List  = [8, 16, 32],
                 num_classes    :int   = 80,
                 num_queries    :int   = 300,
                 # transformer parameters
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 num_levels     :int   = 3,
                 num_points     :int   = 4,
                 ffn_dim        :int   = 1024,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 return_intermediate :bool = False,
                 # Denoising parameters
                 num_denoising       :int  = 100,
                 label_noise_ratio   :float = 0.5,
                 box_noise_scale     :float = 1.0,
                 learnt_init_query   :bool  = False,
                 aux_loss            :bool  = True
                 ):
        # Forward every argument; previously all of them were silently ignored.
        super().__init__(in_dims=in_dims,
                         hidden_dim=hidden_dim,
                         strides=strides,
                         num_classes=num_classes,
                         num_queries=num_queries,
                         num_heads=num_heads,
                         num_layers=num_layers,
                         num_levels=num_levels,
                         num_points=num_points,
                         ffn_dim=ffn_dim,
                         dropout=dropout,
                         act_type=act_type,
                         return_intermediate=return_intermediate,
                         num_denoising=num_denoising,
                         label_noise_ratio=label_noise_ratio,
                         box_noise_scale=box_noise_scale,
                         learnt_init_query=learnt_init_query,
                         aux_loss=aux_loss)

    def forward(self, feats, targets=None):
        # TODO: segmentation forward pass not implemented yet; returns None.
        return