|
@@ -1,19 +1,20 @@
|
|
|
# --------------- Torch components ---------------
|
|
# --------------- Torch components ---------------
|
|
|
import torch
|
|
import torch
|
|
|
import torch.nn as nn
|
|
import torch.nn as nn
|
|
|
|
|
+import torch.nn.functional as F
|
|
|
|
|
|
|
|
# --------------- Model components ---------------
|
|
# --------------- Model components ---------------
|
|
|
-from .yolox2_backbone import build_backbone
|
|
|
|
|
-from .yolox2_neck import build_neck
|
|
|
|
|
-from .yolox2_pafpn import build_fpn
|
|
|
|
|
-from .yolox2_head import build_head
|
|
|
|
|
|
|
+from .yolovx_backbone import build_backbone
|
|
|
|
|
+from .yolovx_neck import build_neck
|
|
|
|
|
+from .yolovx_pafpn import build_fpn
|
|
|
|
|
+from .yolovx_head import build_head
|
|
|
|
|
|
|
|
# --------------- External components ---------------
|
|
# --------------- External components ---------------
|
|
|
from utils.misc import multiclass_nms
|
|
from utils.misc import multiclass_nms
|
|
|
|
|
|
|
|
|
|
|
|
|
-# YOLOX-2
|
|
|
|
|
-class YOLOX2(nn.Module):
|
|
|
|
|
|
|
+# YOLOvx
|
|
|
|
|
+class YOLOvx(nn.Module):
|
|
|
def __init__(self,
|
|
def __init__(self,
|
|
|
cfg,
|
|
cfg,
|
|
|
device,
|
|
device,
|
|
@@ -23,11 +24,12 @@ class YOLOX2(nn.Module):
|
|
|
trainable = False,
|
|
trainable = False,
|
|
|
topk = 1000,
|
|
topk = 1000,
|
|
|
deploy = False):
|
|
deploy = False):
|
|
|
- super(YOLOX2, self).__init__()
|
|
|
|
|
|
|
+ super(YOLOvx, self).__init__()
|
|
|
# ---------------------- Basic Parameters ----------------------
|
|
# ---------------------- Basic Parameters ----------------------
|
|
|
self.cfg = cfg
|
|
self.cfg = cfg
|
|
|
self.device = device
|
|
self.device = device
|
|
|
self.stride = cfg['stride']
|
|
self.stride = cfg['stride']
|
|
|
|
|
+ self.reg_max = cfg['reg_max']
|
|
|
self.num_classes = num_classes
|
|
self.num_classes = num_classes
|
|
|
self.trainable = trainable
|
|
self.trainable = trainable
|
|
|
self.conf_thresh = conf_thresh
|
|
self.conf_thresh = conf_thresh
|
|
@@ -37,6 +39,11 @@ class YOLOX2(nn.Module):
|
|
|
self.head_dim = round(256*cfg['width'])
|
|
self.head_dim = round(256*cfg['width'])
|
|
|
|
|
|
|
|
# ---------------------- Network Parameters ----------------------
|
|
# ---------------------- Network Parameters ----------------------
|
|
|
|
|
+ ## ----------- proj_conv ------------
|
|
|
|
|
+ self.proj = nn.Parameter(torch.linspace(0, cfg['reg_max'], cfg['reg_max']), requires_grad=False)
|
|
|
|
|
+ self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
|
|
|
|
|
+ self.proj_conv.weight = nn.Parameter(self.proj.view([1, cfg['reg_max'], 1, 1]).clone().detach(), requires_grad=False)
|
|
|
|
|
+
|
|
|
## ----------- Backbone -----------
|
|
## ----------- Backbone -----------
|
|
|
self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
|
|
self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
|
|
|
|
|
|
|
@@ -57,7 +64,7 @@ class YOLOX2(nn.Module):
|
|
|
for _ in range(len(self.stride))
|
|
for _ in range(len(self.stride))
|
|
|
])
|
|
])
|
|
|
self.reg_preds = nn.ModuleList(
|
|
self.reg_preds = nn.ModuleList(
|
|
|
- [nn.Conv2d(self.head_dim, 4, kernel_size=1)
|
|
|
|
|
|
|
+ [nn.Conv2d(self.head_dim, 4*cfg['reg_max'], kernel_size=1)
|
|
|
for _ in range(len(self.stride))
|
|
for _ in range(len(self.stride))
|
|
|
])
|
|
])
|
|
|
|
|
|
|
@@ -156,36 +163,45 @@ class YOLOX2(nn.Module):
|
|
|
reg_pred = self.reg_preds[level](reg_feat)
|
|
reg_pred = self.reg_preds[level](reg_feat)
|
|
|
|
|
|
|
|
# anchors: [M, 2]
|
|
# anchors: [M, 2]
|
|
|
- fmp_size = cls_feat.shape[-2:]
|
|
|
|
|
- anchors = self.generate_anchors(level, fmp_size)
|
|
|
|
|
|
|
+ B, _, H, W = cls_feat.size()
|
|
|
|
|
+ anchors = self.generate_anchors(level, [H, W])
|
|
|
|
|
|
|
|
- # [1, C, H, W] -> [H, W, C] -> [M, C]
|
|
|
|
|
- cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
|
|
|
|
|
- reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
|
|
|
|
|
-
|
|
|
|
|
- # decode bbox
|
|
|
|
|
- ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
|
|
|
|
|
- wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
|
|
|
|
|
- pred_x1y1 = ctr_pred - wh_pred * 0.5
|
|
|
|
|
- pred_x2y2 = ctr_pred + wh_pred * 0.5
|
|
|
|
|
- box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
|
|
|
|
|
-
|
|
|
|
|
- all_cls_preds.append(cls_pred)
|
|
|
|
|
- all_box_preds.append(box_pred)
|
|
|
|
|
|
|
+ # process preds
|
|
|
|
|
+ cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
|
|
|
|
|
+ reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
|
|
|
|
|
+
|
|
|
|
|
+ # ----------------------- Decode bbox -----------------------
|
|
|
|
|
+ B, M = reg_pred.shape[:2]
|
|
|
|
|
+ # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
|
|
|
|
|
+ reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
|
|
|
|
|
+ # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
|
|
|
|
|
+ reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
|
|
|
|
|
+ # [B, reg_max, 4, M] -> [B, 1, 4, M]
|
|
|
|
|
+ reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
|
|
|
|
|
+ # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
|
|
|
|
|
+ reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
|
|
|
|
|
+ ## tlbr -> xyxy
|
|
|
|
|
+ x1y1_pred = anchors[None] - reg_pred[..., :2] * self.stride[level]
|
|
|
|
|
+ x2y2_pred = anchors[None] + reg_pred[..., 2:] * self.stride[level]
|
|
|
|
|
+ box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
|
|
|
|
|
+
|
|
|
|
|
+ # collect preds
|
|
|
|
|
+ all_cls_preds.append(cls_pred[0])
|
|
|
|
|
+ all_box_preds.append(box_pred[0])
|
|
|
|
|
|
|
|
if self.deploy:
|
|
if self.deploy:
|
|
|
|
|
+ # no post process
|
|
|
cls_preds = torch.cat(all_cls_preds, dim=0)
|
|
cls_preds = torch.cat(all_cls_preds, dim=0)
|
|
|
- box_preds = torch.cat(all_box_preds, dim=0)
|
|
|
|
|
- scores = cls_preds.sigmoid()
|
|
|
|
|
- bboxes = box_preds
|
|
|
|
|
|
|
+ box_pred = torch.cat(all_box_preds, dim=0)
|
|
|
# [n_anchors_all, 4 + C]
|
|
# [n_anchors_all, 4 + C]
|
|
|
- outputs = torch.cat([bboxes, scores], dim=-1)
|
|
|
|
|
|
|
+ outputs = torch.cat([box_pred, cls_preds.sigmoid()], dim=-1)
|
|
|
|
|
|
|
|
return outputs
|
|
return outputs
|
|
|
|
|
+
|
|
|
else:
|
|
else:
|
|
|
# post process
|
|
# post process
|
|
|
bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
|
|
bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
|
|
|
-
|
|
|
|
|
|
|
+
|
|
|
return bboxes, scores, labels
|
|
return bboxes, scores, labels
|
|
|
|
|
|
|
|
|
|
|
|
@@ -208,7 +224,9 @@ class YOLOX2(nn.Module):
|
|
|
|
|
|
|
|
# ---------------- Preds ----------------
|
|
# ---------------- Preds ----------------
|
|
|
all_anchors = []
|
|
all_anchors = []
|
|
|
|
|
+ all_strides = []
|
|
|
all_cls_preds = []
|
|
all_cls_preds = []
|
|
|
|
|
+ all_reg_preds = []
|
|
|
all_box_preds = []
|
|
all_box_preds = []
|
|
|
for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
|
|
for level, (cls_feat, reg_feat) in enumerate(zip(cls_feats, reg_feats)):
|
|
|
# prediction
|
|
# prediction
|
|
@@ -216,29 +234,44 @@ class YOLOX2(nn.Module):
|
|
|
reg_pred = self.reg_preds[level](reg_feat)
|
|
reg_pred = self.reg_preds[level](reg_feat)
|
|
|
|
|
|
|
|
B, _, H, W = cls_pred.size()
|
|
B, _, H, W = cls_pred.size()
|
|
|
- fmp_size = [H, W]
|
|
|
|
|
# generate anchor boxes: [M, 4]
|
|
# generate anchor boxes: [M, 4]
|
|
|
- anchors = self.generate_anchors(level, fmp_size)
|
|
|
|
|
|
|
+ anchors = self.generate_anchors(level, [H, W])
|
|
|
|
|
+ # stride tensor: [M, 1]
|
|
|
|
|
+ stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
|
|
|
|
|
|
|
|
- # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
|
|
|
|
|
|
|
+ # process preds
|
|
|
cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
|
|
cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
|
|
|
- reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
|
|
|
|
|
-
|
|
|
|
|
- # decode bbox
|
|
|
|
|
- ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
|
|
|
|
|
- wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
|
|
|
|
|
- pred_x1y1 = ctr_pred - wh_pred * 0.5
|
|
|
|
|
- pred_x2y2 = ctr_pred + wh_pred * 0.5
|
|
|
|
|
- box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
|
|
|
|
|
-
|
|
|
|
|
|
|
+ reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
|
|
|
|
|
+
|
|
|
|
|
+ # ----------------------- Decode bbox -----------------------
|
|
|
|
|
+ B, M = reg_pred.shape[:2]
|
|
|
|
|
+ # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
|
|
|
|
|
+ reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
|
|
|
|
|
+ # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
|
|
|
|
|
+ reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
|
|
|
|
|
+ # [B, reg_max, 4, M] -> [B, 1, 4, M]
|
|
|
|
|
+ reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
|
|
|
|
|
+ # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
|
|
|
|
|
+ reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
|
|
|
|
|
+ ## tlbr -> xyxy
|
|
|
|
|
+ x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.stride[level]
|
|
|
|
|
+ x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.stride[level]
|
|
|
|
|
+ box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
|
|
|
|
|
+
|
|
|
|
|
+ # collect preds
|
|
|
all_cls_preds.append(cls_pred)
|
|
all_cls_preds.append(cls_pred)
|
|
|
|
|
+ all_reg_preds.append(reg_pred)
|
|
|
all_box_preds.append(box_pred)
|
|
all_box_preds.append(box_pred)
|
|
|
all_anchors.append(anchors)
|
|
all_anchors.append(anchors)
|
|
|
|
|
+ all_strides.append(stride_tensor)
|
|
|
|
|
|
|
|
# output dict
|
|
# output dict
|
|
|
outputs = {"pred_cls": all_cls_preds, # List(Tensor) [B, M, C]
|
|
outputs = {"pred_cls": all_cls_preds, # List(Tensor) [B, M, C]
|
|
|
|
|
+ "pred_reg": all_reg_preds, # List(Tensor) [B, M, 4*(reg_max)]
|
|
|
"pred_box": all_box_preds, # List(Tensor) [B, M, 4]
|
|
"pred_box": all_box_preds, # List(Tensor) [B, M, 4]
|
|
|
- "anchors": all_anchors, # List(Tensor) [B, M, 2]
|
|
|
|
|
- 'strides': self.stride} # List(Int) [8, 16, 32]
|
|
|
|
|
-
|
|
|
|
|
|
|
+ "anchors": all_anchors, # List(Tensor) [M, 2]
|
|
|
|
|
+ "strides": self.stride, # List(Int) = [8, 16, 32]
|
|
|
|
|
+ "stride_tensor": all_strides # List(Tensor) [M, 1]
|
|
|
|
|
+ }
|
|
|
|
|
+
|
|
|
return outputs
|
|
return outputs
|