yjh0410 9 months ago
parent commit 0e82b2c681

+ 29 - 6
yolo/models/yolov10/yolov10.py

@@ -6,12 +6,18 @@ import torch.nn as nn
 from .yolov10_backbone import Yolov10Backbone
 from .yolov10_pafpn    import Yolov10PaFPN
 from .yolov10_head     import Yolov10DetHead
+from .yolov10_pred     import Yolov10DetPredLayer
 
+# --------------- External components ---------------
 from utils.misc import multiclass_nms
 
+
 # YOLOv10
 class Yolov10(nn.Module):
-    def __init__(self, cfg, is_val = False) -> None:
+    def __init__(self,
+                 cfg,
+                 is_val = False,
+                 ) -> None:
         super(Yolov10, self).__init__()
         # ---------------------- Basic setting ----------------------
         self.cfg = cfg
@@ -21,10 +27,17 @@ class Yolov10(nn.Module):
         self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
         self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
         self.no_multi_labels  = False if is_val else True
-
+        
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
         self.backbone = Yolov10Backbone(cfg)
-        self.pafpn    = Yolov10PaFPN(cfg, self.backbone.feat_dims[-3:])
-        self.det_head = Yolov10DetHead(cfg, self.pafpn.out_dims)
+        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
+        ## Neck: PaFPN
+        self.fpn = Yolov10PaFPN(cfg, self.pyramid_feat_dims)
+        ## Head
+        self.head = Yolov10DetHead(cfg, self.fpn.out_dims)
+        ## Pred
+        self.pred = Yolov10DetPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
 
     def post_process(self, cls_preds, box_preds):
         """
@@ -105,9 +118,19 @@ class Yolov10(nn.Module):
         return bboxes, scores, labels
     
     def forward(self, x):
+        # ---------------- Backbone ----------------
         pyramid_feats = self.backbone(x)
-        pyramid_feats = self.pafpn(pyramid_feats)
-        outputs = self.det_head(pyramid_feats)
+
+        # ---------------- Neck: PaFPN ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head(pyramid_feats)
+
+        # ---------------- Preds ----------------
+        outputs = self.pred(cls_feats, reg_feats)
         outputs['image_size'] = [x.shape[2], x.shape[3]]
 
         if not self.training:
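
For orientation, here is a minimal end-to-end sketch of the refactored pipeline (backbone -> PaFPN -> decoupled head -> pred layer, with post_process on the eval path). The Cfg stub and its threshold values are hypothetical, mirroring the test configs further down; the backbone and FPN may expect additional cfg fields, and the import assumes the repo root is on PYTHONPATH:

    import torch
    from yolo.models.yolov10.yolov10 import Yolov10

    class Cfg:                                   # hypothetical stub; mirrors the test configs below
        width, depth, ratio = 0.25, 0.34, 2.0
        reg_max, num_classes = 16, 80
        out_stride = [8, 16, 32]
        max_stride, num_levels = 32, 3
        num_cls_head = num_reg_head = 2
        val_conf_thresh, val_nms_thresh = 0.001, 0.7   # assumed values

    model = Yolov10(Cfg(), is_val=True).eval()
    x = torch.randn(1, 3, 640, 640)              # dummy 640x640 image
    with torch.no_grad():
        # eval path: backbone -> fpn -> head -> pred -> post_process
        bboxes, scores, labels = model(x)
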

+ 109 - 128
yolo/models/yolov10/yolov10_head.py

@@ -1,163 +1,144 @@
-import math
 import torch
 import torch.nn as nn
 from typing import List
 
 try:
-    from .modules import ConvModule, DflLayer
+    from .modules import ConvModule
 except:
-    from  modules import ConvModule, DflLayer
+    from  modules import ConvModule
+
+
+# -------------------- Detection Head --------------------
+class DetHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        
+        # --------- Network Parameters ----------
+        ## classification head
+        cls_feats = []
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(ConvModule(in_dim, in_dim, kernel_size=3, stride=1, groups=in_dim))
+                cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=1))
+            else:
+                cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, stride=1, groups=self.cls_head_dim))
+                cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=1))
+        
+        ## bbox regression head
+        reg_feats = []
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, stride=1))
+            else:
+                reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, stride=1))
+        
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
 
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
-# YOLOv10 detection head
+    def forward(self, x):
+        """
+            x: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+## Multi-level Detection Head
 class Yolov10DetHead(nn.Module):
-    def __init__(self, cfg, fpn_dims: List = [64, 128, 245]):
+    def __init__(self, cfg, in_dims: List = [256, 512, 1024]):
         super().__init__()
-        self.out_stride = cfg.out_stride
-        self.reg_max = cfg.reg_max
-        self.num_classes = cfg.num_classes
-
-        self.cls_dim = max(fpn_dims[0], min(cfg.num_classes, 128))
-        self.reg_dim = max(fpn_dims[0]//4, 16, 4*cfg.reg_max)
-
-        # classification head
-        self.cls_heads = nn.ModuleList(
-            nn.Sequential(
-                nn.Sequential(ConvModule(dim, dim, kernel_size=3, stride=1, groups=dim),
-                              ConvModule(dim, self.cls_dim, kernel_size=1)),
-                nn.Sequential(ConvModule(self.cls_dim, self.cls_dim, kernel_size=3, stride=1, groups=self.cls_dim),
-                              ConvModule(self.cls_dim, self.cls_dim, kernel_size=1)),
-                nn.Conv2d(self.cls_dim, cfg.num_classes, kernel_size=1),
-            )
-            for dim in fpn_dims
-        )
-
-        # bbox regression head
-        self.reg_heads = nn.ModuleList(
-            nn.Sequential(
-                ConvModule(dim, self.reg_dim, kernel_size=3, stride=1),
-                ConvModule(self.reg_dim, self.reg_dim, kernel_size=3, stride=1),
-                nn.Conv2d(self.reg_dim, 4*cfg.reg_max, kernel_size=1),
-            )
-            for dim in fpn_dims
-        )
-
-        # DFL layer for decoding bbox
-        self.dfl_layer = DflLayer(cfg.reg_max)
-        for p in self.dfl_layer.parameters():
-            p.requires_grad = False
-
-        self.init_bias()
-        
-    def init_bias(self):
-        # cls pred
-        for i, m in enumerate(self.cls_heads):
-            b = m[-1].bias.view(1, -1)
-            b.data.fill_(math.log(5 / self.num_classes / (640. / self.out_stride[i]) ** 2))
-            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-
-        # reg pred
-        for m in self.reg_heads:
-            b = m[-1].bias.view(-1, )
-            b.data.fill_(1.0)
-            m[-1].bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-            
-            w = m[-1].weight
-            w.data.fill_(0.)
-            m[-1].weight = torch.nn.Parameter(w, requires_grad=True)
-
-    def generate_anchors(self, fmp_size, level):
+        self.num_levels = len(cfg.out_stride)
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [DetHead(in_dim       = in_dims[level],
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 128)),
+                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     ) for level in range(self.num_levels)])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+
+    def forward(self, feats):
         """
-            fmp_size: (List) [H, W]
+            feats: List[(Tensor)] [[B, C, H, W], ...]
         """
-        # generate grid cells
-        fmp_h, fmp_w = fmp_size
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        # [H, W, 2] -> [HW, 2]
-        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        anchors += 0.5  # add center offset
-        anchors *= self.out_stride[level]
-
-        return anchors
-
-    def forward(self, fpn_feats):
-        anchors = []
-        strides = []
-        cls_preds = []
-        reg_preds = []
-        box_preds = []
-
-        for lvl, (feat, cls_head, reg_head) in enumerate(zip(fpn_feats, self.cls_heads, self.reg_heads)):
-            bs, c, h, w = feat.size()
-            device = feat.device
-            
-            # Prediction
-            cls_pred = cls_head(feat)
-            reg_pred = reg_head(feat)
-
-            # [bs, c, h, w] -> [bs, c, hw] -> [bs, hw, c]
-            cls_pred = cls_pred.flatten(2).permute(0, 2, 1).contiguous()
-            reg_pred = reg_pred.flatten(2).permute(0, 2, 1).contiguous()
-
-            # anchor points: [M, 2]
-            anchor = self.generate_anchors(fmp_size=[h, w], level=lvl).to(device)
-            stride = torch.ones_like(anchor[..., :1]) * self.out_stride[lvl]
-
-            # Decode bbox coords
-            box_pred = self.dfl_layer(reg_pred, anchor[None], self.out_stride[lvl])
-
-            # collect results
-            anchors.append(anchor)
-            strides.append(stride)
-            cls_preds.append(cls_pred)
-            reg_preds.append(reg_pred)
-            box_preds.append(box_pred)
-
-        # output dict
-        outputs = {"pred_cls":       cls_preds,        # List(Tensor) [B, M, C]
-                   "pred_reg":       reg_preds,        # List(Tensor) [B, M, 4*(reg_max)]
-                   "pred_box":       box_preds,        # List(Tensor) [B, M, 4]
-                   "anchors":        anchors,          # List(Tensor) [M, 2]
-                   "stride_tensor": strides,          # List(Tensor) [M, 1]
-                   "strides":        self.out_stride,  # List(Int) = [8, 16, 32]
-                   }
-
-        return outputs
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
 
 
 if __name__=='__main__':
+    import time
     from thop import profile
-
+    
     # YOLOv10-Base config
     class Yolov10BaseConfig(object):
         def __init__(self) -> None:
             # ---------------- Model config ----------------
-            self.width    = 0.50
+            self.width    = 0.25
             self.depth    = 0.34
             self.ratio    = 2.0
-            
             self.reg_max  = 16
             self.out_stride = [8, 16, 32]
             self.max_stride = 32
             self.num_levels = 3
-            self.num_classes = 80
+            ## Head
+            self.num_cls_head = 2
+            self.num_reg_head = 2
 
     cfg = Yolov10BaseConfig()
+    cfg.num_classes = 80
 
-    # Random data
-    fpn_dims = [256, 512, 512]
-    x = [torch.randn(1, fpn_dims[0], 80, 80),
-         torch.randn(1, fpn_dims[1], 40, 40),
-         torch.randn(1, fpn_dims[2], 20, 20)]
+    # Build a head
+    fpn_dims = [64, 128, 256]
+    pyramid_feats = [torch.randn(1, fpn_dims[0], 80, 80),
+                     torch.randn(1, fpn_dims[1], 40, 40),
+                     torch.randn(1, fpn_dims[2], 20, 20)]
+    head = Yolov10DetHead(cfg, fpn_dims)
 
-    # Neck model
-    model = Yolov10DetHead(cfg, fpn_dims)
 
     # Inference
-    outputs = model(x)
-
-    print('============ FLOPs & Params ===========')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
+    t0 = time.time()
+    cls_feats, reg_feats = head(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print("====== Yolov10 Head output ======")
+    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
+        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
+
+    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
     print('Params : {:.2f} M'.format(params / 1e6))
     

+ 205 - 0
yolo/models/yolov10/yolov10_pred.py

@@ -0,0 +1,205 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# -------------------- Detection Pred Layer --------------------
+class DetPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 stride      :int = 32,
+                 reg_max     :int = 16,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.reg_max = reg_max
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred bias
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred bias
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # pred
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # generate anchor points: [M, 2]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        # NOTE: num_coords must equal 4*reg_max for this view (Yolov10DetPredLayer passes 4*cfg.reg_max)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # List(Tensor) [B, M, C]
+                   "pred_reg": reg_pred,            # List(Tensor) [B, M, 4*(reg_max)]
+                   "anchors": anchors,              # List(Tensor) [M, 2]
+                   "strides": self.stride,          # List(Int) = [8, 16, 32]
+                   "stride_tensor": stride_tensor   # List(Tensor) [M, 1]
+                   }
+
+        return outputs
+
+class Yolov10DetPredLayer(nn.Module):
+    def __init__(self, cfg, cls_dim: int, reg_dim: int):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.num_levels = len(cfg.out_stride)
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [DetPredLayer(cls_dim     = cls_dim,
+                          reg_dim     = reg_dim,
+                          stride      = cfg.out_stride[level],
+                          reg_max     = cfg.reg_max,
+                          num_classes = cfg.num_classes,
+                          num_coords  = 4 * cfg.reg_max)
+                          for level in range(self.num_levels)
+                          ])
+        ## proj conv
+        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = proj_init.view([1, cfg.reg_max, 1, 1])
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level in range(self.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(box_pred)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
+        
+        # output dict
+        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
+                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
+                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
+                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+
+    # YOLOv10-Base config
+    class Yolov10BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.25
+            self.depth    = 0.34
+            self.ratio    = 2.0
+            self.reg_max  = 16
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+
+    cfg = Yolov10BaseConfig()
+    cfg.num_classes = 80
+    cls_dim = 80
+    reg_dim = 64
+    # Build a pred layer
+    pred = Yolov10DetPredLayer(cfg, cls_dim, reg_dim)
+
+    # Inference
+    cls_feats = [torch.randn(1, cls_dim, 80, 80),
+                 torch.randn(1, cls_dim, 40, 40),
+                 torch.randn(1, cls_dim, 20, 20),]
+    reg_feats = [torch.randn(1, reg_dim, 80, 80),
+                 torch.randn(1, reg_dim, 40, 40),
+                 torch.randn(1, reg_dim, 20, 20),]
+    t0 = time.time()
+    output = pred(cls_feats, reg_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== Pred output ======= ')
+    pred_cls = output["pred_cls"]
+    pred_reg = output["pred_reg"]
+    pred_box = output["pred_box"]
+    anchors  = output["anchors"]
+    
+    for level in range(cfg.num_levels):
+        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
+        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
+        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
+        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
+
+    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
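
The proj_conv above implements the DFL-style decode as a frozen 1×1 convolution: a softmax over the reg_max bins followed by a dot product with [0, 1, ..., reg_max-1] yields the expected bin index, i.e. the predicted distance in stride units. A minimal sketch of the same decode without the convolution, for a single anchor and the reg_max = 16 used here:

    import torch
    import torch.nn.functional as F

    reg_max, stride = 16, 32
    anchor = torch.tensor([16.0, 16.0])      # anchor center, already scaled to pixels
    logits = torch.randn(4, reg_max)         # raw t/l/b/r bin logits for one location

    # expectation over the discrete bins = predicted distance per side (stride units)
    bins = torch.arange(reg_max, dtype=torch.float)
    dist = (F.softmax(logits, dim=-1) * bins).sum(-1)

    # tlbr -> xyxy, mirroring the decode in Yolov10DetPredLayer.forward
    x1y1 = anchor - dist[:2] * stride
    x2y2 = anchor + dist[2:] * stride
    box  = torch.cat([x1y1, x2y2])           # [x1, y1, x2, y2]
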