@@ -2,7 +2,7 @@ import math
import torch
import torch.nn as nn
import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_, uniform_
+from torch.nn.init import constant_, xavier_uniform_, uniform_, normal_
from typing import List
 
try:
@@ -96,7 +96,7 @@ class RTDETRTransformer(nn.Module):
)
 
## Deformable transformer decoder
- self.transformer_decoder = DeformableTransformerDecoder(
+ self.decoder = DeformableTransformerDecoder(
d_model = hidden_dim,
num_heads = num_heads,
num_layers = num_layers,
@@ -116,7 +116,7 @@ class RTDETRTransformer(nn.Module):
self.enc_class_head = nn.Linear(hidden_dim, num_classes)
self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
 
- ## Detection head for Decoder
+ ## Detection head for Decoder
self.dec_class_head = nn.ModuleList([
nn.Linear(hidden_dim, num_classes)
for _ in range(num_layers)
@@ -126,18 +126,18 @@ class RTDETRTransformer(nn.Module):
for _ in range(num_layers)
])
 
- ## Denoising part
- self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
-
## Object query
if learnt_init_query:
self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
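+ # query_pos_head maps the 4-d (unactivated) reference boxes to hidden_dim positional embeddings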
self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
 
+ ## Denoising part
+ self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
+
self._reset_parameters()
 
def _reset_parameters(self):
- def _linear_init(module):
+ def linear_init_(module):
bound = 1 / math.sqrt(module.weight.shape[0])
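+ # nn.Linear weight is [out, in], so this bound uses fan-out (PyTorch's own default init uses fan-in)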
uniform_(module.weight, -bound, bound)
if hasattr(module, "bias") and module.bias is not None:
@@ -146,17 +146,17 @@ class RTDETRTransformer(nn.Module):
# class and bbox head init
prior_prob = 0.01
cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
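+ # sigmoid(cls_bias_init) == prior_prob, so initial class scores start near 0.01 (focal-loss style)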
- _linear_init(self.enc_class_head)
+ linear_init_(self.enc_class_head)
constant_(self.enc_class_head.bias, cls_bias_init)
constant_(self.enc_bbox_head.layers[-1].weight, 0.)
constant_(self.enc_bbox_head.layers[-1].bias, 0.)
for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
- _linear_init(cls_)
+ linear_init_(cls_)
constant_(cls_.bias, cls_bias_init)
constant_(reg_.layers[-1].weight, 0.)
constant_(reg_.layers[-1].bias, 0.)
 
- _linear_init(self.enc_output[0])
+ linear_init_(self.enc_output[0])
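+ # xavier_uniform_ below overwrites the weight set by linear_init_; only its bias init survives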
xavier_uniform_(self.enc_output[0].weight)
if self.learnt_init_query:
xavier_uniform_(self.tgt_embed.weight)
@@ -164,21 +164,25 @@ class RTDETRTransformer(nn.Module):
xavier_uniform_(self.query_pos_head.layers[1].weight)
for l in self.input_proj_layers:
xavier_uniform_(l.conv.weight)
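+ # Label embedding used for the denoising queries; normal_ defaults to N(0, 1)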
+ normal_(self.denoising_class_embed.weight)
 
def generate_anchors(self, spatial_shapes, grid_size=0.05):
anchors = []
for lvl, (h, w) in enumerate(spatial_shapes):
grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
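+ # meshgrid defaults to indexing="ij": grid_y varies over rows (h), grid_x over columns (w)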
+ # [H, W, 2]
grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
 
valid_WH = torch.as_tensor([w, h]).float()
grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
- anchors.append(torch.cat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
-
- anchors = torch.cat(anchors, 1)
+ # [H, W, 4] -> [1, N, 4], N=HxW
+ anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
+ # List of L tensors [1, N_i, 4] -> [1, N, 4], N = N_0 + N_1 + N_2 + ...
+ anchors = torch.cat(anchors, dim=1)
valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
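+ # Anchors must lie strictly inside (eps, 1 - eps); the log below is the inverse sigmoid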
anchors = torch.log(anchors / (1 - anchors))
+ # Equivalent to: anchors = anchors.masked_fill(~valid_mask, float("inf"))
anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
 
return anchors, valid_mask
@@ -193,15 +197,14 @@ class RTDETRTransformer(nn.Module):
level_start_index = [0, ]
for i, feat in enumerate(proj_feats):
_, _, h, w = feat.shape
- # [b, c, h, w] -> [b, h*w, c]
- feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
- # [num_levels, 2]
spatial_shapes.append([h, w])
# [l], start index of each level
level_start_index.append(h * w + level_start_index[-1])
+ # [B, C, H, W] -> [B, N, C], N=HxW
+ feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
 
- # [b, l, c]
- feat_flatten = torch.cat(feat_flatten, 1)
+ # [B, N, C], N = N_0 + N_1 + ...
+ feat_flatten = torch.cat(feat_flatten, dim=1)
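+ # The last entry of level_start_index is the running total N, not a level's start, so drop it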
level_start_index.pop()
 
return (feat_flatten, spatial_shapes, level_start_index)
@@ -212,20 +215,26 @@ class RTDETRTransformer(nn.Module):
denoising_class=None,
denoising_bbox_unact=None):
bs, _, _ = memory.shape
- # prepare input for decoder
+ # Prepare input for decoder
anchors, valid_mask = self.generate_anchors(spatial_shapes)
anchors = anchors.to(memory.device)
valid_mask = valid_mask.to(memory.device)
- memory = torch.where(valid_mask, memory, torch.as_tensor(0.))
+
+ # Process encoder's output
+ memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
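+ # Zero out tokens whose anchors were flagged invalid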
output_memory = self.enc_output(memory)
 
- # [bs, num_quries, c]
+ # Heads for encoder's output: class logits [bs, N, nc], N = all tokens
enc_outputs_class = self.enc_class_head(output_memory)
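+ # Bbox head predicts deltas that are added to the logit-space anchors: [bs, N, 4]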
enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
 
+ # Top-k proposals from the encoder's output
topk = self.num_queries
- topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] # [bs, topk]
- reference_points_unact = torch.gather(enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4)) # [bs, topk, 4]
+ topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] # [bs, num_queries]
+ enc_topk_logits = torch.gather(
+ enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes)) # [bs, num_queries, nc]
+ reference_points_unact = torch.gather(
+ enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4)) # [bs, num_queries, 4]
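+ # Equivalent to per-batch indexing: reference_points_unact[b] = enc_outputs_coord_unact[b, topk_ind[b]]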
enc_topk_bboxes = F.sigmoid(reference_points_unact)
 
if denoising_bbox_unact is not None:
@@ -233,12 +242,13 @@ class RTDETRTransformer(nn.Module):
[denoising_bbox_unact, reference_points_unact], 1)
if self.training:
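+ # Detach so decoder gradients do not flow back into the encoder's proposals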
reference_points_unact = reference_points_unact.detach()
- enc_topk_logits = torch.gather(enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes)) # [bs, topk, nc]
 
- # extract region features
+ # Extract region features
if self.learnt_init_query:
+ # [num_queries, c] -> [b, num_queries, c]
target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
else:
+ # [b, N, c] -> [b, num_queries, c]
target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
if self.training:
target = target.detach()
@@ -269,14 +279,14 @@ class RTDETRTransformer(nn.Module):
memory, spatial_shapes, denoising_class, denoising_bbox_unact)
 
# decoder
- out_bboxes, out_logits = self.transformer_decoder(target,
- init_ref_points_unact,
- memory,
- spatial_shapes,
- self.dec_bbox_head,
- self.dec_class_head,
- self.query_pos_head,
- attn_mask)
+ out_bboxes, out_logits = self.decoder(target,
+ init_ref_points_unact,
+ memory,
+ spatial_shapes,
+ self.dec_bbox_head,
+ self.dec_class_head,
+ self.query_pos_head,
+ attn_mask)
 
return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta