
rename YOLOX-Plus to ARTDet

yjh0410 committed 2 years ago
parent
commit 97d3d9874d

+ 4 - 4
config/__init__.py

@@ -22,7 +22,7 @@ from .yolov4_config import yolov4_cfg
 from .yolov5_config import yolov5_cfg
 from .yolov7_config import yolov7_cfg
 from .yolox_config import yolox_cfg
-from .yolox_plus_config import yolox_plus_cfg
+from .artdet_config import artdet_cfg
 
 
 def build_model_config(args):
@@ -49,9 +49,9 @@ def build_model_config(args):
     # YOLOX
     elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
         cfg = yolox_cfg[args.model]
-    # YOLOX-Plus
-    elif args.model in ['yolox_plus_n', 'yolox_plus_s', 'yolox_plus_m', 'yolox_plus_l', 'yolox_plus_x']:
-        cfg = yolox_plus_cfg[args.model]
+    # ARTDet
+    elif args.model in ['artdet_n', 'artdet_s', 'artdet_m', 'artdet_l', 'artdet_x']:
+        cfg = artdet_cfg[args.model]
     return cfg
 
 
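The dispatch above only changes its lookup key: `artdet_*` model names now resolve through `artdet_cfg`. A minimal usage sketch, assuming it is run from the repository root; `Namespace` stands in for the real parsed CLI args, which carry more fields than just `model`:

```python
# Minimal usage sketch for the renamed lookup (run from the repository root;
# `Namespace` is a stand-in for the real parsed CLI args).
from argparse import Namespace
from config import build_model_config

args = Namespace(model='artdet_s')     # formerly 'yolox_plus_s'
cfg = build_model_config(args)         # -> artdet_cfg['artdet_s']
print(sorted(cfg.keys()))
```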

+ 2 - 2
config/yolox_plus_config.py → config/artdet_config.py

@@ -1,8 +1,8 @@
 # YOLOX-Plus Config
 
 
-yolox_plus_cfg = {
-    'yolox_plus_n':{
+artdet_cfg = {
+    'artdet_n':{
         # ---------------- Model config ----------------
         ## Backbone
         'backbone': 'elannet',

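Only the top of the renamed table is visible in this hunk. The sketch below reconstructs its rough shape from keys referenced elsewhere in this commit (`reg_max`, `pretrained`, `num_cls_head`, `num_reg_head`); everything except `'backbone': 'elannet'` is an illustrative assumption, and the real per-size dicts contain many more entries:

```python
# Rough shape of config/artdet_config.py after the rename; only
# 'backbone': 'elannet' is taken from the hunk above, the remaining
# values are placeholders.
artdet_cfg = {
    'artdet_n': {
        # ---------------- Model config ----------------
        ## Backbone
        'backbone': 'elannet',
        'pretrained': True,        # assumption: toggled per experiment
        ## Head
        'reg_max': 16,             # assumption: typical DFL bin count
        'num_cls_head': 2,         # assumption
        'num_reg_head': 2,         # assumption
    },
    # 'artdet_s' / 'artdet_m' / 'artdet_l' / 'artdet_x' follow the same layout
}
```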
+ 4 - 4
models/detectors/__init__.py

@@ -9,7 +9,7 @@ from .yolov4.build import build_yolov4
 from .yolov5.build import build_yolov5
 from .yolov7.build import build_yolov7
 from .yolox.build import build_yolox
-from .yolox_plus.build import build_yolox_plus
+from .artdet.build import build_artdet
 
 
 # build object detector
@@ -47,9 +47,9 @@ def build_model(args,
     elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
         model, criterion = build_yolox(
             args, model_cfg, device, num_classes, trainable, deploy)
-    # YOLOX-Plus  
-    elif args.model in ['yolox_plus_n', 'yolox_plus_s', 'yolox_plus_m', 'yolox_plus_l', 'yolox_plus_x']:
-        model, criterion = build_yolox_plus(
+    # ARTDet  
+    elif args.model in ['artdet_n', 'artdet_s', 'artdet_m', 'artdet_l', 'artdet_x']:
+        model, criterion = build_artdet(
             args, model_cfg, device, num_classes, trainable, deploy)
 
     if trainable:

+ 14 - 67
models/detectors/yolox_plus/yolox_plus.py → models/detectors/artdet/artdet.py

@@ -4,17 +4,17 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 # --------------- Model components ---------------
-from .yolox_plus_backbone import build_backbone
-from .yolox_plus_neck import build_neck
-from .yolox_plus_pafpn import build_fpn
-from .yolox_plus_head import build_head
+from .artdet_backbone import build_backbone
+from .artdet_neck import build_neck
+from .artdet_pafpn import build_fpn
+from .artdet_head import build_head
 
 # --------------- External components ---------------
 from utils.misc import multiclass_nms
 
 
-# YOLOX-Plus
-class YoloxPlus(nn.Module):
+# Anchor-free Real-Time Detection
+class ARTDet(nn.Module):
     def __init__(self, 
                  cfg,
                  device, 
@@ -24,7 +24,7 @@ class YoloxPlus(nn.Module):
                  trainable = False, 
                  topk = 1000,
                  deploy = False):
-        super(YoloxPlus, self).__init__()
+        super(ARTDet, self).__init__()
         # ---------------------- Basic Parameters ----------------------
         self.cfg = cfg
         self.device = device
@@ -38,11 +38,6 @@ class YoloxPlus(nn.Module):
         self.deploy = deploy
         
         # ---------------------- Network Parameters ----------------------
-        ## ----------- proj_conv ------------
-        self.proj = nn.Parameter(torch.linspace(0, cfg['reg_max'], cfg['reg_max']), requires_grad=False)
-        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
-        self.proj_conv.weight = nn.Parameter(self.proj.view([1, cfg['reg_max'], 1, 1]).clone().detach(), requires_grad=False)
-
         ## ----------- Backbone -----------
         self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
 
@@ -135,43 +130,19 @@ class YoloxPlus(nn.Module):
     def inference_single_image(self, x):
         # ---------------- Backbone ----------------
         pyramid_feats = self.backbone(x)
-
-        # ---------------- Neck: SPP ----------------
         pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-        # ---------------- Neck: PaFPN ----------------
         pyramid_feats = self.fpn(pyramid_feats)
 
         # ---------------- Heads ----------------
         all_cls_preds = []
         all_box_preds = []
         for level, (feat, head) in enumerate(zip(pyramid_feats, self.det_heads)):
-            # ---------------- Pred ----------------
-            cls_pred, reg_pred = head(feat)
-
             # anchors: [M, 2]
-            B, _, H, W = reg_pred.size()
-            fmp_size = [H, W]
+            fmp_size = feat.shape[-2:]
             anchors = self.generate_anchors(level, fmp_size)
 
-            # process preds
-            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
-
-            # ----------------------- Decode bbox -----------------------
-            B, M = reg_pred.shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-            reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
-            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
-            # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
-            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
-            ## tlbr -> xyxy
-            x1y1_pred = anchors[None] - reg_pred[..., :2] * self.stride[level]
-            x2y2_pred = anchors[None] + reg_pred[..., 2:] * self.stride[level]
-            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+            # pred
+            cls_pred, reg_pred, box_pred = head(feat, anchors, self.stride[level])
 
             # collect preds
             all_cls_preds.append(cls_pred[0])
@@ -200,11 +171,7 @@ class YoloxPlus(nn.Module):
         else:
             # ---------------- Backbone ----------------
             pyramid_feats = self.backbone(x)
-
-            # ---------------- Neck: SPP ----------------
             pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-            # ---------------- Neck: PaFPN ----------------
             pyramid_feats = self.fpn(pyramid_feats)
 
             # ---------------- Heads ----------------
@@ -214,34 +181,14 @@ class YoloxPlus(nn.Module):
             all_box_preds = []
             all_strides = []
             for level, (feat, head) in enumerate(zip(pyramid_feats, self.det_heads)):
-                # ---------------- Pred ----------------
-                cls_pred, reg_pred = head(feat)
-
-                B, _, H, W = cls_pred.size()
-                fmp_size = [H, W]
-                # generate anchor boxes: [M, 4]
+                # anchors: [M, 2]
+                fmp_size = feat.shape[-2:]
                 anchors = self.generate_anchors(level, fmp_size)
                 # stride tensor: [M, 1]
                 stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
-                
-                # process preds
-                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
 
-                # ----------------------- Decode bbox -----------------------
-                B, M = reg_pred.shape[:2]
-                # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-                reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
-                # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-                reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
-                # [B, reg_max, 4, M] -> [B, 1, 4, M]
-                reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
-                # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-                reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
-                ## tlbr -> xyxy
-                x1y1_pred = anchors[None] - reg_pred_[..., :2] * self.stride[level]
-                x2y2_pred = anchors[None] + reg_pred_[..., 2:] * self.stride[level]
-                box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+                # pred
+                cls_pred, reg_pred, box_pred = head(feat, anchors, self.stride[level])
 
                 # collect preds
                 all_cls_preds.append(cls_pred)

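The bulk of this hunk moves the DFL-style box decoding out of `ARTDet.inference_single_image` and `ARTDet.forward`: the model now only builds the per-level anchor grid, and each `DecoupledHead` returns `(cls_pred, reg_pred, box_pred)` directly. `generate_anchors` itself is untouched by this commit and not shown here; the sketch below is a hedged reconstruction of the `[M, 2]` image-space anchor centres the new head call relies on, assuming the usual `(grid + 0.5) * stride` layout:

```python
# Hedged sketch of the per-level anchor grid used by the refactored loop.
# The (grid + 0.5) * stride layout is an assumption that matches how the
# head consumes `anchors` (image-space centres, shape [M, 2]).
import torch

def generate_anchors(fmp_size, stride):
    h, w = fmp_size
    ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing='ij')
    anchors = torch.stack([xs, ys], dim=-1).float().reshape(-1, 2)  # [M, 2]
    return (anchors + 0.5) * stride

print(generate_anchors((2, 3), stride=8).shape)   # torch.Size([6, 2])
```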
+ 2 - 2
models/detectors/yolox_plus/yolox_plus_backbone.py → models/detectors/artdet/artdet_backbone.py

@@ -1,9 +1,9 @@
 import torch
 import torch.nn as nn
 try:
-    from .yolox_plus_basic import Conv, ELANBlock, DownSample
+    from .artdet_basic import Conv, ELANBlock, DownSample
 except:
-    from yolox_plus_basic import Conv, ELANBlock, DownSample
+    from artdet_basic import Conv, ELANBlock, DownSample
 
 
 

+ 0 - 0
models/detectors/yolox_plus/yolox_plus_basic.py → models/detectors/artdet/artdet_basic.py


+ 32 - 4
models/detectors/yolox_plus/yolox_plus_head.py → models/detectors/artdet/artdet_head.py

@@ -1,9 +1,11 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+
 try:
-    from .yolox_plus_basic import Conv
+    from .artdet_basic import Conv
 except:
-    from yolox_plus_basic import Conv
+    from artdet_basic import Conv
 
 
 class DecoupledHead(nn.Module):
@@ -14,6 +16,7 @@ class DecoupledHead(nn.Module):
         # --------- Basic Parameters ----------
         self.in_dim = in_dim
         self.num_classes = num_classes
+        self.reg_max = cfg['reg_max']
         self.num_cls_head=cfg['num_cls_head']
         self.num_reg_head=cfg['num_reg_head']
 
@@ -61,8 +64,13 @@ class DecoupledHead(nn.Module):
         self.cls_pred = nn.Conv2d(self.cls_out_dim, num_classes, kernel_size=1) 
         self.reg_pred = nn.Conv2d(self.reg_out_dim, 4*cfg['reg_max'], kernel_size=1) 
 
+        ## ----------- proj_conv ------------
+        self.proj = nn.Parameter(torch.linspace(0, cfg['reg_max'], cfg['reg_max']), requires_grad=False)
+        self.proj_conv = nn.Conv2d(self.reg_max, 1, kernel_size=1, bias=False)
+        self.proj_conv.weight = nn.Parameter(self.proj.view([1, cfg['reg_max'], 1, 1]).clone().detach(), requires_grad=False)
+
 
-    def forward(self, x):
+    def forward(self, x, anchors, stride):
         """
             in_feats: (Tensor) [B, C, H, W]
         """
@@ -72,7 +80,27 @@ class DecoupledHead(nn.Module):
         cls_pred = self.cls_pred(cls_feats)
         reg_pred = self.reg_pred(reg_feats)
 
-        return cls_pred, reg_pred
+        # process preds
+        B = x.shape[0]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+
+        # ----------------------- Decode bbox -----------------------
+        M = reg_pred.shape[1]
+        # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+        reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max])
+        # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+        reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
+        # [B, reg_max, 4, M] -> [B, 1, 4, M]
+        reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
+        # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+        reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
+        ## tlbr -> xyxy
+        x1y1_pred = anchors[None] - reg_pred_[..., :2] * stride
+        x2y2_pred = anchors[None] + reg_pred_[..., 2:] * stride
+        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+        return cls_pred, reg_pred, box_pred
     
 
 # build detection head

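The projection conv added here is frozen: its weights are `linspace(0, reg_max, reg_max)`, so applying it after a softmax over the `reg_max` bins yields the expected per-side offset (GFL/DFL-style decoding), which is then turned into `xyxy` boxes around the anchor centres. A self-contained re-play of that decode with random tensors, mirroring the shapes in the hunk above:

```python
# Re-play of the decode that now lives in DecoupledHead.forward: softmax over
# the reg_max bins, then a frozen 1x1 conv whose weights are the bin values,
# i.e. the expected offset per box side, followed by tlbr -> xyxy.
import torch
import torch.nn as nn
import torch.nn.functional as F

B, M, reg_max, stride = 2, 6, 16, 8
anchors = torch.rand(M, 2) * 64                     # [M, 2] image-space centres
reg_pred = torch.randn(B, M, 4 * reg_max)           # raw head output

proj = torch.linspace(0, reg_max, reg_max)          # fixed projection weights
proj_conv = nn.Conv2d(reg_max, 1, kernel_size=1, bias=False)
proj_conv.weight = nn.Parameter(proj.view(1, reg_max, 1, 1).clone(),
                                requires_grad=False)

# [B, M, 4*reg_max] -> [B, reg_max, 4, M]: one distribution per box side
reg_dist = reg_pred.reshape(B, M, 4, reg_max).permute(0, 3, 2, 1).contiguous()
# expectation over the bins: [B, 1, 4, M] -> [B, M, 4]
offsets = proj_conv(F.softmax(reg_dist, dim=1)).view(B, 4, M).permute(0, 2, 1)

# tlbr offsets (in stride units) -> xyxy boxes in image coordinates
x1y1 = anchors[None] - offsets[..., :2] * stride
x2y2 = anchors[None] + offsets[..., 2:] * stride
box_pred = torch.cat([x1y1, x2y2], dim=-1)          # [B, M, 4]
print(box_pred.shape)
```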
+ 1 - 1
models/detectors/yolox_plus/yolox_plus_neck.py → models/detectors/artdet/artdet_neck.py

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-from .yolox_plus_basic import Conv
+from .artdet_basic import Conv
 
 
 # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher

+ 1 - 1
models/detectors/yolox_plus/yolox_plus_pafpn.py → models/detectors/artdet/artdet_pafpn.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .yolox_plus_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
+from .artdet_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
 
 
 # YOLO-Style PaFPN

+ 5 - 5
models/detectors/yolox_plus/build.py → models/detectors/artdet/build.py

@@ -5,16 +5,16 @@ import torch
 import torch.nn as nn
 
 from .loss import build_criterion
-from .yolox_plus import YoloxPlus
+from .artdet import ARTDet
 
 
 # build object detector
-def build_yolox_plus(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+def build_artdet(args, cfg, device, num_classes=80, trainable=False, deploy=False):
     print('==============================')
     print('Build {} ...'.format(args.model.upper()))
         
-    # -------------- Build YOLO --------------
-    model = YoloxPlus(
+    # -------------- Build ARTDet --------------
+    model = ARTDet(
         cfg=cfg,
         device=device, 
         num_classes=num_classes,
@@ -25,7 +25,7 @@ def build_yolox_plus(args, cfg, device, num_classes=80, trainable=False, deploy=
         deploy=deploy
         )
 
-    # -------------- Initialize YOLO --------------
+    # -------------- Initialize ARTDet --------------
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
             m.eps = 1e-3

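`build_artdet` still post-processes the freshly built model by resetting every `BatchNorm2d`; only the `eps` assignment is visible in this hunk, so any further tweaks in the real loop are not shown. A minimal sketch of the same loop on a throwaway module:

```python
# Same BatchNorm re-initialisation as in build_artdet above, applied to a
# throwaway module so the effect is easy to inspect.
import torch.nn as nn

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.SiLU())
for m in model.modules():
    if isinstance(m, nn.BatchNorm2d):
        m.eps = 1e-3            # PyTorch default is 1e-5
print(model[1].eps)             # 0.001
```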
+ 0 - 0
models/detectors/yolox_plus/loss.py → models/detectors/artdet/loss.py


+ 0 - 0
models/detectors/yolox_plus/matcher.py → models/detectors/artdet/matcher.py