View source code

update yolov4

yjh0410 11 months ago
parent
commit
eb1add0868

+ 7 - 91
yolo/config/yolov4_config.py

@@ -2,37 +2,23 @@
 
 
 def build_yolov4_config(args):
-    if   args.model == 'yolov4_n':
-        return Yolov4NConfig()
-    elif args.model == 'yolov4_s':
-        return Yolov4SConfig()
-    elif args.model == 'yolov4_m':
-        return Yolov4MConfig()
-    elif args.model == 'yolov4_l':
-        return Yolov4LConfig()
-    elif args.model == 'yolov4_x':
-        return Yolov4XConfig()
-    else:
-        raise NotImplementedError("No config for model: {}".format(args.model))
+    return Yolov4Config()
     
-# YOLOv4-Base config
-class Yolov4BaseConfig(object):
+# YOLOv4 config
+class Yolov4Config(object):
     def __init__(self) -> None:
         # ---------------- Model config ----------------
-        self.width    = 1.0
-        self.depth    = 1.0
         self.out_stride = [8, 16, 32]
         self.max_stride = 32
-        self.model_scale = "b"
         ## Backbone
         self.use_pretrained = True
         ## Head
         self.head_dim       = 256
         self.num_cls_head   = 2
         self.num_reg_head   = 2
-        self.anchor_size    = {0: [[10, 13],   [16, 30],   [33, 23]],
-                               1: [[30, 61],   [62, 45],   [59, 119]],
-                               2: [[116, 90],  [156, 198], [373, 326]]}
+        self.anchor_size  = [[10, 13],   [16, 30],   [33, 23],
+                             [30, 61],   [62, 45],   [59, 119],
+                             [116, 90],  [156, 198], [373, 326]]
 
         # ---------------- Post-process config ----------------
         ## Post process
@@ -78,7 +64,7 @@ class Yolov4BaseConfig(object):
         # ---------------- Data process config ----------------
         self.aug_type = 'yolo'
         self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.15
+        self.mixup_prob  = 0.1
         self.copy_paste  = 0.0           # approximated by YOLOX-style mixup
         self.multi_scale = [0.5, 1.25]   # multi scale: [img_size * 0.5, img_size * 1.25]
         ## Pixel mean & std
@@ -102,73 +88,3 @@ class Yolov4BaseConfig(object):
         config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
         for k, v in config_dict.items():
             print("{} : {}".format(k, v))
-
-# YOLOv4-N
-class Yolov4NConfig(Yolov4BaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 0.25
-        self.depth = 0.34
-        self.model_scale = "n"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.0
-        self.copy_paste  = 0.0
-
-# YOLOv4-S
-class Yolov4SConfig(Yolov4BaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 0.50
-        self.depth = 0.34
-        self.model_scale = "s"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.0
-        self.copy_paste  = 0.0
-
-# YOLOv4-M
-class Yolov4MConfig(Yolov4BaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 0.75
-        self.depth = 0.67
-        self.model_scale = "m"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.1
-        self.copy_paste  = 0.0
-
-# YOLOv4-L
-class Yolov4LConfig(Yolov4BaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 1.0
-        self.depth = 1.0
-        self.model_scale = "l"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.1
-        self.copy_paste  = 0.0
-
-# YOLOv4-X
-class Yolov4XConfig(Yolov4BaseConfig):
-    def __init__(self) -> None:
-        super().__init__()
-        # ---------------- Model config ----------------
-        self.width = 1.25
-        self.depth = 1.34
-        self.model_scale = "x"
-
-        # ---------------- Data process config ----------------
-        self.mosaic_prob = 1.0
-        self.mixup_prob  = 0.1
-        self.copy_paste  = 0.0
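
The per-level dict for `anchor_size` is gone; the new config stores a flat list of nine anchors and the model reshapes it back into per-level groups at build time. A minimal sketch of that mapping (values copied from the new config; the reshape mirrors `Yolov4.__init__` further down in this commit):

```python
import torch

# Flat 9-anchor list from Yolov4Config; each pair is [w, h] in pixels.
anchor_size = [[10, 13],  [16, 30],   [33, 23],
               [30, 61],  [62, 45],   [59, 119],
               [116, 90], [156, 198], [373, 326]]

num_levels, num_anchors = 3, 3
# [nl * na, 2] -> [nl, na, 2], the same view used in Yolov4.__init__
anchors = torch.as_tensor(anchor_size).float().view(num_levels, num_anchors, 2)
for level, stride in enumerate([8, 16, 32]):
    print("stride {:>2}: {}".format(stride, anchors[level].tolist()))
```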

+ 0 - 2
yolo/models/yolov3/modules.py

@@ -2,8 +2,6 @@ import torch
 import torch.nn as nn
 
 
-
-
 # --------------------- Basic modules ---------------------
 class ConvModule(nn.Module):
     def __init__(self, 

+ 0 - 53
yolo/models/yolov4/README.md

@@ -1,53 +0,0 @@
-# YOLOv4:
-
-|    Model    |     Backbone    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|-------------|-----------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOv4-Tiny | CSPDarkNet-Tiny | 1xb16 |  640  |        31.0            |       49.1        |   8.1             |   2.9              | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov4_t_coco.pth) |
-| YOLOv4      | CSPDarkNet-53   | 1xb16 |  640  |        46.6            |       65.8        |   162.7           |   61.5             | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov4_coco.pth) |
-
-- For training, we train YOLOv4 and YOLOv4-Tiny with 250 epochs on COCO.
-- For data augmentation, we use the large scale jitter (LSJ), Mosaic augmentation and Mixup augmentation, following the setting of [YOLOv5](https://github.com/ultralytics/yolov5).
-- For optimizer, we use SGD with momentum 0.937, weight decay 0.0005 and base lr 0.01.
-- For learning rate scheduler, we use linear decay scheduler.
-- For YOLOv4's structure, we use decoupled head, following the setting of [YOLOX](https://github.com/Megvii-BaseDetection/YOLOX).
-
-## Train YOLOv4
-### Single GPU
-Taking training YOLOv4 on COCO as the example,
-```Shell
-python train.py --cuda -d coco --root path/to/coco -m yolov4 -bs 16 -size 640 --wp_epoch 3 --max_epoch 300 --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --multi_scale 
-```
-
-### Multi GPU
-Taking training YOLOv4 on COCO as the example,
-```Shell
-python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root /data/datasets/ -m yolov4 -bs 128 -size 640 --wp_epoch 3 --max_epoch 300  --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --sybn --multi_scale --save_folder weights/ 
-```
-
-## Test YOLOv4
-Taking testing YOLOv4 on COCO-val as the example,
-```Shell
-python test.py --cuda -d coco --root path/to/coco -m yolov4 --weight path/to/yolov4_coco.pth -size 640 --show 
-```
-
-## Evaluate YOLOv4
-Taking evaluating YOLOv4 on COCO-val as the example,
-```Shell
-python eval.py --cuda -d coco --root path/to/coco -m yolov4 --weight path/to/yolov4_coco.pth
-```
-
-## Demo
-### Detect with Image
-```Shell
-python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov4 --weight path/to/yolov4_coco.pth -size 640 --show
-```
-
-### Detect with Video
-```Shell
-python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov4 --weight path/to/yolov4_coco.pth -size 640 --show --gif
-```
-
-### Detect with Camera
-```Shell
-python demo.py --mode camera --cuda -m yolov4 --weight path/to/yolov4_coco.pth -size 640 --show --gif
-```

+ 2 - 10
yolo/models/yolov4/build.py

@@ -1,5 +1,3 @@
-import torch.nn as nn
-
 from .loss import SetCriterion
 from .yolov4 import Yolov4
 
@@ -8,17 +6,11 @@ from .yolov4 import Yolov4
 def build_yolov4(cfg, is_val=False):
     # -------------- Build YOLO --------------
     model = Yolov4(cfg, is_val)
-
-    # -------------- Initialize YOLO --------------
-    for m in model.modules():
-        if isinstance(m, nn.BatchNorm2d):
-            m.eps = 1e-3
-            m.momentum = 0.03    
-            
+  
     # -------------- Build criterion --------------
     criterion = None
     if is_val:
         # build criterion for training
         criterion = SetCriterion(cfg)
         
-    return model, criterion
+    return model, criterion

+ 14 - 16
yolo/models/yolov4/loss.py

@@ -1,23 +1,21 @@
 import torch
 import torch.nn.functional as F
-
+from .matcher import Yolov4Matcher
 from utils.box_ops import get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import Yolov3Matcher
-
 
 class SetCriterion(object):
     def __init__(self, cfg):
         self.cfg = cfg
         self.num_classes = cfg.num_classes
-        self.loss_obj_weight = cfg.loss_obj
-        self.loss_cls_weight = cfg.loss_cls
-        self.loss_box_weight = cfg.loss_box
+        # loss weight
+        self.loss_obj_weight = cfg.loss_obj_weight
+        self.loss_cls_weight = cfg.loss_cls_weight
+        self.loss_box_weight = cfg.loss_box_weight
 
         # matcher
-        anchor_size = cfg.anchor_size[0] + cfg.anchor_size[1] + cfg.anchor_size[2]
-        self.matcher = Yolov3Matcher(cfg.num_classes, 3, anchor_size, cfg.iou_thresh)
+        self.matcher = Yolov4Matcher(self.num_classes, 3, cfg.anchor_size, cfg.iou_thresh)
 
     def loss_objectness(self, pred_obj, gt_obj):
         loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
@@ -40,21 +38,21 @@ class SetCriterion(object):
         return loss_box, ious
 
     def __call__(self, outputs, targets):
-        device = outputs['pred_cls'][0].device
-        fpn_strides = outputs['strides']
-        fmp_sizes = outputs['fmp_sizes']
+        # Label assignment
         (
             gt_objectness, 
             gt_classes, 
             gt_bboxes,
-            ) = self.matcher(fmp_sizes=fmp_sizes, 
-                             fpn_strides=fpn_strides, 
-                             targets=targets)
+            ) = self.matcher(fmp_sizes   = outputs['fmp_sizes'], 
+                             fpn_strides = outputs['strides'], 
+                             targets     = targets)
+        
         # List[B, M, C] -> [B, M, C] -> [BM, C]
         pred_obj = torch.cat(outputs['pred_obj'], dim=1).view(-1)                      # [BM,]
         pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)    # [BM, C]
         pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)                   # [BM, 4]
-       
+        device = pred_box.device
+
         gt_objectness = gt_objectness.view(-1).to(device).float()               # [BM,]
         gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()   # [BM, C]
         gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                    # [BM, 4]
@@ -96,6 +94,6 @@ class SetCriterion(object):
 
         return loss_dict
     
-    
+
 if __name__ == "__main__":
     pass
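
With the renamed weights (`loss_obj_weight`, `loss_cls_weight`, `loss_box_weight`) and the new `Yolov4Matcher`, the criterion keeps the flatten-then-BCE pattern shown above. A self-contained sketch of that pattern with toy shapes (the sizes here are hypothetical, not the repo's API):

```python
import torch
import torch.nn.functional as F

B, M, C = 2, 100, 80                         # toy: batch, anchors per image, classes
pred_obj = torch.randn(B, M, 1).view(-1)     # [BM,]
pred_cls = torch.randn(B, M, C).view(-1, C)  # [BM, C]
gt_obj = torch.randint(0, 2, (B * M,)).float()
gt_cls = torch.zeros(B * M, C)

pos = gt_obj > 0                             # positives chosen by the matcher
num_fgs = max(pos.sum().item(), 1)           # normalize by foreground count
loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none').sum() / num_fgs
loss_cls = F.binary_cross_entropy_with_logits(pred_cls[pos], gt_cls[pos], reduction='none').sum() / num_fgs
print(loss_obj.item(), loss_cls.item())
```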

+ 21 - 10
yolo/models/yolov4/matcher.py

@@ -2,7 +2,7 @@ import numpy as np
 import torch
 
 
-class Yolov3Matcher(object):
+class Yolov4Matcher(object):
     def __init__(self, num_classes, num_anchors, anchor_size, iou_thresh):
         self.num_classes = num_classes
         self.num_anchors = num_anchors
@@ -12,6 +12,7 @@ class Yolov3Matcher(object):
             for anchor in anchor_size]
             )  # [KA, 4]
 
+
     def compute_iou(self, anchor_boxes, gt_box):
         """
             anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).
@@ -49,6 +50,7 @@ class Yolov3Matcher(object):
         
         return iou
 
+
     @torch.no_grad()
     def __call__(self, fmp_sizes, fpn_strides, targets):
         """
@@ -136,17 +138,26 @@ class Yolov3Matcher(object):
                 # label assignment
                 for result in label_assignment_results:
                     grid_x, grid_y, level, anchor_idx = result
+                    stride = fpn_strides[level]
+                    x1s, y1s = x1 / stride, y1 / stride
+                    x2s, y2s = x2 / stride, y2 / stride
                     fmp_h, fmp_w = fmp_sizes[level]
 
-                    if grid_x < fmp_w and grid_y < fmp_h:
-                        # obj
-                        gt_objectness[level][batch_index, grid_y, grid_x, anchor_idx] = 1.0
-                        # cls
-                        cls_ont_hot = torch.zeros(self.num_classes)
-                        cls_ont_hot[int(gt_label)] = 1.0
-                        gt_classes[level][batch_index, grid_y, grid_x, anchor_idx] = cls_ont_hot
-                        # box
-                        gt_bboxes[level][batch_index, grid_y, grid_x, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
+                    # 3x3 center sampling
+                    for j in range(grid_y - 1, grid_y + 2):
+                        for i in range(grid_x - 1, grid_x + 2):
+                            is_in_box = (j >= y1s and j < y2s) and (i >= x1s and i < x2s)
+                            is_valid = (j >= 0 and j < fmp_h) and (i >= 0 and i < fmp_w)
+
+                            if is_in_box and is_valid:
+                                # obj
+                                gt_objectness[level][batch_index, j, i, anchor_idx] = 1.0
+                                # cls
+                                cls_one_hot = torch.zeros(self.num_classes)
+                                cls_one_hot[int(gt_label)] = 1.0
+                                gt_classes[level][batch_index, j, i, anchor_idx] = cls_one_hot
+                                # box
+                                gt_bboxes[level][batch_index, j, i, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
 
         # [B, M, C]
         gt_objectness = torch.cat([gt.view(bs, -1, 1) for gt in gt_objectness], dim=1).float()
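
The substantive change here: instead of marking a single positive cell per assignment, the matcher now sweeps the 3x3 neighborhood around the matched cell and keeps every neighbor whose coordinates fall inside the stride-normalized ground-truth box. A standalone sketch of that test with a toy box (values are hypothetical):

```python
# Toy illustration of the 3x3 center-sampling test added above.
stride = 8
x1, y1, x2, y2 = 20.0, 12.0, 60.0, 44.0             # gt box in pixels
x1s, y1s, x2s, y2s = x1/stride, y1/stride, x2/stride, y2/stride
grid_x = int((x1 + x2) * 0.5 / stride)              # matched center cell
grid_y = int((y1 + y2) * 0.5 / stride)
fmp_h, fmp_w = 80, 80

positives = []
for j in range(grid_y - 1, grid_y + 2):
    for i in range(grid_x - 1, grid_x + 2):
        is_in_box = (y1s <= j < y2s) and (x1s <= i < x2s)
        is_valid = (0 <= j < fmp_h) and (0 <= i < fmp_w)
        if is_in_box and is_valid:
            positives.append((i, j))
print(positives)  # up to 9 cells become positive for this anchor
```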

+ 39 - 49
yolo/models/yolov4/modules.py

@@ -1,76 +1,66 @@
 import torch
 import torch.nn as nn
-from typing import List
 
 
 # --------------------- Basic modules ---------------------
 class ConvModule(nn.Module):
     def __init__(self, 
-                 in_dim,        # in channels
-                 out_dim,       # out channels 
-                 kernel_size=1, # kernel size 
-                 padding=0,     # padding
-                 stride=1,      # padding
-                 dilation=1,    # dilation
-                ):
+                 in_dim: int,          # in channels
+                 out_dim: int,         # out channels 
+                 kernel_size: int = 1, # kernel size 
+                 stride: int = 1,      # stride
+                 ):
         super(ConvModule, self).__init__()
-        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
-        self.norm = nn.BatchNorm2d(out_dim)
-        self.act  = nn.SiLU(inplace=True)
+        convs = []
+        convs.append(nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, padding=kernel_size//2, stride=stride, bias=False))
+        convs.append(nn.BatchNorm2d(out_dim))
+        convs.append(nn.SiLU(inplace=True))
+        self.convs = nn.Sequential(*convs)
 
     def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
+        return self.convs(x)
 
-class YoloBottleneck(nn.Module):
+class Bottleneck(nn.Module):
     def __init__(self,
-                 in_dim       :int,
-                 out_dim      :int,
-                 kernel_size  :List  = [1, 3],
-                 expansion    :float = 0.5,
-                 shortcut     :bool  = False,
-                 ) -> None:
-        super(YoloBottleneck, self).__init__()
-        inter_dim = int(out_dim * expansion)
-        # ----------------- Network setting -----------------
-        self.conv_layer1 = ConvModule(in_dim, inter_dim, kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1)
-        self.conv_layer2 = ConvModule(inter_dim, out_dim, kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1)
+                 in_dim: int,
+                 out_dim: int,
+                 expand_ratio: float = 0.5,
+                 shortcut: bool = False,
+                 ):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(inter_dim, out_dim, kernel_size=3, stride=1)
         self.shortcut = shortcut and in_dim == out_dim
 
     def forward(self, x):
-        h = self.conv_layer2(self.conv_layer1(x))
+        h = self.cv2(self.cv1(x))
 
         return x + h if self.shortcut else h
 
 class CSPBlock(nn.Module):
     def __init__(self,
-                 in_dim,
-                 out_dim,
-                 num_blocks :int   = 1,
-                 expansion  :float = 0.5,
-                 shortcut   :bool  = False,
+                 in_dim: int,
+                 out_dim: int,
+                 expand_ratio: float = 0.5,
+                 num_blocks: int = 1,
+                 shortcut: bool = False,
                  ):
         super(CSPBlock, self).__init__()
-        # ---------- Basic parameters ----------
-        self.num_blocks = num_blocks
-        self.expansion = expansion
-        self.shortcut = shortcut
-        inter_dim = round(out_dim * expansion)
-        # ---------- Model parameters ----------
-        self.conv_layer_1 = ConvModule(in_dim, inter_dim, kernel_size=1)
-        self.conv_layer_2 = ConvModule(in_dim, inter_dim, kernel_size=1)
-        self.conv_layer_3 = ConvModule(inter_dim * 2, out_dim, kernel_size=1)
-        self.module = nn.Sequential(*[
-            YoloBottleneck(inter_dim,
-                           inter_dim,
-                           kernel_size = [1, 3],
-                           expansion   = 1.0,
-                           shortcut    = shortcut,
-                           ) for _ in range(num_blocks)])
+        inter_dim = int(out_dim * expand_ratio)
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.cv3 = ConvModule(2 * inter_dim, out_dim, kernel_size=1)
+        self.m = nn.Sequential(*[
+            Bottleneck(inter_dim, inter_dim, expand_ratio=1.0, shortcut=shortcut)
+                       for _ in range(num_blocks)
+                       ])
 
     def forward(self, x):
-        x1 = self.conv_layer_1(x)
-        x2 = self.module(self.conv_layer_2(x))
-        out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x1)
+        out = self.cv3(torch.cat([x3, x2], dim=1))
 
         return out
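
A quick shape check on the rewritten blocks: the CSP split runs `num_blocks` bottlenecks on one 1x1 branch, concatenates the result with the bypass branch, and fuses with a final 1x1, so spatial size is always preserved. A minimal sketch (the import path is an assumption based on this repo's layout):

```python
import torch
from yolo.models.yolov4.modules import CSPBlock  # assumed path per this commit

block = CSPBlock(in_dim=128, out_dim=128, expand_ratio=0.5, num_blocks=2, shortcut=True)
x = torch.randn(1, 128, 40, 40)
print(block(x).shape)  # torch.Size([1, 128, 40, 40])
```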
     

+ 103 - 40
yolo/models/yolov4/yolov4.py

@@ -1,46 +1,83 @@
-# --------------- Torch components ---------------
 import torch
 import torch.nn as nn
 
 # --------------- Model components ---------------
 from .yolov4_backbone import Yolov4Backbone
-from .yolov4_neck     import SPPF
+from .yolov4_neck     import SPPFBlockCSP
 from .yolov4_pafpn    import Yolov4PaFPN
-from .yolov4_head     import Yolov4DetHead
-from .yolov4_pred     import Yolov4DetPredLayer
+from .yolov4_head     import DecoupledHead
 
 # --------------- External components ---------------
 from utils.misc import multiclass_nms
 
 
-# YOLOv4
 class Yolov4(nn.Module):
-    def __init__(self,
-                 cfg,
-                 is_val = False,
-                 ) -> None:
+    def __init__(self, cfg, is_val: bool = False) -> None:
         super(Yolov4, self).__init__()
         # ---------------------- Basic setting ----------------------
         self.cfg = cfg
         self.num_classes = cfg.num_classes
+        self.out_stride = cfg.out_stride
+        self.num_levels = len(cfg.out_stride)
         ## Post-process parameters
         self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
         self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
         self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
         self.no_multi_labels  = False if is_val else True
+
+        # ------------------- Anchor box setting -------------------
+        self.num_anchors = len(cfg.anchor_size) // self.num_levels
+        self.anchor_size = torch.as_tensor(
+            cfg.anchor_size
+            ).float().view(self.num_levels, self.num_anchors, 2) # [nl, na, 2]
         
-        # ---------------------- Network Parameters ----------------------
-        ## Backbone
-        self.backbone = Yolov4Backbone(cfg)
-        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
-        ## Neck: SPP
-        self.neck = SPPF(self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
-        ## Neck: FPN
-        self.fpn = Yolov4PaFPN(cfg, self.pyramid_feat_dims)
-        ## Head
-        self.head = Yolov4DetHead(cfg, self.fpn.out_dims)
-        ## Pred
-        self.pred = Yolov4DetPredLayer(cfg)
+        # ------------------- Network Structure -------------------
+        self.backbone = Yolov4Backbone(use_pretrained=cfg.use_pretrained)
+        self.neck     = SPPFBlockCSP(self.backbone.feat_dims[-1], self.backbone.feat_dims[-1], expand_ratio=0.5)
+        self.fpn      = Yolov4PaFPN(self.backbone.feat_dims[-3:], head_dim=cfg.head_dim)
+        self.non_shared_heads = nn.ModuleList([DecoupledHead(cfg, in_dim)
+                                               for in_dim in self.fpn.fpn_out_dims
+                                               ])
+
+        ## Prediction layers
+        self.obj_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_head_dim, 1 * self.num_anchors, kernel_size=1)
+                             for head in self.non_shared_heads
+                             ]) 
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(head.cls_head_dim, self.num_classes * self.num_anchors, kernel_size=1) 
+                             for head in self.non_shared_heads
+                             ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_head_dim, 4 * self.num_anchors, kernel_size=1) 
+                             for head in self.non_shared_heads
+                             ])                 
+    
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        fmp_h, fmp_w = fmp_size
+        # [KA, 2]
+        anchor_size = self.anchor_size[level]
+
+        # generate grid cells
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        # [HW, 2] -> [HW, KA, 2] -> [M, 2]
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
+        anchor_xy = anchor_xy.view(-1, 2)
+        anchor_xy += 0.5
+
+        # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] -> [M, 2]
+        anchor_wh = anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
+        anchor_wh = anchor_wh.view(-1, 2)
+
+        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
+
+        return anchors
 
     def post_process(self, obj_preds, cls_preds, box_preds):
         """
@@ -64,8 +101,7 @@ class Yolov4(nn.Module):
             box_pred_i = box_pred_i[0]
             if self.no_multi_labels:
                 # [M,]
-                scores, labels = torch.max(
-                    torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
+                scores, labels = torch.max(torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
 
                 # Keep top k top scoring indices only.
                 num_topk = min(self.topk_candidates, box_pred_i.size(0))
@@ -124,32 +160,59 @@ class Yolov4(nn.Module):
         return bboxes, scores, labels
     
     def forward(self, x):
-        # ---------------- Backbone ----------------
+        bs = x.shape[0]
         pyramid_feats = self.backbone(x)
-        # ---------------- Neck: SPP ----------------
         pyramid_feats[-1] = self.neck(pyramid_feats[-1])
-
-        # ---------------- Neck: PaFPN ----------------
         pyramid_feats = self.fpn(pyramid_feats)
 
-        # ---------------- Heads ----------------
-        cls_feats, reg_feats = self.head(pyramid_feats)
-
-        # ---------------- Preds ----------------
-        outputs = self.pred(cls_feats, reg_feats)
-        outputs['image_size'] = [x.shape[2], x.shape[3]]
+        all_fmp_sizes = []
+        all_obj_preds = []
+        all_cls_preds = []
+        all_box_preds = []
+        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+            cls_feat, reg_feat = head(feat)
+
+            # [B, C, H, W]
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+
+            fmp_size = cls_pred.shape[-2:]
+
+            # generate anchor boxes: [M, 4]
+            anchors = self.generate_anchors(level, fmp_size)
+            anchors = anchors.to(x.device)
+            
+            # [B, AC, H, W] -> [B, H, W, AC] -> [B, M, C]
+            obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 1)
+            cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, self.num_classes)
+            reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 4)
+
+            # decode bbox
+            ctr_pred = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.out_stride[level]
+            wh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+            pred_x1y1 = ctr_pred - wh_pred * 0.5
+            pred_x2y2 = ctr_pred + wh_pred * 0.5
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+            all_obj_preds.append(obj_pred)
+            all_cls_preds.append(cls_pred)
+            all_box_preds.append(box_pred)
+            all_fmp_sizes.append(fmp_size)
 
         if not self.training:
-            all_obj_preds = outputs['pred_obj']
-            all_cls_preds = outputs['pred_cls']
-            all_box_preds = outputs['pred_box']
-
-            # post process
             bboxes, scores, labels = self.post_process(all_obj_preds, all_cls_preds, all_box_preds)
             outputs = {
                 "scores": scores,
                 "labels": labels,
                 "bboxes": bboxes
             }
-        
-        return outputs
+        else:
+            outputs = {"pred_obj":  all_obj_preds,   # List [B, M, 1]
+                       "pred_cls":  all_cls_preds,   # List [B, M, C]
+                       "pred_box":  all_box_preds,   # List [B, M, 4]
+                       "fmp_sizes": all_fmp_sizes,   # List
+                       "strides":   self.out_stride, # List
+                        }
+
+        return outputs 
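
Note the new center decode: `sigmoid(tx) * 3.0 - 1.5` lets a cell place a box center up to 1.5 cells away from itself, which is what makes the matcher's 3x3 neighbors able to reach the true center. A minimal numeric sketch of the decode (toy values):

```python
import torch

stride = 16.0
anchor = torch.tensor([10.5, 20.5, 62.0, 45.0])  # [grid_cx, grid_cy, w, h]
reg = torch.tensor([0.2, -0.3, 0.1, 0.4])        # raw tx, ty, tw, th

ctr = (torch.sigmoid(reg[:2]) * 3.0 - 1.5 + anchor[:2]) * stride  # center in pixels
wh = torch.exp(reg[2:]) * anchor[2:]                              # size in pixels
box = torch.cat([ctr - wh * 0.5, ctr + wh * 0.5])                 # [x1, y1, x2, y2]
print(box)
```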

+ 70 - 75
yolo/models/yolov4/yolov4_backbone.py

@@ -5,71 +5,44 @@ try:
     from .modules import ConvModule, CSPBlock
 except:
     from  modules import ConvModule, CSPBlock
+    
 
-# IN1K pretrained weight
-pretrained_urls = {
-    'n': None,
-    's': None,
-    'm': None,
-    'l': None,
-    'x': None,
+in1k_pretrained_urls = {
+    "cspdarknet53": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet53_silu.pth",
 }
 
-# --------------------- Yolov3's Backbone -----------------------
-## Modified DarkNet
+# --------------------- Yolov4 backbone (CSPDarkNet-53 with SiLU) -----------------------
 class Yolov4Backbone(nn.Module):
-    def __init__(self, cfg):
+    def __init__(self, use_pretrained: bool = False):
         super(Yolov4Backbone, self).__init__()
-        # ------------------ Basic setting ------------------
-        self.model_scale = cfg.model_scale
-        self.feat_dims = [round(64   * cfg.width),
-                          round(128  * cfg.width),
-                          round(256  * cfg.width),
-                          round(512  * cfg.width),
-                          round(1024 * cfg.width)]
-        
-        # ------------------ Network setting ------------------
-        ## P1/2
-        self.layer_1 = ConvModule(3, self.feat_dims[0], kernel_size=6, padding=2, stride=2)
-        # P2/4
+        self.feat_dims = [256, 512, 1024]
+        self.use_pretrained = use_pretrained
+
+        # P1
+        self.layer_1 = nn.Sequential(
+            ConvModule(3, 32, kernel_size=3),
+            ConvModule(32, 64, kernel_size=3, stride=2),
+            CSPBlock(64, 64, expand_ratio=0.5, num_blocks=1, shortcut=True)
+        )
+        # P2
         self.layer_2 = nn.Sequential(
-            ConvModule(self.feat_dims[0], self.feat_dims[1], kernel_size=3, padding=1, stride=2),
-            CSPBlock(in_dim     = self.feat_dims[1],
-                     out_dim    = self.feat_dims[1],
-                     num_blocks = round(3*cfg.depth),
-                     expansion  = 0.5,
-                     shortcut   = True,
-                     )
+            ConvModule(64, 128, kernel_size=3, stride=2),
+            CSPBlock(128, 128, expand_ratio=0.5, num_blocks=2, shortcut=True)
         )
-        # P3/8
+        # P3
         self.layer_3 = nn.Sequential(
-            ConvModule(self.feat_dims[1], self.feat_dims[2], kernel_size=3, padding=1, stride=2),
-            CSPBlock(in_dim     = self.feat_dims[2],
-                     out_dim    = self.feat_dims[2],
-                     num_blocks = round(9*cfg.depth),
-                     expansion  = 0.5,
-                     shortcut   = True,
-                     )
+            ConvModule(128, 256, kernel_size=3, stride=2),
+            CSPBlock(256, 256, expand_ratio=0.5, num_blocks=8, shortcut=True)
         )
-        # P4/16
+        # P4
         self.layer_4 = nn.Sequential(
-            ConvModule(self.feat_dims[2], self.feat_dims[3], kernel_size=3, padding=1, stride=2),
-            CSPBlock(in_dim     = self.feat_dims[3],
-                     out_dim    = self.feat_dims[3],
-                     num_blocks = round(9*cfg.depth),
-                     expansion  = 0.5,
-                     shortcut   = True,
-                     )
+            ConvModule(256, 512, kernel_size=3, stride=2),
+            CSPBlock(512, 512, expand_ratio=0.5, num_blocks=8, shortcut=True)
         )
-        # P5/32
+        # P5
         self.layer_5 = nn.Sequential(
-            ConvModule(self.feat_dims[3], self.feat_dims[4], kernel_size=3, padding=1, stride=2),
-            CSPBlock(in_dim     = self.feat_dims[4],
-                     out_dim    = self.feat_dims[4],
-                     num_blocks = round(3*cfg.depth),
-                     expansion  = 0.5,
-                     shortcut   = True,
-                     )
+            ConvModule(512, 1024, kernel_size=3, stride=2),
+            CSPBlock(1024, 1024, expand_ratio=0.5, num_blocks=4, shortcut=True)
         )
 
         # Initialize all layers
@@ -81,42 +54,64 @@ class Yolov4Backbone(nn.Module):
             if isinstance(m, torch.nn.Conv2d):
                 m.reset_parameters()
 
+        # Load imagenet pretrained weight
+        if self.use_pretrained:
+            self.load_pretrained()
+
+    def load_pretrained(self):
+        url = in1k_pretrained_urls["cspdarknet53"]
+        if url is not None:
+            print('Loading backbone pretrained weight from : {}'.format(url))
+            # checkpoint state dict
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = self.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print('Unused key: ', k)
+            # load the weight
+            self.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No pretrained weight for cspdarknet53.')
+
     def forward(self, x):
         c1 = self.layer_1(x)
         c2 = self.layer_2(c1)
         c3 = self.layer_3(c2)
         c4 = self.layer_4(c3)
         c5 = self.layer_5(c4)
+
         outputs = [c3, c4, c5]
 
         return outputs
 
 
-if __name__ == '__main__':
-    import time
+if __name__=='__main__':
     from thop import profile
-    class BaseConfig(object):
-        def __init__(self) -> None:
-            self.width = 0.5
-            self.depth = 0.34
-            self.model_scale = "s"
-            self.use_pretrained = True
-
-    cfg = BaseConfig()
-    model = Yolov4Backbone(cfg)
-    x = torch.randn(1, 3, 640, 640)
-    t0 = time.time()
+
+    # Build backbone
+    model = Yolov4Backbone(use_pretrained=True)
+
+    # Randomly generate input data
+    x = torch.randn(2, 3, 640, 640)
+
+    # Inference
     outputs = model(x)
-    print(model)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
+    print(' - the shape of input :  ', x.shape)
     for out in outputs:
-        print(out.shape)
+        print(' - the shape of output : ', out.shape)
 
     x = torch.randn(1, 3, 640, 640)
-    print('==============================')
     flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
-    
+    print('============== FLOPs & Params ================')
+    print(' - FLOPs  : {:.2f} G'.format(flops / 1e9 * 2))
+    print(' - Params : {:.2f} M'.format(params / 1e6))

+ 36 - 96
yolo/models/yolov4/yolov4_head.py

@@ -1,5 +1,6 @@
 import torch
 import torch.nn as nn
+from typing import List
 
 try:
     from .modules import ConvModule
@@ -7,38 +8,31 @@ except:
     from  modules import ConvModule
 
 
-## Single-level Detection Head
-class DetHead(nn.Module):
-    def __init__(self,
-                 in_dim       :int  = 256,
-                 cls_head_dim :int  = 256,
-                 reg_head_dim :int  = 256,
-                 num_cls_head :int  = 2,
-                 num_reg_head :int  = 2,
-                 ):
+class DecoupledHead(nn.Module):
+    def __init__(self, cfg, in_dim: int = 256):
         super().__init__()
-        # --------- Basic Parameters ----------
         self.in_dim = in_dim
-        self.num_cls_head = num_cls_head
-        self.num_reg_head = num_reg_head
-        
-        # --------- Network Parameters ----------
-        ## cls head
+        self.cls_head_dim = cfg.head_dim
+        self.reg_head_dim = cfg.head_dim
+        self.num_cls_head = cfg.num_cls_head
+        self.num_reg_head = cfg.num_reg_head
+
+        # classification feature head
         cls_feats = []
-        self.cls_head_dim = cls_head_dim
-        for i in range(num_cls_head):
+        for i in range(self.num_cls_head):
             if i == 0:
-                cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
+                cls_feats.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, stride=1))
             else:
-                cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
-        ## reg head
+                cls_feats.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, stride=1))
+                
+        # box regression feature head
         reg_feats = []
-        self.reg_head_dim = reg_head_dim
-        for i in range(num_reg_head):
+        for i in range(self.num_reg_head):
             if i == 0:
-                reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
+                reg_feats.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, stride=1))
             else:
-                reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
+                reg_feats.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, stride=1))
+
         self.cls_feats = nn.Sequential(*cls_feats)
         self.reg_feats = nn.Sequential(*reg_feats)
 
@@ -50,88 +44,34 @@ class DetHead(nn.Module):
         reg_feats = self.reg_feats(x)
 
         return cls_feats, reg_feats
-    
-## Multi-level Detection Head
-class Yolov4DetHead(nn.Module):
-    def __init__(self, cfg, in_dims):
-        super().__init__()
-        ## ----------- Network Parameters -----------
-        self.multi_level_heads = nn.ModuleList(
-            [DetHead(in_dim       = in_dims[level],
-                     cls_head_dim = round(cfg.head_dim * cfg.width),
-                     reg_head_dim = round(cfg.head_dim * cfg.width),
-                     num_cls_head = cfg.num_cls_head,
-                     num_reg_head = cfg.num_reg_head,
-                     ) for level in range(len(cfg.out_stride))])
-        # --------- Basic Parameters ----------
-        self.in_dims = in_dims
-        self.cls_head_dim = cfg.head_dim
-        self.reg_head_dim = cfg.head_dim
-
-        # Initialize all layers
-        self.init_weights()
-
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                # In order to be consistent with the source code,
-                # reset the Conv2d initialization parameters
-                m.reset_parameters()
-
-    def forward(self, feats):
-        """
-            feats: List[(Tensor)] [[B, C, H, W], ...]
-        """
-        cls_feats = []
-        reg_feats = []
-        for feat, head in zip(feats, self.multi_level_heads):
-            # ---------------- Pred ----------------
-            cls_feat, reg_feat = head(feat)
-
-            cls_feats.append(cls_feat)
-            reg_feats.append(reg_feat)
-
-        return cls_feats, reg_feats
 
 
 if __name__=='__main__':
-    import time
     from thop import profile
-    # Model config
     
-    # YOLOv4-Base config
-    class Yolov4BaseConfig(object):
+    # YOLOv4 head configuration
+    class Yolov4BaseConfig(object):
         def __init__(self) -> None:
             # ---------------- Model config ----------------
-            self.width    = 0.50
-            self.depth    = 0.34
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## Head
             self.head_dim  = 256
-            self.num_cls_head   = 2
-            self.num_reg_head   = 2
+            self.num_cls_head = 2
+            self.num_reg_head = 2
+    cfg = Yolov4BaseConfig()
 
-    cfg = Yolov4BaseConfig()
     # Build a head
-    pyramid_feats = [torch.randn(1, cfg.head_dim, 80, 80),
-                     torch.randn(1, cfg.head_dim, 40, 40),
-                     torch.randn(1, cfg.head_dim, 20, 20)]
-    head = Yolov4DetHead(cfg, [cfg.head_dim]*3)
+    model = DecoupledHead(cfg, in_dim=256)
 
+    # Randomly generate input data
+    x = torch.randn(2, 256, 20, 20)
 
     # Inference
-    t0 = time.time()
-    cls_feats, reg_feats = head(pyramid_feats)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print("====== Yolov4 Head output ======")
-    for level, (cls_f, reg_f) in enumerate(zip(cls_feats, reg_feats)):
-        print("- Level-{} : ".format(level), cls_f.shape, reg_f.shape)
-
-    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    cls_feats, reg_feats = model(x)
+    print(' - the shape of input :  ', x.shape)
+    print(' - the shape of cls feats : ', cls_feats.shape)
+    print(' - the shape of reg feats : ', reg_feats.shape)
+
+    x = torch.randn(1, 256, 20, 20)
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('============== FLOPs & Params ================')
+    print(' - FLOPs  : {:.2f} G'.format(flops / 1e9 * 2))
+    print(' - Params : {:.2f} M'.format(params / 1e6))

+ 42 - 29
yolo/models/yolov4/yolov4_neck.py

@@ -7,30 +7,18 @@ except:
     from  modules import ConvModule
 
 
-# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
 class SPPF(nn.Module):
     """
         This code references https://github.com/ultralytics/yolov5
     """
-    def __init__(self, in_dim, out_dim):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5):
         super().__init__()
-        ## ----------- Basic Parameters -----------
-        inter_dim = in_dim // 2
+        inter_dim = int(in_dim * expand_ratio)
         self.out_dim = out_dim
-        ## ----------- Network Parameters -----------
-        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1, padding=0, stride=1)
-        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1, padding=0, stride=1)
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(inter_dim * 4, out_dim, kernel_size=1)
         self.m = nn.MaxPool2d(kernel_size=5, stride=1, padding=2)
 
-        # Initialize all layers
-        self.init_weights()
-
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                m.reset_parameters()
-
     def forward(self, x):
         x = self.cv1(x)
         y1 = self.m(x)
@@ -38,26 +26,51 @@ class SPPF(nn.Module):
 
         return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
 
+class SPPFBlockCSP(nn.Module):
+    def __init__(self,
+                 in_dim: int,
+                 out_dim: int,
+                 expand_ratio: float = 0.5,
+                 ):
+        super(SPPFBlockCSP, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.cv2 = ConvModule(in_dim, inter_dim, kernel_size=1)
+        self.m = nn.Sequential(
+            ConvModule(inter_dim, inter_dim, kernel_size=3),
+            SPPF(inter_dim, inter_dim, expand_ratio=1.0),
+            ConvModule(inter_dim, inter_dim, kernel_size=3)
+        )
+        self.cv3 = ConvModule(inter_dim * 2, self.out_dim, kernel_size=1)
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x2)
+        y = self.cv3(torch.cat([x1, x3], dim=1))
+
+        return y
+
 
 if __name__=='__main__':
-    import time
     from thop import profile
-    # Model config
     
     # Build a neck
     in_dim  = 512
     out_dim = 512
-    neck = SPPF(in_dim, out_dim)
+    model = SPPFBlockCSP(512, 512, expand_ratio=0.5)
+
+    # Randomly generate input data
+    x = torch.randn(2, in_dim, 20, 20)
 
     # Inference
+    output = model(x)
+    print(' - the shape of input :  ', x.shape)
+    print(' - the shape of output : ', output.shape)
+
     x = torch.randn(1, in_dim, 20, 20)
-    t0 = time.time()
-    output = neck(x)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print('Neck output: ', output.shape)
-
-    flops, params = profile(neck, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('============== FLOPs & Params ================')
+    print(' - FLOPs  : {:.2f} G'.format(flops / 1e9 * 2))
+    print(' - Params : {:.2f} M'.format(params / 1e6))
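
`SPPFBlockCSP` wraps the SPPF in a CSP split. SPPF itself chains three 5x5 max-pools, so the pooled maps cover effective windows of 5, 9, and 13, matching the classic parallel SPP at lower cost. A small check of that equivalence (PyTorch max-pool padding is implicit negative infinity, so the identity holds at the borders too):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 20, 20)
m5 = nn.MaxPool2d(5, stride=1, padding=2)
m9 = nn.MaxPool2d(9, stride=1, padding=4)
m13 = nn.MaxPool2d(13, stride=1, padding=6)

y1 = m5(x); y2 = m5(y1); y3 = m5(y2)  # SPPF: sequential 5x5 pools
assert torch.allclose(y2, m9(x)) and torch.allclose(y3, m13(x))
print("chained 5x5 pools match 9x9 and 13x13 pools")
```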

+ 66 - 85
yolo/models/yolov4/yolov4_pafpn.py

@@ -1,7 +1,7 @@
-from typing import List
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import List
 
 try:
     from .modules import ConvModule, CSPBlock
@@ -9,96 +9,88 @@ except:
     from  modules import ConvModule, CSPBlock
 
 
-# Yolov4FPN
+# PaFPN-CSP
 class Yolov4PaFPN(nn.Module):
-    def __init__(self, cfg, in_dims: List = [256, 512, 1024]):
+    def __init__(self, 
+                 in_dims: List = [256, 512, 1024],
+                 head_dim: int = 256,
+                 ):
         super(Yolov4PaFPN, self).__init__()
         self.in_dims = in_dims
+        self.head_dim = head_dim
+        self.fpn_out_dims = [head_dim] * 3
         c3, c4, c5 = in_dims
 
-        # ---------------------- Yolov4's Top down FPN ----------------------
+        # top down
         ## P5 -> P4
-        self.reduce_layer_1   = ConvModule(c5, round(512*cfg.width), kernel_size=1, padding=0, stride=1)
-        self.top_down_layer_1 = CSPBlock(in_dim     = c4 + round(512*cfg.width),
-                                         out_dim    = round(512*cfg.width),
-                                         num_blocks = round(3*cfg.depth),
-                                         expansion  = 0.5,
-                                         shortcut   = False,
+        self.reduce_layer_1   = ConvModule(c5, 512, kernel_size=1)
+        self.top_down_layer_1 = CSPBlock(in_dim = c4 + 512,
+                                         out_dim = 512,
+                                         expand_ratio = 0.5,
+                                         num_blocks = 3,
+                                         shortcut = False,
                                          )
 
         ## P4 -> P3
-        self.reduce_layer_2   = ConvModule(round(512*cfg.width), round(256*cfg.width), kernel_size=1, padding=0, stride=1)
-        self.top_down_layer_2 = CSPBlock(in_dim     = c3 + round(256*cfg.width),
-                                         out_dim    = round(256*cfg.width),
-                                         num_blocks = round(3*cfg.depth),
-                                         expansion  = 0.5,
-                                         shortcut   = False,
+        self.reduce_layer_2   = ConvModule(512, 256, kernel_size=1)
+        self.top_down_layer_2 = CSPBlock(in_dim = c3 + 256, 
+                                         out_dim = 256,
+                                         expand_ratio = 0.5,
+                                         num_blocks = 3,
+                                         shortcut = False,
                                          )
-        
-        # ---------------------- Yolov4's Bottom up PAN ----------------------
+
+        # bottom up
         ## P3 -> P4
-        self.downsample_layer_1 = ConvModule(round(256*cfg.width), round(256*cfg.width), kernel_size=3, padding=1, stride=2)
-        self.bottom_up_layer_1  = CSPBlock(in_dim     = round(256*cfg.width) + round(256*cfg.width),
-                                           out_dim    = round(512*cfg.width),
-                                           num_blocks = round(3*cfg.depth),
-                                           expansion  = 0.5,
-                                           shortcut   = False,
-                                           )
+        self.reduce_layer_3    = ConvModule(256, 256, kernel_size=3, stride=2)
+        self.bottom_up_layer_1 = CSPBlock(in_dim = 256 + 256,
+                                          out_dim = 512,
+                                          expand_ratio = 0.5,
+                                          num_blocks = 3,
+                                          shortcut = False,
+                                          )
+
         ## P4 -> P5
-        self.downsample_layer_2 = ConvModule(round(512*cfg.width), round(512*cfg.width), kernel_size=3, padding=1, stride=2)
-        self.bottom_up_layer_2  = CSPBlock(in_dim     = round(512*cfg.width) + round(512*cfg.width),
-                                           out_dim    = round(1024*cfg.width),
-                                           num_blocks = round(3*cfg.depth),
-                                           expansion  = 0.5,
-                                           shortcut   = False,
-                                           )
-
-        # ---------------------- Yolov4's output projection ----------------------
-        self.out_layers = nn.ModuleList([
-            ConvModule(in_dim, round(cfg.head_dim*cfg.width), kernel_size=1)
-                      for in_dim in [round(256*cfg.width), round(512*cfg.width), round(1024*cfg.width)]
-                      ])
-        self.out_dims = [round(cfg.head_dim*cfg.width)] * 3
-
-        # Initialize all layers
-        self.init_weights()
-
-    def init_weights(self):
-        """Initialize the parameters."""
-        for m in self.modules():
-            if isinstance(m, torch.nn.Conv2d):
-                m.reset_parameters()
+        self.reduce_layer_4    = ConvModule(512, 512, kernel_size=3, stride=2)
+        self.bottom_up_layer_2 = CSPBlock(in_dim = 512 + 512,
+                                          out_dim = 1024,
+                                          expand_ratio = 0.5,
+                                          num_blocks = 3,
+                                          shortcut = False,
+                                          )
+
+        # output proj layers
+        self.out_layers = nn.ModuleList([ConvModule(in_dim, head_dim, kernel_size=1)
+                                         for in_dim in [256, 512, 1024]
+                                         ])
 
     def forward(self, features):
         c3, c4, c5 = features
-        
-        # ------------------ Top down FPN ------------------
-        ## P5 -> P4
-        p5 = self.reduce_layer_1(c5)
-        p5_up = F.interpolate(p5, scale_factor=2.0)
-        p4 = self.top_down_layer_1(torch.cat([c4, p5_up], dim=1))
-
-        ## P4 -> P3
-        p4 = self.reduce_layer_2(p4)
-        p4_up = F.interpolate(p4, scale_factor=2.0)
-        p3 = self.top_down_layer_2(torch.cat([c3, p4_up], dim=1))
-
-        # ------------------ Bottom up PAN ------------------
-        ## P3 -> P4
-        p3_ds = self.downsample_layer_1(p3)
-        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
-
-        ## P4 -> P5
-        p4_ds = self.downsample_layer_2(p4)
-        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
 
-        out_feats = [p3, p4, p5]
+        c6 = self.reduce_layer_1(c5)
+        c7 = F.interpolate(c6, scale_factor=2.0)   # s32->s16
+        c8 = torch.cat([c7, c4], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        # P3/8
+        c10 = self.reduce_layer_2(c9)
+        c11 = F.interpolate(c10, scale_factor=2.0)   # s16->s8
+        c12 = torch.cat([c11, c3], dim=1)
+        c13 = self.top_down_layer_2(c12)  # to det
+        # p4/16
+        c14 = self.reduce_layer_3(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)  # to det
+        # p5/32
+        c17 = self.reduce_layer_4(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)  # to det
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
 
         # output proj layers
         out_feats_proj = []
         for feat, layer in zip(out_feats, self.out_layers):
             out_feats_proj.append(layer(feat))
-            
         return out_feats_proj
 
 
@@ -107,27 +99,16 @@ if __name__=='__main__':
     from thop import profile
     # Model config
     
-    # YOLOv4-Base config
-    class Yolov4BaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.width    = 0.50
-            self.depth    = 0.34
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## Head
-            self.head_dim = 256
-
-    cfg = Yolov4BaseConfig()
     # Build an FPN
     in_dims  = [128, 256, 512]
-    fpn = Yolov4PaFPN(cfg, in_dims)
+    fpn = Yolov4PaFPN(in_dims, head_dim=256)
 
-    # Inference
+    # Randomly generate input data
     x = [torch.randn(1, in_dims[0], 80, 80),
          torch.randn(1, in_dims[1], 40, 40),
          torch.randn(1, in_dims[2], 20, 20)]
+    
+    # Inference
     t0 = time.time()
     output = fpn(x)
     t1 = time.time()

+ 0 - 216
yolo/models/yolov4/yolov4_pred.py

@@ -1,216 +0,0 @@
-import torch
-import torch.nn as nn
-from typing import List
-
-# -------------------- Detection Pred Layer --------------------
-class DetPredLayer(nn.Module):
-    def __init__(self,
-                 cls_dim      :int,
-                 reg_dim      :int,
-                 stride       :int,
-                 num_classes  :int,
-                 anchor_sizes :List,
-                 ):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.stride  = stride
-        self.cls_dim = cls_dim
-        self.reg_dim = reg_dim
-        self.num_classes = num_classes
-        # ------------------- Anchor box -------------------
-        self.anchor_size = torch.as_tensor(anchor_sizes).float().view(-1, 2) # [A, 2]
-        self.num_anchors = self.anchor_size.shape[0]
-
-        # --------- Network Parameters ----------
-        self.obj_pred = nn.Conv2d(self.cls_dim, 1 * self.num_anchors, kernel_size=1)
-        self.cls_pred = nn.Conv2d(self.cls_dim, num_classes * self.num_anchors, kernel_size=1)
-        self.reg_pred = nn.Conv2d(self.reg_dim, 4 * self.num_anchors, kernel_size=1)                
-
-        self.init_bias()
-        
-    def init_bias(self):
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # obj pred
-        b = self.obj_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        self.obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # cls pred
-        b = self.cls_pred.bias.view(1, -1)
-        b.data.fill_(bias_value.item())
-        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred
-        b = self.reg_pred.bias.view(-1, )
-        b.data.fill_(1.0)
-        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        w = self.reg_pred.weight
-        w.data.fill_(0.)
-        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-    def generate_anchors(self, fmp_size):
-        """
-            fmp_size: (List) [H, W]
-        """
-        # width and height of the feature map
-        fmp_h, fmp_w = fmp_size
-
-        # generate the x and y coordinates of the grid
-        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-
-        # stack the x and y coordinates together: [H, W, 2] -> [HW, 2]
-        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
-        # [HW, 2] -> [HW, A, 2] -> [M, 2], M=HWA
-        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
-        anchor_xy = anchor_xy.view(-1, 2)
-
-        # [A, 2] -> [1, A, 2] -> [HW, A, 2] -> [M, 2], M=HWA
-        anchor_wh = self.anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
-        anchor_wh = anchor_wh.view(-1, 2)
-
-        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
-
-        return anchors
-        
-    def forward(self, cls_feat, reg_feat):
-        # prediction layers
-        obj_pred = self.obj_pred(reg_feat)
-        cls_pred = self.cls_pred(cls_feat)
-        reg_pred = self.reg_pred(reg_feat)
-
-        # generate grid coordinates
-        B, _, H, W = cls_pred.size()
-        fmp_size = [H, W]
-        anchors = self.generate_anchors(fmp_size)
-        anchors = anchors.to(cls_pred.device)
-
-        # reshape the predictions to make subsequent processing easier
-        # [B, C*A, H, W] -> [B, H, W, C*A] -> [B, H*W*A, C]
-        obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
-        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
-        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
-        
-        # decode the bounding box coordinates
-        cxcy_pred = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride
-        bwbh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
-        pred_x1y1 = cxcy_pred - bwbh_pred * 0.5
-        pred_x2y2 = cxcy_pred + bwbh_pred * 0.5
-        box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-        # output dict
-        outputs = {"pred_obj": obj_pred,       # (torch.Tensor) [B, M, 1]
-                   "pred_cls": cls_pred,       # (torch.Tensor) [B, M, C]
-                   "pred_reg": reg_pred,       # (torch.Tensor) [B, M, 4]
-                   "pred_box": box_pred,       # (torch.Tensor) [B, M, 4]
-                   "anchors" : anchors,        # (torch.Tensor) [M, 2]
-                   "fmp_size": fmp_size,
-                   "stride"  : self.stride,    # (Int)
-                   }
-
-        return outputs
-
-class Yolov4DetPredLayer(nn.Module):
-    def __init__(self, cfg):
-        super().__init__()
-        # --------- Basic Parameters ----------
-        self.cfg = cfg
-        self.num_levels = len(cfg.out_stride)
-
-        # ----------- Network Parameters -----------
-        ## pred layers
-        self.multi_level_preds = nn.ModuleList(
-            [DetPredLayer(cls_dim      = round(cfg.head_dim * cfg.width),
-                          reg_dim      = round(cfg.head_dim * cfg.width),
-                          stride       = cfg.out_stride[level],
-                          anchor_sizes = cfg.anchor_size[level],
-                          num_classes  = cfg.num_classes,)
-                          for level in range(self.num_levels)
-                          ])
-
-    def forward(self, cls_feats, reg_feats):
-        all_anchors = []
-        all_strides = []
-        all_fmp_sizes = []
-        all_obj_preds = []
-        all_cls_preds = []
-        all_reg_preds = []
-        all_box_preds = []
-        for level in range(self.num_levels):
-            # -------------- Single-level prediction --------------
-            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
-
-            # collect results
-            all_obj_preds.append(outputs["pred_obj"])
-            all_cls_preds.append(outputs["pred_cls"])
-            all_reg_preds.append(outputs["pred_reg"])
-            all_box_preds.append(outputs["pred_box"])
-            all_fmp_sizes.append(outputs["fmp_size"])
-            all_anchors.append(outputs["anchors"])
-        
-        # output dict
-        outputs = {"pred_obj":  all_obj_preds,         # List(Tensor) [B, M, 1]
-                   "pred_cls":  all_cls_preds,         # List(Tensor) [B, M, C]
-                   "pred_reg":  all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
-                   "pred_box":  all_box_preds,         # List(Tensor) [B, M, 4]
-                   "fmp_sizes": all_fmp_sizes,         # List(Tensor) [M, 1]
-                   "anchors":   all_anchors,           # List(Tensor) [M, 2]
-                   "strides":   self.cfg.out_stride,   # List(Int) = [8, 16, 32]
-                   }
-
-        return outputs
-
-
-if __name__=='__main__':
-    import time
-    from thop import profile
-    # Model config
-    
-    # YOLOv8-Base config
-    class Yolov4BaseConfig(object):
-        def __init__(self) -> None:
-            # ---------------- Model config ----------------
-            self.width    = 1.0
-            self.depth    = 1.0
-            self.out_stride = [8, 16, 32]
-            self.max_stride = 32
-            self.num_levels = 3
-            ## Head
-            self.head_dim  = 256
-            self.anchor_size = {0: [[10, 13],   [16, 30],   [33, 23]],
-                                1: [[30, 61],   [62, 45],   [59, 119]],
-                                2: [[116, 90],  [156, 198], [373, 326]]}
-
-    cfg = Yolov4BaseConfig()
-    cfg.num_classes = 20
-    # Build a pred layer
-    pred = Yolov4DetPredLayer(cfg)
-
-    # Inference
-    cls_feats = [torch.randn(1, cfg.head_dim, 80, 80),
-                 torch.randn(1, cfg.head_dim, 40, 40),
-                 torch.randn(1, cfg.head_dim, 20, 20),]
-    reg_feats = [torch.randn(1, cfg.head_dim, 80, 80),
-                 torch.randn(1, cfg.head_dim, 40, 40),
-                 torch.randn(1, cfg.head_dim, 20, 20),]
-    t0 = time.time()
-    output = pred(cls_feats, reg_feats)
-    t1 = time.time()
-    print('Time: ', t1 - t0)
-    print('====== Pred output ======= ')
-    pred_obj = output["pred_obj"]
-    pred_cls = output["pred_cls"]
-    pred_reg = output["pred_reg"]
-    pred_box = output["pred_box"]
-    anchors  = output["anchors"]
-    
-    for level in range(cfg.num_levels):
-        print("- Level-{} : objectness       -> {}".format(level, pred_obj[level].shape))
-        print("- Level-{} : classification   -> {}".format(level, pred_cls[level].shape))
-        print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
-        print("- Level-{} : bbox regression  -> {}".format(level, pred_box[level].shape))
-        print("- Level-{} : anchor boxes     -> {}".format(level, anchors[level].shape))
-
-    flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))