yjh0410 1 year ago
parent
commit
4d786199ea

+ 3 - 0
yolo/config/__init__.py

@@ -2,6 +2,7 @@
 from .yolov1_config  import build_yolov1_config
 from .yolov2_config  import build_yolov2_config
 from .yolov3_config  import build_yolov3_config
+from .yolov4_config  import build_yolov4_config
 from .yolov5_config  import build_yolov5_config
 from .yolox_config   import build_yolox_config
 from .yolov6_config  import build_yolov6_config
@@ -21,6 +22,8 @@ def build_config(args):
         cfg = build_yolov2_config(args)
     elif 'yolov3' in args.model:
         cfg = build_yolov3_config(args)
+    elif 'yolov4' in args.model:
+        cfg = build_yolov4_config(args)
     elif 'yolox' in args.model:
         cfg = build_yolox_config(args)
     elif 'yolov5' in args.model:

+ 174 - 0
yolo/config/yolov4_config.py

@@ -0,0 +1,174 @@
+# YOLOv4 config
+
+
+def build_yolov4_config(args):
+    if   args.model == 'yolov4_n':
+        return Yolov4NConfig()
+    elif args.model == 'yolov4_s':
+        return Yolov4SConfig()
+    elif args.model == 'yolov4_m':
+        return Yolov4MConfig()
+    elif args.model == 'yolov4_l':
+        return Yolov4LConfig()
+    elif args.model == 'yolov4_x':
+        return Yolov4XConfig()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# YOLOv4-Base config
+class Yolov4BaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.width    = 1.0
+        self.depth    = 1.0
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+        self.model_scale = "b"
+        ## Backbone
+        self.use_pretrained = True
+        ## Head
+        self.head_dim       = 256
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
+        self.anchor_size    = {0: [[10, 13],   [16, 30],   [33, 23]],
+                               1: [[30, 61],   [62, 45],   [59, 119]],
+                               2: [[116, 90],  [156, 198], [373, 326]]}
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 100
+        self.test_conf_thresh = 0.3
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.iou_thresh = 0.5
+        ## Loss weight
+        self.loss_obj = 1.0
+        self.loss_cls = 1.0
+        self.loss_box = 5.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.base_lr      = 0.001     # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.batch_size_base = 64
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = 35.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 300
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 20
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.15
+        self.copy_paste  = 0.0           # approximated by YOLOX-style mixup
+        self.multi_scale = [0.5, 1.25]   # multi scale: [img_size * 0.5, img_size * 1.25]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.2,
+            'scale': [0.1, 2.0],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# YOLOv4-N
+class Yolov4NConfig(Yolov4BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.25
+        self.depth = 0.34
+        self.model_scale = "n"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0
+
+# YOLOv4-S
+class Yolov4SConfig(Yolov4BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.50
+        self.depth = 0.34
+        self.model_scale = "s"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0
+
+# YOLOv4-M
+class Yolov4MConfig(Yolov4BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.75
+        self.depth = 0.67
+        self.model_scale = "m"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.0
+
+# YOLOv4-L
+class Yolov4LConfig(Yolov4BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 1.0
+        self.depth = 1.0
+        self.model_scale = "l"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.0
+
+# YOLOv4-X
+class Yolov4XConfig(Yolov4BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 1.25
+        self.depth = 1.34
+        self.model_scale = "x"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.1
+        self.copy_paste  = 0.0
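
For reference, a minimal usage sketch of the new config factory. The `SimpleNamespace` stand-in for the parsed CLI args is an assumption (not part of this commit), and it assumes the repository root is on sys.path:

    # Hypothetical usage -- 'args' normally comes from argparse.
    from types import SimpleNamespace
    from yolo.config import build_config  # exposed via the __init__.py change above

    args = SimpleNamespace(model='yolov4_s')
    cfg = build_config(args)   # dispatches to build_yolov4_config(args)
    cfg.print_config()         # prints every field, e.g. "width : 0.5"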

+ 4 - 0
yolo/models/__init__.py

@@ -5,6 +5,7 @@ import torch
 from .yolov1.build import build_yolov1
 from .yolov2.build import build_yolov2
 from .yolov3.build import build_yolov3
+from .yolov4.build import build_yolov4
 from .yolov5.build import build_yolov5
 from .yolox.build  import build_yolox
 from .yolov6.build import build_yolov6
@@ -26,6 +27,9 @@ def build_model(args, cfg, is_val=False):
     ## Modified YOLOv3
     elif 'yolov3' in args.model:
         model, criterion = build_yolov3(cfg, is_val)
+    ## Modified YOLOv4
+    elif 'yolov4' in args.model:
+        model, criterion = build_yolov4(cfg, is_val)
     ## Anchor-free YOLOv5
     elif 'yolox' in args.model:
         model, criterion = build_yolox(cfg, is_val)

+ 9 - 50
yolo/models/yolov4/build.py

@@ -1,65 +1,24 @@
-#!/usr/bin/env python3
-# -*- coding:utf-8 -*-
-
-import torch
 import torch.nn as nn
 
-from .loss import build_criterion
-from .yolov4 import YOLOv4
+from .loss import SetCriterion
+from .yolov4 import Yolov4
 
 
 # build object detector
-def build_yolov4(args, cfg, device, num_classes=80, trainable=False, deploy=False):
-    print('==============================')
-    print('Build {} ...'.format(args.model.upper()))
-    
-    print('==============================')
-    print('Model Configuration: \n', cfg)
-    
+def build_yolov4(cfg, is_val=False):
     # -------------- Build YOLO --------------
-    model = YOLOv4(cfg                = cfg,
-                   device             = device, 
-                   num_classes        = num_classes,
-                   trainable          = trainable,
-                   conf_thresh        = args.conf_thresh,
-                   nms_thresh         = args.nms_thresh,
-                   topk               = args.topk,
-                   deploy             = deploy,
-                   no_multi_labels    = args.no_multi_labels,
-                   nms_class_agnostic = args.nms_class_agnostic
-                   )
+    model = Yolov4(cfg, is_val)
 
     # -------------- Initialize YOLO --------------
     for m in model.modules():
         if isinstance(m, nn.BatchNorm2d):
             m.eps = 1e-3
             m.momentum = 0.03    
-    # Init bias
-    init_prob = 0.01
-    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-    # obj pred
-    for obj_pred in model.obj_preds:
-        b = obj_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-    # cls pred
-    for cls_pred in model.cls_preds:
-        b = cls_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-    # reg pred
-    for reg_pred in model.reg_preds:
-        b = reg_pred.bias.view(-1, )
-        b.data.fill_(1.0)
-        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = reg_pred.weight
-        w.data.fill_(0.)
-        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-
+            
     # -------------- Build criterion --------------
     criterion = None
-    if trainable:
+    if is_val:
         # build criterion for loss computation
-        criterion = build_criterion(cfg, device, num_classes)
-    return model, criterion
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion
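
The BatchNorm loop above loosens eps and momentum, a common YOLO training tweak; a self-contained sketch of the same re-initialization on a toy module, to make the effect concrete:

    import torch.nn as nn

    # Sketch of the BatchNorm tweak performed in build_yolov4 above.
    toy = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    for m in toy.modules():
        if isinstance(m, nn.BatchNorm2d):
            m.eps = 1e-3       # PyTorch default: 1e-5
            m.momentum = 0.03  # PyTorch default: 0.1
    print(toy[1].eps, toy[1].momentum)  # 0.001 0.03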

+ 12 - 25
yolo/models/yolov4/loss.py

@@ -1,36 +1,34 @@
 import torch
 import torch.nn.functional as F
-from .matcher import Yolov4Matcher
+
 from utils.box_ops import get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
+from .matcher import Yolov3Matcher
+
 
-class Criterion(object):
-    def __init__(self, cfg, device, num_classes=80):
+class SetCriterion(object):
+    def __init__(self, cfg):
         self.cfg = cfg
-        self.device = device
-        self.num_classes = num_classes
-        # loss weight
-        self.loss_obj_weight = cfg['loss_obj_weight']
-        self.loss_cls_weight = cfg['loss_cls_weight']
-        self.loss_box_weight = cfg['loss_box_weight']
+        self.num_classes = cfg.num_classes
+        self.loss_obj_weight = cfg.loss_obj
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
 
         # matcher
-        self.matcher = Yolov4Matcher(num_classes, 3, cfg['anchor_size'], cfg['iou_thresh'])
-
+        anchor_size = cfg.anchor_size[0] + cfg.anchor_size[1] + cfg.anchor_size[2]
+        self.matcher = Yolov3Matcher(cfg.num_classes, 3, anchor_size, cfg.iou_thresh)
 
     def loss_objectness(self, pred_obj, gt_obj):
         loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
 
         return loss_obj
     
-
     def loss_classes(self, pred_cls, gt_label):
         loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
 
         return loss_cls
 
-
     def loss_bboxes(self, pred_box, gt_box):
         # regression loss
         ious = get_ious(pred_box,
@@ -41,8 +39,7 @@ class Criterion(object):
 
         return loss_box, ious
 
-
-    def __call__(self, outputs, targets, epoch=0):
+    def __call__(self, outputs, targets):
         device = outputs['pred_cls'][0].device
         fpn_strides = outputs['strides']
         fmp_sizes = outputs['fmp_sizes']
@@ -99,16 +96,6 @@ class Criterion(object):
 
         return loss_dict
     
-
-def build_criterion(cfg, device, num_classes):
-    criterion = Criterion(
-        cfg=cfg,
-        device=device,
-        num_classes=num_classes
-        )
-
-    return criterion
-
     
 if __name__ == "__main__":
     pass
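
SetCriterion flattens the per-level anchor dict from the config into a single nine-anchor list before handing it to the matcher; a quick sketch of that concatenation with the base-config values:

    # cfg.anchor_size is a dict of per-level anchor lists (see yolov4_config.py).
    anchor_size = {0: [[10, 13],  [16, 30],   [33, 23]],
                   1: [[30, 61],  [62, 45],   [59, 119]],
                   2: [[116, 90], [156, 198], [373, 326]]}
    flat = anchor_size[0] + anchor_size[1] + anchor_size[2]
    print(len(flat))  # 9 anchors, 3 per pyramid level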

+ 10 - 21
yolo/models/yolov4/matcher.py

@@ -2,7 +2,7 @@ import numpy as np
 import torch
 
 
-class Yolov4Matcher(object):
+class Yolov3Matcher(object):
     def __init__(self, num_classes, num_anchors, anchor_size, iou_thresh):
         self.num_classes = num_classes
         self.num_anchors = num_anchors
@@ -12,7 +12,6 @@ class Yolov4Matcher(object):
             for anchor in anchor_size]
             )  # [KA, 4]
 
-
     def compute_iou(self, anchor_boxes, gt_box):
         """
             anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).
@@ -50,7 +49,6 @@ class Yolov4Matcher(object):
         
         return iou
 
-
     @torch.no_grad()
     def __call__(self, fmp_sizes, fpn_strides, targets):
         """
@@ -138,26 +136,17 @@ class Yolov4Matcher(object):
                 # label assignment
                 for result in label_assignment_results:
                     grid_x, grid_y, level, anchor_idx = result
-                    stride = fpn_strides[level]
-                    x1s, y1s = x1 / stride, y1 / stride
-                    x2s, y2s = x2 / stride, y2 / stride
                     fmp_h, fmp_w = fmp_sizes[level]
 
-                    # 3x3 center sampling
-                    for j in range(grid_y - 1, grid_y + 2):
-                        for i in range(grid_x - 1, grid_x + 2):
-                            is_in_box = (j >= y1s and j < y2s) and (i >= x1s and i < x2s)
-                            is_valid = (j >= 0 and j < fmp_h) and (i >= 0 and i < fmp_w)
-
-                            if is_in_box and is_valid:
-                                # obj
-                                gt_objectness[level][batch_index, j, i, anchor_idx] = 1.0
-                                # cls
-                                cls_ont_hot = torch.zeros(self.num_classes)
-                                cls_ont_hot[int(gt_label)] = 1.0
-                                gt_classes[level][batch_index, j, i, anchor_idx] = cls_ont_hot
-                                # box
-                                gt_bboxes[level][batch_index, j, i, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
+                    if grid_x < fmp_w and grid_y < fmp_h:
+                        # obj
+                        gt_objectness[level][batch_index, grid_y, grid_x, anchor_idx] = 1.0
+                        # cls
+                        cls_one_hot = torch.zeros(self.num_classes)
+                        cls_one_hot[int(gt_label)] = 1.0
+                        gt_classes[level][batch_index, grid_y, grid_x, anchor_idx] = cls_one_hot
+                        # box
+                        gt_bboxes[level][batch_index, grid_y, grid_x, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
 
         # [B, M, C]
         gt_objectness = torch.cat([gt.view(bs, -1, 1) for gt in gt_objectness], dim=1).float()
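
The rewritten assignment drops the old 3x3 center-sampling neighborhood and marks only the matched anchor's own cell as positive. A hedged sketch of the cell arithmetic (grid_x/grid_y are computed upstream of this hunk, so the floor formula below is an assumption based on the usual YOLO convention):

    # Assumed mapping from a gt box center (cx, cy) to its cell at a given stride.
    cx, cy, stride = 153.0, 98.0, 8
    grid_x, grid_y = int(cx / stride), int(cy / stride)  # -> (19, 12)
    fmp_h, fmp_w = 80, 80
    # New logic: exactly one positive cell per matched anchor, bounds-checked.
    assert grid_x < fmp_w and grid_y < fmp_h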

+ 76 - 0
yolo/models/yolov4/modules.py

@@ -0,0 +1,76 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+class ConvModule(nn.Module):
+    def __init__(self, 
+                 in_dim,        # in channels
+                 out_dim,       # out channels 
+                 kernel_size=1, # kernel size 
+                 padding=0,     # padding
+                 stride=1,      # stride
+                 dilation=1,    # dilation
+                ):
+        super(ConvModule, self).__init__()
+        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
+        self.norm = nn.BatchNorm2d(out_dim)
+        self.act  = nn.SiLU(inplace=True)
+
+    def forward(self, x):
+        return self.act(self.norm(self.conv(x)))
+
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 kernel_size  :List  = [1, 3],
+                 expansion    :float = 0.5,
+                 shortcut     :bool  = False,
+                 ) -> None:
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expansion)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1)
+        self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class CSPBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 num_blocks :int   = 1,
+                 expansion  :float = 0.5,
+                 shortcut   :bool  = False,
+                 ):
+        super(CSPBlock, self).__init__()
+        # ---------- Basic parameters ----------
+        self.num_blocks = num_blocks
+        self.expansion = expansion
+        self.shortcut = shortcut
+        inter_dim = round(out_dim * expansion)
+        # ---------- Model parameters ----------
+        self.conv_layer_1 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.conv_layer_2 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.conv_layer_3 = ConvModule(inter_dim * 2, out_dim, kernel_size=1)
+        self.module = nn.Sequential(*[
+            YoloBottleneck(inter_dim,
+                           inter_dim,
+                           kernel_size = [1, 3],
+                           expansion   = 1.0,
+                           shortcut    = shortcut,
+                           ) for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x1 = self.conv_layer_1(x)
+        x2 = self.module(self.conv_layer_2(x))
+        out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
+
+        return out
+    
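
A hypothetical smoke test for the new CSPBlock (not part of the commit), confirming that it preserves spatial size while remapping channels:

    import torch
    from modules import CSPBlock  # assumes this file's directory is on sys.path

    block = CSPBlock(in_dim=64, out_dim=128, num_blocks=2, expansion=0.5, shortcut=True)
    x = torch.randn(1, 64, 40, 40)
    print(block(x).shape)  # torch.Size([1, 128, 40, 40])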

+ 63 - 218
yolo/models/yolov4/yolov4.py

@@ -1,127 +1,71 @@
+# --------------- Torch components ---------------
 import torch
 import torch.nn as nn
 
-from utils.misc import multiclass_nms
+# --------------- Model components ---------------
+from .yolov4_backbone import Yolov4Backbone
+from .yolov4_neck     import SPPF
+from .yolov4_pafpn    import Yolov4PaFPN
+from .yolov4_head     import Yolov4DetHead
+from .yolov4_pred     import Yolov4DetPredLayer
 
-from .yolov4_backbone import build_backbone
-from .yolov4_neck import build_neck
-from .yolov4_pafpn import build_fpn
-from .yolov4_head import build_head
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
 
 
 # YOLOv4
-class YOLOv4(nn.Module):
+class Yolov4(nn.Module):
     def __init__(self,
                  cfg,
-                 device,
-                 num_classes=20,
-                 conf_thresh=0.01,
-                 nms_thresh=0.5,
-                 topk=100,
-                 trainable=False,
-                 deploy=False,
-                 no_multi_labels=False,
-                 nms_class_agnostic=False):
-        super(YOLOv4, self).__init__()
-        # ------------------- Basic parameters -------------------
-        self.cfg = cfg                                 # model config
-        self.device = device                           # cuda or cpu
-        self.num_classes = num_classes                 # number of classes
-        self.trainable = trainable                     # training flag
-        self.conf_thresh = conf_thresh                 # score threshold
-        self.nms_thresh = nms_thresh                   # NMS threshold
-        self.topk_candidates = topk                    # topk
-        self.stride = [8, 16, 32]                      # network output strides
-        self.deploy = deploy
-        self.no_multi_labels = no_multi_labels
-        self.nms_class_agnostic = nms_class_agnostic
-        # ------------------- Anchor box -------------------
-        self.num_levels = 3
-        self.num_anchors = len(cfg['anchor_size']) // self.num_levels
-        self.anchor_size = torch.as_tensor(
-            cfg['anchor_size']
-            ).float().view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
+                 is_val = False,
+                 ) -> None:
+        super(Yolov4, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = not is_val
         
-        # ------------------- Network Structure -------------------
-        ## Backbone network
-        self.backbone, feats_dim = build_backbone(
-            cfg['backbone'], trainable&cfg['pretrained'])
-
-        ## Neck network: SPP module
-        self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1])
-        feats_dim[-1] = self.neck.out_dim
-
-        ## Neck network: feature pyramid
-        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=int(256*cfg['width']))
-        self.head_dim = self.fpn.out_dim
-
-        ## Detection head
-        self.non_shared_heads = nn.ModuleList(
-            [build_head(cfg, head_dim, head_dim, num_classes) 
-            for head_dim in self.head_dim
-            ])
-
-        ## Prediction layers
-        self.obj_preds = nn.ModuleList(
-                            [nn.Conv2d(head.reg_out_dim, 1 * self.num_anchors, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ]) 
-        self.cls_preds = nn.ModuleList(
-                            [nn.Conv2d(head.cls_out_dim, self.num_classes * self.num_anchors, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ]) 
-        self.reg_preds = nn.ModuleList(
-                            [nn.Conv2d(head.reg_out_dim, 4 * self.num_anchors, kernel_size=1) 
-                                for head in self.non_shared_heads
-                              ])                 
-    
-
-    # ---------------------- Basic Functions ----------------------
-    ## generate anchor points
-    def generate_anchors(self, level, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        fmp_h, fmp_w = fmp_size
-        # [KA, 2]
-        anchor_size = self.anchor_size[level]
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone = Yolov4Backbone(cfg)
+        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
+        ## Neck: SPP
+        self.neck = SPPF(self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
+        ## Neck: FPN
+        self.fpn = Yolov4PaFPN(cfg, self.pyramid_feat_dims)
+        ## Head
+        self.head = Yolov4DetHead(cfg, self.fpn.out_dims)
+        ## Pred
+        self.pred = Yolov4DetPredLayer(cfg)
 
-        # generate grid cells
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        # [HW, 2] -> [HW, KA, 2] -> [M, 2]
-        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1) + 0.5
-        anchor_xy = anchor_xy.view(-1, 2).to(self.device)
-
-        # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] -> [M, 2]
-        anchor_wh = anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
-        anchor_wh = anchor_wh.view(-1, 2).to(self.device)
-
-        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
-
-        return anchors
-        
-    ## post-process
     def post_process(self, obj_preds, cls_preds, box_preds):
         """
+        We process predictions at each scale hierarchically
         Input:
-            cls_preds: List[np.array] -> [[M, C], ...]
-            box_preds: List[np.array] -> [[M, 4], ...]
-            obj_preds: List[np.array] -> [[M, 1], ...] or None
+            obj_preds: List[torch.Tensor] -> [[B, M, 1], ...], B=1
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
         Output:
             bboxes: np.array -> [N, 4]
             scores: np.array -> [N,]
             labels: np.array -> [N,]
         """
-        assert len(cls_preds) == self.num_levels
         all_scores = []
         all_labels = []
         all_bboxes = []
         
         for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            obj_pred_i = obj_pred_i[0]
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
             if self.no_multi_labels:
                 # [M,]
-                scores, labels = torch.max(torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
+                scores, labels = torch.max(
+                    torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -138,10 +82,9 @@ class YOLOv4(nn.Module):
 
                 labels = labels[topk_idxs]
                 bboxes = box_pred_i[topk_idxs]
-
             else:
                 # [M, C] -> [MC,]
-                scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+                scores_i = torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()).flatten()
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -165,9 +108,9 @@ class YOLOv4(nn.Module):
             all_labels.append(labels)
             all_bboxes.append(bboxes)
 
-        scores = torch.cat(all_scores)
-        labels = torch.cat(all_labels)
-        bboxes = torch.cat(all_bboxes)
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
 
         # to cpu & numpy
         scores = scores.cpu().numpy()
@@ -176,135 +119,37 @@ class YOLOv4(nn.Module):
 
         # nms
         scores, labels, bboxes = multiclass_nms(
-            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
-
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
         return bboxes, scores, labels
     
-
-    # ---------------------- Main Process for Inference ----------------------
-    @torch.no_grad()
-    def inference(self, x):
-        # Backbone network
+    def forward(self, x):
+        # ---------------- Backbone ----------------
         pyramid_feats = self.backbone(x)
-
-        # Neck network
+        # ---------------- Neck: SPP ----------------
         pyramid_feats[-1] = self.neck(pyramid_feats[-1])
 
-        # Feature pyramid
+        # ---------------- Neck: PaFPN ----------------
         pyramid_feats = self.fpn(pyramid_feats)
 
-        # Detection head
-        all_anchors = []
-        all_obj_preds = []
-        all_cls_preds = []
-        all_box_preds = []
-        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
-            cls_feat, reg_feat = head(feat)
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head(pyramid_feats)
 
-            # [1, C, H, W]
-            obj_pred = self.obj_preds[level](reg_feat)
-            cls_pred = self.cls_preds[level](cls_feat)
-            reg_pred = self.reg_preds[level](reg_feat)
+        # ---------------- Preds ----------------
+        outputs = self.pred(cls_feats, reg_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
 
-            # anchors: [M, 2]
-            fmp_size = cls_pred.shape[-2:]
-            anchors = self.generate_anchors(level, fmp_size)
+        if not self.training:
+            all_obj_preds = outputs['pred_obj']
+            all_cls_preds = outputs['pred_cls']
+            all_box_preds = outputs['pred_box']
 
-            # [1, AC, H, W] -> [H, W, AC] -> [M, C]
-            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
-            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
-            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
-
-            # decode bbox
-            ctr_pred = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.stride[level]
-            wh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
-            pred_x1y1 = ctr_pred - wh_pred * 0.5
-            pred_x2y2 = ctr_pred + wh_pred * 0.5
-            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-            all_obj_preds.append(obj_pred)
-            all_cls_preds.append(cls_pred)
-            all_box_preds.append(box_pred)
-            all_anchors.append(anchors)
-
-        if self.deploy:
-            obj_preds = torch.cat(all_obj_preds, dim=0)
-            cls_preds = torch.cat(all_cls_preds, dim=0)
-            box_preds = torch.cat(all_box_preds, dim=0)
-            scores = torch.sqrt(obj_preds.sigmoid() * cls_preds.sigmoid())
-            bboxes = box_preds
-            # [n_anchors_all, 4 + C]
-            outputs = torch.cat([bboxes, scores], dim=-1)
-
-        else:
             # post process
-            bboxes, scores, labels = self.post_process(
-                all_obj_preds, all_cls_preds, all_box_preds)
+            bboxes, scores, labels = self.post_process(all_obj_preds, all_cls_preds, all_box_preds)
             outputs = {
                 "scores": scores,
                 "labels": labels,
                 "bboxes": bboxes
             }
-
+        
         return outputs
-
-
-    # ---------------------- Main Process for Training ----------------------
-    def forward(self, x):
-        if not self.trainable:
-            return self.inference(x)
-        else:
-            bs = x.shape[0]
-            # Backbone network
-            pyramid_feats = self.backbone(x)
-
-            # Neck network
-            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-            # Feature pyramid
-            pyramid_feats = self.fpn(pyramid_feats)
-
-            # Detection head
-            all_fmp_sizes = []
-            all_obj_preds = []
-            all_cls_preds = []
-            all_box_preds = []
-            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
-                cls_feat, reg_feat = head(feat)
-
-                # [B, C, H, W]
-                obj_pred = self.obj_preds[level](reg_feat)
-                cls_pred = self.cls_preds[level](cls_feat)
-                reg_pred = self.reg_preds[level](reg_feat)
-
-                fmp_size = cls_pred.shape[-2:]
-
-                # generate anchor boxes: [M, 4]
-                anchors = self.generate_anchors(level, fmp_size)
-                
-                # [B, AC, H, W] -> [B, H, W, AC] -> [B, M, C]
-                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 1)
-                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, self.num_classes)
-                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 4)
-
-                # decode bbox
-                ctr_pred = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.stride[level]
-                wh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
-                pred_x1y1 = ctr_pred - wh_pred * 0.5
-                pred_x2y2 = ctr_pred + wh_pred * 0.5
-                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-                all_obj_preds.append(obj_pred)
-                all_cls_preds.append(cls_pred)
-                all_box_preds.append(box_pred)
-                all_fmp_sizes.append(fmp_size)
-
-            # output dict
-            outputs = {"pred_obj": all_obj_preds,        # List [B, M, 1]
-                       "pred_cls": all_cls_preds,        # List [B, M, C]
-                       "pred_box": all_box_preds,        # List [B, M, 4]
-                       'fmp_sizes': all_fmp_sizes,       # List
-                       'strides': self.stride,           # List
-                       }
-
-            return outputs 
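
post_process keeps the geometric-mean score fusion from the old inference path; a small sketch of that fusion and the top-k cut in the single-label branch, on random logits:

    import torch

    # Sketch of the score fusion used in post_process (no_multi_labels branch).
    obj_pred = torch.randn(100, 1)   # [M, 1] objectness logits
    cls_pred = torch.randn(100, 20)  # [M, C] class logits
    scores, labels = torch.max(torch.sqrt(obj_pred.sigmoid() * cls_pred.sigmoid()), dim=1)
    topk = scores.topk(min(10, scores.numel()))  # keep top-k candidates only
    print(topk.values.shape, labels[topk.indices].shape)  # torch.Size([10]) twice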

+ 80 - 122
yolo/models/yolov4/yolov4_backbone.py

@@ -2,94 +2,84 @@ import torch
 import torch.nn as nn
 
 try:
-    from .yolov4_basic import Conv, CSPBlock
+    from .modules import ConvModule, CSPBlock
 except:
-    from yolov4_basic import Conv, CSPBlock
-    
-
-model_urls = {
-    "cspdarknet_tiny": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet_tiny.pth",
-    "cspdarknet53": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet53_silu.pth",
+    from  modules import ConvModule, CSPBlock
+
+# IN1K pretrained weight
+pretrained_urls = {
+    'n': None,
+    's': None,
+    'm': None,
+    'l': None,
+    'x': None,
 }
 
-# --------------------- CSPDarkNet-53 -----------------------
-## CSPDarkNet-53
-class CSPDarkNet53(nn.Module):
-    def __init__(self, act_type='silu', norm_type='BN'):
-        super(CSPDarkNet53, self).__init__()
-        self.feat_dims = [256, 512, 1024]
-
-        # P1
-        self.layer_1 = nn.Sequential(
-            Conv(3, 32, k=3, p=1, act_type=act_type, norm_type=norm_type),
-            Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(64, 64, expand_ratio=0.5, nblocks=1, shortcut=True, act_type=act_type, norm_type=norm_type)
-        )
-        # P2
-        self.layer_2 = nn.Sequential(
-            Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(128, 128, expand_ratio=0.5, nblocks=2, shortcut=True, act_type=act_type, norm_type=norm_type)
-        )
-        # P3
-        self.layer_3 = nn.Sequential(
-            Conv(128, 256, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(256, 256, expand_ratio=0.5, nblocks=8, shortcut=True, act_type=act_type, norm_type=norm_type)
-        )
-        # P4
-        self.layer_4 = nn.Sequential(
-            Conv(256, 512, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(512, 512, expand_ratio=0.5, nblocks=8, shortcut=True, act_type=act_type, norm_type=norm_type)
-        )
-        # P5
-        self.layer_5 = nn.Sequential(
-            Conv(512, 1024, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(1024, 1024, expand_ratio=0.5, nblocks=4, shortcut=True, act_type=act_type, norm_type=norm_type)
-        )
-
-
-    def forward(self, x):
-        c1 = self.layer_1(x)
-        c2 = self.layer_2(c1)
-        c3 = self.layer_3(c2)
-        c4 = self.layer_4(c3)
-        c5 = self.layer_5(c4)
-
-        outputs = [c3, c4, c5]
-
-        return outputs
-
-## CSPDarkNet-Tiny
-class CSPDarkNetTiny(nn.Module):
-    def __init__(self, act_type='silu', norm_type='BN'):
-        super(CSPDarkNetTiny, self).__init__()
-        self.feat_dims = [64, 128, 256]
-
-        # stride = 2
-        self.layer_1 = nn.Sequential(
-            Conv(3, 16, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(16, 16, expand_ratio=0.5, nblocks=1, shortcut=True, act_type=act_type, norm_type=norm_type)
-        )
-        # stride = 4
+# --------------------- Yolov4's Backbone -----------------------
+## Modified CSPDarkNet
+class Yolov4Backbone(nn.Module):
+    def __init__(self, cfg):
+        super(Yolov4Backbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.model_scale = cfg.model_scale
+        self.feat_dims = [round(64   * cfg.width),
+                          round(128  * cfg.width),
+                          round(256  * cfg.width),
+                          round(512  * cfg.width),
+                          round(1024 * cfg.width)]
+        
+        # ------------------ Network setting ------------------
+        ## P1/2
+        self.layer_1 = ConvModule(3, self.feat_dims[0], kernel_size=6, padding=2, stride=2)
+        # P2/4
         self.layer_2 = nn.Sequential(
-            Conv(16, 32, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(32, 32, expand_ratio=0.5, nblocks=1, shortcut=True, act_type=act_type, norm_type=norm_type)
+            ConvModule(self.feat_dims[0], self.feat_dims[1], kernel_size=3, padding=1, stride=2),
+            CSPBlock(in_dim     = self.feat_dims[1],
+                     out_dim    = self.feat_dims[1],
+                     num_blocks = round(3*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     )
         )
-        # stride = 8
+        # P3/8
         self.layer_3 = nn.Sequential(
-            Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(64, 64, expand_ratio=0.5, nblocks=3, shortcut=True, act_type=act_type, norm_type=norm_type)
+            ConvModule(self.feat_dims[1], self.feat_dims[2], kernel_size=3, padding=1, stride=2),
+            CSPBlock(in_dim     = self.feat_dims[2],
+                     out_dim    = self.feat_dims[2],
+                     num_blocks = round(9*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     )
         )
-        # stride = 16
+        # P4/16
         self.layer_4 = nn.Sequential(
-            Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(128, 128, expand_ratio=0.5, nblocks=3, shortcut=True, act_type=act_type, norm_type=norm_type)
+            ConvModule(self.feat_dims[2], self.feat_dims[3], kernel_size=3, padding=1, stride=2),
+            CSPBlock(in_dim     = self.feat_dims[3],
+                     out_dim    = self.feat_dims[3],
+                     num_blocks = round(9*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     )
         )
-        # stride = 32
+        # P5/32
         self.layer_5 = nn.Sequential(
-            Conv(128, 256, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
-            CSPBlock(256, 256, expand_ratio=0.5, nblocks=2, shortcut=True, act_type=act_type, norm_type=norm_type)
+            ConvModule(self.feat_dims[3], self.feat_dims[4], kernel_size=3, padding=1, stride=2),
+            CSPBlock(in_dim     = self.feat_dims[4],
+                     out_dim    = self.feat_dims[4],
+                     num_blocks = round(3*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     )
         )
 
+        # Initialize all layers
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
     def forward(self, x):
         c1 = self.layer_1(x)
@@ -97,68 +87,36 @@ class CSPDarkNetTiny(nn.Module):
         c3 = self.layer_3(c2)
         c4 = self.layer_4(c3)
         c5 = self.layer_5(c4)
-
         outputs = [c3, c4, c5]
 
         return outputs
 
 
-# --------------------- Functions -----------------------
-def build_backbone(model_name='cspdarknet53', pretrained=False): 
-    """Constructs a cspdarknet-53 model.
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-    """
-    if model_name == 'cspdarknet53':
-        backbone = CSPDarkNet53(act_type='silu', norm_type='BN')
-        feat_dims = backbone.feat_dims
-    elif model_name == 'cspdarknet_tiny':
-        backbone = CSPDarkNetTiny(act_type='silu', norm_type='BN')
-        feat_dims = backbone.feat_dims
-
-    if pretrained:
-        url = model_urls[model_name]
-        if url is not None:
-            print('Loading pretrained weight ...')
-            checkpoint = torch.hub.load_state_dict_from_url(
-                url=url, map_location="cpu", check_hash=True)
-            # checkpoint state dict
-            checkpoint_state_dict = checkpoint.pop("model")
-            # model state dict
-            model_state_dict = backbone.state_dict()
-            # check
-            for k in list(checkpoint_state_dict.keys()):
-                if k in model_state_dict:
-                    shape_model = tuple(model_state_dict[k].shape)
-                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
-                    if shape_model != shape_checkpoint:
-                        checkpoint_state_dict.pop(k)
-                else:
-                    checkpoint_state_dict.pop(k)
-                    print('Unused key: ', k)
-
-            backbone.load_state_dict(checkpoint_state_dict)
-        else:
-            print('No backbone pretrained: CSPDarkNet53')        
-
-    return backbone, feat_dims
-
-
 if __name__ == '__main__':
     import time
     from thop import profile
-    model, feats = build_backbone(model_name='cspdarknet_tiny', pretrained=False)
-    x = torch.randn(1, 3, 224, 224)
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            self.width = 0.5
+            self.depth = 0.34
+            self.model_scale = "s"
+            self.use_pretrained = True
+
+    cfg = BaseConfig()
+    model = Yolov4Backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
     t0 = time.time()
     outputs = model(x)
+    print(model)
     t1 = time.time()
     print('Time: ', t1 - t0)
     for out in outputs:
         print(out.shape)
 
-    x = torch.randn(1, 3, 224, 224)
+    x = torch.randn(1, 3, 640, 640)
     print('==============================')
     flops, params = profile(model, inputs=(x, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    print('Params : {:.2f} M'.format(params / 1e6))
+    
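
The backbone derives all channel widths and block counts from the two scale factors; a worked example for the "s" scale (width 0.50, depth 0.34), matching what the __main__ block above would build:

    # Worked width/depth scaling for model_scale "s".
    width, depth = 0.50, 0.34
    feat_dims  = [round(c * width) for c in (64, 128, 256, 512, 1024)]
    num_blocks = [round(n * depth) for n in (3, 9, 9, 3)]
    print(feat_dims)   # [32, 64, 128, 256, 512]
    print(num_blocks)  # [1, 3, 3, 1]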

+ 0 - 131
yolo/models/yolov4/yolov4_basic.py

@@ -1,131 +0,0 @@
-import torch
-import torch.nn as nn
-
-
-class SiLU(nn.Module):
-    """export-friendly version of nn.SiLU()"""
-
-    @staticmethod
-    def forward(x):
-        return x * torch.sigmoid(x)
-
-
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-
-
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-
-
-# Basic conv layer
-class Conv(nn.Module):
-    def __init__(self, 
-                 c1,                   # in channels
-                 c2,                   # out channels 
-                 k=1,                  # kernel size 
-                 p=0,                  # padding
-                 s=1,                  # padding
-                 d=1,                  # dilation
-                 act_type='lrelu',     # activation
-                 norm_type='BN',       # normalization
-                 depthwise=False):
-        super(Conv, self).__init__()
-        convs = []
-        add_bias = False if norm_type else True
-        if depthwise:
-            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
-            # depthwise conv
-            if norm_type:
-                convs.append(get_norm(norm_type, c1))
-            if act_type:
-                convs.append(get_activation(act_type))
-            # pointwise conv
-            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-
-        else:
-            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-            
-        self.convs = nn.Sequential(*convs)
-
-
-    def forward(self, x):
-        return self.convs(x)
-
-
-# BottleNeck
-class Bottleneck(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expand_ratio=0.5,
-                 shortcut=False,
-                 depthwise=False,
-                 act_type='silu',
-                 norm_type='BN'):
-        super(Bottleneck, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
-        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
-        self.cv2 = Conv(inter_dim, out_dim, k=3, p=1, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
-        self.shortcut = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.cv2(self.cv1(x))
-
-        return x + h if self.shortcut else h
-
-
-# CSP-stage block
-class CSPBlock(nn.Module):
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expand_ratio=0.5,
-                 nblocks=1,
-                 shortcut=False,
-                 depthwise=False,
-                 act_type='silu',
-                 norm_type='BN'):
-        super(CSPBlock, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)
-        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
-        self.cv3 = Conv(2 * inter_dim, out_dim, k=1, norm_type=norm_type, act_type=act_type)
-        self.m = nn.Sequential(*[
-            Bottleneck(inter_dim, inter_dim, expand_ratio=1.0, shortcut=shortcut,
-                       norm_type=norm_type, act_type=act_type, depthwise=depthwise)
-                       for _ in range(nblocks)
-                       ])
-
-    def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        x3 = self.m(x1)
-        out = self.cv3(torch.cat([x3, x2], dim=1))
-
-        return out
-    

+ 110 - 47
yolo/models/yolov4/yolov4_head.py

@@ -1,62 +1,47 @@
 import torch
 import torch.nn as nn
 
-from .yolov4_basic import Conv
+try:
+    from .modules import ConvModule
+except:
+    from  modules import ConvModule
 
 
-class DecoupledHead(nn.Module):
-    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+## Single-level Detection Head
+class DetHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 ):
         super().__init__()
-        print('==============================')
-        print('Head: Decoupled Head')
+        # --------- Basic Parameters ----------
         self.in_dim = in_dim
-        self.num_cls_head=cfg['num_cls_head']
-        self.num_reg_head=cfg['num_reg_head']
-        self.act_type=cfg['head_act']
-        self.norm_type=cfg['head_norm']
-
-        # cls head
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        
+        # --------- Network Parameters ----------
+        ## cls head
         cls_feats = []
-        self.cls_out_dim = max(out_dim, num_classes)
-        for i in range(cfg['num_cls_head']):
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
             if i == 0:
-                cls_feats.append(
-                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
+                cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                cls_feats.append(
-                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-                
-        # reg head
+                cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
+        ## reg head
         reg_feats = []
-        self.reg_out_dim = max(out_dim, 64)
-        for i in range(cfg['num_reg_head']):
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
             if i == 0:
-                reg_feats.append(
-                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
+                reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                reg_feats.append(
-                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
-                        act_type=self.act_type,
-                        norm_type=self.norm_type,
-                        depthwise=cfg['head_depthwise'])
-                        )
-
+                reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
         self.cls_feats = nn.Sequential(*cls_feats)
         self.reg_feats = nn.Sequential(*reg_feats)
 
-
     def forward(self, x):
         """
             in_feats: (Tensor) [B, C, H, W]
@@ -66,9 +51,87 @@ class DecoupledHead(nn.Module):
 
         return cls_feats, reg_feats
     
+## Multi-level Detection Head
+class Yolov4DetHead(nn.Module):
+    def __init__(self, cfg, in_dims):
+        super().__init__()
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [DetHead(in_dim       = in_dims[level],
+                     cls_head_dim = round(cfg.head_dim * cfg.width),
+                     reg_head_dim = round(cfg.head_dim * cfg.width),
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     ) for level in range(len(cfg.out_stride))])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = cfg.head_dim
+        self.reg_head_dim = cfg.head_dim
+
+        # Initialize all layers
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv4-Base config
+    class Yolov4BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+            self.head_dim  = 256
+            self.num_cls_head   = 2
+            self.num_reg_head   = 2
+
+    cfg = Yolov4BaseConfig()
+    # Build a head
+    pyramid_feats = [torch.randn(1, cfg.head_dim, 80, 80),
+                     torch.randn(1, cfg.head_dim, 40, 40),
+                     torch.randn(1, cfg.head_dim, 20, 20)]
+    head = Yolov4DetHead(cfg, [cfg.head_dim]*3)
+
 
-# build detection head
-def build_head(cfg, in_dim, out_dim, num_classes=80):
-    head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+    # Inference
+    t0 = time.time()
+    cls_feats, reg_feats = head(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print("====== Yolov4 Head output ======")
+    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
+        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
 
-    return head
+    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 42 - 77
yolo/models/yolov4/yolov4_neck.py

@@ -1,6 +1,10 @@
 import torch
 import torch.nn as nn
-from .yolov4_basic import Conv
+
+try:
+    from .modules import ConvModule
+except:
+    from  modules import ConvModule
 
 
 # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
@@ -8,13 +12,24 @@ class SPPF(nn.Module):
     """
         This code referenced to https://github.com/ultralytics/yolov5
     """
-    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='lrelu', norm_type='BN'):
+    def __init__(self, in_dim, out_dim):
         super().__init__()
-        inter_dim = int(in_dim * expand_ratio)
+        ## ----------- Basic Parameters -----------
+        inter_dim = in_dim // 2
         self.out_dim = out_dim
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
+        ## ----------- Network Parameters -----------
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, padding=0, stride=1)
+        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, padding=0, stride=1)
+        self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
+
+        # Initialize all layers
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
     def forward(self, x):
         x = self.cv1(x)
@@ -24,75 +39,25 @@ class SPPF(nn.Module):
         return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
 
 
-# SPPF block with CSP module
-class SPPFBlockCSP(nn.Module):
-    """
-        CSP Spatial Pyramid Pooling Block
-    """
-    def __init__(self,
-                 in_dim,
-                 out_dim,
-                 expand_ratio=0.5,
-                 pooling_size=5,
-                 act_type='lrelu',
-                 norm_type='BN',
-                 depthwise=False
-                 ):
-        super(SPPFBlockCSP, self).__init__()
-        inter_dim = int(in_dim * expand_ratio)
-        self.out_dim = out_dim
-        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.Sequential(
-            Conv(inter_dim, inter_dim, k=3, p=1, 
-                 act_type=act_type, norm_type=norm_type, 
-                 depthwise=depthwise),
-            SPPF(inter_dim, 
-                 inter_dim, 
-                 expand_ratio=1.0, 
-                 pooling_size=pooling_size, 
-                 act_type=act_type, 
-                 norm_type=norm_type),
-            Conv(inter_dim, inter_dim, k=3, p=1, 
-                 act_type=act_type, norm_type=norm_type, 
-                 depthwise=depthwise)
-        )
-        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
-
-        
-    def forward(self, x):
-        x1 = self.cv1(x)
-        x2 = self.cv2(x)
-        x3 = self.m(x2)
-        y = self.cv3(torch.cat([x1, x3], dim=1))
-
-        return y
-
-
-def build_neck(cfg, in_dim, out_dim):
-    model = cfg['neck']
+if __name__=='__main__':
+    import time
+    from thop import profile
+    
+    # Build a neck
+    in_dim  = 512
+    out_dim = 512
+    neck = SPPF(in_dim, out_dim)
+
+    # Inference
+    x = torch.randn(1, in_dim, 20, 20)
+    t0 = time.time()
+    output = neck(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('Neck output: ', output.shape)
+
+    flops, params = profile(neck, inputs=(x, ), verbose=False)
     print('==============================')
-    print('Neck: {}'.format(model))
-    # build neck
-    if model == 'sppf':
-        neck = SPPF(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            expand_ratio=cfg['expand_ratio'], 
-            pooling_size=cfg['pooling_size'],
-            act_type=cfg['neck_act'],
-            norm_type=cfg['neck_norm']
-            )
-    elif model == 'csp_sppf':
-        neck = SPPFBlockCSP(
-            in_dim=in_dim,
-            out_dim=out_dim,
-            expand_ratio=cfg['expand_ratio'], 
-            pooling_size=cfg['pooling_size'],
-            act_type=cfg['neck_act'],
-            norm_type=cfg['neck_norm'],
-            depthwise=cfg['neck_depthwise']
-            )
-
-    return neck
-        
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
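
Note: the rewritten SPPF hard-codes kernel_size=5 because cascading stride-1 5x5 max-pools reproduces the 5/9/13 pooling windows of the original SPP at lower cost: two chained 5x5 pools equal one 9x9 pool, and three equal one 13x13 pool. A minimal standalone sketch verifying the equivalence (not tied to this repo):

    import torch
    import torch.nn as nn

    x   = torch.randn(1, 8, 32, 32)
    m5  = nn.MaxPool2d(5,  stride=1, padding=2)
    m9  = nn.MaxPool2d(9,  stride=1, padding=4)
    m13 = nn.MaxPool2d(13, stride=1, padding=6)
    assert torch.equal(m5(m5(x)), m9(x))       # two 5x5 pools == one 9x9 pool
    assert torch.equal(m5(m5(m5(x))), m13(x))  # three 5x5 pools == one 13x13 pool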

+ 118 - 113
yolo/models/yolov4/yolov4_pafpn.py

@@ -1,137 +1,142 @@
+from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .yolov4_basic import Conv, CSPBlock
 
+try:
+    from .modules import ConvModule, CSPBlock
+except ImportError:
+    from  modules import ConvModule, CSPBlock
 
-# PaFPN-CSP
+
+# Yolov4 PaFPN
 class Yolov4PaFPN(nn.Module):
-    def __init__(self, 
-                 in_dims=[256, 512, 1024],
-                 out_dim=256,
-                 width=1.0,
-                 depth=1.0,
-                 act_type='silu',
-                 norm_type='BN',
-                 depthwise=False):
+    def __init__(self, cfg, in_dims: List = [256, 512, 1024]):
         super(Yolov4PaFPN, self).__init__()
         self.in_dims = in_dims
-        self.out_dim = out_dim
         c3, c4, c5 = in_dims
 
-        # top down
+        # ---------------------- Yolov4's Top down FPN ----------------------
         ## P5 -> P4
-        self.reduce_layer_1 = Conv(c5, int(512*width), k=1, norm_type=norm_type, act_type=act_type)
-        self.top_down_layer_1 = CSPBlock(in_dim = c4 + int(512*width),
-                                         out_dim = int(512*width),
-                                         expand_ratio = 0.5,
-                                         nblocks = int(3*depth),
-                                         shortcut = False,
-                                         depthwise = depthwise,
-                                         norm_type = norm_type,
-                                         act_type = act_type
+        self.reduce_layer_1   = ConvModule(c5, round(512*cfg.width), kernel_size=1, padding=0, stride=1)
+        self.top_down_layer_1 = CSPBlock(in_dim     = c4 + round(512*cfg.width),
+                                         out_dim    = round(512*cfg.width),
+                                         num_blocks = round(3*cfg.depth),
+                                         expansion  = 0.5,
+                                         shortcut   = False,
                                          )
 
         ## P4 -> P3
-        self.reduce_layer_2 = Conv(c4, int(256*width), k=1, norm_type=norm_type, act_type=act_type)
-        self.top_down_layer_2 = CSPBlock(in_dim = c3 + int(256*width), 
-                                         out_dim = int(256*width),
-                                         expand_ratio = 0.5,
-                                         nblocks = int(3*depth),
-                                         shortcut = False,
-                                         depthwise = depthwise,
-                                         norm_type = norm_type,
-                                         act_type=act_type
+        self.reduce_layer_2   = ConvModule(round(512*cfg.width), round(256*cfg.width), kernel_size=1, padding=0, stride=1)
+        self.top_down_layer_2 = CSPBlock(in_dim     = c3 + round(256*cfg.width),
+                                         out_dim    = round(256*cfg.width),
+                                         num_blocks = round(3*cfg.depth),
+                                         expansion  = 0.5,
+                                         shortcut   = False,
                                          )
-
-        # bottom up
+        
+        # ---------------------- Yolov4's Bottom up PAN ----------------------
         ## P3 -> P4
-        self.reduce_layer_3 = Conv(int(256*width), int(256*width), k=3, p=1, s=2,
-                                   depthwise=depthwise, norm_type=norm_type, act_type=act_type)
-        self.bottom_up_layer_1 = CSPBlock(in_dim = int(256*width) + int(256*width),
-                                          out_dim = int(512*width),
-                                          expand_ratio = 0.5,
-                                          nblocks = int(3*depth),
-                                          shortcut = False,
-                                          depthwise = depthwise,
-                                          norm_type = norm_type,
-                                          act_type=act_type
-                                          )
-
+        self.downsample_layer_1 = ConvModule(round(256*cfg.width), round(256*cfg.width), kernel_size=3, padding=1, stride=2)
+        self.bottom_up_layer_1  = CSPBlock(in_dim     = round(256*cfg.width) + round(256*cfg.width),
+                                           out_dim    = round(512*cfg.width),
+                                           num_blocks = round(3*cfg.depth),
+                                           expansion  = 0.5,
+                                           shortcut   = False,
+                                           )
         ## P4 -> P5
-        self.reduce_layer_4 = Conv(int(512*width), int(512*width), k=3, p=1, s=2,
-                                   depthwise=depthwise, norm_type=norm_type, act_type=act_type)
-        self.bottom_up_layer_2 = CSPBlock(in_dim = int(512*width) + int(512*width),
-                                          out_dim = int(1024*width),
-                                          expand_ratio = 0.5,
-                                          nblocks = int(3*depth),
-                                          shortcut = False,
-                                          depthwise = depthwise,
-                                          norm_type = norm_type,
-                                          act_type=act_type
-                                          )
+        self.downsample_layer_2 = ConvModule(round(512*cfg.width), round(512*cfg.width), kernel_size=3, padding=1, stride=2)
+        self.bottom_up_layer_2  = CSPBlock(in_dim     = round(512*cfg.width) + round(512*cfg.width),
+                                           out_dim    = round(1024*cfg.width),
+                                           num_blocks = round(3*cfg.depth),
+                                           expansion  = 0.5,
+                                           shortcut   = False,
+                                           )
+
+        # ---------------------- Yolov4's output projection ----------------------
+        self.out_layers = nn.ModuleList([
+            ConvModule(in_dim, round(cfg.head_dim*cfg.width), kernel_size=1)
+                      for in_dim in [round(256*cfg.width), round(512*cfg.width), round(1024*cfg.width)]
+                      ])
+        self.out_dims = [round(cfg.head_dim*cfg.width)] * 3
+
+        # Initialize all layers
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                m.reset_parameters()
 
-        # output proj layers
-        if out_dim is not None:
-            # output proj layers
-            self.out_layers = nn.ModuleList([
-                Conv(in_dim, out_dim, k=1,
-                        norm_type=norm_type, act_type=act_type)
-                        for in_dim in [int(256 * width), int(512 * width), int(1024 * width)]
-                        ])
-            self.out_dim = [out_dim] * 3
+    def forward(self, features):
+        c3, c4, c5 = features
+        
+        # ------------------ Top down FPN ------------------
+        ## P5 -> P4
+        p5 = self.reduce_layer_1(c5)
+        p5_up = F.interpolate(p5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([c4, p5_up], dim=1))
 
-        else:
-            self.out_layers = None
-            self.out_dim = [int(256 * width), int(512 * width), int(1024 * width)]
+        ## P4 -> P3
+        p4 = self.reduce_layer_2(p4)
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([c3, p4_up], dim=1))
 
+        # ------------------ Bottom up PAN ------------------
+        ## P3 -> P4
+        p3_ds = self.downsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
 
-    def forward(self, features):
-        c3, c4, c5 = features
+        ## P4 -> P5
+        p4_ds = self.downsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
 
-        c6 = self.reduce_layer_1(c5)
-        c7 = F.interpolate(c6, scale_factor=2.0)   # s32->s16
-        c8 = torch.cat([c7, c4], dim=1)
-        c9 = self.top_down_layer_1(c8)
-        # P3/8
-        c10 = self.reduce_layer_2(c9)
-        c11 = F.interpolate(c10, scale_factor=2.0)   # s16->s8
-        c12 = torch.cat([c11, c3], dim=1)
-        c13 = self.top_down_layer_2(c12)  # to det
-        # p4/16
-        c14 = self.reduce_layer_3(c13)
-        c15 = torch.cat([c14, c10], dim=1)
-        c16 = self.bottom_up_layer_1(c15)  # to det
-        # p5/32
-        c17 = self.reduce_layer_4(c16)
-        c18 = torch.cat([c17, c6], dim=1)
-        c19 = self.bottom_up_layer_2(c18)  # to det
-
-        out_feats = [c13, c16, c19] # [P3, P4, P5]
+        out_feats = [p3, p4, p5]
 
         # output proj layers
-        if self.out_layers is not None:
-            # output proj layers
-            out_feats_proj = []
-            for feat, layer in zip(out_feats, self.out_layers):
-                out_feats_proj.append(layer(feat))
-            return out_feats_proj
-
-        return out_feats
-
-
-def build_fpn(cfg, in_dims, out_dim=None):
-    model = cfg['fpn']
-    # build neck
-    if model == 'yolov4_pafpn':
-        fpn_net = Yolov4PaFPN(in_dims=in_dims,
-                             out_dim=out_dim,
-                             width=cfg['width'],
-                             depth=cfg['depth'],
-                             act_type=cfg['fpn_act'],
-                             norm_type=cfg['fpn_norm'],
-                             depthwise=cfg['fpn_depthwise']
-                             )
-
-
-    return fpn_net
+        out_feats_proj = []
+        for feat, layer in zip(out_feats, self.out_layers):
+            out_feats_proj.append(layer(feat))
+            
+        return out_feats_proj
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv4-Base config
+    class Yolov4BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+            self.head_dim = 256
+
+    cfg = Yolov4BaseConfig()
+    # Build a PaFPN
+    in_dims  = [128, 256, 512]
+    fpn = Yolov4PaFPN(cfg, in_dims)
+
+    # Inference
+    x = [torch.randn(1, in_dims[0], 80, 80),
+         torch.randn(1, in_dims[1], 40, 40),
+         torch.randn(1, in_dims[2], 20, 20)]
+    t0 = time.time()
+    output = fpn(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== FPN output ====== ')
+    for level, feat in enumerate(output):
+        print("- Level-{} : ".format(level), feat.shape)
+
+    flops, params = profile(fpn, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
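
Note: reduce_layer_2 now declares round(512*cfg.width) input channels instead of c4. It is applied to the output of top_down_layer_1, not to the raw backbone feature, so the old signature only worked when c4 happened to equal int(512*width). A quick channel-bookkeeping check under the test config above (width=0.50, head_dim=256):

    width, head_dim = 0.50, 256
    c3, c4, c5 = round(256 * width), round(512 * width), round(1024 * width)
    print(c3, c4, c5)               # 128 256 512, matching in_dims in the test
    print(round(head_dim * width))  # 128, the channel width of every projected output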

+ 216 - 0
yolo/models/yolov4/yolov4_pred.py

@@ -0,0 +1,216 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+# -------------------- Detection Pred Layer --------------------
+class DetPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim      :int,
+                 reg_dim      :int,
+                 stride       :int,
+                 num_classes  :int,
+                 anchor_sizes :List,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride  = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.num_classes = num_classes
+        # ------------------- Anchor box -------------------
+        self.anchor_size = torch.as_tensor(anchor_sizes).float().view(-1, 2) # [A, 2]
+        self.num_anchors = self.anchor_size.shape[0]
+
+        # --------- Network Parameters ----------
+        self.obj_pred = nn.Conv2d(self.cls_dim, 1 * self.num_anchors, kernel_size=1)
+        self.cls_pred = nn.Conv2d(self.cls_dim, num_classes * self.num_anchors, kernel_size=1)
+        self.reg_pred = nn.Conv2d(self.reg_dim, 4 * self.num_anchors, kernel_size=1)
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        b = self.obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # cls pred
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # Height and width of the feature map
+        fmp_h, fmp_w = fmp_size
+
+        # Generate the x and y coordinates of the grid
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+
+        # Stack the x and y coordinates together: [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        # [HW, 2] -> [HW, A, 2] -> [M, 2], M=HWA
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
+        anchor_xy = anchor_xy.view(-1, 2)
+
+        # [A, 2] -> [1, A, 2] -> [HW, A, 2] -> [M, 2], M=HWA
+        anchor_wh = self.anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
+        anchor_wh = anchor_wh.view(-1, 2)
+
+        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # Prediction layers
+        obj_pred = self.obj_pred(reg_feat)
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # Generate grid coordinates
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+
+        # Reshape the predictions to simplify downstream processing:
+        # [B, C*A, H, W] -> [B, H, W, C*A] -> [B, H*W*A, C]
+        obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+        
+        # Decode the bounding-box coordinates
+        cxcy_pred = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride
+        bwbh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+        pred_x1y1 = cxcy_pred - bwbh_pred * 0.5
+        pred_x2y2 = cxcy_pred + bwbh_pred * 0.5
+        box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        # output dict
+        outputs = {"pred_obj": obj_pred,       # (torch.Tensor) [B, M, 1]
+                   "pred_cls": cls_pred,       # (torch.Tensor) [B, M, C]
+                   "pred_reg": reg_pred,       # (torch.Tensor) [B, M, 4]
+                   "pred_box": box_pred,       # (torch.Tensor) [B, M, 4]
+                   "anchors" : anchors,        # (torch.Tensor) [M, 2]
+                   "fmp_size": fmp_size,
+                   "stride"  : self.stride,    # (Int)
+                   }
+
+        return outputs
+
+class Yolov4DetPredLayer(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.num_levels = len(cfg.out_stride)
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [DetPredLayer(cls_dim      = round(cfg.head_dim * cfg.width),
+                          reg_dim      = round(cfg.head_dim * cfg.width),
+                          stride       = cfg.out_stride[level],
+                          anchor_sizes = cfg.anchor_size[level],
+                          num_classes  = cfg.num_classes,)
+                          for level in range(self.num_levels)
+                          ])
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_fmp_sizes = []
+        all_obj_preds = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level in range(self.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # collect results
+            all_obj_preds.append(outputs["pred_obj"])
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(outputs["pred_box"])
+            all_fmp_sizes.append(outputs["fmp_size"])
+            all_anchors.append(outputs["anchors"])
+        
+        # output dict
+        outputs = {"pred_obj":  all_obj_preds,         # List(Tensor) [B, M, 1]
+                   "pred_cls":  all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":  all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":  all_box_preds,         # List(Tensor) [B, M, 4]
+                   "fmp_sizes": all_fmp_sizes,         # List(Tensor) [M, 1]
+                   "anchors":   all_anchors,           # List(Tensor) [M, 2]
+                   "strides":   self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv4-Base config
+    class Yolov4BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 1.0
+            self.depth    = 1.0
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+            self.head_dim  = 256
+            self.anchor_size = {0: [[10, 13],   [16, 30],   [33, 23]],
+                                1: [[30, 61],   [62, 45],   [59, 119]],
+                                2: [[116, 90],  [156, 198], [373, 326]]}
+
+    cfg = Yolov4BaseConfig()
+    cfg.num_classes = 20
+    # Build a pred layer
+    pred = Yolov4DetPredLayer(cfg)
+
+    # Inference
+    cls_feats = [torch.randn(1, cfg.head_dim, 80, 80),
+                 torch.randn(1, cfg.head_dim, 40, 40),
+                 torch.randn(1, cfg.head_dim, 20, 20),]
+    reg_feats = [torch.randn(1, cfg.head_dim, 80, 80),
+                 torch.randn(1, cfg.head_dim, 40, 40),
+                 torch.randn(1, cfg.head_dim, 20, 20),]
+    t0 = time.time()
+    output = pred(cls_feats, reg_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== Pred output ======= ')
+    pred_obj = output["pred_obj"]
+    pred_cls = output["pred_cls"]
+    pred_reg = output["pred_reg"]
+    pred_box = output["pred_box"]
+    anchors  = output["anchors"]
+    
+    for level in range(cfg.num_levels):
+        print("- Level-{} : objectness       -> {}".format(level, pred_obj[level].shape))
+        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
+        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
+        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
+        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
+
+    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
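
Note: the decode in DetPredLayer.forward follows the classic anchor-based YOLO parameterization: center = (sigmoid(txy) + grid_xy) * stride, size = exp(twh) * anchor_wh, then conversion to corners. A minimal numeric sketch for a single anchor, with arbitrary illustrative values:

    import torch

    t         = torch.tensor([0.2, -0.1, 0.3, 0.5])  # raw [tx, ty, tw, th]
    grid_xy   = torch.tensor([4., 6.])                # cell index on the feature map
    anchor_wh = torch.tensor([62., 45.])              # a stride-16 anchor from cfg.anchor_size
    stride    = 16
    cxcy = (torch.sigmoid(t[:2]) + grid_xy) * stride  # box center in input pixels
    bwbh = torch.exp(t[2:]) * anchor_wh               # box size in input pixels
    x1y1, x2y2 = cxcy - 0.5 * bwbh, cxcy + 0.5 * bwbh
    print(x1y1, x2y2)                                 # corners of the decoded box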

+ 2 - 2
yolo/models/yolov5/yolov5_head.py

@@ -62,7 +62,7 @@ class Yolov5DetHead(nn.Module):
                      reg_head_dim = round(cfg.head_dim * cfg.width),
                      num_cls_head = cfg.num_cls_head,
                      num_reg_head = cfg.num_reg_head,
-                     ) for level in range(cfg.num_levels)])
+                     ) for level in range(len(cfg.out_stride))])
         # --------- Basic Parameters ----------
         self.in_dims = in_dims
         self.cls_head_dim = cfg.head_dim
@@ -134,4 +134,4 @@ if __name__=='__main__':
     flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))    
+    print('Params : {:.2f} M'.format(params / 1e6))
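
Note: same cleanup as in yolov4_pred.py above, also applied in yolov5_pred.py below: the number of pyramid levels is now derived from cfg.out_stride instead of a separate cfg.num_levels field, so the two can no longer drift apart. The pattern, in isolation:

    out_stride = [8, 16, 32]      # one entry per pyramid level (P3, P4, P5)
    num_levels = len(out_stride)  # 3, replacing the standalone num_levels field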

+ 4 - 3
yolo/models/yolov5/yolov5_pred.py

@@ -112,6 +112,7 @@ class Yolov5DetPredLayer(nn.Module):
         super().__init__()
         # --------- Basic Parameters ----------
         self.cfg = cfg
+        self.num_levels = len(cfg.out_stride)
 
         # ----------- Network Parameters -----------
         ## pred layers
@@ -121,7 +122,7 @@ class Yolov5DetPredLayer(nn.Module):
                           stride       = cfg.out_stride[level],
                           anchor_sizes = cfg.anchor_size[level],
                           num_classes  = cfg.num_classes,)
-                          for level in range(cfg.num_levels)
+                          for level in range(self.num_levels)
                           ])
 
     def forward(self, cls_feats, reg_feats):
@@ -131,7 +132,7 @@ class Yolov5DetPredLayer(nn.Module):
         all_cls_preds = []
         all_reg_preds = []
         all_box_preds = []
-        for level in range(self.cfg.num_levels):
+        for level in range(self.num_levels):
             # -------------- Single-level prediction --------------
             outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
 
@@ -199,7 +200,7 @@ if __name__=='__main__':
     pred_box = output["pred_box"]
     anchors  = output["anchors"]
     
-    for level in range(cfg.num_levels):
+    for level in range(len(cfg.out_stride)):
         print("- Level-{} : objectness       -> {}".format(level, pred_obj[level].shape))
         print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
         print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))