|
@@ -37,38 +37,45 @@ class DetectHead(nn.Module):
|
|
|
nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
|
|
nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+ def inverse_sigmoid(self, x):
|
|
|
|
|
+ x = x.clamp(min=0, max=1)
|
|
|
|
|
+ return torch.log(x.clamp(min=1e-5)/(1 - x).clamp(min=1e-5))
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+ def decode_bbox(self, outputs_coords):
|
|
|
|
|
+ ## cxcywh -> xyxy
|
|
|
|
|
+ x1y1_pred = outputs_coords[..., :2] - outputs_coords[..., 2:] * 0.5
|
|
|
|
|
+ x2y2_pred = outputs_coords[..., :2] + outputs_coords[..., 2:] * 0.5
|
|
|
|
|
+ box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
|
|
|
|
|
+
|
|
|
|
|
+ return box_pred
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
def forward(self, hs, reference, multi_layer=False):
|
|
def forward(self, hs, reference, multi_layer=False):
|
|
|
if multi_layer:
|
|
if multi_layer:
|
|
|
- ## class embed
|
|
|
|
|
- outputs_class = torch.stack([layer_cls_embed(layer_hs) for
|
|
|
|
|
- layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
|
|
|
|
|
- ## Bbox embed
|
|
|
|
|
|
|
+ # class embed
|
|
|
|
|
+ outputs_class = torch.stack([
|
|
|
|
|
+ layer_cls_embed(layer_hs) for layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
|
|
|
|
|
+ # Bbox embed
|
|
|
outputs_coords = []
|
|
outputs_coords = []
|
|
|
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
|
|
for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
|
|
|
layer_delta_unsig = layer_bbox_embed(layer_hs)
|
|
layer_delta_unsig = layer_bbox_embed(layer_hs)
|
|
|
- # ---------- start <inverse sigmoid> ----------
|
|
|
|
|
- layer_ref_sig = layer_ref_sig.clamp(min=0, max=1)
|
|
|
|
|
- layer_ref_sig_1 = layer_ref_sig.clamp(min=1e-5)
|
|
|
|
|
- layer_ref_sig_2 = (1 - layer_ref_sig).clamp(min=1e-5)
|
|
|
|
|
- layer_ref_sig = torch.log(layer_ref_sig_1/layer_ref_sig_2)
|
|
|
|
|
- # ---------- end <inverse sigmoid> ----------
|
|
|
|
|
|
|
+ layer_ref_sig = self.inverse_sigmoid(layer_ref_sig)
|
|
|
layer_outputs_unsig = layer_delta_unsig + layer_ref_sig
|
|
layer_outputs_unsig = layer_delta_unsig + layer_ref_sig
|
|
|
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
|
|
layer_outputs_unsig = layer_outputs_unsig.sigmoid()
|
|
|
outputs_coords.append(layer_outputs_unsig)
|
|
outputs_coords.append(layer_outputs_unsig)
|
|
|
else:
|
|
else:
|
|
|
- ## class embed
|
|
|
|
|
|
|
+ # class embed
|
|
|
outputs_class = self.class_embed[-1](hs[-1])
|
|
outputs_class = self.class_embed[-1](hs[-1])
|
|
|
- ## bbox embed
|
|
|
|
|
|
|
+ # bbox embed
|
|
|
delta_unsig = self.bbox_embed[-1](hs[-1])
|
|
delta_unsig = self.bbox_embed[-1](hs[-1])
|
|
|
ref_sig = reference[-2]
|
|
ref_sig = reference[-2]
|
|
|
- ## ---------- start <inverse sigmoid> ----------
|
|
|
|
|
- ref_sig = ref_sig.clamp(min=0, max=1)
|
|
|
|
|
- ref_sig_1 = ref_sig.clamp(min=1e-5)
|
|
|
|
|
- ref_sig_2 = (1 - ref_sig).clamp(min=1e-5)
|
|
|
|
|
- ref_sig = torch.log(ref_sig_1/ref_sig_2)
|
|
|
|
|
- ## ---------- end <inverse sigmoid> ----------
|
|
|
|
|
|
|
+ ref_sig = self.inverse_sigmoid(ref_sig)
|
|
|
outputs_unsig = delta_unsig + ref_sig
|
|
outputs_unsig = delta_unsig + ref_sig
|
|
|
outputs_coords = outputs_unsig.sigmoid()
|
|
outputs_coords = outputs_unsig.sigmoid()
|
|
|
|
|
+ # decode bbox
|
|
|
|
|
+ outputs_coords = self.decode_bbox(outputs_coords)
|
|
|
|
|
+
|
|
|
|
|
|
|
|
return outputs_class, outputs_coords
|
|
return outputs_class, outputs_coords
|
|
|
|
|
|