|
|
@@ -97,6 +97,7 @@ class MultiLevelPredLayer(nn.Module):
|
|
|
all_cls_preds = []
|
|
|
all_reg_preds = []
|
|
|
all_box_preds = []
|
|
|
+ all_delta_preds = []
|
|
|
for level in range(self.num_levels):
|
|
|
# pred
|
|
|
cls_pred, reg_pred = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
|
|
|
@@ -116,21 +117,22 @@ class MultiLevelPredLayer(nn.Module):
|
|
|
# ----------------------- Decode bbox -----------------------
|
|
|
B, M = reg_pred.shape[:2]
|
|
|
# [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
|
|
|
- reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
|
|
|
+ delta_pred = reg_pred.reshape([B, M, 4, self.reg_max])
|
|
|
# [B, M, 4, reg_max] -> [B, reg_max, 4, M]
|
|
|
- reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
|
|
|
+ delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
|
|
|
# [B, reg_max, 4, M] -> [B, 1, 4, M]
|
|
|
- reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
|
|
|
+ delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
|
|
|
# [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
|
|
|
- reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
|
|
|
+ delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
|
|
|
## tlbr -> xyxy
|
|
|
- x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.strides[level]
|
|
|
- x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.strides[level]
|
|
|
+ x1y1_pred = anchors[None] - delta_pred[..., :2] * self.strides[level]
|
|
|
+ x2y2_pred = anchors[None] + delta_pred[..., 2:] * self.strides[level]
|
|
|
box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
|
|
|
|
|
|
all_cls_preds.append(cls_pred)
|
|
|
all_reg_preds.append(reg_pred)
|
|
|
all_box_preds.append(box_pred)
|
|
|
+ all_delta_preds.append(delta_pred)
|
|
|
all_anchors.append(anchors)
|
|
|
all_strides.append(stride_tensor)
|
|
|
|
|
|
@@ -138,6 +140,7 @@ class MultiLevelPredLayer(nn.Module):
|
|
|
outputs = {"pred_cls": all_cls_preds, # List(Tensor) [B, M, C]
|
|
|
"pred_reg": all_reg_preds, # List(Tensor) [B, M, 4*(reg_max)]
|
|
|
"pred_box": all_box_preds, # List(Tensor) [B, M, 4]
|
|
|
+ "pred_delta": all_delta_preds, # List(Tensor) [B, M, 4]
|
|
|
"anchors": all_anchors, # List(Tensor) [M, 2]
|
|
|
"strides": self.strides, # List(Int) = [8, 16, 32]
|
|
|
"stride_tensor": all_strides # List(Tensor) [M, 1]
|