@@ -44,6 +44,38 @@ class RTDETR(nn.Module):
     # ---------------------- Basic Functions ----------------------
+    def position_embedding(self, x, temperature=10000):
+        hs, ws = x.shape[-2:]
+        device = x.device
+        num_pos_feats = x.shape[1] // 2
+        scale = 2 * 3.141592653589793  # 2 * pi
+
+        # generate the normalized xy coordinate grid
+        y_embed, x_embed = torch.meshgrid(
+            [torch.arange(1, hs + 1, dtype=torch.float32),
+             torch.arange(1, ws + 1, dtype=torch.float32)])
+        y_embed = y_embed / (hs + 1e-6) * scale
+        x_embed = x_embed / (ws + 1e-6) * scale
+
+        # [H, W] -> [1, H, W]
+        y_embed = y_embed[None, :, :].to(device)
+        x_embed = x_embed[None, :, :].to(device)
+
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)  # one frequency per sin/cos pair
+
+        pos_x = torch.div(x_embed[:, :, :, None], dim_t)
+        pos_y = torch.div(y_embed[:, :, :, None], dim_t)
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+
+        # [1, H, W, C] -> [1, C, H, W]; broadcasts over the batch dim
+        pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+
+        return pos_embed
+
+
     @torch.jit.unused
     def set_aux_loss(self, outputs_class, outputs_coord):
         # this is a workaround to make torchscript happy, as torchscript
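
For readers skimming the hunk: a self-contained sketch of the same sinusoidal scheme, written as a free function so it can be run in isolation. The function name and the shape check are illustrative, not part of this PR:

import math
import torch

def sine_position_embedding(feat, temperature=10000):
    # feat: [B, C, H, W]; returns a [1, C, H, W] embedding that
    # broadcasts over the batch dimension.
    hs, ws = feat.shape[-2:]
    num_pos_feats = feat.shape[1] // 2
    scale = 2 * math.pi

    y_embed, x_embed = torch.meshgrid(
        torch.arange(1, hs + 1, dtype=torch.float32),
        torch.arange(1, ws + 1, dtype=torch.float32),
        indexing='ij')
    y_embed = (y_embed / (hs + 1e-6) * scale)[None]   # [1, H, W]
    x_embed = (x_embed / (ws + 1e-6) * scale)[None]

    # temperature**(2k / num_pos_feats): one frequency per sin/cos pair
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
    dim_t = temperature ** (2 * (torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats))

    pos_x = x_embed[..., None] / dim_t                # [1, H, W, num_pos_feats]
    pos_y = y_embed[..., None] / dim_t
    pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=4).flatten(3)
    pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=4).flatten(3)
    return torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)

feat = torch.randn(2, 256, 20, 20)
print(sine_position_embedding(feat).shape)            # torch.Size([1, 256, 20, 20])
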
@@ -57,7 +89,13 @@ class RTDETR(nn.Module):
     @torch.no_grad()
     def inference_single_image(self, x):
         # -------------------- Encoder --------------------
-        memory, memory_pos = self.encoder(x)
+        pyramid_feats = self.encoder(x)
+
+        # -------------------- Pos Embed --------------------
+        memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)  # [B, C, N]
+        memory_pos = torch.cat([self.position_embedding(feat).flatten(2) for feat in pyramid_feats], dim=-1)
+        memory = memory.permute(0, 2, 1).contiguous()  # [B, N, C]
+        memory_pos = memory_pos.permute(0, 2, 1).contiguous()
 
         # -------------------- Decoder --------------------
         hs, reference = self.decoder(memory, memory_pos)
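
The Pos Embed block is pure shape bookkeeping: each [B, C, Hi, Wi] pyramid level is flattened to [B, C, Hi*Wi], the levels are concatenated along the token axis, and the result is transposed to the [B, N, C] layout the decoder consumes. A minimal sketch with dummy tensors (the level sizes below are illustrative):

import torch

B, C = 2, 256
pyramid_feats = [torch.randn(B, C, s, s) for s in (40, 20, 10)]           # e.g. strides 8/16/32

memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)   # [B, C, 2100]
memory = memory.permute(0, 2, 1).contiguous()                             # [B, 2100, C]
print(memory.shape)                                                       # torch.Size([2, 2100, 256])
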
@@ -93,8 +131,14 @@ class RTDETR(nn.Module):
             return self.inference_single_image(x)
         else:
             # -------------------- Encoder --------------------
-            memory, memory_pos = self.encoder(x)
+            pyramid_feats = self.encoder(x)
 
+            # -------------------- Pos Embed --------------------
+            memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)  # [B, C, N]
+            memory_pos = torch.cat([self.position_embedding(feat).flatten(2) for feat in pyramid_feats], dim=-1)
+            memory = memory.permute(0, 2, 1).contiguous()  # [B, N, C]
+            memory_pos = memory_pos.permute(0, 2, 1).contiguous()
+
             # -------------------- Decoder --------------------
             hs, reference = self.decoder(memory, memory_pos)
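
Since inference_single_image and the training branch now carry the same four Pos Embed lines, a follow-up could hoist them into a shared helper. A hypothetical sketch (the method name is not from this PR):

def flatten_with_pos_embed(self, pyramid_feats):
    # [B, C, Hi, Wi] levels -> ([B, N, C] tokens, [B, N, C] position embeddings)
    memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
    memory_pos = torch.cat([self.position_embedding(feat).flatten(2)
                            for feat in pyramid_feats], dim=-1)
    return memory.permute(0, 2, 1).contiguous(), memory_pos.permute(0, 2, 1).contiguous()

Both call sites would then reduce to memory, memory_pos = self.flatten_with_pos_embed(pyramid_feats).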