Browse Source

add yolov5

yjh0410 1 năm trước cách đây
mục cha
commit
366f291021

+ 4 - 1
config/__init__.py

@@ -3,13 +3,14 @@ from .yolov1_config   import build_yolov1_config
 from .yolov2_config   import build_yolov2_config
 from .yolov3_config   import build_yolov3_config
 from .yolov4_config   import build_yolov4_config
+from .yolov5_config   import build_yolov5_config
 from .yolov8_config   import build_yolov8_config
 from .rtdetr_config import build_rtdetr_config
 
 def build_config(args):
     print('==============================')
     print('Model: {} ...'.format(args.model.upper()))
-    # YOLOv8
+    # YOLO series
     if   'yolov1' in args.model:
         cfg = build_yolov1_config(args)
     elif 'yolov2' in args.model:
@@ -18,6 +19,8 @@ def build_config(args):
         cfg = build_yolov3_config(args)
     elif 'yolov4' in args.model:
         cfg = build_yolov4_config(args)
+    elif 'yolov5' in args.model:
+        cfg = build_yolov5_config(args)
     elif 'yolov8' in args.model:
         cfg = build_yolov8_config(args)
     # RT-DETR

+ 129 - 0
config/yolov5_config.py

@@ -0,0 +1,129 @@
+# yolo Config
+
+
+def build_yolov5_config(args):
+    if args.model == 'yolov5_s':
+        return Yolov5SConfig()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# YOLOv5-Base config
+class Yolov5BaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.width    = 1.0
+        self.depth    = 1.0
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+        self.num_levels = 3
+        self.scale      = "b"
+        ## Backbone
+        self.bk_act   = 'silu'
+        self.bk_norm  = 'BN'
+        self.bk_depthwise = False
+        ## Neck
+        self.neck_act       = 'silu'
+        self.neck_norm      = 'BN'
+        self.neck_depthwise = False
+        self.neck_expand_ratio = 0.5
+        self.spp_pooling_size  = 5
+        ## FPN
+        self.fpn_act  = 'silu'
+        self.fpn_norm = 'BN'
+        self.fpn_depthwise = False
+        ## Head
+        self.head_act  = 'silu'
+        self.head_norm = 'BN'
+        self.head_depthwise = False
+        self.head_dim       = 256
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
+        self.anchor_size    = {0: [[10, 13],   [16, 30],   [33, 23]],
+                               1: [[30, 61],   [62, 45],   [59, 119]],
+                               2: [[116, 90],  [156, 198], [373, 326]]}
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 100
+        self.test_conf_thresh = 0.2
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.anchor_thresh = 4.0
+        ## Loss weight
+        self.loss_obj = 1.0
+        self.loss_cls = 1.0
+        self.loss_box = 5.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.per_image_lr = 0.001 / 64
+        self.base_lr      = None      # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = -1.
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 300
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 20
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.box_format = 'xyxy'
+        self.normalize_coords = False
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.15
+        self.copy_paste  = 0.0           # approximated by the YOLOX's mixup
+        self.multi_scale = [0.5, 1.25]   # multi scale: [img_size * 0.5, img_size * 1.25]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.use_ablu = True
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.1,
+            'scale': [0.1, 2.0],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# YOLOv5-S
+class Yolov5SConfig(Yolov5BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.50
+        self.depth = 0.34
+        self.scale = "s"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0

+ 4 - 0
models/__init__.py

@@ -6,6 +6,7 @@ from .yolov1.build import build_yolov1
 from .yolov2.build import build_yolov2
 from .yolov3.build import build_yolov3
 from .yolov4.build import build_yolov4
+from .yolov5.build import build_yolov5
 from .yolov8.build import build_yolov8
 from .rtdetr.build import build_rtdetr
 
@@ -24,6 +25,9 @@ def build_model(args, cfg, is_val=False):
     ## Modified YOLOv4
     elif 'yolov4' in args.model:
         model, criterion = build_yolov4(cfg, is_val)
+    ## Modified YOLOv5
+    elif 'yolov5' in args.model:
+        model, criterion = build_yolov5(cfg, is_val)
     ## YOLOv8
     elif 'yolov8' in args.model:
         model, criterion = build_yolov8(cfg, is_val)

+ 11 - 18
models/yolov4/matcher.py

@@ -12,6 +12,7 @@ class Yolov4Matcher(object):
             for anchor in anchor_size]
             )  # [KA, 4]
 
+
     def compute_iou(self, anchor_boxes, gt_box):
         """
             anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).
@@ -49,6 +50,7 @@ class Yolov4Matcher(object):
         
         return iou
 
+
     @torch.no_grad()
     def __call__(self, fmp_sizes, fpn_strides, targets):
         """
@@ -136,26 +138,17 @@ class Yolov4Matcher(object):
                 # label assignment
                 for result in label_assignment_results:
                     grid_x, grid_y, level, anchor_idx = result
-                    stride = fpn_strides[level]
-                    x1s, y1s = x1 / stride, y1 / stride
-                    x2s, y2s = x2 / stride, y2 / stride
                     fmp_h, fmp_w = fmp_sizes[level]
 
-                    # 3x3 center sampling
-                    for j in range(grid_y - 1, grid_y + 2):
-                        for i in range(grid_x - 1, grid_x + 2):
-                            is_in_box = (j >= y1s and j < y2s) and (i >= x1s and i < x2s)
-                            is_valid = (j >= 0 and j < fmp_h) and (i >= 0 and i < fmp_w)
-
-                            if is_in_box and is_valid:
-                                # obj
-                                gt_objectness[level][batch_index, j, i, anchor_idx] = 1.0
-                                # cls
-                                cls_ont_hot = torch.zeros(self.num_classes)
-                                cls_ont_hot[int(gt_label)] = 1.0
-                                gt_classes[level][batch_index, j, i, anchor_idx] = cls_ont_hot
-                                # box
-                                gt_bboxes[level][batch_index, j, i, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
+                    if grid_x < fmp_w and grid_y < fmp_h:
+                        # obj
+                        gt_objectness[level][batch_index, grid_y, grid_x, anchor_idx] = 1.0
+                        # cls
+                        cls_ont_hot = torch.zeros(self.num_classes)
+                        cls_ont_hot[int(gt_label)] = 1.0
+                        gt_classes[level][batch_index, grid_y, grid_x, anchor_idx] = cls_ont_hot
+                        # box
+                        gt_bboxes[level][batch_index, grid_y, grid_x, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
 
         # [B, M, C]
         gt_objectness = torch.cat([gt.view(bs, -1, 1) for gt in gt_objectness], dim=1).float()

+ 2 - 2
models/yolov4/yolov4_pred.py

@@ -59,7 +59,7 @@ class DetPredLayer(nn.Module):
         # 将xy两部分的坐标拼起来:[H, W, 2] -> [HW, 2]
         anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
         # [HW, 2] -> [HW, A, 2] -> [M, 2], M=HWA
-        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1) + 0.5
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
         anchor_xy = anchor_xy.view(-1, 2)
 
         # [A, 2] -> [1, A, 2] -> [HW, A, 2] -> [M, 2], M=HWA
@@ -89,7 +89,7 @@ class DetPredLayer(nn.Module):
         reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
         
         # 解算边界框坐标
-        cxcy_pred = (reg_pred[..., :2] + anchors[..., :2]) * self.stride
+        cxcy_pred = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride
         bwbh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
         pred_x1y1 = cxcy_pred - bwbh_pred * 0.5
         pred_x2y2 = cxcy_pred + bwbh_pred * 0.5

+ 24 - 0
models/yolov5/build.py

@@ -0,0 +1,24 @@
+import torch.nn as nn
+
+from .loss import SetCriterion
+from .yolov5 import Yolov5
+
+
+# build object detector
+def build_yolov5(cfg, is_val=False):
+    # -------------- Build YOLO --------------
+    model = Yolov5(cfg, is_val)
+
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 101 - 0
models/yolov5/loss.py

@@ -0,0 +1,101 @@
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import Yolov5Matcher
+
+
+class SetCriterion(object):
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        self.loss_obj_weight = cfg.loss_obj
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
+
+        # matcher
+        anchor_size = cfg.anchor_size[0] + cfg.anchor_size[1] + cfg.anchor_size[2]
+        self.matcher = Yolov5Matcher(cfg.num_classes, 3, anchor_size, cfg.anchor_thresh)
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
+
+        return loss_box, ious
+
+    def __call__(self, outputs, targets):
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        fmp_sizes = outputs['fmp_sizes']
+        (
+            gt_objectness, 
+            gt_classes, 
+            gt_bboxes,
+            ) = self.matcher(fmp_sizes=fmp_sizes, 
+                             fpn_strides=fpn_strides, 
+                             targets=targets)
+        # List[B, M, C] -> [B, M, C] -> [BM, C]
+        pred_obj = torch.cat(outputs['pred_obj'], dim=1).view(-1)                      # [BM,]
+        pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)    # [BM, C]
+        pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)                   # [BM, 4]
+       
+        gt_objectness = gt_objectness.view(-1).to(device).float()               # [BM,]
+        gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()   # [BM, C]
+        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                    # [BM, 4]
+
+        pos_masks = (gt_objectness > 0)
+        num_fgs = pos_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # box loss
+        pred_box_pos = pred_box[pos_masks]
+        gt_bboxes_pos = gt_bboxes[pos_masks]
+        loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
+        loss_box = loss_box.sum() / num_fgs
+        
+        # cls loss
+        pred_cls_pos = pred_cls[pos_masks]
+        gt_classes_pos = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
+        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # obj loss
+        loss_obj = self.loss_objectness(pred_obj, gt_objectness)
+        loss_obj = loss_obj.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        loss_dict = dict(
+                loss_obj = loss_obj,
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
+    
+if __name__ == "__main__":
+    pass

+ 214 - 0
models/yolov5/matcher.py

@@ -0,0 +1,214 @@
+import numpy as np
+import torch
+
+
+class Yolov5Matcher(object):
+    def __init__(self, num_classes, num_anchors, anchor_size, anchor_theshold):
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+        self.anchor_theshold = anchor_theshold
+        # [KA, 2]
+        self.anchor_sizes = np.array([[anchor[0], anchor[1]]
+                                      for anchor in anchor_size])
+        # [KA, 4]
+        self.anchor_boxes = np.array([[0., 0., anchor[0], anchor[1]]
+                                      for anchor in anchor_size])
+
+    def compute_iou(self, anchor_boxes, gt_box):
+        """
+            anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).
+            gt_box       : ndarray -> [1, 4] (cx, cy, bw, bh).
+        """
+        # anchors: [KA, 4]
+        anchors_xyxy = np.zeros_like(anchor_boxes)
+        anchors_area = anchor_boxes[..., 2] * anchor_boxes[..., 3]
+        # convert [cx, cy, bw, bh] -> [x1, y1, x2, y2]
+        anchors_xyxy[..., :2] = anchor_boxes[..., :2] - anchor_boxes[..., 2:] * 0.5  # x1y1
+        anchors_xyxy[..., 2:] = anchor_boxes[..., :2] + anchor_boxes[..., 2:] * 0.5  # x2y2
+        
+        # expand gt_box: [1, 4] -> [KA, 4]
+        gt_box = np.array(gt_box).reshape(-1, 4)
+        gt_box = np.repeat(gt_box, anchors_xyxy.shape[0], axis=0)
+        gt_box_area = gt_box[..., 2] * gt_box[..., 3]
+        # convert [cx, cy, bw, bh] -> [x1, y1, x2, y2]
+        gt_box_xyxy = np.zeros_like(gt_box)
+        gt_box_xyxy[..., :2] = gt_box[..., :2] - gt_box[..., 2:] * 0.5  # x1y1
+        gt_box_xyxy[..., 2:] = gt_box[..., :2] + gt_box[..., 2:] * 0.5  # x2y2
+
+        # intersection
+        inter_w = np.minimum(anchors_xyxy[:, 2], gt_box_xyxy[:, 2]) - \
+                  np.maximum(anchors_xyxy[:, 0], gt_box_xyxy[:, 0])
+        inter_h = np.minimum(anchors_xyxy[:, 3], gt_box_xyxy[:, 3]) - \
+                  np.maximum(anchors_xyxy[:, 1], gt_box_xyxy[:, 1])
+        inter_area = inter_w * inter_h
+        
+        # union
+        union_area = anchors_area + gt_box_area - inter_area
+
+        # iou
+        iou = inter_area / union_area
+        iou = np.clip(iou, a_min=1e-10, a_max=1.0)
+        
+        return iou
+
+    def iou_assignment(self, ctr_points, gt_box, fpn_strides):
+        # compute IoU
+        iou = self.compute_iou(self.anchor_boxes, gt_box)
+        iou_mask = (iou > 0.5)
+
+        label_assignment_results = []
+        if iou_mask.sum() == 0:
+            # We assign the anchor box with highest IoU score.
+            iou_ind = np.argmax(iou)
+
+            level = iou_ind // self.num_anchors              # pyramid level
+            anchor_idx = iou_ind - level * self.num_anchors  # anchor index
+
+            # get the corresponding stride
+            stride = fpn_strides[level]
+
+            # compute the grid cell
+            xc, yc = ctr_points
+            xc_s = xc / stride
+            yc_s = yc / stride
+            grid_x = int(xc_s)
+            grid_y = int(yc_s)
+
+            label_assignment_results.append([grid_x, grid_y, xc_s, yc_s, level, anchor_idx])
+        else:            
+            for iou_ind, iou_m in enumerate(iou_mask):
+                if iou_m:
+                    level = iou_ind // self.num_anchors              # pyramid level
+                    anchor_idx = iou_ind - level * self.num_anchors  # anchor index
+
+                    # get the corresponding stride
+                    stride = fpn_strides[level]
+
+                    # compute the gride cell
+                    xc, yc = ctr_points
+                    xc_s = xc / stride
+                    yc_s = yc / stride
+                    grid_x = int(xc_s)
+                    grid_y = int(yc_s)
+
+                    label_assignment_results.append([grid_x, grid_y, xc_s, yc_s, level, anchor_idx])
+
+        return label_assignment_results
+
+    def aspect_ratio_assignment(self, ctr_points, keeps, fpn_strides):
+        label_assignment_results = []
+        for keep_idx, keep in enumerate(keeps):
+            if keep:
+                level = keep_idx // self.num_anchors              # pyramid level
+                anchor_idx = keep_idx - level * self.num_anchors  # anchor index
+
+                # get the corresponding stride
+                stride = fpn_strides[level]
+
+                # compute the gride cell
+                xc, yc = ctr_points
+                xc_s = xc / stride
+                yc_s = yc / stride
+                grid_x = int(xc_s)
+                grid_y = int(yc_s)
+
+                label_assignment_results.append([grid_x, grid_y, xc_s, yc_s, level, anchor_idx])
+        
+        return label_assignment_results
+    
+    @torch.no_grad()
+    def __call__(self, fmp_sizes, fpn_strides, targets):
+        """
+            fmp_size: (List) [fmp_h, fmp_w]
+            fpn_strides: (List) -> [8, 16, 32, ...] stride of network output.
+            targets: (Dict) dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}
+        """
+        assert len(fmp_sizes) == len(fpn_strides)
+        # prepare
+        bs = len(targets)
+        gt_objectness = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, 1]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+        gt_classes = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, self.num_classes]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+        gt_bboxes = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, 4]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+
+        for batch_index in range(bs):
+            targets_per_image = targets[batch_index]
+            # [N,]
+            tgt_cls = targets_per_image["labels"].numpy()
+            # [N, 4]
+            tgt_box = targets_per_image['boxes'].numpy()
+
+            for gt_box, gt_label in zip(tgt_box, tgt_cls):
+                # get a bbox coords
+                x1, y1, x2, y2 = gt_box.tolist()
+                # xyxy -> cxcywh
+                xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
+                bw, bh = x2 - x1, y2 - y1
+                gt_box = np.array([[0., 0., bw, bh]])
+
+                # check target
+                if bw < 1. or bh < 1.:
+                    # invalid target
+                    continue
+
+                # compute aspect ratio
+                ratios = gt_box[..., 2:] / self.anchor_sizes
+                keeps = np.maximum(ratios, 1 / ratios).max(-1) < self.anchor_theshold
+
+                if keeps.sum() == 0:
+                    label_assignment_results = self.iou_assignment([xc, yc], gt_box, fpn_strides)
+                else:
+                    label_assignment_results = self.aspect_ratio_assignment([xc, yc], keeps, fpn_strides)
+
+                # label assignment
+                for result in label_assignment_results:
+                    # assignment
+                    grid_x, grid_y, xc_s, yc_s, level, anchor_idx = result
+                    stride = fpn_strides[level]
+                    fmp_h, fmp_w = fmp_sizes[level]
+                    # coord on the feature
+                    x1s, y1s = x1 / stride, y1 / stride
+                    x2s, y2s = x2 / stride, y2 / stride
+                    # offset
+                    off_x = xc_s - grid_x
+                    off_y = yc_s - grid_y
+ 
+                    if off_x <= 0.5 and off_y <= 0.5:  # top left
+                        grids = [(grid_x-1, grid_y), (grid_x, grid_y-1), (grid_x, grid_y)]
+                    elif off_x > 0.5 and off_y <= 0.5: # top right
+                        grids = [(grid_x+1, grid_y), (grid_x, grid_y-1), (grid_x, grid_y)]
+                    elif off_x <= 0.5 and off_y > 0.5: # bottom left
+                        grids = [(grid_x-1, grid_y), (grid_x, grid_y+1), (grid_x, grid_y)]
+                    elif off_x > 0.5 and off_y > 0.5:  # bottom right
+                        grids = [(grid_x+1, grid_y), (grid_x, grid_y+1), (grid_x, grid_y)]
+
+                    for (i, j) in grids:
+                        is_in_box = (j >= y1s and j < y2s) and (i >= x1s and i < x2s)
+                        is_valid = (j >= 0 and j < fmp_h) and (i >= 0 and i < fmp_w)
+
+                        if is_in_box and is_valid:
+                            # obj
+                            gt_objectness[level][batch_index, j, i, anchor_idx] = 1.0
+                            # cls
+                            cls_ont_hot = torch.zeros(self.num_classes)
+                            cls_ont_hot[int(gt_label)] = 1.0
+                            gt_classes[level][batch_index, j, i, anchor_idx] = cls_ont_hot
+                            # box
+                            gt_bboxes[level][batch_index, j, i, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
+
+        # [B, M, C]
+        gt_objectness = torch.cat([gt.view(bs, -1, 1) for gt in gt_objectness], dim=1).float()
+        gt_classes = torch.cat([gt.view(bs, -1, self.num_classes) for gt in gt_classes], dim=1).float()
+        gt_bboxes = torch.cat([gt.view(bs, -1, 4) for gt in gt_bboxes], dim=1).float()
+
+        return gt_objectness, gt_classes, gt_bboxes

+ 155 - 0
models/yolov5/yolov5.py

@@ -0,0 +1,155 @@
+# --------------- Torch components ---------------
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .yolov5_backbone import Yolov5Backbone
+from .yolov5_neck     import SPPF
+from .yolov5_pafpn    import Yolov5PaFPN
+from .yolov5_head     import Yolov5DetHead
+from .yolov5_pred     import Yolov5DetPredLayer
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# YOLOv5
+class Yolov5(nn.Module):
+    def __init__(self,
+                 cfg,
+                 is_val = False,
+                 ) -> None:
+        super(Yolov5, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = False if is_val else True
+        
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone = Yolov5Backbone(cfg)
+        self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
+        ## Neck: SPP
+        self.neck     = SPPF(cfg, self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
+        ## Neck: FPN
+        self.fpn      = Yolov5PaFPN(cfg, self.pyramid_feat_dims)
+        ## Head
+        self.head     = Yolov5DetHead(cfg, self.fpn.out_dims)
+        ## Pred
+        self.pred     = Yolov5DetPredLayer(cfg)
+
+    def post_process(self, obj_preds, cls_preds, box_preds):
+        """
+        We process predictions at each scale hierarchically
+        Input:
+            obj_preds: List[torch.Tensor] -> [[B, M, 1], ...], B=1
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
+        Output:
+            bboxes: np.array -> [N, 4]
+            scores: np.array -> [N,]
+            labels: np.array -> [N,]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            obj_pred_i = obj_pred_i[0]
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            if self.no_multi_labels:
+                # [M,]
+                scores, labels = torch.max(
+                    torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # topk candidates
+                predicted_prob, topk_idxs = scores.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                labels = labels[topk_idxs]
+                bboxes = box_pred_i[topk_idxs]
+            else:
+                # [M, C] -> [MC,]
+                scores_i = torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()).flatten()
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # torch.sort is actually faster than .topk (at least on GPUs)
+                predicted_prob, topk_idxs = scores_i.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+                labels = topk_idxs % self.num_classes
+
+                bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
+        return bboxes, scores, labels
+    
+    def forward(self, x):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+        # ---------------- Neck: SPP ----------------
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # ---------------- Neck: PaFPN ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head(pyramid_feats)
+
+        # ---------------- Preds ----------------
+        outputs = self.pred(cls_feats, reg_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
+
+        if not self.training:
+            all_obj_preds = outputs['pred_obj']
+            all_cls_preds = outputs['pred_cls']
+            all_box_preds = outputs['pred_box']
+
+            # post process
+            bboxes, scores, labels = self.post_process(all_obj_preds, all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+        
+        return outputs 

+ 135 - 0
models/yolov5/yolov5_backbone.py

@@ -0,0 +1,135 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov5_basic import BasicConv, CSPBlock
+except:
+    from  yolov5_basic import BasicConv, CSPBlock
+
+
+# --------------------- Yolov3's Backbone -----------------------
+## Modified DarkNet
+class Yolov5Backbone(nn.Module):
+    def __init__(self, cfg):
+        super(Yolov5Backbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.model_scale = cfg.scale
+        self.feat_dims = [round(64   * cfg.width),
+                          round(128  * cfg.width),
+                          round(256  * cfg.width),
+                          round(512  * cfg.width),
+                          round(1024 * cfg.width)]
+        
+        # ------------------ Network setting ------------------
+        ## P1/2
+        self.layer_1 = BasicConv(3, self.feat_dims[0],
+                                 kernel_size=6, padding=2, stride=2,
+                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims[0], self.feat_dims[1],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            CSPBlock(in_dim     = self.feat_dims[1],
+                     out_dim    = self.feat_dims[1],
+                     num_blocks = round(3*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     act_type   = cfg.bk_act,
+                     norm_type  = cfg.bk_norm,
+                     depthwise  = cfg.bk_depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            BasicConv(self.feat_dims[1], self.feat_dims[2],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            CSPBlock(in_dim     = self.feat_dims[2],
+                     out_dim    = self.feat_dims[2],
+                     num_blocks = round(9*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     act_type   = cfg.bk_act,
+                     norm_type  = cfg.bk_norm,
+                     depthwise  = cfg.bk_depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            BasicConv(self.feat_dims[2], self.feat_dims[3],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            CSPBlock(in_dim     = self.feat_dims[3],
+                     out_dim    = self.feat_dims[3],
+                     num_blocks = round(9*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     act_type   = cfg.bk_act,
+                     norm_type  = cfg.bk_norm,
+                     depthwise  = cfg.bk_depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            BasicConv(self.feat_dims[3], self.feat_dims[4],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            CSPBlock(in_dim     = self.feat_dims[4],
+                     out_dim    = self.feat_dims[4],
+                     num_blocks = round(3*cfg.depth),
+                     expansion  = 0.5,
+                     shortcut   = True,
+                     act_type   = cfg.bk_act,
+                     norm_type  = cfg.bk_norm,
+                     depthwise  = cfg.bk_depthwise)
+        )
+
+        # Initialize all layers
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            self.bk_act = 'silu'
+            self.bk_norm = 'BN'
+            self.bk_depthwise = False
+            self.width = 0.5
+            self.depth = 0.34
+            self.scale = "s"
+            self.use_pretrained = True
+
+    cfg = BaseConfig()
+    model = Yolov5Backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 640, 640)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 137 - 0
models/yolov5/yolov5_basic.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # padding
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.norm2(self.conv2(x))
+            return x
+
+
+# ---------------------------- Basic Modules ----------------------------
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 kernel_size  :List  = [1, 3],
+                 expansion    :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ) -> None:
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expansion)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = BasicConv(in_dim, inter_dim,
+                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer2 = BasicConv(inter_dim, out_dim,
+                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class CSPBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 num_blocks   :int   = 1,
+                 expansion    :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,
+                 ):
+        super(CSPBlock, self).__init__()
+        # ---------- Basic parameters ----------
+        self.num_blocks = num_blocks
+        self.expansion = expansion
+        self.shortcut = shortcut
+        inter_dim = round(out_dim * expansion)
+        # ---------- Model parameters ----------
+        self.conv_layer_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv_layer_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv_layer_3 = BasicConv(inter_dim * 2, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module       = nn.Sequential(*[YoloBottleneck(inter_dim,
+                                                           inter_dim,
+                                                           kernel_size  = [1, 3],
+                                                           expansion    = 1.0,
+                                                           shortcut     = shortcut,
+                                                           act_type     = act_type,
+                                                           norm_type    = norm_type,
+                                                           depthwise    = depthwise)
+                                                           for _ in range(num_blocks)
+                                                           ])
+
+    def forward(self, x):
+        x1 = self.conv_layer_1(x)
+        x2 = self.module(self.conv_layer_2(x))
+        out = self.conv_layer_3(torch.cat([x1, x2], dim=1))
+
+        return out
+    

+ 171 - 0
models/yolov5/yolov5_head.py

@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov5_basic import BasicConv
+except:
+    from  yolov5_basic import BasicConv
+
+
+## Single-level Detection Head
+class DetHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 act_type     :str  = "silu",
+                 norm_type    :str  = "BN",
+                 depthwise    :bool = False):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                cls_feats.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## reg head
+        reg_feats = []
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                reg_feats.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+## Multi-level Detection Head
+class Yolov5DetHead(nn.Module):
+    def __init__(self, cfg, in_dims):
+        super().__init__()
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [DetHead(in_dim       = in_dims[level],
+                     cls_head_dim = round(cfg.head_dim * cfg.width),
+                     reg_head_dim = round(cfg.head_dim * cfg.width),
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     act_type     = cfg.head_act,
+                     norm_type    = cfg.head_norm,
+                     depthwise    = cfg.head_depthwise)
+                     for level in range(cfg.num_levels)
+                     ])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = cfg.head_dim
+        self.reg_head_dim = cfg.head_dim
+
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv3-Base config
+    class Yolov5BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.out_stride = 32
+            self.max_stride = 32
+            self.num_levels = 3
+            ## Head
+            self.head_act  = 'lrelu'
+            self.head_norm = 'BN'
+            self.head_depthwise = False
+            self.head_dim  = 256
+            self.num_cls_head   = 2
+            self.num_reg_head   = 2
+
+    cfg = Yolov5BaseConfig()
+    # Build a head
+    pyramid_feats = [torch.randn(1, cfg.head_dim, 80, 80),
+                     torch.randn(1, cfg.head_dim, 40, 40),
+                     torch.randn(1, cfg.head_dim, 20, 20)]
+    head = Yolov5DetHead(cfg, [cfg.head_dim]*3)
+
+
+    # Inference
+    t0 = time.time()
+    cls_feats, reg_feats = head(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for cls_f, reg_f in zip(cls_feats, reg_feats):
+        print(cls_f.shape, reg_f.shape)
+
+    print('==============================')
+    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))    

+ 33 - 0
models/yolov5/yolov5_neck.py

@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+from .yolov5_basic import BasicConv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    """
+        This code referenced to https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim):
+        super().__init__()
+        ## ----------- Basic Parameters -----------
+        inter_dim = round(in_dim * cfg.neck_expand_ratio)
+        self.out_dim = out_dim
+        ## ----------- Network Parameters -----------
+        self.cv1 = BasicConv(in_dim, inter_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.cv2 = BasicConv(inter_dim * 4, out_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.m = nn.MaxPool2d(kernel_size=cfg.spp_pooling_size,
+                              stride=1,
+                              padding=cfg.spp_pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))

+ 102 - 0
models/yolov5/yolov5_pafpn.py

@@ -0,0 +1,102 @@
+from typing import List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .yolov5_basic import BasicConv, CSPBlock
+
+
+# Yolov5FPN
+class Yolov5PaFPN(nn.Module):
+    def __init__(self, cfg, in_dims: List = [256, 512, 1024],
+                 ):
+        super(Yolov5PaFPN, self).__init__()
+        self.in_dims = in_dims
+        c3, c4, c5 = in_dims
+
+        # ---------------------- Yolov5's Top down FPN ----------------------
+        ## P5 -> P4
+        self.reduce_layer_1   = BasicConv(c5, round(512*cfg.width), kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
+        self.top_down_layer_1 = CSPBlock(in_dim     = c4 + round(512*cfg.width),
+                                         out_dim    = round(512*cfg.width),
+                                         num_blocks = round(3*cfg.depth),
+                                         expansion  = 0.5,
+                                         shortcut   = False,
+                                         act_type   = cfg.fpn_act,
+                                         norm_type  = cfg.fpn_norm,
+                                         depthwise  = cfg.fpn_depthwise)
+
+        ## P4 -> P3
+        self.reduce_layer_2   = BasicConv(round(512*cfg.width), round(256*cfg.width), kernel_size=1, act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
+        self.top_down_layer_2 = CSPBlock(in_dim     = c3 + round(256*cfg.width),
+                                         out_dim    = round(256*cfg.width),
+                                         num_blocks = round(3*cfg.depth),
+                                         expansion  = 0.5,
+                                         shortcut   = False,
+                                         act_type   = cfg.fpn_act,
+                                         norm_type  = cfg.fpn_norm,
+                                         depthwise  = cfg.fpn_depthwise)
+        
+        # ---------------------- Yolov5's Bottom up PAN ----------------------
+        ## P3 -> P4
+        self.downsample_layer_1 = BasicConv(round(256*cfg.width), round(256*cfg.width),
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+        self.bottom_up_layer_1  = CSPBlock(in_dim     = round(256*cfg.width) + round(256*cfg.width),
+                                           out_dim    = round(512*cfg.width),
+                                           num_blocks = round(3*cfg.depth),
+                                           expansion  = 0.5,
+                                           shortcut   = False,
+                                           act_type   = cfg.fpn_act,
+                                           norm_type  = cfg.fpn_norm,
+                                           depthwise  = cfg.fpn_depthwise)
+        ## P4 -> P5
+        self.downsample_layer_2 = BasicConv(round(512*cfg.width), round(512*cfg.width),
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+        self.bottom_up_layer_2  = CSPBlock(in_dim     = round(512*cfg.width) + round(512*cfg.width),
+                                           out_dim    = round(1024*cfg.width),
+                                           num_blocks = round(3*cfg.depth),
+                                           expansion  = 0.5,
+                                           shortcut   = False,
+                                           act_type   = cfg.fpn_act,
+                                           norm_type  = cfg.fpn_norm,
+                                           depthwise  = cfg.fpn_depthwise)
+
+        # ---------------------- Yolov5's output projection ----------------------
+        self.out_layers = nn.ModuleList([
+            BasicConv(in_dim, round(cfg.head_dim*cfg.width), kernel_size=1,
+                      act_type=cfg.fpn_act, norm_type=cfg.fpn_norm)
+                      for in_dim in [round(256*cfg.width), round(512*cfg.width), round(1024*cfg.width)]
+                      ])
+        self.out_dims = [round(cfg.head_dim*cfg.width)] * 3
+
+    def forward(self, features):
+        c3, c4, c5 = features
+        
+        # P5 -> P4
+        p5 = self.reduce_layer_1(c5)
+        p5_up = F.interpolate(p5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([c4, p5_up], dim=1))
+
+        # P4 -> P3
+        p4 = self.reduce_layer_2(p4)
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([c3, p4_up], dim=1))
+
+        # P3 -> P4
+        p3_ds = self.downsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
+
+        # P4 -> P5
+        p4_ds = self.downsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
+
+        out_feats = [p3, p4, p5]
+
+        # output proj layers
+        out_feats_proj = []
+        for feat, layer in zip(out_feats, self.out_layers):
+            out_feats_proj.append(layer(feat))
+            
+        return out_feats_proj

+ 158 - 0
models/yolov5/yolov5_pred.py

@@ -0,0 +1,158 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+# -------------------- Detection Pred Layer --------------------
+## Single-level pred layer
+class DetPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim      :int,
+                 reg_dim      :int,
+                 stride       :int,
+                 num_classes  :int,
+                 anchor_sizes :List,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride  = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.num_classes = num_classes
+        # ------------------- Anchor box -------------------
+        self.anchor_size = torch.as_tensor(anchor_sizes).float().view(-1, 2) # [A, 2]
+        self.num_anchors = self.anchor_size.shape[0]
+
+        # --------- Network Parameters ----------
+        self.obj_pred = nn.Conv2d(self.cls_dim, 1 * self.num_anchors, kernel_size=1)
+        self.cls_pred = nn.Conv2d(self.cls_dim, num_classes * self.num_anchors, kernel_size=1)
+        self.reg_pred = nn.Conv2d(self.reg_dim, 4 * self.num_anchors, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        b = self.obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # cls pred
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # 特征图的宽和高
+        fmp_h, fmp_w = fmp_size
+
+        # 生成网格的x坐标和y坐标
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+
+        # 将xy两部分的坐标拼起来:[H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        # [HW, 2] -> [HW, A, 2] -> [M, 2], M=HWA
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
+        anchor_xy = anchor_xy.view(-1, 2)
+
+        # [A, 2] -> [1, A, 2] -> [HW, A, 2] -> [M, 2], M=HWA
+        anchor_wh = self.anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
+        anchor_wh = anchor_wh.view(-1, 2)
+
+        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # 预测层
+        obj_pred = self.obj_pred(reg_feat)
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # 生成网格坐标
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+
+        # 对 pred 的size做一些view调整,便于后续的处理
+        # [B, C*A, H, W] -> [B, H, W, C*A] -> [B, H*W*A, C]
+        obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+        
+        # 解算边界框坐标
+        cxcy_pred = (torch.sigmoid(reg_pred[..., :2]) * 2.0 - 0.5 + anchors[..., :2]) * self.stride
+        bwbh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+        pred_x1y1 = cxcy_pred - bwbh_pred * 0.5
+        pred_x2y2 = cxcy_pred + bwbh_pred * 0.5
+        box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        # output dict
+        outputs = {"pred_obj": obj_pred,       # (torch.Tensor) [B, M, 1]
+                   "pred_cls": cls_pred,       # (torch.Tensor) [B, M, C]
+                   "pred_reg": reg_pred,       # (torch.Tensor) [B, M, 4]
+                   "pred_box": box_pred,       # (torch.Tensor) [B, M, 4]
+                   "anchors" : anchors,        # (torch.Tensor) [M, 2]
+                   "fmp_size": fmp_size,
+                   "stride"  : self.stride,    # (Int)
+                   }
+
+        return outputs
+
+## Multi-level pred layer
+class Yolov5DetPredLayer(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [DetPredLayer(cls_dim      = round(cfg.head_dim * cfg.width),
+                          reg_dim      = round(cfg.head_dim * cfg.width),
+                          stride       = cfg.out_stride[level],
+                          anchor_sizes = cfg.anchor_size[level],
+                          num_classes  = cfg.num_classes,)
+                          for level in range(cfg.num_levels)
+                          ])
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_fmp_sizes = []
+        all_obj_preds = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level in range(self.cfg.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # collect results
+            all_obj_preds.append(outputs["pred_obj"])
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(outputs["pred_box"])
+            all_fmp_sizes.append(outputs["fmp_size"])
+            all_anchors.append(outputs["anchors"])
+        
+        # output dict
+        outputs = {"pred_obj":  all_obj_preds,         # List(Tensor) [B, M, 1]
+                   "pred_cls":  all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":  all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":  all_box_preds,         # List(Tensor) [B, M, 4]
+                   "fmp_sizes": all_fmp_sizes,         # List(Tensor) [M, 1]
+                   "anchors":   all_anchors,           # List(Tensor) [M, 2]
+                   "strides":   self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs