
add YOLOv5 & YOLOv7

yjh0410, 2 years ago
Parent
Commit
ee2cca5217

+ 18 - 14
README.md

@@ -64,13 +64,15 @@ For example:
 python train.py --cuda -d voc --root path/to/VOCdevkit -v yolov1 -bs 16 --max_epoch 150 --wp_epoch 1 --eval_epoch 10 --fp16 --ema --multi_scale
 ```
 
-| Model  | Scale |  IP  | Epoch | AP50 | FPS<sup>3090<br>FP32-bs1 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|--------|-------|------|-------|------|--------------------------|-------------------|--------------------|--------|
-| YOLOv1 |  640  |  √   |  150  | 76.7 |                          |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov1_voc.pth) |
-| YOLOv2 |  640  |  √   |  150  | 79.8 |                          |   53.9            |   30.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov2_voc.pth) |
-| YOLOv3 |  640  |  √   |  150  | 82.0 |                          |   167.4           |   54.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov3_voc.pth) |
-| YOLOv4 |  640  |  √   |  150  | 83.6 |                          |   162.7           |   61.5             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov4_voc.pth) |
-| YOLOX  |  640  |  √   |  150  |      |                          |                   |                    |  |
+| Model  |   Backbone    | Scale |  IP  | Epoch | AP50 | FPS<sup>3090<br>FP32-bs1 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|--------|---------------|-------|------|-------|------|--------------------------|-------------------|--------------------|--------|
+| YOLOv1 | ResNet-18     |  640  |  √   |  150  | 76.7 |                          |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov1_voc.pth) |
+| YOLOv2 | DarkNet-19    |  640  |  √   |  150  | 79.8 |                          |   53.9            |   30.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov2_voc.pth) |
+| YOLOv3 | DarkNet-53    |  640  |  √   |  150  | 82.0 |                          |   167.4           |   54.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov3_voc.pth) |
+| YOLOv4 | CSPDarkNet-53 |  640  |  √   |  150  | 83.6 |                          |   162.7           |   61.5             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov4_voc.pth) |
+| YOLOv5 | CSPDarkNet-L  |  640  |  √   |  150  |      |                          |                   |                    |  |
+| YOLOX  | CSPDarkNet-L  |  640  |  √   |  150  |      |                          |                   |                    |  |
+| YOLOv7 | ELANNet       |  640  |  √   |  150  |      |                          |                   |                    |  |
 
 *All models are trained with ImageNet pretrained weights (IP). All FLOPs are measured with a 640x640 image size on VOC2007 test. The FPS is measured with batch size 1 on a 3090 GPU, from model inference to the NMS operation.*
 
@@ -96,13 +98,15 @@ For example:
 python train.py --cuda -d coco --root path/to/COCO -v yolov1 -bs 16 --max_epoch 150 --wp_epoch 1 --eval_epoch 10 --fp16 --ema --multi_scale
 ```
 
-| Model  | Scale |  IP  | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>test<br>50 | Weight |
-|--------|-------|------|-------|------------------------|-------------------|--------|
-| YOLOv1 |  640  |  √   |  150  |        27.9            |       47.5        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov1_coco.pth) |
-| YOLOv2 |  640  |  √   |  150  |        32.7            |       50.9        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov2_coco.pth) |
-| YOLOv3 |  640  |  √   |  250  |                        |                         |  |
-| YOLOv4 |  640  |  √   |  250  |                        |                         |  |
-| YOLOX  |  640  |  √   |  250  |                        |                         |  |
+| Model  |   Backbone    | Scale |  IP  | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>test<br>50 | Weight |
+|--------|---------------|-------|------|-------|------------------------|-------------------|--------|
+| YOLOv1 | ResNet-18     |  640  |  √   |  150  |        27.9            |       47.5        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov1_coco.pth) |
+| YOLOv2 | DarkNet-19    |  640  |  √   |  150  |        32.7            |       50.9        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov2_coco.pth) |
+| YOLOv3 | DarkNet-53    |  640  |  √   |  250  |                        |                   |  |
+| YOLOv4 | CSPDarkNet-53 |  640  |  √   |  250  |                        |                   |  |
+| YOLOv5 | CSPDarkNet-L  |  640  |  √   |  250  |                        |                   |  |
+| YOLOX  | CSPDarkNet-L  |  640  |  √   |  250  |                        |                   |  |
+| YOLOv7 | ELANNet       |  640  |  √   |  250  |                        |                   |  |
 
 *All models are trained with ImageNet pretrained weights (IP). All FLOPs are measured with a 640x640 image size on COCO val2017. The FPS is measured with batch size 1 on a 3090 GPU, from model inference to the NMS operation.*
 

+ 18 - 14
README_CN.md

@@ -64,13 +64,15 @@ python train.py --cuda -d voc --root path/to/VOC -v yolov1 -bs 16 --max_epoch 15
 
 **P5-Model on VOC:**
 
-| Model  | Scale |  IP  | Epoch | AP50 | FPS<sup>3090<br>FP32-bs1 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
-|--------|-------|------|-------|------|--------------------------|-------------------|--------------------|--------|
-| YOLOv1 |  640  |  √   |  150  | 76.7 |                          |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov1_voc.pth) |
-| YOLOv2 |  640  |  √   |  150  | 79.8 |                          |   53.9            |   30.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov2_voc.pth) |
-| YOLOv3 |  640  |  √   |  150  | 82.0 |                          |   167.4           |   54.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov3_voc.pth) |
-| YOLOv4 |  640  |  √   |  150  | 83.6 |                          |   162.7           |   61.5             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov4_voc.pth) |
-| YOLOX  |  640  |  √   |  150  |      |                          |                   |                    |  |
+| Model  |   Backbone    | Scale |  IP  | Epoch | AP50 | FPS<sup>3090<br>FP32-bs1 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|--------|---------------|-------|------|-------|------|--------------------------|-------------------|--------------------|--------|
+| YOLOv1 | ResNet-18     |  640  |  √   |  150  | 76.7 |                          |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov1_voc.pth) |
+| YOLOv2 | DarkNet-19    |  640  |  √   |  150  | 79.8 |                          |   53.9            |   30.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov2_voc.pth) |
+| YOLOv3 | DarkNet-53    |  640  |  √   |  150  | 82.0 |                          |   167.4           |   54.9             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov3_voc.pth) |
+| YOLOv4 | CSPDarkNet-53 |  640  |  √   |  150  | 83.6 |                          |   162.7           |   61.5             | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpy/yolov4_voc.pth) |
+| YOLOv5 | CSPDarkNet-L  |  640  |  √   |  150  |      |                          |                   |                    |  |
+| YOLOX  | CSPDarkNet-L  |  640  |  √   |  150  |      |                          |                   |                    |  |
+| YOLOv7 | ELANNet       |  640  |  √   |  150  |      |                          |                   |                    |  |
 
 *All models are trained with ImageNet pretrained weights (IP). All FLOPs are measured on the VOC2007 test set with a 640x640 or 1280x1280 input size. The FPS is measured with batch size 1 on a 3090 GPU; note that the timing covers model forward inference, post-processing, and the NMS operation.*
 
@@ -99,13 +101,15 @@ python train.py --cuda -d coco --root path/to/COCO -v yolov1 -bs 16 --max_epoch
 
 **P5-Model on COCO:**
 
-| Model  | Scale |  IP  | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>test<br>0.5:0.95 | Weight |
-|--------|-------|------|-------|------------------------|-------------------------|--------|
-| YOLOv1 |  640  |  √   |  150  |                        |                         |  |
-| YOLOv2 |  640  |  √   |  150  |                        |                         |  |
-| YOLOv3 |  640  |  √   |  250  |                        |                         |  |
-| YOLOv4 |  640  |  √   |  250  |                        |                         |  |
-| YOLOX  |  640  |  √   |  250  |                        |                         |  |
+| Model  |   Backbone    | Scale |  IP  | Epoch | AP<sup>val<br>0.5:0.95 | AP<sup>test<br>50 | Weight |
+|--------|---------------|-------|------|-------|------------------------|-------------------|--------|
+| YOLOv1 | ResNet-18     |  640  |  √   |  150  |        27.9            |       47.5        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov1_coco.pth) |
+| YOLOv2 | DarkNet-19    |  640  |  √   |  150  |        32.7            |       50.9        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov2_coco.pth) |
+| YOLOv3 | DarkNet-53    |  640  |  √   |  250  |                        |                   |  |
+| YOLOv4 | CSPDarkNet-53 |  640  |  √   |  250  |                        |                   |  |
+| YOLOv5 | CSPDarkNet-L  |  640  |  √   |  250  |                        |                   |  |
+| YOLOX  | CSPDarkNet-L  |  640  |  √   |  250  |                        |                   |  |
+| YOLOv7 | ELANNet       |  640  |  √   |  250  |                        |                   |  |
 
 *All models are trained with ImageNet pretrained weights (IP). All FLOPs are measured on the COCO val set with a 640x640 or 1280x1280 input size. The FPS is measured with batch size 1 on a 3090 GPU; note that the timing covers model forward inference, post-processing, and the NMS operation.*
 

+ 8 - 0
config/__init__.py

@@ -3,6 +3,8 @@ from .yolov1_config import yolov1_cfg
 from .yolov2_config import yolov2_cfg
 from .yolov3_config import yolov3_cfg
 from .yolov4_config import yolov4_cfg
+from .yolov5_config import yolov5_cfg
+from .yolov7_config import yolov7_cfg
 from .yolox_config import yolox_cfg
 
 
@@ -21,6 +23,12 @@ def build_model_config(args):
     # YOLOv4
     elif args.model == 'yolov4':
         cfg = yolov4_cfg
+    # YOLOv5
+    elif args.model == 'yolov5':
+        cfg = yolov5_cfg
+    # YOLOv7
+    elif args.model == 'yolov7':
+        cfg = yolov7_cfg
     # YOLOX
     elif args.model == 'yolox':
         cfg = yolox_cfg
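
A minimal sketch of how the two new configs are picked up through `build_model_config`, assuming it is run from the repository root; the `SimpleNamespace` is only a stand-in for the real argparse namespace used elsewhere in the repo.

```python
# Hypothetical usage sketch: SimpleNamespace mimics the argparse args.
from types import SimpleNamespace
from config import build_model_config

args = SimpleNamespace(model='yolov5')   # or 'yolov7'
cfg = build_model_config(args)
print(cfg['backbone'], cfg['fpn'])       # expected: cspdarknet yolo_pafpn
```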

+ 53 - 0
config/yolov5_config.py

@@ -0,0 +1,53 @@
+# YOLOv5 Config
+
+yolov5_cfg = {
+    # input
+    'trans_type': 'yolov5',
+    'multi_scale': [0.5, 1.0],
+    # model
+    'backbone': 'cspdarknet',
+    'pretrained': False,
+    'bk_act': 'silu',
+    'bk_norm': 'BN',
+    'bk_dpw': False,
+    'stride': [8, 16, 32],  # P3, P4, P5
+    'width': 1.0,
+    'depth': 1.0,
+    # fpn
+    'fpn': 'yolo_pafpn',
+    'fpn_act': 'silu',
+    'fpn_norm': 'BN',
+    'fpn_depthwise': False,
+    # head
+    'head': 'decoupled_head',
+    'head_act': 'silu',
+    'head_norm': 'BN',
+    'num_cls_head': 2,
+    'num_reg_head': 2,
+    'head_depthwise': False,
+    'anchor_size': [[10, 13],   [16, 30],   [33, 23],     # P3
+                    [30, 61],   [62, 45],   [59, 119],    # P4
+                    [116, 90],  [156, 198], [373, 326]],  # P5
+    # matcher
+    'anchor_thresh': 4.0,
+    # loss weight
+    'loss_obj_weight': 1.0,
+    'loss_cls_weight': 1.0,
+    'loss_box_weight': 5.0,
+    # training configuration
+    'no_aug_epoch': 10,
+    # optimizer
+    'optimizer': 'sgd',        # optional: sgd, adam, adamw
+    'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+    'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+    'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+    # model EMA
+    'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+    'ema_tau': 2000,
+    # lr schedule
+    'scheduler': 'linear',
+    'lr0': 0.01,               # SGD: 0.01;     AdamW: 0.004
+    'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+    'warmup_momentum': 0.8,
+    'warmup_bias_lr': 0.1,
+}
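
The `'scheduler': 'linear'` entry together with `lr0`/`lrf` suggests a YOLOv5-style linear decay from `lr0` down to `lr0 * lrf`. The trainer itself is not part of this diff, so the lambda below is only a hedged sketch of one common way to realize those keys.

```python
# Sketch only: assumes a YOLOv5-style linear schedule; the repo's trainer may differ.
import torch

lr0, lrf, max_epoch = 0.01, 0.01, 150
lf = lambda e: (1.0 - e / max_epoch) * (1.0 - lrf) + lrf   # decays lr from lr0 to lr0 * lrf

params = [torch.zeros(1, requires_grad=True)]
optimizer = torch.optim.SGD(params, lr=lr0, momentum=0.937, weight_decay=5e-4)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
```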

+ 56 - 0
config/yolov7_config.py

@@ -0,0 +1,56 @@
+# YOLOv7 Config
+
+yolov7_cfg = {
+    # input
+    'trans_type': 'yolov5',
+    'multi_scale': [0.5, 1.0],
+    # model
+    'backbone': 'elannet',
+    'pretrained': True,
+    'bk_act': 'silu',
+    'bk_norm': 'BN',
+    'bk_dpw': False,
+    'stride': [8, 16, 32],  # P3, P4, P5
+    # neck
+    'neck': 'csp_sppf',
+    'expand_ratio': 0.5,
+    'pooling_size': 5,
+    'neck_act': 'silu',
+    'neck_norm': 'BN',
+    'neck_depthwise': False,
+    # fpn
+    'fpn': 'yolov7_pafpn',
+    'fpn_act': 'silu',
+    'fpn_norm': 'BN',
+    'fpn_depthwise': False,
+    # head
+    'head': 'decoupled_head',
+    'head_act': 'silu',
+    'head_norm': 'BN',
+    'num_cls_head': 2,
+    'num_reg_head': 2,
+    'head_depthwise': False,
+    # matcher
+    'matcher': {'center_sampling_radius': 2.5,
+                'topk_candicate': 10},
+    # loss weight
+    'loss_obj_weight': 1.0,
+    'loss_cls_weight': 1.0,
+    'loss_box_weight': 5.0,
+    # training configuration
+    'no_aug_epoch': 20,
+    # optimizer
+    'optimizer': 'sgd',        # optional: sgd, adam, adamw
+    'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+    'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+    'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+    # model EMA
+    'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+    'ema_tau': 2000,
+    # lr schedule
+    'scheduler': 'linear',
+    'lr0': 0.01,               # SGD: 0.01;     AdamW: 0.004
+    'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+    'warmup_momentum': 0.8,
+    'warmup_bias_lr': 0.1,
+}
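
The `matcher` block points to a SimOTA-style assigner (as in YOLOX). The yolov7 matcher code is not shown in this commit, so the snippet below only illustrates, under that assumption, what `center_sampling_radius` usually controls: an anchor point counts as a candidate when it lies within `radius * stride` of the ground-truth center.

```python
# Hedged illustration only; the actual yolov7 matcher is not part of this diff.
import torch

radius, stride = 2.5, 8.0
gt_cx, gt_cy = 100.0, 60.0
anchor_xy = torch.tensor([[98., 62.], [150., 60.]])   # anchor-point centers in pixels
in_center = ((anchor_xy[:, 0] - gt_cx).abs() < radius * stride) & \
            ((anchor_xy[:, 1] - gt_cy).abs() < radius * stride)
print(in_center)   # tensor([ True, False])
```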

+ 1 - 1
eval.py

@@ -28,7 +28,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolox'], help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
     parser.add_argument('--conf_thresh', default=0.001, type=float,

+ 10 - 0
models/__init__.py

@@ -6,6 +6,8 @@ from .yolov1.build import build_yolov1
 from .yolov2.build import build_yolov2
 from .yolov3.build import build_yolov3
 from .yolov4.build import build_yolov4
+from .yolov5.build import build_yolov5
+from .yolov7.build import build_yolov7
 from .yolox.build import build_yolox
 
 
@@ -31,6 +33,14 @@ def build_model(args,
     elif args.model == 'yolov4':
         model, criterion = build_yolov4(
             args, model_cfg, device, num_classes, trainable)
+    # YOLOv5   
+    elif args.model == 'yolov5':
+        model, criterion = build_yolov5(
+            args, model_cfg, device, num_classes, trainable)
+    # YOLOv7
+    elif args.model == 'yolov7':
+        model, criterion = build_yolov7(
+            args, model_cfg, device, num_classes, trainable)
     # YOLOX   
     elif args.model == 'yolox':
         model, criterion = build_yolox(

+ 31 - 0
models/yolov5/build.py

@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .loss import build_criterion
+from .yolov5 import YOLOv5
+
+
+# build object detector
+def build_yolov5(args, cfg, device, num_classes=80, trainable=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    model = YOLOv5(
+        cfg = cfg,
+        device = device,
+        num_classes = num_classes,
+        conf_thresh = args.conf_thresh,
+        nms_thresh = args.nms_thresh,
+        topk = args.topk,
+        trainable = trainable
+        )
+
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, device, num_classes)
+
+    return model, criterion
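
A minimal sketch of calling the new builder, assuming it is run from the repository root. `conf_thresh` and `nms_thresh` mirror flags exposed in `eval.py`; the `SimpleNamespace` and the `topk` value are stand-ins for illustration.

```python
# Illustrative only: SimpleNamespace replaces the real argparse args.
import torch
from types import SimpleNamespace
from config import build_model_config
from models.yolov5.build import build_yolov5

args = SimpleNamespace(model='yolov5', conf_thresh=0.001, nms_thresh=0.6, topk=100)
cfg = build_model_config(args)
model, criterion = build_yolov5(args, cfg, torch.device('cpu'), num_classes=80, trainable=False)
print(type(model).__name__, criterion)   # YOLOv5 None (criterion is built only when trainable=True)
```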

+ 114 - 0
models/yolov5/loss.py

@@ -0,0 +1,114 @@
+import torch
+import torch.nn.functional as F
+from .matcher import Yolov5Matcher
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+
+class Criterion(object):
+    def __init__(self, cfg, device, num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        # loss weight
+        self.loss_obj_weight = cfg['loss_obj_weight']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+
+        # matcher
+        self.matcher = Yolov5Matcher(num_classes, 3, cfg['anchor_size'], cfg['anchor_thresh'])
+
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
+
+        return loss_box, ious
+
+
+    def __call__(self, outputs, targets):
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        fmp_sizes = outputs['fmp_sizes']
+        (
+            gt_objectness, 
+            gt_classes, 
+            gt_bboxes,
+            ) = self.matcher(fmp_sizes=fmp_sizes, 
+                             fpn_strides=fpn_strides, 
+                             targets=targets)
+        # List[B, M, C] -> [B, M, C] -> [BM, C]
+        pred_obj = torch.cat(outputs['pred_obj'], dim=1).view(-1)                      # [BM,]
+        pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)    # [BM, C]
+        pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)                   # [BM, 4]
+       
+        gt_objectness = gt_objectness.view(-1).to(device).float()               # [BM,]
+        gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()   # [BM, C]
+        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                    # [BM, 4]
+
+        pos_masks = (gt_objectness > 0)
+        num_fgs = pos_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # box loss
+        pred_box_pos = pred_box[pos_masks]
+        gt_bboxes_pos = gt_bboxes[pos_masks]
+        loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
+        loss_box = loss_box.sum() / num_fgs
+        
+        # cls loss
+        pred_cls_pos = pred_cls[pos_masks]
+        gt_classes_pos = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
+        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # obj loss
+        loss_obj = self.loss_objectness(pred_obj, gt_objectness)
+        loss_obj = loss_obj.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        loss_dict = dict(
+                loss_obj = loss_obj,
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+    
+if __name__ == "__main__":
+    pass
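
A toy sketch of the box and objectness terms above on two fake positives. `utils.box_ops.get_ious` is the repo's helper; `torchvision.ops.generalized_box_iou` is used here only so the snippet stands alone, and all tensors are made up.

```python
# Standalone sketch of the loss terms; numbers are illustrative.
import torch
import torch.nn.functional as F
from torchvision.ops import generalized_box_iou

num_fgs = 2.0
pred_obj = torch.randn(6)                                   # objectness logits for all anchors
gt_obj   = torch.tensor([1., 0., 1., 0., 0., 0.])           # 1 for matched anchors
pred_box = torch.tensor([[10., 10., 50., 60.], [20., 20., 40., 80.]])   # positives, xyxy
gt_box   = torch.tensor([[12.,  8., 48., 64.], [18., 25., 45., 75.]])

ious = generalized_box_iou(pred_box, gt_box).diagonal()     # matched GIoU values
loss_box = (1.0 - ious).sum() / num_fgs
loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none').sum() / num_fgs
print(loss_box.item(), loss_obj.item())
```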

+ 214 - 0
models/yolov5/matcher.py

@@ -0,0 +1,214 @@
+import numpy as np
+import torch
+
+
+class Yolov5Matcher(object):
+    def __init__(self, num_classes, num_anchors, anchor_size, anchor_threshold):
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+        self.anchor_threshold = anchor_threshold
+        # [KA, 2]
+        self.anchor_sizes = np.array([[anchor[0], anchor[1]]
+                                      for anchor in anchor_size])
+        # [KA, 4]
+        self.anchor_boxes = np.array([[0., 0., anchor[0], anchor[1]]
+                                      for anchor in anchor_size])
+
+    def compute_iou(self, anchor_boxes, gt_box):
+        """
+            anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).
+            gt_box : ndarray -> [1, 4] (cx, cy, bw, bh).
+        """
+        # anchors: [KA, 4]
+        anchors = np.zeros_like(anchor_boxes)
+        anchors[..., :2] = anchor_boxes[..., :2] - anchor_boxes[..., 2:] * 0.5  # x1y1
+        anchors[..., 2:] = anchor_boxes[..., :2] + anchor_boxes[..., 2:] * 0.5  # x2y2
+        anchors_area = anchor_boxes[..., 2] * anchor_boxes[..., 3]
+        
+        # gt_box: [1, 4] -> [KA, 4]
+        gt_box = np.array(gt_box).reshape(-1, 4)
+        gt_box = np.repeat(gt_box, anchors.shape[0], axis=0)
+        gt_box_ = np.zeros_like(gt_box)
+        gt_box_[..., :2] = gt_box[..., :2] - gt_box[..., 2:] * 0.5  # x1y1
+        gt_box_[..., 2:] = gt_box[..., :2] + gt_box[..., 2:] * 0.5  # x2y2
+        gt_box_area = np.prod(gt_box_[..., 2:] - gt_box_[..., :2], axis=1)
+
+        # intersection
+        inter_w = np.minimum(anchors[:, 2], gt_box_[:, 2]) - \
+                  np.maximum(anchors[:, 0], gt_box_[:, 0])
+        inter_h = np.minimum(anchors[:, 3], gt_box_[:, 3]) - \
+                  np.maximum(anchors[:, 1], gt_box_[:, 1])
+        inter_area = inter_w * inter_h
+        
+        # union
+        union_area = anchors_area + gt_box_area - inter_area
+
+        # iou
+        iou = inter_area / union_area
+        iou = np.clip(iou, a_min=1e-10, a_max=1.0)
+        
+        return iou
+
+
+    def iou_assignment(self, ctr_points, gt_box, fpn_strides):
+        # compute IoU
+        iou = self.compute_iou(self.anchor_boxes, gt_box)
+        iou_mask = (iou > 0.5)
+
+        label_assignment_results = []
+        if iou_mask.sum() == 0:
+            # We assign the anchor box with highest IoU score.
+            iou_ind = np.argmax(iou)
+
+            level = iou_ind // self.num_anchors              # pyramid level
+            anchor_idx = iou_ind - level * self.num_anchors  # anchor index
+
+            # get the corresponding stride
+            stride = fpn_strides[level]
+
+            # compute the grid cell
+            xc, yc = ctr_points
+            xc_s = xc / stride
+            yc_s = yc / stride
+            grid_x = int(xc_s)
+            grid_y = int(yc_s)
+
+            label_assignment_results.append([grid_x, grid_y, xc_s, yc_s, level, anchor_idx])
+        else:
+            xc, yc = ctr_points
+            for iou_ind, iou_m in enumerate(iou_mask):
+                if iou_m:
+                    level = iou_ind // self.num_anchors              # pyramid level
+                    anchor_idx = iou_ind - level * self.num_anchors  # anchor index
+
+                    # get the corresponding stride
+                    stride = fpn_strides[level]
+
+                    # compute the grid cell
+                    xc_s = xc / stride
+                    yc_s = yc / stride
+                    grid_x = int(xc_s)
+                    grid_y = int(yc_s)
+
+                    label_assignment_results.append([grid_x, grid_y, xc_s, yc_s, level, anchor_idx])
+
+        return label_assignment_results
+
+
+    def aspect_ratio_assignment(self, ctr_points, keeps, fpn_strides):
+        label_assignment_results = []
+        for keep_idx, keep in enumerate(keeps):
+            if keep:
+                level = keep_idx // self.num_anchors              # pyramid level
+                anchor_idx = keep_idx - level * self.num_anchors  # anchor index
+
+                # get the corresponding stride
+                stride = fpn_strides[level]
+
+                # compute the grid cell
+                xc, yc = ctr_points
+                xc_s = xc / stride
+                yc_s = yc / stride
+                grid_x = int(xc_s)
+                grid_y = int(yc_s)
+
+                label_assignment_results.append([grid_x, grid_y, xc_s, yc_s, level, anchor_idx])
+        
+        return label_assignment_results
+    
+
+    @torch.no_grad()
+    def __call__(self, fmp_sizes, fpn_strides, targets):
+        """
+            fmp_sizes: (List) [[fmp_h, fmp_w], ...] feature map size of each level.
+            fpn_strides: (List) -> [8, 16, 32, ...] stride of network output.
+            targets: (Dict) dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}
+        """
+        assert len(fmp_sizes) == len(fpn_strides)
+        # prepare
+        bs = len(targets)
+        gt_objectness = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, 1]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+        gt_classes = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, self.num_classes]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+        gt_bboxes = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, 4]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+
+        for batch_index in range(bs):
+            targets_per_image = targets[batch_index]
+            # [N,]
+            tgt_cls = targets_per_image["labels"].numpy()
+            # [N, 4]
+            tgt_box = targets_per_image['boxes'].numpy()
+
+            for gt_box, gt_label in zip(tgt_box, tgt_cls):
+                # get a bbox coords
+                x1, y1, x2, y2 = gt_box.tolist()
+                # xyxy -> cxcywh
+                xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
+                bw, bh = x2 - x1, y2 - y1
+                gt_box = np.array([[0., 0., bw, bh]])
+
+                # check target
+                if bw < 1. or bh < 1.:
+                    # invalid target
+                    continue
+
+                # compute aspect ratio
+                ratios = gt_box[..., 2:] / self.anchor_sizes
+                keeps = np.maximum(ratios, 1 / ratios).max(-1) < self.anchor_threshold
+
+                if keeps.sum() == 0:
+                    label_assignment_results = self.iou_assignment([xc, yc], gt_box, fpn_strides)
+                else:
+                    label_assignment_results = self.aspect_ratio_assignment([xc, yc], keeps, fpn_strides)
+
+                # label assignment
+                for result in label_assignment_results:
+                    stride = fpn_strides[level]
+                    fmp_h, fmp_w = fmp_sizes[level]
+                    # assignment
+                    grid_x, grid_y, xc_s, yc_s, level, anchor_idx = result
+                    # coord on the feature
+                    x1s, y1s = x1 / stride, y1 / stride
+                    x2s, y2s = x2 / stride, y2 / stride
+                    # offset
+                    off_x = xc_s - grid_x
+                    off_y = yc_s - grid_y
+
+                    if off_x <= 0.5 and off_y <= 0.5:  # top left
+                        grids = [(grid_x-1, grid_y), (grid_x, grid_y-1), (grid_x, grid_y)]
+                    elif off_x > 0.5 and off_y <= 0.5: # top right
+                        grids = [(grid_x+1, grid_y), (grid_x, grid_y-1), (grid_x, grid_y)]
+                    elif off_x <= 0.5 and off_y > 0.5:  # bottom left
+                        grids = [(grid_x-1, grid_y), (grid_x, grid_y+1), (grid_x, grid_y)]
+                    else:                               # bottom right
+                        grids = [(grid_x+1, grid_y), (grid_x, grid_y+1), (grid_x, grid_y)]
+
+                    for (i, j) in grids:
+                        is_in_box = (j >= y1s and j < y2s) and (i >= x1s and i < x2s)
+                        is_valid = (j >= 0 and j < fmp_h) and (i >= 0 and i < fmp_w)
+
+                        if is_in_box and is_valid:
+                            # obj
+                            gt_objectness[level][batch_index, j, i, anchor_idx] = 1.0
+                            # cls
+                            cls_one_hot = torch.zeros(self.num_classes)
+                            cls_one_hot[int(gt_label)] = 1.0
+                            gt_classes[level][batch_index, j, i, anchor_idx] = cls_one_hot
+                            # box
+                            gt_bboxes[level][batch_index, j, i, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
+
+        # [B, M, C]
+        gt_objectness = torch.cat([gt.view(bs, -1, 1) for gt in gt_objectness], dim=1).float()
+        gt_classes = torch.cat([gt.view(bs, -1, self.num_classes) for gt in gt_classes], dim=1).float()
+        gt_bboxes = torch.cat([gt.view(bs, -1, 4) for gt in gt_bboxes], dim=1).float()
+
+        return gt_objectness, gt_classes, gt_bboxes
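
A small numeric check of the aspect-ratio test used in `__call__`: a ground-truth box is kept for an anchor when `max(ratio, 1/ratio)` over width and height stays below `anchor_thresh` (4.0 in the YOLOv5 config). The box size below is made up.

```python
import numpy as np

anchor_sizes = np.array([[10, 13],  [16, 30],   [33, 23],     # P3
                         [30, 61],  [62, 45],   [59, 119],    # P4
                         [116, 90], [156, 198], [373, 326]],  # P5
                        dtype=np.float64)
gt_wh = np.array([[64., 48.]])                                # one gt box (bw, bh)
ratios = gt_wh / anchor_sizes                                 # [9, 2]
keeps = np.maximum(ratios, 1.0 / ratios).max(-1) < 4.0
print(keeps)   # True only for anchors whose shape is within 4x of the gt box
```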

+ 285 - 0
models/yolov5/yolov5.py

@@ -0,0 +1,285 @@
+import torch
+import torch.nn as nn
+
+from utils.nms import multiclass_nms
+
+from .yolov5_backbone import build_backbone
+from .yolov5_fpn import build_fpn
+from .yolov5_head import build_head
+
+
+# YOLOv5
+class YOLOv5(nn.Module):
+    def __init__(self,
+                 cfg,
+                 device,
+                 num_classes=20,
+                 conf_thresh=0.01,
+                 topk=100,
+                 nms_thresh=0.5,
+                 trainable=False):
+        super(YOLOv5, self).__init__()
+        # ------------------- Basic parameters -------------------
+        self.cfg = cfg                                 # model config
+        self.device = device                           # cuda or cpu
+        self.num_classes = num_classes                 # number of classes
+        self.trainable = trainable                     # training flag
+        self.conf_thresh = conf_thresh                 # confidence (score) threshold
+        self.nms_thresh = nms_thresh                   # NMS threshold
+        self.topk = topk                               # topk
+        self.stride = [8, 16, 32]                      # output strides of the network
+        # ------------------- Anchor box -------------------
+        self.num_levels = 3
+        self.num_anchors = len(cfg['anchor_size']) // self.num_levels
+        self.anchor_size = torch.as_tensor(
+            cfg['anchor_size']
+            ).view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
+        
+        # ------------------- Network Structure -------------------
+        ## backbone
+        self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
+
+        ## neck: feature pyramid network
+        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=int(256*cfg['width']))
+        self.head_dim = self.fpn.out_dim
+
+        ## detection heads
+        self.non_shared_heads = nn.ModuleList(
+            [build_head(cfg, head_dim, head_dim, num_classes) 
+            for head_dim in self.head_dim
+            ])
+
+        ## prediction layers
+        self.obj_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 1 * self.num_anchors, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(head.cls_out_dim, self.num_classes * self.num_anchors, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 4 * self.num_anchors, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ])                 
+    
+
+        # --------- Network Initialization ----------
+        self.init_yolo()
+
+
+    def init_yolo(self): 
+        # Init yolo
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03
+                
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        for obj_pred in self.obj_preds:
+            b = obj_pred.bias.view(self.num_anchors, -1)
+            b.data.fill_(bias_value.item())
+            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # cls pred
+        for cls_pred in self.cls_preds:
+            b = cls_pred.bias.view(self.num_anchors, -1)
+            b.data.fill_(bias_value.item())
+            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        fmp_h, fmp_w = fmp_size
+        # [KA, 2]
+        anchor_size = self.anchor_size[level]
+
+        # generate grid cells
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        # [HW, 2] -> [HW, KA, 2] -> [M, 2]
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
+        anchor_xy = anchor_xy.view(-1, 2).to(self.device)
+
+        # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] -> [M, 2]
+        anchor_wh = anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
+        anchor_wh = anchor_wh.view(-1, 2).to(self.device)
+
+        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
+
+        return anchors
+        
+
+    def decode_boxes(self, level, anchors, reg_pred):
+        """
+            将txtytwth转换为常用的x1y1x2y2形式。
+        """
+
+        # 计算预测边界框的中心点坐标和宽高
+        pred_ctr = (torch.sigmoid(reg_pred[..., :2]) * 2.0 - 0.5 + anchors[..., :2]) * self.stride[level]
+        pred_wh = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+
+        # 将所有bbox的中心带你坐标和宽高换算成x1y1x2y2形式
+        pred_x1y1 = pred_ctr - pred_wh * 0.5
+        pred_x2y2 = pred_ctr + pred_wh * 0.5
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+
+
+    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+        """
+        Input:
+            obj_preds: List(Tensor) [[H x W, 1], ...]
+            cls_preds: List(Tensor) [[H x W, C], ...]
+            reg_preds: List(Tensor) [[H x W, 4], ...]
+            anchors:  List(Tensor) [[H x W, 2], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchor_i) \
+                in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
+            # (H x W x KA x C,)
+            scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, reg_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            labels = topk_idxs % self.num_classes
+
+            reg_pred_i = reg_pred_i[anchor_idxs]
+            anchor_i = anchor_i[anchor_idxs]
+
+            # decode box: [M, 4]
+            bboxes = self.decode_boxes(level, anchor_i, reg_pred_i)
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+
+        return bboxes, scores, labels
+
+
+    @torch.no_grad()
+    def inference(self, x):
+        # backbone
+        pyramid_feats = self.backbone(x)
+
+        # feature pyramid
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # detection heads
+        all_anchors = []
+        all_obj_preds = []
+        all_cls_preds = []
+        all_reg_preds = []
+        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+            cls_feat, reg_feat = head(feat)
+
+            # [1, C, H, W]
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+
+            # anchors: [M, 2]
+            fmp_size = cls_pred.shape[-2:]
+            anchors = self.generate_anchors(level, fmp_size)
+
+            # [1, AC, H, W] -> [H, W, AC] -> [M, C]
+            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
+            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
+            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
+
+            all_obj_preds.append(obj_pred)
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_anchors.append(anchors)
+
+        # post process
+        bboxes, scores, labels = self.post_process(
+            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
+
+        return bboxes, scores, labels
+
+
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference(x)
+        else:
+            bs = x.shape[0]
+            # backbone
+            pyramid_feats = self.backbone(x)
+
+            # feature pyramid
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # detection heads
+            all_fmp_sizes = []
+            all_obj_preds = []
+            all_cls_preds = []
+            all_box_preds = []
+            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+                cls_feat, reg_feat = head(feat)
+
+                # [B, C, H, W]
+                obj_pred = self.obj_preds[level](reg_feat)
+                cls_pred = self.cls_preds[level](cls_feat)
+                reg_pred = self.reg_preds[level](reg_feat)
+
+                fmp_size = cls_pred.shape[-2:]
+
+                # generate anchor boxes: [M, 4]
+                anchors = self.generate_anchors(level, fmp_size)
+                
+                # [B, AC, H, W] -> [B, H, W, AC] -> [B, M, C]
+                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 1)
+                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, self.num_classes)
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 4)
+
+                # decode bbox
+                box_pred = self.decode_boxes(level, anchors, reg_pred)
+
+                all_obj_preds.append(obj_pred)
+                all_cls_preds.append(cls_pred)
+                all_box_preds.append(box_pred)
+                all_fmp_sizes.append(fmp_size)
+
+            # output dict
+            outputs = {"pred_obj": all_obj_preds,        # List [B, M, 1]
+                       "pred_cls": all_cls_preds,        # List [B, M, C]
+                       "pred_box": all_box_preds,        # List [B, M, 4]
+                       'fmp_sizes': all_fmp_sizes,       # List
+                       'strides': self.stride,           # List
+                       }
+
+            return outputs 
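
A tiny numeric check of `decode_boxes`: with zero regression logits the predicted center lands at `(grid + sigmoid(0) * 2 - 0.5) * stride = (grid + 0.5) * stride` and the size equals the anchor size. The anchor values below are illustrative only.

```python
import torch

stride   = 8
anchors  = torch.tensor([[10., 10., 30., 60.]])   # (grid_x, grid_y, anchor_w, anchor_h)
reg_pred = torch.zeros(1, 4)

pred_ctr = (torch.sigmoid(reg_pred[..., :2]) * 2.0 - 0.5 + anchors[..., :2]) * stride
pred_wh  = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
print(pred_ctr, pred_wh)   # tensor([[84., 84.]]) tensor([[30., 60.]])
```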

+ 128 - 0
models/yolov5/yolov5_backbone.py

@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov5_basic import Conv, CSPBlock
+    from .yolov5_neck import SPPF
+except:
+    from yolov5_basic import Conv, CSPBlock
+    from yolov5_neck import SPPF
+
+model_urls = {
+    "cspdarknet_large": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/cspdarknet_large.pth",
+}
+
+# CSPDarkNet
+class CSPDarkNet(nn.Module):
+    def __init__(self, depth=1.0, width=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(CSPDarkNet, self).__init__()
+        self.feat_dims = [int(256*width), int(512*width), int(1024*width)]
+
+        # P1
+        self.layer_1 = Conv(3, int(64*width), k=6, p=2, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        
+        # P2
+        self.layer_2 = nn.Sequential(
+            Conv(int(64*width), int(128*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(int(128*width), int(128*width), expand_ratio=0.5, nblocks=int(3*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P3
+        self.layer_3 = nn.Sequential(
+            Conv(int(128*width), int(256*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(int(256*width), int(256*width), expand_ratio=0.5, nblocks=int(9*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P4
+        self.layer_4 = nn.Sequential(
+            Conv(int(256*width), int(512*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(int(512*width), int(512*width), expand_ratio=0.5, nblocks=int(9*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P5
+        self.layer_5 = nn.Sequential(
+            Conv(int(512*width), int(1024*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            SPPF(int(1024*width), int(1024*width), expand_ratio=0.5, act_type=act_type, norm_type=norm_type),
+            CSPBlock(int(1024*width), int(1024*width), expand_ratio=0.5, nblocks=int(3*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+def build_backbone(cfg, pretrained=False): 
+    """Constructs a darknet-53 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    backbone = CSPDarkNet(cfg['depth'], cfg['width'], cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
+    feat_dims = backbone.feat_dims
+
+    if pretrained:
+        url = None
+        if cfg['width'] == 1.0 and cfg['depth'] == 1.0:
+            url = model_urls['cspdarknet_large']
+        if url is not None:
+            print('Loading pretrained weight ...')
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            # checkpoint state dict
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = backbone.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print(k)
+
+            backbone.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No backbone pretrained: CSPDarkNet53')        
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': False,
+        'bk_act': 'lrelu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'p6_feat': False,
+        'p7_feat': False,
+        'width': 1.0,
+        'depth': 1.0,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 224, 224)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 224, 224)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
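
A quick sketch of the `width`/`depth` multipliers: halving the width halves every channel count, so the exported pyramid dims shrink accordingly. The 0.5/0.34 values are illustrative (roughly a "small" variant) and no such pretrained weights ship with this commit; assumes the snippet is run from `models/yolov5` so the non-relative fallback import resolves.

```python
from yolov5_backbone import CSPDarkNet   # fallback (non-relative) import path

small = CSPDarkNet(depth=0.34, width=0.5, act_type='silu', norm_type='BN')
print(small.feat_dims)   # [128, 256, 512]
```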

+ 131 - 0
models/yolov5/yolov5_basic.py

@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+
+# Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# BottleNeck
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 shortcut=False,
+                 depthwise=False,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
+        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.cv2 = Conv(inter_dim, out_dim, k=3, p=1, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
+
+# CSP-stage block
+class CSPBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 nblocks=1,
+                 shortcut=False,
+                 depthwise=False,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(CSPBlock, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.cv3 = Conv(2 * inter_dim, out_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.m = nn.Sequential(*[
+            Bottleneck(inter_dim, inter_dim, expand_ratio=1.0, shortcut=shortcut,
+                       norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+                       for _ in range(nblocks)
+                       ])
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x1)
+        out = self.cv3(torch.cat([x3, x2], dim=1))
+
+        return out
+    
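
A quick shape check for `CSPBlock`, assuming it is run from `models/yolov5` so the plain `yolov5_basic` import resolves (the try/except imports in the sibling files suggest that layout).

```python
import torch
from yolov5_basic import CSPBlock

block = CSPBlock(in_dim=64, out_dim=128, expand_ratio=0.5, nblocks=2, shortcut=False)
x = torch.randn(1, 64, 40, 40)
print(block(x).shape)   # torch.Size([1, 128, 40, 40]) -- channels change, spatial size is kept
```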

+ 137 - 0
models/yolov5/yolov5_fpn.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .yolov5_basic import Conv, CSPBlock
+
+
+# PaFPN-CSP
+class YoloPaFPN(nn.Module):
+    def __init__(self, 
+                 in_dims=[256, 512, 1024],
+                 out_dim=256,
+                 width=1.0,
+                 depth=1.0,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False):
+        super(YoloPaFPN, self).__init__()
+        self.in_dims = in_dims
+        self.out_dim = out_dim
+        c3, c4, c5 = in_dims
+
+        # top down
+        ## P5 -> P4
+        self.reduce_layer_1 = Conv(c5, int(512*width), k=1, norm_type=norm_type, act_type=act_type)
+        self.top_down_layer_1 = CSPBlock(in_dim = c4 + int(512*width),
+                                         out_dim = int(512*width),
+                                         expand_ratio = 0.5,
+                                         nblocks = int(3*depth),
+                                         shortcut = False,
+                                         depthwise = depthwise,
+                                         norm_type = norm_type,
+                                         act_type = act_type
+                                         )
+
+        ## P4 -> P3
+        self.reduce_layer_2 = Conv(c4, int(256*width), k=1, norm_type=norm_type, act_type=act_type)
+        self.top_down_layer_2 = CSPBlock(in_dim = c3 + int(256*width), 
+                                         out_dim = int(256*width),
+                                         expand_ratio = 0.5,
+                                         nblocks = int(3*depth),
+                                         shortcut = False,
+                                         depthwise = depthwise,
+                                         norm_type = norm_type,
+                                         act_type=act_type
+                                         )
+
+        # bottom up
+        ## P3 -> P4
+        self.reduce_layer_3 = Conv(int(256*width), int(256*width), k=3, p=1, s=2,
+                                   depthwise=depthwise, norm_type=norm_type, act_type=act_type)
+        self.bottom_up_layer_1 = CSPBlock(in_dim = int(256*width) + int(256*width),
+                                          out_dim = int(512*width),
+                                          expand_ratio = 0.5,
+                                          nblocks = int(3*depth),
+                                          shortcut = False,
+                                          depthwise = depthwise,
+                                          norm_type = norm_type,
+                                          act_type=act_type
+                                          )
+
+        ## P4 -> P5
+        self.reduce_layer_4 = Conv(int(512*width), int(512*width), k=3, p=1, s=2,
+                                   depthwise=depthwise, norm_type=norm_type, act_type=act_type)
+        self.bottom_up_layer_2 = CSPBlock(in_dim = int(512*width) + int(512*width),
+                                          out_dim = int(1024*width),
+                                          expand_ratio = 0.5,
+                                          nblocks = int(3*depth),
+                                          shortcut = False,
+                                          depthwise = depthwise,
+                                          norm_type = norm_type,
+                                          act_type=act_type
+                                          )
+
+        # output proj layers
+        if out_dim is not None:
+            # output proj layers
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                        norm_type=norm_type, act_type=act_type)
+                        for in_dim in [int(256 * width), int(512 * width), int(1024 * width)]
+                        ])
+            self.out_dim = [out_dim] * 3
+
+        else:
+            self.out_layers = None
+            self.out_dim = [int(256 * width), int(512 * width), int(1024 * width)]
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        c6 = self.reduce_layer_1(c5)
+        c7 = F.interpolate(c6, scale_factor=2.0)   # s32->s16
+        c8 = torch.cat([c7, c4], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        # P3/8
+        c10 = self.reduce_layer_2(c9)
+        c11 = F.interpolate(c10, scale_factor=2.0)   # s16->s8
+        c12 = torch.cat([c11, c3], dim=1)
+        c13 = self.top_down_layer_2(c12)  # to det
+        # p4/16
+        c14 = self.reduce_layer_3(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)  # to det
+        # p5/32
+        c17 = self.reduce_layer_4(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)  # to det
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+
+        # output proj layers
+        if self.out_layers is not None:
+            # output proj layers
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build neck
+    if model == 'yolo_pafpn':
+        fpn_net = YoloPaFPN(in_dims=in_dims,
+                             out_dim=out_dim,
+                             width=cfg['width'],
+                             depth=cfg['depth'],
+                             act_type=cfg['fpn_act'],
+                             norm_type=cfg['fpn_norm'],
+                             depthwise=cfg['fpn_depthwise']
+                             )
+
+
+    return fpn_net
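
A shape sketch for `YoloPaFPN`, assuming it is run from the repository root so the package-relative import inside `yolov5_fpn` resolves. The input channels match the CSPDarkNet `feat_dims` at `width=1.0`; the random tensors are placeholders.

```python
import torch
from models.yolov5.yolov5_fpn import YoloPaFPN

fpn = YoloPaFPN(in_dims=[256, 512, 1024], out_dim=256, width=1.0, depth=1.0)
feats = [torch.randn(1, 256,  80, 80),
         torch.randn(1, 512,  40, 40),
         torch.randn(1, 1024, 20, 20)]
for p in fpn(feats):
    print(p.shape)   # [1, 256, 80, 80], [1, 256, 40, 40], [1, 256, 20, 20]
```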

+ 137 - 0
models/yolov5/yolov5_head.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+try:
+    from .yolov5_basic import Conv
+except:
+    from yolov5_basic import Conv
+
+
+class DecoupledHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+        super().__init__()
+        print('==============================')
+        print('Head: Decoupled Head')
+        self.in_dim = in_dim
+        self.num_cls_head=cfg['num_cls_head']
+        self.num_reg_head=cfg['num_reg_head']
+        self.act_type=cfg['head_act']
+        self.norm_type=cfg['head_norm']
+
+        # cls head
+        cls_feats = []
+        self.cls_out_dim = max(out_dim, num_classes)
+        for i in range(cfg['num_cls_head']):
+            if i == 0:
+                cls_feats.append(
+                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                cls_feats.append(
+                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+                
+        # reg head
+        reg_feats = []
+        self.reg_out_dim = max(out_dim, 64)
+        for i in range(cfg['num_reg_head']):
+            if i == 0:
+                reg_feats.append(
+                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                reg_feats.append(
+                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+
+    def forward(self, x):
+        """
+            x: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+
+# build detection head
+def build_head(cfg, in_dim, out_dim, num_classes=80):
+    head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+
+    return head
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'head_depthwise': False,
+        'reg_max': 16,
+    }
+    fpn_dims = [256, 512, 512]
+    # Head-1
+    model = build_head(cfg, fpn_dims[0], fpn_dims[0], num_classes=80)
+    x = torch.randn(1, 256, 80, 80)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-1: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-2
+    model = build_head(cfg, fpn_dims[1], fpn_dims[1], num_classes=80)
+    x = torch.randn(1, 512, 40, 40)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-2: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-3
+    model = build_head(cfg, fpn_dims[2], fpn_dims[2], num_classes=80)
+    x = torch.randn(1, 512, 20, 20)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-3: Params : {:.2f} M'.format(params / 1e6))

+ 95 - 0
models/yolov5/yolov5_neck.py

@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from .yolov5_basic import Conv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='lrelu', norm_type='BN'):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+# SPPF block with CSP module
+class SPPFBlockCSP(nn.Module):
+    """
+        CSP Spatial Pyramid Pooling Block
+    """
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 pooling_size=5,
+                 act_type='lrelu',
+                 norm_type='BN',
+                 depthwise=False
+                 ):
+        super(SPPFBlockCSP, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.Sequential(
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=act_type, norm_type=norm_type, 
+                 depthwise=depthwise),
+            SPPF(inter_dim, 
+                 inter_dim, 
+                 expand_ratio=1.0, 
+                 pooling_size=pooling_size, 
+                 act_type=act_type, 
+                 norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=act_type, norm_type=norm_type, 
+                 depthwise=depthwise)
+        )
+        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+        
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x2)
+        y = self.cv3(torch.cat([x1, x3], dim=1))
+
+        return y
+
+
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm']
+            )
+    elif model == 'csp_sppf':
+        neck = SPPFBlockCSP(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm'],
+            depthwise=cfg['neck_depthwise']
+            )
+
+    return neck
+        
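
The SPPF module above chains three 5x5 max-pools instead of running 5x5/9x9/13x13 pools in parallel as the original SPP does; the two are equivalent because consecutive stride-1 max-pools compose into one larger pool. A quick standalone check of that equivalence (illustrative only, not part of the commit):

```python
import torch
import torch.nn as nn

x = torch.randn(1, 8, 32, 32)
m5 = nn.MaxPool2d(5, stride=1, padding=2)
m9 = nn.MaxPool2d(9, stride=1, padding=4)
m13 = nn.MaxPool2d(13, stride=1, padding=6)

# two chained 5x5 pools cover a 9x9 window, three cover a 13x13 window
assert torch.equal(m5(m5(x)), m9(x))
assert torch.equal(m5(m5(m5(x))), m13(x))
```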

+ 31 - 0
models/yolov7/build.py

@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .loss import build_criterion
+from .yolov7 import YOLOv7
+
+
+# build object detector
+def build_yolov7(args, cfg, device, num_classes=80, trainable=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    model = YOLOv7(
+        cfg = cfg,
+        device = device,
+        num_classes = num_classes,
+        conf_thresh = args.conf_thresh,
+        nms_thresh = args.nms_thresh,
+        topk = args.topk,
+        trainable = trainable
+        )
+
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, device, num_classes)
+
+    return model, criterion

+ 168 - 0
models/yolov7/loss.py

@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .matcher import SimOTA
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+
+
+class Criterion(object):
+    def __init__(self, 
+                 cfg, 
+                 device, 
+                 num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        # loss weight
+        self.loss_obj_weight = cfg['loss_obj_weight']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+        # matcher
+        matcher_config = cfg['matcher']
+        self.matcher = SimOTA(
+            num_classes=num_classes,
+            center_sampling_radius=matcher_config['center_sampling_radius'],
+            topk_candidate=matcher_config['topk_candicate']
+            )
+
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
+
+        return loss_box
+
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_obj']: List(Tensor) [B, M, 1]
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
+        # preds: [B, M, C]
+        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+
+        # label assignment
+        cls_targets = []
+        box_targets = []
+        obj_targets = []
+        fg_masks = []
+
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                num_anchors = sum([ab.shape[0] for ab in anchors])
+                # There is no valid gt
+                cls_target = obj_preds.new_zeros((0, self.num_classes))
+                box_target = obj_preds.new_zeros((0, 4))
+                obj_target = obj_preds.new_zeros((num_anchors, 1))
+                fg_mask = obj_preds.new_zeros(num_anchors).bool()
+            else:
+                (
+                    gt_matched_classes,
+                    fg_mask,
+                    pred_ious_this_matching,
+                    matched_gt_inds,
+                    num_fg_img,
+                ) = self.matcher(
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_obj = obj_preds[batch_idx],
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
+                    )
+
+                obj_target = fg_mask.unsqueeze(-1)
+                cls_target = F.one_hot(gt_matched_classes.long(), self.num_classes)
+                cls_target = cls_target * pred_ious_this_matching.unsqueeze(-1)
+                box_target = tgt_bboxes[matched_gt_inds]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            obj_targets.append(obj_target)
+            fg_masks.append(fg_mask)
+
+        cls_targets = torch.cat(cls_targets, 0)
+        box_targets = torch.cat(box_targets, 0)
+        obj_targets = torch.cat(obj_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # obj loss
+        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
+        loss_obj = loss_obj.sum() / num_fgs
+        
+        # cls loss
+        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
+        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # regression loss
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        loss_box = loss_box.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        loss_dict = dict(
+                loss_obj = loss_obj,
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass
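
In equation form, the criterion above computes the following (a sketch of what the code does; the lambdas are the `loss_obj_weight` / `loss_cls_weight` / `loss_box_weight` config entries, and N_fg is the number of positive anchors after SimOTA, all-reduced across GPUs and clamped to at least 1):

```latex
\mathcal{L} = \frac{1}{N_{fg}}\Big(
    \lambda_{obj}\sum_{i}\mathrm{BCE}(\hat{o}_i,\, o_i)
  + \lambda_{cls}\sum_{i \in \mathrm{fg}}\mathrm{BCE}(\hat{c}_i,\, \mathrm{IoU}_i \cdot y_i)
  + \lambda_{box}\sum_{i \in \mathrm{fg}}\big(1 - \mathrm{GIoU}(\hat{b}_i,\, b_i)\big)
\Big)
```

The objectness term runs over all anchors, while the classification and GIoU terms run only over the matched (foreground) anchors, and the classification target is the one-hot label scaled by the matched IoU.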

+ 204 - 0
models/yolov7/matcher.py

@@ -0,0 +1,204 @@
+import torch
+import torch.nn.functional as F
+from utils.box_ops import *
+
+
+
+# YOLOX SimOTA
+class SimOTA(object):
+    def __init__(self, 
+                 num_classes,
+                 center_sampling_radius,
+                 topk_candidate
+                 ) -> None:
+        self.num_classes = num_classes
+        self.center_sampling_radius = center_sampling_radius
+        self.topk_candidate = topk_candidate
+
+
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_obj, 
+                 pred_cls, 
+                 pred_box, 
+                 tgt_labels,
+                 tgt_bboxes):
+        # [M,]
+        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        anchors = torch.cat(anchors, dim=0)
+        num_anchor = anchors.shape[0]        
+        num_gt = len(tgt_labels)
+
+        fg_mask, is_in_boxes_and_center = \
+            self.get_in_boxes_info(
+                tgt_bboxes,
+                anchors,
+                strides,
+                num_anchor,
+                num_gt
+                )
+
+        obj_preds_ = pred_obj[fg_mask]   # [Mp, 1]
+        cls_preds_ = pred_cls[fg_mask]   # [Mp, C]
+        box_preds_ = pred_box[fg_mask]   # [Mp, 4]
+        num_in_boxes_anchor = box_preds_.shape[0]
+
+        # [N, Mp]
+        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds_)
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
+
+        # [N, C] -> [N, Mp, C]
+        gt_cls = (
+            F.one_hot(tgt_labels.long(), self.num_classes)
+            .float()
+            .unsqueeze(1)
+            .repeat(1, num_in_boxes_anchor, 1)
+        )
+
+        with torch.cuda.amp.autocast(enabled=False):
+            score_preds_ = torch.sqrt(
+                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+            ) # [N, Mp, C]
+            pair_wise_cls_loss = F.binary_cross_entropy(
+                score_preds_, gt_cls, reduction="none"
+            ).sum(-1) # [N, Mp]
+        del score_preds_
+
+        cost = (
+            pair_wise_cls_loss
+            + 3.0 * pair_wise_ious_loss
+            + 100000.0 * (~is_in_boxes_and_center)
+        ) # [N, Mp]
+
+        (
+            num_fg,
+            gt_matched_classes,         # [num_fg,]
+            pred_ious_this_matching,    # [num_fg,]
+            matched_gt_inds,            # [num_fg,]
+        ) = self.dynamic_k_matching(
+            cost,
+            pair_wise_ious,
+            tgt_labels,
+            num_gt,
+            fg_mask
+            )
+        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
+
+        return (
+                gt_matched_classes,
+                fg_mask,
+                pred_ious_this_matching,
+                matched_gt_inds,
+                num_fg,
+        )
+
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes,   # [N, 4]
+        anchors,     # [M, 2]
+        strides,     # [M,]
+        num_anchors, # M
+        num_gt,      # N
+        ):
+        # anchor center
+        x_centers = anchors[:, 0]
+        y_centers = anchors[:, 1]
+
+        # [M,] -> [1, M] -> [N, M]
+        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
+        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
+
+        # [N,] -> [N, 1] -> [N, M]
+        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
+        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
+        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
+        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
+
+        b_l = x_centers - gt_bboxes_l
+        b_r = gt_bboxes_r - x_centers
+        b_t = y_centers - gt_bboxes_t
+        b_b = gt_bboxes_b - y_centers
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+        center_radius = self.center_sampling_radius
+
+        # [N, 2]
+        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
+        
+        # [1, M]
+        center_radius_ = center_radius * strides.unsqueeze(0)
+
+        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
+        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
+        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
+        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
+
+        c_l = x_centers - gt_bboxes_l
+        c_r = gt_bboxes_r - x_centers
+        c_t = y_centers - gt_bboxes_t
+        c_b = gt_bboxes_b - y_centers
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = (
+            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        )
+        return is_in_boxes_anchor, is_in_boxes_and_center
+    
+    
+    def dynamic_k_matching(
+        self, 
+        cost, 
+        pair_wise_ious, 
+        gt_classes, 
+        num_gt, 
+        fg_mask
+        ):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        dynamic_ks = dynamic_ks.tolist()
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        num_fg = fg_mask_inboxes.sum().item()
+
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        gt_matched_classes = gt_classes[matched_gt_inds]
+
+        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
+    
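
The dynamic-k rule in `dynamic_k_matching` is easier to follow on a toy example (hypothetical numbers, not part of the commit): each GT gets k = floor(sum of its top IoUs) candidates, clamped to at least 1, and an anchor claimed by several GTs keeps only the GT with the lowest cost:

```python
import torch

# toy problem: 2 GTs x 6 candidate anchors (made-up IoUs; cost stands in for cls + 3*IoU loss)
ious = torch.tensor([[0.80, 0.30, 0.70, 0.10, 0.00, 0.30],
                     [0.00, 0.10, 0.75, 0.90, 0.40, 0.20]])
cost = 1.0 - ious

# dynamic k per GT: sum of its top-10 IoUs, truncated and clamped to >= 1  ->  tensor([2, 2])
topk_ious, _ = torch.topk(ious, k=min(10, ious.size(1)), dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)

matching = torch.zeros_like(cost, dtype=torch.uint8)
for g in range(ious.size(0)):
    _, pos = torch.topk(cost[g], k=int(dynamic_ks[g]), largest=False)
    matching[g, pos] = 1
# GT0 takes anchors {0, 2}, GT1 takes anchors {2, 3} -> anchor 2 is contested

# an anchor claimed by several GTs keeps only the GT with the lowest cost
multi = matching.sum(0) > 1
if multi.any():
    cheapest = cost[:, multi].argmin(0)
    matching[:, multi] = 0
    matching[cheapest, multi] = 1
print(matching)   # [[1, 0, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0]]
```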

+ 292 - 0
models/yolov7/yolov7.py

@@ -0,0 +1,292 @@
+import torch
+import torch.nn as nn
+
+from utils.nms import multiclass_nms
+
+from .yolov7_backbone import build_backbone
+from .yolov7_neck import build_neck
+from .yolov7_fpn import build_fpn
+from .yolov7_head import build_head
+
+
+# YOLOv7
+class YOLOv7(nn.Module):
+    def __init__(self,
+                 cfg,
+                 device,
+                 num_classes=20,
+                 conf_thresh=0.01,
+                 topk=100,
+                 nms_thresh=0.5,
+                 trainable=False):
+        super(YOLOv7, self).__init__()
+        # ------------------- Basic parameters -------------------
+        self.cfg = cfg                                 # model configuration dict
+        self.device = device                           # cuda or cpu
+        self.num_classes = num_classes                 # number of classes
+        self.trainable = trainable                     # training flag
+        self.conf_thresh = conf_thresh                 # confidence threshold
+        self.nms_thresh = nms_thresh                   # NMS threshold
+        self.topk = topk                               # topk
+        self.stride = [8, 16, 32]                      # output strides of the network
+        # ------------------- Anchor box -------------------
+        self.num_levels = 3
+        self.num_anchors = len(cfg['anchor_size']) // self.num_levels
+        self.anchor_size = torch.as_tensor(
+            cfg['anchor_size']
+            ).view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
+        
+        # ------------------- Network Structure -------------------
+        ## Backbone
+        self.backbone, feats_dim = build_backbone(cfg, trainable&cfg['pretrained'])
+
+        ## Neck: SPP module
+        self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1]//2)
+        feats_dim[-1] = self.neck.out_dim
+
+        ## Neck: feature pyramid
+        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=256)
+        self.head_dim = self.fpn.out_dim
+
+        ## Detection heads
+        self.non_shared_heads = nn.ModuleList(
+            [build_head(cfg, head_dim, head_dim, num_classes) 
+            for head_dim in self.head_dim
+            ])
+
+        ## Prediction layers
+        self.obj_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 1, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 4, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ])                 
+
+        # --------- Network Initialization ----------
+        # init bias
+        self.init_yolo()
+
+
+    def init_yolo(self): 
+        # Init yolo
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03    
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        for obj_pred in self.obj_preds:
+            b = obj_pred.bias.view(1, -1)
+            b.data.fill_(bias_value.item())
+            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # cls pred
+        for cls_pred in self.cls_preds:
+            b = cls_pred.bias.view(1, -1)
+            b.data.fill_(bias_value.item())
+            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred
+        for reg_pred in self.reg_preds:
+            b = reg_pred.bias.view(-1, )
+            b.data.fill_(1.0)
+            reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+            w = reg_pred.weight
+            w.data.fill_(0.)
+            reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchor_xy += 0.5  # add center offset
+        anchor_xy *= self.stride[level]
+        anchors = anchor_xy.to(self.device)
+
+        return anchors
+        
+
+    def decode_boxes(self, anchors, reg_pred, stride):
+        """
+            anchors:  (Tensor) [1, M, 2] or [M, 2]
+            reg_pred: (Tensor) [B, M, 4] or [M, 4]
+        """
+        # center of bbox
+        pred_ctr_xy = anchors + reg_pred[..., :2] * stride
+        # size of bbox
+        pred_box_wh = reg_pred[..., 2:].exp() * stride
+
+        pred_x1y1 = pred_ctr_xy - 0.5 * pred_box_wh
+        pred_x2y2 = pred_ctr_xy + 0.5 * pred_box_wh
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+
+
+    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+        """
+        Input:
+            obj_preds: List(Tensor) [[H x W, 1], ...]
+            cls_preds: List(Tensor) [[H x W, C], ...]
+            reg_preds: List(Tensor) [[H x W, 4], ...]
+            anchors:  List(Tensor) [[H x W, 2], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchors_i) in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
+            # (H x W x C,)
+            scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, reg_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            labels = topk_idxs % self.num_classes
+
+            reg_pred_i = reg_pred_i[anchor_idxs]
+            anchors_i = anchors_i[anchor_idxs]
+
+            # decode box: [M, 4]
+            bboxes = self.decode_boxes(anchors_i, reg_pred_i, self.stride[level])
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+
+        return bboxes, scores, labels
+
+
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # backbone
+        pyramid_feats = self.backbone(x)
+
+        # neck
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # feature pyramid
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # detection heads
+        all_obj_preds = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_anchors = []
+        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+            cls_feat, reg_feat = head(feat)
+
+            # [1, C, H, W]
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+
+            # anchors: [M, 2]
+            fmp_size = cls_pred.shape[-2:]
+            anchors = self.generate_anchors(level, fmp_size)
+
+            # [1, C, H, W] -> [H, W, C] -> [M, C]
+            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
+            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
+            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
+
+            all_obj_preds.append(obj_pred)
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_anchors.append(anchors)
+
+        # post process
+        bboxes, scores, labels = self.post_process(
+            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
+        
+        return bboxes, scores, labels
+
+
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # backbone
+            pyramid_feats = self.backbone(x)
+
+            # neck
+            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+            # feature pyramid
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # detection heads
+            all_anchors = []
+            all_obj_preds = []
+            all_cls_preds = []
+            all_box_preds = []
+            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+                cls_feat, reg_feat = head(feat)
+
+                # [B, C, H, W]
+                obj_pred = self.obj_preds[level](reg_feat)
+                cls_pred = self.cls_preds[level](cls_feat)
+                reg_pred = self.reg_preds[level](reg_feat)
+
+                B, _, H, W = cls_pred.size()
+                fmp_size = [H, W]
+                # generate anchor points: [M, 2]
+                anchors = self.generate_anchors(level, fmp_size)
+                
+                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+
+                # decode box: [M, 4]
+                box_pred = self.decode_boxes(anchors, reg_pred, self.stride[level])
+
+                all_obj_preds.append(obj_pred)
+                all_cls_preds.append(cls_pred)
+                all_box_preds.append(box_pred)
+                all_anchors.append(anchors)
+            
+            # output dict
+            outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
+                       "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                       "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                       "anchors": all_anchors,           # List(Tensor) [B, M, 2]
+                       'strides': self.stride}           # List(Int) [8, 16, 32]
+
+            return outputs 
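
For reference, `decode_boxes` works in input-image pixels: the anchor point is the grid-cell centre, the first two regression channels shift it in units of the stride, and the last two are log-scale sizes. A single-anchor example with made-up values (not part of the commit):

```python
import torch

stride = 32
anchor = torch.tensor([[16.0, 16.0]])            # centre of grid cell (0, 0): (0 + 0.5) * 32
reg    = torch.tensor([[0.25, -0.5, 1.0, 0.5]])  # (tx, ty, tw, th), made-up values

ctr = anchor + reg[..., :2] * stride             # (24.0, 0.0)
wh  = reg[..., 2:].exp() * stride                # (~87.0, ~52.8)
box = torch.cat([ctr - 0.5 * wh, ctr + 0.5 * wh], dim=-1)
print(box)                                       # x1y1x2y2 in input-image pixels
```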

+ 131 - 0
models/yolov7/yolov7_backbone.py

@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov7_basic import Conv, ELANBlock, DownSample
+except:
+    from yolov7_basic import Conv, ELANBlock, DownSample
+    
+
+model_urls = {
+    "elannet": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/yolov7_elannet.pth",
+}
+
+# --------------------- ELANNet (YOLOv7 backbone) -----------------------
+class ELANNet(nn.Module):
+    """
+    ELAN-Net of YOLOv7-L.
+    """
+    def __init__(self, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANNet, self).__init__()
+        self.feat_dims = [512, 1024, 1024]
+        
+        # P1/2
+        self.layer_1 = nn.Sequential(
+            Conv(3, 32, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),      
+            Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            Conv(64, 64, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P2/4
+        self.layer_2 = nn.Sequential(   
+            Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
+            ELANBlock(in_dim=128, out_dim=256, expand_ratio=0.5,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            DownSample(in_dim=256, act_type=act_type),             
+            ELANBlock(in_dim=256, out_dim=512, expand_ratio=0.5,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            DownSample(in_dim=512, act_type=act_type),             
+            ELANBlock(in_dim=512, out_dim=1024, expand_ratio=0.5,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            DownSample(in_dim=1024, act_type=act_type),             
+            ELANBlock(in_dim=1024, out_dim=1024, expand_ratio=0.25,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# --------------------- Functions -----------------------
+def build_backbone(cfg, pretrained=False): 
+    """Constructs a ELANNet model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    backbone = ELANNet(cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
+    feat_dims = backbone.feat_dims
+
+    if pretrained:
+        url = model_urls['elannet']
+        if url is not None:
+            print('Loading pretrained weight ...')
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            # checkpoint state dict
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = backbone.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print(k)
+
+            backbone.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No backbone pretrained: ELANNet')        
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': False,
+
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'p6_feat': False,
+        'p7_feat': False,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 224, 224)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 224, 224)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 230 - 0
models/yolov7/yolov7_basic.py

@@ -0,0 +1,230 @@
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+
+# Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            # depthwise conv
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# ELAN Block
+class ELANBlock(nn.Module):
+    """
+    ELAN Block of YOLOv7's backbone
+    """
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANBlock, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv3 = nn.Sequential(*[
+            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(2)
+        ])
+        self.cv4 = nn.Sequential(*[
+            Conv(inter_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(2)
+        ])
+
+        self.out = Conv(inter_dim*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C, H, W]
+        Output:
+            out: [B, C_out, H, W]
+        """
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.cv3(x2)
+        x4 = self.cv4(x3)
+
+        # [B, C_in, H, W] -> [B, C_out, H, W]
+        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
+
+        return out
+
+
+# DownSample Block
+class DownSample(nn.Module):
+    def __init__(self, in_dim, act_type='silu', norm_type='BN'):
+        super().__init__()
+        inter_dim = in_dim // 2
+        self.mp = nn.MaxPool2d((2, 2), 2)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = nn.Sequential(
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type)
+        )
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C, H, W]
+        Output:
+            out: [B, C, H//2, W//2]
+        """
+        # [B, C, H, W] -> [B, C//2, H//2, W//2]
+        x1 = self.cv1(self.mp(x))
+        x2 = self.cv2(x)
+
+        # [B, C, H//2, W//2]
+        out = torch.cat([x1, x2], dim=1)
+
+        return out
+
+
+# ELAN Block for PaFPN
+class ELANBlockFPN(nn.Module):
+    """
+    ELAN Block of YOLOv7's PaFPN
+    """
+    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANBlockFPN, self).__init__()
+        # Basic parameters
+        e1, e2 = 0.5, 0.5
+        width = 4
+        depth = 1
+        inter_dim = int(in_dim * e1)
+        inter_dim2 = int(inter_dim * e2) 
+        # Network structure
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv3 = nn.ModuleList()
+        for idx in range(width):
+            if idx == 0:
+                cvs = [Conv(inter_dim, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            else:
+                cvs = [Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            # deeper
+            if depth > 1:
+                for _ in range(1, depth):
+                    cvs.append(Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+                self.cv3.append(nn.Sequential(*cvs))
+            else:
+                self.cv3.append(cvs[0])
+
+        self.out = Conv(inter_dim*2+inter_dim2*len(self.cv3), out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C_in, H, W]
+        Output:
+            out: [B, C_out, H, W]
+        """
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        inter_outs = [x1, x2]
+        for m in self.cv3:
+            y1 = inter_outs[-1]
+            y2 = m(y1)
+            inter_outs.append(y2)
+
+        # [B, C_in, H, W] -> [B, C_out, H, W]
+        out = self.out(torch.cat(inter_outs, dim=1))
+
+        return out
+
+
+# DownSample Block for PaFPN
+class DownSampleFPN(nn.Module):
+    def __init__(self, in_dim, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        inter_dim = in_dim
+        self.mp = nn.MaxPool2d((2, 2), 2)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = nn.Sequential(
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C, H, W]
+        Output:
+            out: [B, 2C, H//2, W//2]
+        """
+        # [B, C, H, W] -> [B, C, H//2, W//2]
+        x1 = self.cv1(self.mp(x))
+        x2 = self.cv2(x)
+
+        # [B, 2C, H//2, W//2]
+        out = torch.cat([x1, x2], dim=1)
+
+        return out
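
The input width of `self.out` in `ELANBlockFPN` (`inter_dim*2 + inter_dim2*len(self.cv3)`) follows from the forward pass concatenating the two 1x1 branches plus one tensor per `cv3` stage. A quick check of that arithmetic for a hypothetical in_dim of 512 with the defaults above (e1=e2=0.5, width=4), not part of the commit:

```python
in_dim, out_dim = 512, 256                 # hypothetical ELANBlockFPN(512, 256)
e1, e2, width = 0.5, 0.5, 4                # defaults used inside ELANBlockFPN
inter_dim  = int(in_dim * e1)              # 256: channels of the cv1 and cv2 branches
inter_dim2 = int(inter_dim * e2)           # 128: channels of each of the `width` cv3 stages
concat_dim = inter_dim * 2 + inter_dim2 * width
print(concat_dim)                          # 1024 channels fed into self.out, projected to out_dim
```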

+ 123 - 0
models/yolov7/yolov7_fpn.py

@@ -0,0 +1,123 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .yolov7_basic import Conv, ELANBlockFPN, DownSampleFPN
+
+
+# PaFPN-ELAN (YOLOv7's)
+class Yolov7PaFPN(nn.Module):
+    def __init__(self, 
+                 in_dims=[512, 1024, 512],
+                 out_dim=None,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False):
+        super(Yolov7PaFPN, self).__init__()
+        self.in_dims = in_dims
+        c3, c4, c5 = in_dims
+
+        # top down
+        ## P5 -> P4
+        self.reduce_layer_1 = Conv(c5, 256, k=1, norm_type=norm_type, act_type=act_type)
+        self.reduce_layer_2 = Conv(c4, 256, k=1, norm_type=norm_type, act_type=act_type)
+        self.top_down_layer_1 = ELANBlockFPN(in_dim=256 + 256,
+                                             out_dim=256,
+                                             act_type=act_type,
+                                             norm_type=norm_type,
+                                             depthwise=depthwise
+                                             )
+
+        # P4 -> P3
+        self.reduce_layer_3 = Conv(256, 128, k=1, norm_type=norm_type, act_type=act_type)
+        self.reduce_layer_4 = Conv(c3, 128, k=1, norm_type=norm_type, act_type=act_type)
+        self.top_down_layer_2 = ELANBlockFPN(in_dim=128 + 128,
+                                             out_dim=128,
+                                             act_type=act_type,
+                                             norm_type=norm_type,
+                                             depthwise=depthwise
+                                             )
+
+        # bottom up
+        # P3 -> P4
+        self.downsample_layer_1 = DownSampleFPN(128, act_type=act_type,
+                                    norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_1 = ELANBlockFPN(in_dim=256 + 256,
+                                              out_dim=256,
+                                              act_type=act_type,
+                                              norm_type=norm_type,
+                                              depthwise=depthwise
+                                              )
+
+        # P4 -> P5
+        self.downsample_layer_2 = DownSampleFPN(256, act_type=act_type,
+                                    norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_2 = ELANBlockFPN(in_dim=512 + c5,
+                                              out_dim=512,
+                                              act_type=act_type,
+                                              norm_type=norm_type,
+                                              depthwise=depthwise
+                                              )
+
+        # output proj layers
+        if out_dim is not None:
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                     norm_type=norm_type, act_type=act_type)
+                     for in_dim in [128, 256, 512]
+                     ])
+            self.out_dim = [out_dim] * 3
+        else:
+            self.out_layers = None
+            self.out_dim = [128, 256, 512]
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # Top down
+        ## P5 -> P4
+        c6 = self.reduce_layer_1(c5)
+        c7 = F.interpolate(c6, scale_factor=2.0)
+        c8 = torch.cat([c7, self.reduce_layer_2(c4)], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        ## P4 -> P3
+        c10 = self.reduce_layer_3(c9)
+        c11 = F.interpolate(c10, scale_factor=2.0)
+        c12 = torch.cat([c11, self.reduce_layer_4(c3)], dim=1)
+        c13 = self.top_down_layer_2(c12)
+
+        # Bottom up
+        # p3 -> P4
+        c14 = self.downsample_layer_1(c13)
+        c15 = torch.cat([c14, c9], dim=1)
+        c16 = self.bottom_up_layer_1(c15)
+        # P4 -> P5
+        c17 = self.downsample_layer_2(c16)
+        c18 = torch.cat([c17, c5], dim=1)
+        c19 = self.bottom_up_layer_2(c18)
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+        
+        # output proj layers
+        if self.out_layers is not None:
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build neck
+    if model == 'yolov7_pafpn':
+        fpn_net = Yolov7PaFPN(in_dims=in_dims,
+                             out_dim=out_dim,
+                             act_type=cfg['fpn_act'],
+                             norm_type=cfg['fpn_norm'],
+                             depthwise=cfg['fpn_depthwise']
+                             )
+
+
+    return fpn_net
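
A minimal shape-check sketch for the PaFPN above (not part of the commit); it assumes the repository root is on PYTHONPATH, a 640x640 input, and the default `in_dims=[512, 1024, 512]`, i.e. C5 has already been halved by the SPPF neck:

```python
import torch
from models.yolov7.yolov7_fpn import Yolov7PaFPN

fpn = Yolov7PaFPN(in_dims=[512, 1024, 512], out_dim=256)
feats = [torch.randn(1, 512, 80, 80),    # C3, stride 8
         torch.randn(1, 1024, 40, 40),   # C4, stride 16
         torch.randn(1, 512, 20, 20)]    # C5, stride 32 (after the neck halves 1024 -> 512)
for p in fpn(feats):
    print(p.shape)   # [1, 256, 80, 80], [1, 256, 40, 40], [1, 256, 20, 20]
```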

+ 137 - 0
models/yolov7/yolov7_head.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+try:
+    from .yolov7_basic import Conv
+except:
+    from yolov7_basic import Conv
+
+
+class DecoupledHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+        super().__init__()
+        print('==============================')
+        print('Head: Decoupled Head')
+        self.in_dim = in_dim
+        self.num_cls_head=cfg['num_cls_head']
+        self.num_reg_head=cfg['num_reg_head']
+        self.act_type=cfg['head_act']
+        self.norm_type=cfg['head_norm']
+
+        # cls head
+        cls_feats = []
+        self.cls_out_dim = max(out_dim, num_classes)
+        for i in range(cfg['num_cls_head']):
+            if i == 0:
+                cls_feats.append(
+                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                cls_feats.append(
+                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+                
+        # reg head
+        reg_feats = []
+        self.reg_out_dim = max(out_dim, 64)
+        for i in range(cfg['num_reg_head']):
+            if i == 0:
+                reg_feats.append(
+                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                reg_feats.append(
+                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+
+    def forward(self, x):
+        """
+            x: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+
+# build detection head
+def build_head(cfg, in_dim, out_dim, num_classes=80):
+    head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+
+    return head
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'head_depthwise': False,
+        'reg_max': 16,
+    }
+    fpn_dims = [256, 512, 512]
+    # Head-1
+    model = build_head(cfg, fpn_dims[0], fpn_dims[0], num_classes=80)
+    x = torch.randn(1, 256, 80, 80)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-1: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-2
+    model = build_head(cfg, fpn_dims[1], fpn_dims[1], num_classes=80)
+    x = torch.randn(1, 512, 40, 40)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-2: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-3
+    model = build_head(cfg, fpn_dims[2], fpn_dims[2], num_classes=80)
+    x = torch.randn(1, 512, 20, 20)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-3: Params : {:.2f} M'.format(params / 1e6))

+ 95 - 0
models/yolov7/yolov7_neck.py

@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from .yolov7_basic import Conv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='lrelu', norm_type='BN'):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+# SPPF block with CSP module
+class SPPFBlockCSP(nn.Module):
+    """
+        CSP Spatial Pyramid Pooling Block
+    """
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 pooling_size=5,
+                 act_type='lrelu',
+                 norm_type='BN',
+                 depthwise=False
+                 ):
+        super(SPPFBlockCSP, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.Sequential(
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=act_type, norm_type=norm_type, 
+                 depthwise=depthwise),
+            SPPF(inter_dim, 
+                 inter_dim, 
+                 expand_ratio=1.0, 
+                 pooling_size=pooling_size, 
+                 act_type=act_type, 
+                 norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=act_type, norm_type=norm_type, 
+                 depthwise=depthwise)
+        )
+        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+        
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x2)
+        y = self.cv3(torch.cat([x1, x3], dim=1))
+
+        return y
+
+
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm']
+            )
+    elif model == 'csp_sppf':
+        neck = SPPFBlockCSP(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm'],
+            depthwise=cfg['neck_depthwise']
+            )
+
+    return neck
+        

+ 1 - 1
test.py

@@ -40,7 +40,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolox'], help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
     parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,

+ 1 - 1
train.py

@@ -56,7 +56,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolov5', 'yolov7', 'yolox'], help='build yolo')
     parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
                         help='confidence threshold')
     parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,