yjh0410 2 years ago
Commit
88b035039c

+ 3 - 3
README.md

@@ -94,9 +94,9 @@ python train.py --cuda -d voc --root path/to/VOCdevkit -v yolov1 -bs 16 --max_ep
 | Model  | Scale |  IP  | mAP  | FPS<sup>3090<br>FP32-bs1 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |--------|-------|------|------|--------------------------|-------------------|--------------------|--------|
 | YOLOv1 |  640  |  √   | 76.7 |                          |   37.8            |   21.3             |  |
-| YOLOv2 |  640  |  √   |     |                          |   53.9            |   30.9             |  |
-| YOLOv3 |  640  |  √   |     |                          |                   |                    |  |
-| YOLOv4 |  640  |  √   |     |                          |                   |                    |  |
+| YOLOv2 |  640  |  √   |      |                          |   53.9            |   30.9             |  |
+| YOLOv3 |  640  |  √   |      |                          |                   |                    |  |
+| YOLOv4 |  640  |  √   |      |                          |                   |                    |  |
 
 *All models are trained with ImageNet-pretrained weights (IP). All FLOPs are measured with a 640x640 input size on the VOC2007 test set. FPS is measured with batch size 1 on a 3090 GPU, from model inference through the NMS operation.*
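The FLOPs/Params columns can be reproduced with thop, the same profiler the new `__main__` checks below use; a minimal sketch, assuming `model` is an already-built detector in eval mode:

import torch
from thop import profile

x = torch.randn(1, 3, 640, 640)                     # same 640x640 input as the table
flops, params = profile(model, inputs=(x,), verbose=False)
print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))    # x2 matches the repo's GFLOPs convention (thop counts MACs)
print('Params : {:.2f} M'.format(params / 1e6))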
 

+ 4 - 3
config/__init__.py

@@ -1,6 +1,7 @@
 # ------------------ Model Config ----------------------
 from .yolov1_config import yolov1_cfg
 from .yolov2_config import yolov2_cfg
+from .yolov3_config import yolov3_cfg
 
 
 def build_model_config(args):
@@ -12,9 +13,9 @@ def build_model_config(args):
     # YOLOv2
     elif args.model == 'yolov2':
         cfg = yolov2_cfg
-    # # YOLOv3
-    # elif args.model == 'yolov3':
-    #     cfg = yolov3_cfg
+    # YOLOv3
+    elif args.model == 'yolov3':
+        cfg = yolov3_cfg
     # # YOLOv4
     # elif args.model == 'yolov4':
     #     cfg = yolov4_cfg

+ 56 - 0
config/yolov3_config.py

@@ -0,0 +1,56 @@
+# YOLOv3 Config
+
+yolov3_cfg = {
+    # input
+    'trans_type': 'yolov5',
+    # model
+    'backbone': 'darknet53',
+    'pretrained': True,
+    'stride': [8, 16, 32],  # P3, P4, P5
+    'width': 1.0,
+    'depth': 1.0,
+    # neck
+    'neck': 'sppf',
+    'expand_ratio': 0.5,
+    'pooling_size': 5,
+    'neck_act': 'silu',
+    'neck_norm': 'BN',
+    'neck_depthwise': False,
+     # fpn
+    'fpn': 'yolo_fpn',
+    'fpn_act': 'silu',
+    'fpn_norm': 'BN',
+    'fpn_depthwise': False,
+    # head
+    'head': 'decoupled_head',
+    'head_act': 'silu',
+    'head_norm': 'BN',
+    'num_cls_head': 2,
+    'num_reg_head': 2,
+    'head_depthwise': False,
+    'anchor_size': [[10, 13],   [16, 30],   [33, 23],     # P3
+                    [30, 61],   [62, 45],   [59, 119],    # P4
+                    [116, 90],  [156, 198], [373, 326]],  # P5
+    # matcher
+    'iou_thresh': 0.5,
+    # loss weight
+    'loss_obj_weight': 1.0,
+    'loss_cls_weight': 1.0,
+    'loss_box_weight': 5.0,
+    # training configuration
+    'no_aug_epoch': 20,
+    # optimizer
+    'optimizer': 'sgd',        # optional: sgd, adam, adamw
+    'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+    'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+    'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+    # model EMA
+    'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+    'ema_tau': 2000,
+    # lr schedule
+    'scheduler': 'linear',
+    'lr0': 0.01,               # SGD: 0.01;     AdamW: 0.004
+    'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+    'warmup_momentum': 0.8,
+    'warmup_bias_lr': 0.1,
+}
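A small sketch of how this flat anchor list is consumed: YOLOv3 below reshapes it to [num_levels, num_anchors, 2], so the first three pairs serve the stride-8 level (P3), the next three stride 16 (P4), and the last three stride 32 (P5):

import torch
from config.yolov3_config import yolov3_cfg

anchors = torch.as_tensor(yolov3_cfg['anchor_size']).view(3, 3, 2)   # [levels, anchors per level, 2]
for stride, per_level in zip(yolov3_cfg['stride'], anchors):
    print('stride {:2d}: {}'.format(stride, per_level.tolist()))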

+ 5 - 0
models/__init__.py

@@ -4,6 +4,7 @@
 import torch
 from .yolov1.build import build_yolov1
 from .yolov2.build import build_yolov2
+from .yolov3.build import build_yolov3
 
 
 # build object detector
@@ -20,6 +21,10 @@ def build_model(args,
     elif args.model == 'yolov2':
         model, criterion = build_yolov2(
             args, model_cfg, device, num_classes, trainable)
+    # YOLOv3   
+    elif args.model == 'yolov3':
+        model, criterion = build_yolov3(
+            args, model_cfg, device, num_classes, trainable)
 
     if trainable:
         # Load pretrained weight

+ 0 - 1
models/yolov1/yolov1_backbone.py

@@ -16,7 +16,6 @@ model_urls = {
 }
 
 # --------------------- Basic Module -----------------------
-
 def conv3x3(in_planes, out_planes, stride=1):
     """3x3 convolution with padding"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,

+ 1 - 1
models/yolov2/loss.py

@@ -43,7 +43,7 @@ class Criterion(object):
 
 
     def __call__(self, outputs, targets):
-        device = outputs['pred_cls'][0].device
+        device = outputs['pred_cls'].device
         stride = outputs['stride']
         fmp_size = outputs['fmp_size']
         (

+ 1 - 1
models/yolov2/yolov2.py

@@ -47,7 +47,7 @@ class YOLOv2(nn.Module):
         ## Detection head
         self.head = build_head(cfg, head_dim, head_dim, num_classes)
 
-        ## Prediction layers
+        ## Prediction layers
         self.obj_pred = nn.Conv2d(head_dim, 1*self.num_anchors, kernel_size=1)
         self.cls_pred = nn.Conv2d(head_dim, num_classes*self.num_anchors, kernel_size=1)
         self.reg_pred = nn.Conv2d(head_dim, 4*self.num_anchors, kernel_size=1)

+ 4 - 1
models/yolov2/yolov2_backbone.py

@@ -1,6 +1,6 @@
 import torch
 import torch.nn as nn
-import os
+
 
 model_urls = {
     "darknet19": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet19.pth",
@@ -10,6 +10,7 @@ model_urls = {
 __all__ = ['darknet19']
 
 
+# --------------------- Basic Module -----------------------
 class Conv_BN_LeakyReLU(nn.Module):
     def __init__(self, in_channels, out_channels, ksize, padding=0, stride=1, dilation=1):
         super(Conv_BN_LeakyReLU, self).__init__()
@@ -23,6 +24,7 @@ class Conv_BN_LeakyReLU(nn.Module):
         return self.convs(x)
 
 
+# --------------------- DarkNet-19 -----------------------
 class DarkNet19(nn.Module):
     def __init__(self):
         
@@ -87,6 +89,7 @@ class DarkNet19(nn.Module):
         return c5
 
 
+# --------------------- Functions -----------------------
 def build_backbone(model_name='darknet19', pretrained=False):
     if model_name == 'darknet19':
         # model

+ 32 - 0
models/yolov3/build.py

@@ -0,0 +1,32 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .loss import build_criterion
+from .yolov3 import YOLOv3
+
+
+# build object detector
+def build_yolov3(args, cfg, device, num_classes=80, trainable=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    model = YOLOv3(
+        cfg = cfg,
+        device = device,
+        img_size = args.img_size,
+        num_classes = num_classes,
+        conf_thresh = args.conf_thresh,
+        nms_thresh = args.nms_thresh,
+        topk = args.topk,
+        trainable = trainable
+        )
+
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, device, num_classes)
+
+    return model, criterion
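A hedged usage sketch of the builder above; `args` is assumed to be the parsed train.py namespace (build_yolov3 reads its model, img_size, conf_thresh, nms_thresh and topk attributes):

import torch
from config import build_model_config
from models.yolov3.build import build_yolov3

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
cfg = build_model_config(args)            # picks yolov3_cfg when args.model == 'yolov3'
model, criterion = build_yolov3(args, cfg, device, num_classes=20, trainable=True)
model = model.to(device).train()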

+ 114 - 0
models/yolov3/loss.py

@@ -0,0 +1,114 @@
+import torch
+import torch.nn.functional as F
+from .matcher import Yolov3Matcher
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+
+class Criterion(object):
+    def __init__(self, cfg, device, num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        # loss weight
+        self.loss_obj_weight = cfg['loss_obj_weight']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+
+        # matcher
+        self.matcher = Yolov3Matcher(num_classes, 3, cfg['anchor_size'], cfg['iou_thresh'])
+
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
+
+        return loss_box, ious
+
+
+    def __call__(self, outputs, targets):
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        fmp_sizes = outputs['fmp_sizes']
+        (
+            gt_objectness, 
+            gt_classes, 
+            gt_bboxes,
+            ) = self.matcher(fmp_sizes=fmp_sizes, 
+                             fpn_strides=fpn_strides, 
+                             targets=targets)
+        # List[B, M, C] -> [B, M, C] -> [BM, C]
+        pred_obj = torch.cat(outputs['pred_obj'], dim=1).view(-1)                      # [BM,]
+        pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)    # [BM, C]
+        pred_box = torch.cat(outputs['pred_box'], dim=1).view(-1, 4)                   # [BM, 4]
+       
+        gt_objectness = gt_objectness.view(-1).to(device).float()               # [BM,]
+        gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()   # [BM, C]
+        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                    # [BM, 4]
+
+        pos_masks = (gt_objectness > 0)
+        num_fgs = pos_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # box loss
+        pred_box_pos = pred_box[pos_masks]
+        gt_bboxes_pos = gt_bboxes[pos_masks]
+        loss_box, ious = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
+        loss_box = loss_box.sum() / num_fgs
+        
+        # cls loss
+        pred_cls_pos = pred_cls[pos_masks]
+        gt_classes_pos = gt_classes[pos_masks] * ious.unsqueeze(-1).clamp(0.)
+        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # obj loss
+        loss_obj = self.loss_objectness(pred_obj, gt_objectness)
+        loss_obj = loss_obj.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        loss_dict = dict(
+                loss_obj = loss_obj,
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+    
+if __name__ == "__main__":
+    pass
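A hedged sketch of one training step with the criterion above; `model`, `criterion` and `device` are assumed to come from the build sketch earlier, and the targets follow the format the matcher expects:

import torch

images = torch.randn(2, 3, 640, 640, device=device)
targets = [{'boxes':  torch.tensor([[100., 120., 200., 260.]]),   # xyxy, in input-image pixels
            'labels': torch.tensor([12])} for _ in range(2)]

outputs = model(images)                   # trainable=True -> dict of raw multi-level predictions
loss_dict = criterion(outputs, targets)
loss_dict['losses'].backward()            # 1.0*loss_obj + 1.0*loss_cls + 5.0*loss_box, each normalized by num_fgs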

+ 156 - 0
models/yolov3/matcher.py

@@ -0,0 +1,156 @@
+import numpy as np
+import torch
+
+
+class Yolov3Matcher(object):
+    def __init__(self, num_classes, num_anchors, anchor_size, iou_thresh):
+        self.num_classes = num_classes
+        self.num_anchors = num_anchors
+        self.iou_thresh = iou_thresh
+        self.anchor_boxes = np.array(
+            [[0., 0., anchor[0], anchor[1]]
+            for anchor in anchor_size]
+            )  # [KA, 4]
+
+
+    def compute_iou(self, anchor_boxes, gt_box):
+        """
+            anchor_boxes : ndarray -> [KA, 4] (cx, cy, bw, bh).
+            gt_box : ndarray -> [1, 4] (cx, cy, bw, bh).
+        """
+        # anchors: [KA, 4]
+        anchors = np.zeros_like(anchor_boxes)
+        anchors[..., :2] = anchor_boxes[..., :2] - anchor_boxes[..., 2:] * 0.5  # x1y1
+        anchors[..., 2:] = anchor_boxes[..., :2] + anchor_boxes[..., 2:] * 0.5  # x2y2
+        anchors_area = anchor_boxes[..., 2] * anchor_boxes[..., 3]
+        
+        # gt_box: [1, 4] -> [KA, 4]
+        gt_box = np.array(gt_box).reshape(-1, 4)
+        gt_box = np.repeat(gt_box, anchors.shape[0], axis=0)
+        gt_box_ = np.zeros_like(gt_box)
+        gt_box_[..., :2] = gt_box[..., :2] - gt_box[..., 2:] * 0.5  # x1y1
+        gt_box_[..., 2:] = gt_box[..., :2] + gt_box[..., 2:] * 0.5  # x2y2
+        gt_box_area = np.prod(gt_box[..., 2:] - gt_box[..., :2], axis=1)
+
+        # intersection
+        inter_w = np.minimum(anchors[:, 2], gt_box_[:, 2]) - \
+                  np.maximum(anchors[:, 0], gt_box_[:, 0])
+        inter_h = np.minimum(anchors[:, 3], gt_box_[:, 3]) - \
+                  np.maximum(anchors[:, 1], gt_box_[:, 1])
+        inter_area = inter_w * inter_h
+        
+        # union
+        union_area = anchors_area + gt_box_area - inter_area
+
+        # iou
+        iou = inter_area / union_area
+        iou = np.clip(iou, a_min=1e-10, a_max=1.0)
+        
+        return iou
+
+
+    @torch.no_grad()
+    def __call__(self, fmp_sizes, fpn_strides, targets):
+        """
+            fmp_sizes: (List) list of [fmp_h, fmp_w] for each pyramid level
+            fpn_strides: (List) -> [8, 16, 32, ...] stride of network output.
+            targets: (Dict) dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}
+        """
+        assert len(fmp_sizes) == len(fpn_strides)
+        # prepare
+        bs = len(targets)
+        gt_objectness = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, 1]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+        gt_classes = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, self.num_classes]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+        gt_bboxes = [
+            torch.zeros([bs, fmp_h, fmp_w, self.num_anchors, 4]) 
+            for (fmp_h, fmp_w) in fmp_sizes
+            ]
+
+        for batch_index in range(bs):
+            targets_per_image = targets[batch_index]
+            # [N,]
+            tgt_cls = targets_per_image["labels"].numpy()
+            # [N, 4]
+            tgt_box = targets_per_image['boxes'].numpy()
+
+            for gt_box, gt_label in zip(tgt_box, tgt_cls):
+                # get a bbox coords
+                x1, y1, x2, y2 = gt_box.tolist()
+                # xyxy -> cxcywh
+                xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
+                bw, bh = x2 - x1, y2 - y1
+                gt_box = [0, 0, bw, bh]
+
+                # check target
+                if bw < 1. or bh < 1.:
+                    # invalid target
+                    continue
+
+                # compute IoU
+                iou = self.compute_iou(self.anchor_boxes, gt_box)
+                iou_mask = (iou > self.iou_thresh)
+
+                label_assignment_results = []
+                if iou_mask.sum() == 0:
+                    # We assign the anchor box with highest IoU score.
+                    iou_ind = np.argmax(iou)
+
+                    level = iou_ind // self.num_anchors              # pyramid level
+                    anchor_idx = iou_ind - level * self.num_anchors  # anchor index
+
+                    # get the corresponding stride
+                    stride = fpn_strides[level]
+
+                    # compute the grid cell
+                    xc_s = xc / stride
+                    yc_s = yc / stride
+                    grid_x = int(xc_s)
+                    grid_y = int(yc_s)
+
+                    label_assignment_results.append([grid_x, grid_y, level, anchor_idx])
+                else:            
+                    for iou_ind, iou_m in enumerate(iou_mask):
+                        if iou_m:
+                            level = iou_ind // self.num_anchors              # pyramid level
+                            anchor_idx = iou_ind - level * self.num_anchors  # anchor index
+
+                            # get the corresponding stride
+                            stride = fpn_strides[level]
+
+                            # compute the grid cell
+                            xc_s = xc / stride
+                            yc_s = yc / stride
+                            grid_x = int(xc_s)
+                            grid_y = int(yc_s)
+
+                            label_assignment_results.append([grid_x, grid_y, level, anchor_idx])
+
+                # label assignment
+                for result in label_assignment_results:
+                    grid_x, grid_y, level, anchor_idx = result
+                    fmp_h, fmp_w = fmp_sizes[level]
+
+                    if grid_x < fmp_w and grid_y < fmp_h:
+                        # obj
+                        gt_objectness[level][batch_index, grid_y, grid_x, anchor_idx] = 1.0
+                        # cls
+                        cls_one_hot = torch.zeros(self.num_classes)
+                        cls_one_hot[int(gt_label)] = 1.0
+                        gt_classes[level][batch_index, grid_y, grid_x, anchor_idx] = cls_one_hot
+                        # box
+                        gt_bboxes[level][batch_index, grid_y, grid_x, anchor_idx] = torch.as_tensor([x1, y1, x2, y2])
+
+        # [B, M, C]
+        gt_objectness = torch.cat([gt.view(bs, -1, 1) for gt in gt_objectness], dim=1).float()
+        gt_classes = torch.cat([gt.view(bs, -1, self.num_classes) for gt in gt_classes], dim=1).float()
+        gt_bboxes = torch.cat([gt.view(bs, -1, 4) for gt in gt_bboxes], dim=1).float()
+
+        return gt_objectness, gt_classes, gt_bboxes
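A worked sketch of the assignment rule above, using the config's anchors and a 100x120 ground-truth box; anchors whose centered IoU clears iou_thresh become positives, and since the nine anchors are flattened three-per-level, index // 3 gives the pyramid level and index % 3 the anchor slot:

import numpy as np
from models.yolov3.matcher import Yolov3Matcher

anchor_size = [[10, 13],  [16, 30],   [33, 23],     # P3
               [30, 61],  [62, 45],   [59, 119],    # P4
               [116, 90], [156, 198], [373, 326]]   # P5
matcher = Yolov3Matcher(num_classes=20, num_anchors=3,
                        anchor_size=anchor_size, iou_thresh=0.5)

iou = matcher.compute_iou(matcher.anchor_boxes, [0., 0., 100., 120.])
print(np.round(iou, 3))                              # only [59, 119] and [116, 90] exceed 0.5
for idx in np.where(iou > 0.5)[0]:
    print('positive at level P{}, anchor {}'.format(idx // 3 + 3, idx % 3))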

+ 305 - 0
models/yolov3/yolov3.py

@@ -0,0 +1,305 @@
+import torch
+import torch.nn as nn
+
+from utils.nms import multiclass_nms
+
+from .yolov3_backbone import build_backbone
+from .yolov3_neck import build_neck
+from .yolov3_fpn import build_fpn
+from .yolov3_head import build_head
+
+
+# YOLOv3
+class YOLOv3(nn.Module):
+    def __init__(self,
+                 cfg,
+                 device,
+                 img_size=None,
+                 num_classes=20,
+                 conf_thresh=0.01,
+                 topk=100,
+                 nms_thresh=0.5,
+                 trainable=False):
+        super(YOLOv3, self).__init__()
+        # ------------------- Basic parameters -------------------
+        self.cfg = cfg                                 # model configuration
+        self.img_size = img_size                       # input image size
+        self.device = device                           # cuda or cpu
+        self.num_classes = num_classes                 # number of classes
+        self.trainable = trainable                     # training flag
+        self.conf_thresh = conf_thresh                 # confidence threshold
+        self.nms_thresh = nms_thresh                   # NMS threshold
+        self.topk = topk                               # topk
+        self.stride = [8, 16, 32]                      # output strides of the network
+        # ------------------- Anchor box -------------------
+        self.num_levels = 3
+        self.num_anchors = len(cfg['anchor_size']) // self.num_levels
+        self.anchor_size = torch.as_tensor(
+            cfg['anchor_size']
+            ).view(self.num_levels, self.num_anchors, 2) # [S, A, 2]
+        
+        # ------------------- Network Structure -------------------
+        ## Backbone
+        self.backbone, feats_dim = build_backbone(
+            cfg['backbone'], trainable&cfg['pretrained'])
+
+        ## Neck
+        self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1])
+        feats_dim[-1] = self.neck.out_dim
+
+        ## Feature pyramid
+        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=int(256*cfg['width']))
+        self.head_dim = self.fpn.out_dim
+
+        ## Detection heads
+        self.non_shared_heads = nn.ModuleList(
+            [build_head(cfg, head_dim, head_dim, num_classes) 
+            for head_dim in self.head_dim
+            ])
+
+        ## Prediction layers
+        self.obj_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 1 * self.num_anchors, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(head.cls_out_dim, self.num_classes * self.num_anchors, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 4 * self.num_anchors, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ])                 
+    
+
+        # --------- Network Initialization ----------
+        self.init_yolo()
+
+
+    def init_yolo(self): 
+        # Init yolo
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03
+                
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        for obj_pred in self.obj_preds:
+            b = obj_pred.bias.view(self.num_anchors, -1)
+            b.data.fill_(bias_value.item())
+            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # cls pred
+        for cls_pred in self.cls_preds:
+            b = cls_pred.bias.view(self.num_anchors, -1)
+            b.data.fill_(bias_value.item())
+            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        fmp_h, fmp_w = fmp_size
+        # [KA, 2]
+        anchor_size = self.anchor_size[level]
+
+        # generate grid cells
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        # [HW, 2] -> [HW, KA, 2] -> [M, 2]
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
+        anchor_xy = anchor_xy.view(-1, 2).to(self.device)
+
+        # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] -> [M, 2]
+        anchor_wh = anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
+        anchor_wh = anchor_wh.view(-1, 2).to(self.device)
+
+        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
+
+        return anchors
+        
+
+    def decode_boxes(self, level, anchors, reg_pred):
+        """
+            Convert (tx, ty, tw, th) predictions into the common x1y1x2y2 format.
+        """
+
+        # compute the center coordinates and width/height of the predicted boxes
+        pred_ctr = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride[level]
+        pred_wh = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+
+        # convert all box centers and sizes into x1y1x2y2 format
+        pred_x1y1 = pred_ctr - pred_wh * 0.5
+        pred_x2y2 = pred_ctr + pred_wh * 0.5
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+
+
+    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+        """
+        Input:
+            obj_preds: List(Tensor) [[M, 1], ...], M = H x W x KA
+            cls_preds: List(Tensor) [[M, C], ...]
+            reg_preds: List(Tensor) [[M, 4], ...]
+            anchors:   List(Tensor) [[M, 4], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchor_i) \
+                in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
+            # (H x W x KA x C,)
+            scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, reg_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            labels = topk_idxs % self.num_classes
+
+            reg_pred_i = reg_pred_i[anchor_idxs]
+            anchor_i = anchor_i[anchor_idxs]
+
+            # decode box: [M, 4]
+            bboxes = self.decode_boxes(level, anchor_i, reg_pred_i)
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # threshold
+        keep_idxs = scores.gt(self.conf_thresh)
+        scores = scores[keep_idxs]
+        labels = labels[keep_idxs]
+        bboxes = bboxes[keep_idxs]
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+
+        return bboxes, scores, labels
+
+
+    @torch.no_grad()
+    def inference(self, x):
+        # Backbone
+        pyramid_feats = self.backbone(x)
+
+        # Neck
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # Feature pyramid
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # Detection heads
+        all_anchors = []
+        all_obj_preds = []
+        all_cls_preds = []
+        all_reg_preds = []
+        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+            cls_feat, reg_feat = head(feat)
+
+            # [1, C, H, W]
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+
+            # anchors: [M, 2]
+            fmp_size = cls_pred.shape[-2:]
+            anchors = self.generate_anchors(level, fmp_size)
+
+            # [1, AC, H, W] -> [H, W, AC] -> [M, C]
+            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
+            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
+            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
+
+            all_obj_preds.append(obj_pred)
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_anchors.append(anchors)
+
+        # post process
+        bboxes, scores, labels = self.post_process(
+            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
+
+        return bboxes, scores, labels
+
+
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference(x)
+        else:
+            bs = x.shape[0]
+            # Backbone
+            pyramid_feats = self.backbone(x)
+
+            # Neck
+            pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+            # Feature pyramid
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # Detection heads
+            all_fmp_sizes = []
+            all_obj_preds = []
+            all_cls_preds = []
+            all_box_preds = []
+            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+                cls_feat, reg_feat = head(feat)
+
+                # [B, C, H, W]
+                obj_pred = self.obj_preds[level](reg_feat)
+                cls_pred = self.cls_preds[level](cls_feat)
+                reg_pred = self.reg_preds[level](reg_feat)
+
+                fmp_size = cls_pred.shape[-2:]
+
+                # generate anchor boxes: [M, 4]
+                anchors = self.generate_anchors(level, fmp_size)
+                
+                # [B, AC, H, W] -> [B, H, W, AC] -> [B, M, C]
+                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 1)
+                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, self.num_classes)
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 4)
+
+                # decode bbox
+                box_pred = self.decode_boxes(level, anchors, reg_pred)
+
+                all_obj_preds.append(obj_pred)
+                all_cls_preds.append(cls_pred)
+                all_box_preds.append(box_pred)
+                all_fmp_sizes.append(fmp_size)
+
+            # output dict
+            outputs = {"pred_obj": all_obj_preds,        # List [B, M, 1]
+                       "pred_cls": all_cls_preds,        # List [B, M, C]
+                       "pred_box": all_box_preds,        # List [B, M, 4]
+                       'fmp_sizes': all_fmp_sizes,       # List
+                       'strides': self.stride,           # List
+                       }
+
+            return outputs 
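A small numeric check of decode_boxes above (stride-8 level, illustrative values): a cell at grid (10, 7) with all-zero raw outputs decodes to a box centered at ((10 + 0.5) * 8, (7 + 0.5) * 8) = (84, 60) with the anchor's own width and height:

import torch

anchor = torch.tensor([[10., 7., 10., 13.]])    # [grid_x, grid_y, anchor_w, anchor_h]
reg_pred = torch.zeros(1, 4)                    # tx, ty, tw, th

ctr = (torch.sigmoid(reg_pred[..., :2]) + anchor[..., :2]) * 8    # (84., 60.)
wh = torch.exp(reg_pred[..., 2:]) * anchor[..., 2:]               # (10., 13.)
print(torch.cat([ctr - wh * 0.5, ctr + wh * 0.5], dim=-1))        # x1y1x2y2 = (79.0, 53.5, 89.0, 66.5)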

+ 116 - 0
models/yolov3/yolov3_backbone.py

@@ -0,0 +1,116 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov3_basic import Conv, ResBlock
+except:
+    from yolov3_basic import Conv, ResBlock
+    
+
+model_urls = {
+    "darknet53": "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/darknet53_silu.pth",
+}
+
+
+# --------------------- DarkNet-53 -----------------------
+class DarkNet53(nn.Module):
+    def __init__(self, act_type='silu', norm_type='BN'):
+        super(DarkNet53, self).__init__()
+        self.feat_dims = [256, 512, 1024]
+
+        # P1
+        self.layer_1 = nn.Sequential(
+            Conv(3, 32, k=3, p=1, act_type=act_type, norm_type=norm_type),
+            Conv(32, 64, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ResBlock(64, 64, nblocks=1, act_type=act_type, norm_type=norm_type)
+        )
+        # P2
+        self.layer_2 = nn.Sequential(
+            Conv(64, 128, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ResBlock(128, 128, nblocks=2, act_type=act_type, norm_type=norm_type)
+        )
+        # P3
+        self.layer_3 = nn.Sequential(
+            Conv(128, 256, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ResBlock(256, 256, nblocks=8, act_type=act_type, norm_type=norm_type)
+        )
+        # P4
+        self.layer_4 = nn.Sequential(
+            Conv(256, 512, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ResBlock(512, 512, nblocks=8, act_type=act_type, norm_type=norm_type)
+        )
+        # P5
+        self.layer_5 = nn.Sequential(
+            Conv(512, 1024, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            ResBlock(1024, 1024, nblocks=4, act_type=act_type, norm_type=norm_type)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# --------------------- Functions -----------------------
+def build_backbone(model_name='darknet53', pretrained=False): 
+    """Constructs a darknet-53 model.
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    if model_name == 'darknet53':
+        backbone = DarkNet53(act_type='silu', norm_type='BN')
+        feat_dims = backbone.feat_dims
+
+    if pretrained:
+        url = model_urls['darknet53']
+        if url is not None:
+            print('Loading pretrained weight ...')
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            # checkpoint state dict
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = backbone.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print(k)
+
+            backbone.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No backbone pretrained: DarkNet53')        
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    model, feats = build_backbone(pretrained=False)
+    x = torch.randn(1, 3, 224, 224)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 224, 224)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 137 - 0
models/yolov3/yolov3_basic.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+
+# Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# BottleNeck
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 shortcut=False,
+                 depthwise=False,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
+        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.cv2 = Conv(inter_dim, out_dim, k=3, p=1, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
+
+# ResBlock
+class ResBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 nblocks=1,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(ResBlock, self).__init__()
+        assert in_dim == out_dim
+        self.m = nn.Sequential(*[
+            Bottleneck(in_dim, out_dim, expand_ratio=0.5, shortcut=True,
+                       norm_type=norm_type, act_type=act_type)
+                       for _ in range(nblocks)
+                       ])
+
+    def forward(self, x):
+        return self.m(x)
+
+
+# ConvBlocks
+class ConvBlocks(nn.Module):
+    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.convs = nn.Sequential(
+            Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(out_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            Conv(inter_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(out_dim, inter_dim, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            Conv(inter_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        )
+
+    def forward(self, x):
+        return self.convs(x)
+    

+ 87 - 0
models/yolov3/yolov3_fpn.py

@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .yolov3_basic import Conv, ConvBlocks
+
+
+# YoloFPN
+class YoloFPN(nn.Module):
+    def __init__(self,
+                 in_dims=[256, 512, 1024],
+                 width=1.0,
+                 depth=1.0,
+                 out_dim=None,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(YoloFPN, self).__init__()
+        self.in_dims = in_dims
+        self.out_dim = out_dim
+        c3, c4, c5 = in_dims
+
+        # P5 -> P4
+        self.top_down_layer_1 = ConvBlocks(c5, int(512*width), act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_1 = Conv(int(512*width), int(256*width), k=1, act_type=act_type, norm_type=norm_type)
+
+        # P4 -> P3
+        self.top_down_layer_2 = ConvBlocks(c4 + int(256*width), int(256*width), act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_2 = Conv(int(256*width), int(128*width), k=1, act_type=act_type, norm_type=norm_type)
+
+        # P3
+        self.top_down_layer_3 = ConvBlocks(c3 + int(128*width), int(128*width), act_type=act_type, norm_type=norm_type)
+
+        # output proj layers
+        if out_dim is not None:
+            # output proj layers
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                        norm_type=norm_type, act_type=act_type)
+                        for in_dim in [int(128 * width), int(256 * width), int(512 * width)]
+                        ])
+            self.out_dim = [out_dim] * 3
+
+        else:
+            self.out_layers = None
+            self.out_dim = [int(128 * width), int(256 * width), int(512 * width)]
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+        
+        # p5/32
+        p5 = self.top_down_layer_1(c5)
+
+        # p4/16
+        p5_up = F.interpolate(self.reduce_layer_1(p5), scale_factor=2.0)
+        p4 = self.top_down_layer_2(torch.cat([c4, p5_up], dim=1))
+
+        # P3/8
+        p4_up = F.interpolate(self.reduce_layer_2(p4), scale_factor=2.0)
+        p3 = self.top_down_layer_3(torch.cat([c3, p4_up], dim=1))
+
+        out_feats = [p3, p4, p5]
+
+        # output proj layers
+        if self.out_layers is not None:
+            # output proj layers
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build neck
+    if model == 'yolo_fpn':
+        fpn_net = YoloFPN(in_dims=in_dims,
+                            out_dim=out_dim,
+                            width=cfg['width'],
+                            depth=cfg['depth'],
+                            act_type=cfg['fpn_act'],
+                            norm_type=cfg['fpn_norm']
+                            )
+
+    return fpn_net
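A hedged shape check for the FPN above, with the darknet53 feature dims and the 256-channel output projection that yolov3.py requests (width=1.0); it assumes the repo root is importable:

import torch
from models.yolov3.yolov3_fpn import YoloFPN

fpn = YoloFPN(in_dims=[256, 512, 1024], width=1.0, out_dim=256)
feats = [torch.randn(1, 256, 80, 80),     # c3, stride 8
         torch.randn(1, 512, 40, 40),     # c4, stride 16
         torch.randn(1, 1024, 20, 20)]    # c5, stride 32
for p in fpn(feats):
    print(p.shape)                        # every level projected to 256 channels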

+ 137 - 0
models/yolov3/yolov3_head.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+try:
+    from .yolov3_basic import Conv
+except:
+    from yolov3_basic import Conv
+
+
+class DecoupledHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+        super().__init__()
+        print('==============================')
+        print('Head: Decoupled Head')
+        self.in_dim = in_dim
+        self.num_cls_head=cfg['num_cls_head']
+        self.num_reg_head=cfg['num_reg_head']
+        self.act_type=cfg['head_act']
+        self.norm_type=cfg['head_norm']
+
+        # cls head
+        cls_feats = []
+        self.cls_out_dim = max(out_dim, num_classes)
+        for i in range(cfg['num_cls_head']):
+            if i == 0:
+                cls_feats.append(
+                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                cls_feats.append(
+                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+                
+        # reg head
+        reg_feats = []
+        self.reg_out_dim = max(out_dim, 64)
+        for i in range(cfg['num_reg_head']):
+            if i == 0:
+                reg_feats.append(
+                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                reg_feats.append(
+                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+
+# build detection head
+def build_head(cfg, in_dim, out_dim, num_classes=80):
+    head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+
+    return head
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'head_depthwise': False,
+        'reg_max': 16,
+    }
+    fpn_dims = [256, 512, 512]
+    # Head-1
+    model = build_head(cfg, fpn_dims[0], fpn_dims[0], num_classes=80)
+    x = torch.randn(1, 256, 80, 80)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-1: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-2
+    model = build_head(cfg, fpn_dims[1], fpn_dims[1], num_classes=80)
+    x = torch.randn(1, 512, 40, 40)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-2: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-3
+    model = build_head(cfg, fpn_dims[2], fpn_dims[2], num_classes=80)
+    x = torch.randn(1, 512, 20, 20)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-3: Params : {:.2f} M'.format(params / 1e6))

+ 40 - 0
models/yolov3/yolov3_neck.py

@@ -0,0 +1,40 @@
+import torch
+import torch.nn as nn
+from .yolov3_basic import Conv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='lrelu', norm_type='BN'):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm']
+            )
+
+    return neck
+        
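A side note on why a single 5x5 pool is enough here: chaining it reproduces the 5/9/13 receptive fields of the original parallel SPP block, which is what makes SPPF cheaper; a quick equivalence check:

import torch
import torch.nn as nn

x = torch.randn(1, 8, 20, 20)
m5 = nn.MaxPool2d(5, stride=1, padding=2)
m9 = nn.MaxPool2d(9, stride=1, padding=4)
m13 = nn.MaxPool2d(13, stride=1, padding=6)

y1 = m5(x)
print(torch.equal(m5(y1), m9(x)), torch.equal(m5(m5(y1)), m13(x)))   # True True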

+ 2 - 2
train.sh

@@ -3,11 +3,11 @@ python train.py \
         --cuda \
         -d voc \
         --root /mnt/share/ssd2/dataset/ \
-        -m yolov2 \
+        -m yolov3 \
         -bs 16 \
         -size 640 \
         --wp_epoch 1 \
-        --max_epoch 150 \
+        --max_epoch 300 \
         --eval_epoch 10 \
         --ema \
         --fp16 \