@@ -4,17 +4,17 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 # --------------- Model components ---------------
-from .yolox_plus_backbone import build_backbone
-from .yolox_plus_neck import build_neck
-from .yolox_plus_pafpn import build_fpn
-from .yolox_plus_head import build_head
+from .artdet_backbone import build_backbone
+from .artdet_neck import build_neck
+from .artdet_pafpn import build_fpn
+from .artdet_head import build_head
 
 # --------------- External components ---------------
 from utils.misc import multiclass_nms
 
 
-# YOLOX-Plus
-class YoloxPlus(nn.Module):
+# Anchor-free Real-Time Detection
+class ARTDet(nn.Module):
     def __init__(self,
                  cfg,
                  device,
@@ -24,7 +24,7 @@ class YoloxPlus(nn.Module):
                  trainable = False,
                  topk = 1000,
                  deploy = False):
-        super(YoloxPlus, self).__init__()
+        super(ARTDet, self).__init__()
         # ---------------------- Basic Parameters ----------------------
         self.cfg = cfg
         self.device = device
@@ -38,11 +38,6 @@ class YoloxPlus(nn.Module):
         self.deploy = deploy
 
         # ---------------------- Network Parameters ----------------------
-        ## ----------- proj_conv ------------
-        self.proj = nn.Parameter(torch.linspace(0, cfg['reg_max'], cfg['reg_max']), requires_grad=False)
-        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
-        self.proj_conv.weight = nn.Parameter(self.proj.view([1, cfg['reg_max'], 1, 1]).clone().detach(), requires_grad=False)
-
         ## ----------- Backbone -----------
         self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
 
@@ -135,43 +130,19 @@ class YoloxPlus(nn.Module):
     def inference_single_image(self, x):
         # ---------------- Backbone ----------------
         pyramid_feats = self.backbone(x)
-
-        # ---------------- Neck: SPP ----------------
         pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-        # ---------------- Neck: PaFPN ----------------
         pyramid_feats = self.fpn(pyramid_feats)
 
         # ---------------- Heads ----------------
         all_cls_preds = []
         all_box_preds = []
         for level, (feat, head) in enumerate(zip(pyramid_feats, self.det_heads)):
-            # ---------------- Pred ----------------
-            cls_pred, reg_pred = head(feat)
-
             # anchors: [M, 2]
-            B, _, H, W = reg_pred.size()
-            fmp_size = [H, W]
+            fmp_size = feat.shape[-2:]
             anchors = self.generate_anchors(level, fmp_size)
 
-            # process preds
-            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-
-            # ----------------------- Decode bbox -----------------------
-            B, M = reg_pred.shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-            reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
-            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
-            # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
-            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
-            ## tlbr -> xyxy
-            x1y1_pred = anchors[None] - reg_pred[..., :2] * self.stride[level]
-            x2y2_pred = anchors[None] + reg_pred[..., 2:] * self.stride[level]
-            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+            # pred
+            cls_pred, reg_pred, box_pred = head(feat, anchors, self.stride[level])
 
             # collect preds
             all_cls_preds.append(cls_pred[0])
@@ -200,11 +171,7 @@ class YoloxPlus(nn.Module):
         else:
             # ---------------- Backbone ----------------
             pyramid_feats = self.backbone(x)
-
-            # ---------------- Neck: SPP ----------------
             pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-            # ---------------- Neck: PaFPN ----------------
             pyramid_feats = self.fpn(pyramid_feats)
 
             # ---------------- Heads ----------------
@@ -214,34 +181,14 @@ class YoloxPlus(nn.Module):
             all_box_preds = []
             all_strides = []
             for level, (feat, head) in enumerate(zip(pyramid_feats, self.det_heads)):
-                # ---------------- Pred ----------------
-                cls_pred, reg_pred = head(feat)
-
-                B, _, H, W = cls_pred.size()
-                fmp_size = [H, W]
-                # generate anchor boxes: [M, 4]
+                # anchors: [M, 2]
+                fmp_size = feat.shape[-2:]
                 anchors = self.generate_anchors(level, fmp_size)
                 # stride tensor: [M, 1]
                 stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
-
-                # process preds
-                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
 
-                # ----------------------- Decode bbox -----------------------
-                B, M = reg_pred.shape[:2]
-                # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-                reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
-                # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-                reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
-                # [B, reg_max, 4, M] -> [B, 1, 4, M]
-                reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
-                # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-                reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
-                ## tlbr -> xyxy
-                x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.stride[level]
-                x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.stride[level]
-                box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+                # pred
+                cls_pred, reg_pred, box_pred = head(feat, anchors, self.stride[level])
 
                 # collect preds
                 all_cls_preds.append(cls_pred)
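
Note on the decode logic removed above: the DFL projection (softmax over each side's distribution followed by the fixed `proj_conv`) and the tlbr-to-xyxy conversion are presumed to move into the new `artdet_head` module, since `head(feat, anchors, self.stride[level])` now returns decoded boxes directly. Below is a minimal sketch of that assumed decode, mirroring the deleted lines; the standalone function and its name `decode_dfl_boxes` are illustrative only, not the actual `artdet_head` API.

import torch
import torch.nn.functional as F

def decode_dfl_boxes(reg_pred, anchors, stride, proj_conv, reg_max):
    # reg_pred: [B, M, 4*reg_max] raw DFL logits; anchors: [M, 2] grid-cell centers
    B, M = reg_pred.shape[:2]
    # [B, M, 4*reg_max] -> [B, M, 4, reg_max] -> [B, reg_max, 4, M]
    reg_pred = reg_pred.reshape(B, M, 4, reg_max).permute(0, 3, 2, 1).contiguous()
    # expected distance per box side via the fixed 1x1 projection conv: [B, 1, 4, M]
    reg_pred = proj_conv(F.softmax(reg_pred, dim=1))
    # [B, 1, 4, M] -> [B, M, 4], i.e. tlbr distances in grid units
    reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
    # tlbr -> xyxy in input-image coordinates
    x1y1 = anchors[None] - reg_pred[..., :2] * stride
    x2y2 = anchors[None] + reg_pred[..., 2:] * stride
    return torch.cat([x1y1, x2y2], dim=-1)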