yjh0410 2 years ago
parent
commit
6bf9fafba9

+ 1 - 0
README.md

@@ -97,6 +97,7 @@ python train.py --cuda -d voc --root path/to/VOCdevkit -v yolov1 -bs 16 --max_ep
 | YOLOv2 |  640  |  √   |      |                          |   53.9            |   30.9             |  |
 | YOLOv3 |  640  |  √   |      |                          |   167.4           |   54.9             |  |
 | YOLOv4 |  640  |  √   |      |                          |                   |                    |  |
+| YOLOX  |  640  |  ×   |      |                          |                   |                    |  |
 
 *All models are trained with ImageNet pretrained weights (IP). All FLOPs are measured with a 640x640 input on the VOC2007 test set. FPS is measured with batch size 1 on a 3090 GPU, from model inference through the NMS operation.*
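
For context, the FLOPs/Params columns are obtained with `thop`-style profiling, the same tool the module-level tests added later in this commit use. A minimal sketch; the two-layer model is only a stand-in, not a detector from this repo:

```python
# Sketch: measuring GFLOPs/Params with thop, as the repo's own test blocks do.
# The tiny Sequential model is a placeholder, not a detector from this repo.
import torch
import torch.nn as nn
from thop import profile

model = nn.Sequential(nn.Conv2d(3, 16, 3, 2, 1), nn.SiLU(), nn.Conv2d(16, 32, 3, 2, 1))
x = torch.randn(1, 3, 640, 640)  # 640x640 input, batch size 1
flops, params = profile(model, inputs=(x,), verbose=False)
print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))  # thop counts MACs; x2 converts to FLOPs
print('Params : {:.2f} M'.format(params / 1e6))
```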
 

+ 8 - 1
config/__init__.py

@@ -3,6 +3,7 @@ from .yolov1_config import yolov1_cfg
 from .yolov2_config import yolov2_cfg
 from .yolov3_config import yolov3_cfg
 from .yolov4_config import yolov4_cfg
+from .yolox_config import yolox_cfg
 
 
 def build_model_config(args):
@@ -20,12 +21,15 @@ def build_model_config(args):
     # YOLOv4
     elif args.model == 'yolov4':
         cfg = yolov4_cfg
+    # YOLOX
+    elif args.model == 'yolox':
+        cfg = yolox_cfg
 
     return cfg
 
 
 # ------------------ Transform Config ----------------------
-from .transform_config import yolov5_trans_config, ssd_trans_config
+from .transform_config import yolov5_trans_config, yolox_trans_config, ssd_trans_config
 
 def build_trans_config(trans_config='ssd'):
     print('==============================')
@@ -36,5 +40,8 @@ def build_trans_config(trans_config='ssd'):
     # YOLOv5-style transform 
     elif trans_config == 'yolov5':
         cfg = yolov5_trans_config
+    # YOLOX-style transform 
+    elif trans_config == 'yolox':
+        cfg = yolox_trans_config
 
     return cfg
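
For reference, a sketch of how the two builders above dispatch on the new `yolox` option; `SimpleNamespace` stands in for the argparse result the scripts actually pass:

```python
# Sketch: dispatching to the new YOLOX config through the common builders.
from types import SimpleNamespace
from config import build_model_config, build_trans_config

args = SimpleNamespace(model='yolox')
model_cfg = build_model_config(args)                     # -> yolox_cfg
trans_cfg = build_trans_config(model_cfg['trans_type'])  # 'yolox' -> yolox_trans_config
print(trans_cfg['mixup_type'])                           # 'yolox_mixup'
```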

+ 20 - 2
config/transform_config.py

@@ -17,13 +17,31 @@ yolov5_trans_config = {
     'mixup_prob': 0.15,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
-    'mixup_scale': [0.5, 1.5]
+    'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
+yolox_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': 0.9,
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 1.0,
+    'mixup_prob': 1.0,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolox_mixup',
+    'mixup_scale': [0.5, 1.5]
+}
 
 ssd_trans_config = {
     'aug_type': 'ssd',
-    # Mosaic & Mixup are nor used for SSD-style augmentation
+    # Mosaic & Mixup are not used for SSD-style augmentation
     'mosaic_prob': 0.,
     'mixup_prob': 0.,
     'mosaic_type': 'yolov5_mosaic',
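
The probability fields are presumably consumed as per-sample Bernoulli gates in the dataset pipeline, as in the YOLOv5/YOLOX recipes this repo follows; a hedged sketch with placeholder stubs (not functions from this repo):

```python
# Sketch of how 'mosaic_prob' / 'mixup_prob' are typically consumed.
# apply_mosaic / apply_mixup are placeholders for the repo's actual transforms.
import random

def apply_mosaic(sample):
    return sample  # placeholder for yolov5_mosaic

def apply_mixup(sample):
    return sample  # placeholder for yolox_mixup (jitter scale in [0.5, 1.5])

def maybe_augment(sample, trans_cfg):
    # With yolox_trans_config both probabilities are 1.0, so mosaic and mixup
    # fire on every training sample until the no-augmentation epochs kick in.
    if random.random() < trans_cfg['mosaic_prob']:
        sample = apply_mosaic(sample)
    if random.random() < trans_cfg['mixup_prob']:
        sample = apply_mixup(sample)
    return sample
```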

+ 50 - 0
config/yolox_config.py

@@ -0,0 +1,50 @@
+# YOLOX Config
+
+yolox_cfg = {
+    # input
+    'trans_type': 'yolox',
+    # model
+    'backbone': 'cspdarknet',
+    'pretrained': True,
+    'bk_act': 'silu',
+    'bk_norm': 'BN',
+    'bk_dpw': False,
+    'stride': [8, 16, 32],  # P3, P4, P5
+    'width': 1.0,
+    'depth': 1.0,
+    # fpn
+    'fpn': 'yolo_pafpn',
+    'fpn_act': 'silu',
+    'fpn_norm': 'BN',
+    'fpn_depthwise': False,
+    # head
+    'head': 'decoupled_head',
+    'head_act': 'silu',
+    'head_norm': 'BN',
+    'num_cls_head': 2,
+    'num_reg_head': 2,
+    'head_depthwise': False,
+    # matcher
+    'matcher': {'center_sampling_radius': 2.5,
+                'topk_candidate': 10},
+    # loss weight
+    'loss_obj_weight': 1.0,
+    'loss_cls_weight': 1.0,
+    'loss_box_weight': 5.0,
+    # training configuration
+    'no_aug_epoch': 20,
+    # optimizer
+    'optimizer': 'sgd',        # optional: sgd, adam, adamw
+    'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
+    'weight_decay': 5e-4,      # SGD: 5e-4;     AdamW: 5e-2
+    'clip_grad': 10,           # SGD: 10.0;     AdamW: -1
+    # model EMA
+    'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+    'ema_tau': 2000,
+    # lr schedule
+    'scheduler': 'linear',
+    'lr0': 0.01,               # SGD: 0.01;     AdamW: 0.004
+    'lrf': 0.01,               # SGD: 0.01;     AdamW: 0.05
+    'warmup_momentum': 0.8,
+    'warmup_bias_lr': 0.1,
+}
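
A sketch of how the optimizer keys above map onto `torch.optim`; `build_sgd_from_cfg` is a hypothetical helper name, not part of this commit:

```python
# Sketch: turning yolox_cfg's optimizer keys into a torch.optim.SGD instance.
# build_sgd_from_cfg is a hypothetical helper, not a function in this repo.
import torch

def build_sgd_from_cfg(model, cfg):
    return torch.optim.SGD(
        model.parameters(),
        lr=cfg['lr0'],                     # 0.01; the linear schedule decays toward lr0 * lrf
        momentum=cfg['momentum'],          # 0.937
        weight_decay=cfg['weight_decay'],  # 5e-4
        nesterov=True,
    )
```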

+ 3 - 3
eval.py

@@ -21,14 +21,14 @@ from config import build_model_config, build_trans_config
 def parse_args():
     parser = argparse.ArgumentParser(description='YOLO-Tutorial')
     # basic
-    parser.add_argument('-size', '--img_size', default=416, type=int,
+    parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='the max size of input image')
     parser.add_argument('--cuda', action='store_true', default=False,
                         help='Use cuda')
 
     # model
-    parser.add_argument('-m', '--model', default='yolo_anchor', type=str,
-                        help='build YOLO')
+    parser.add_argument('-m', '--model', default='yolov1', type=str,
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'], help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
     parser.add_argument('--conf_thresh', default=0.001, type=float,

+ 5 - 0
models/__init__.py

@@ -6,6 +6,7 @@ from .yolov1.build import build_yolov1
 from .yolov2.build import build_yolov2
 from .yolov3.build import build_yolov3
 from .yolov4.build import build_yolov4
+from .yolox.build import build_yolox
 
 
 # build object detector
@@ -30,6 +31,10 @@ def build_model(args,
     elif args.model == 'yolov4':
         model, criterion = build_yolov4(
             args, model_cfg, device, num_classes, trainable)
+    # YOLOX   
+    elif args.model == 'yolox':
+        model, criterion = build_yolox(
+            args, model_cfg, device, num_classes, trainable)
 
     if trainable:
         # Load pretrained weight

+ 0 - 1
models/yolov3/build.py

@@ -16,7 +16,6 @@ def build_yolov3(args, cfg, device, num_classes=80, trainable=False):
     model = YOLOv3(
         cfg = cfg,
         device = device,
-        img_size = args.img_size,
         num_classes = num_classes,
         conf_thresh = args.conf_thresh,
         nms_thresh = args.nms_thresh,

+ 2 - 4
models/yolov3/yolov3.py

@@ -14,7 +14,6 @@ class YOLOv3(nn.Module):
     def __init__(self,
                  cfg,
                  device,
-                 img_size=None,
                  num_classes=20,
                  conf_thresh=0.01,
                  topk=100,
@@ -23,7 +22,6 @@ class YOLOv3(nn.Module):
         super(YOLOv3, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # model config file
-        self.img_size = img_size                       # input image size
         self.device = device                           # cuda or cpu
         self.num_classes = num_classes                 # number of classes
         self.trainable = trainable                     # training flag
@@ -43,11 +41,11 @@ class YOLOv3(nn.Module):
         self.backbone, feats_dim = build_backbone(
             cfg['backbone'], trainable&cfg['pretrained'])
 
-        ## Neck network
+        ## Neck network: SPP module
         self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1])
         feats_dim[-1] = self.neck.out_dim
 
-        ## Feature pyramid
+        ## Neck network: feature pyramid
         self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=int(256*cfg['width']))
         self.head_dim = self.fpn.out_dim
 

+ 0 - 1
models/yolov4/build.py

@@ -16,7 +16,6 @@ def build_yolov4(args, cfg, device, num_classes=80, trainable=False):
     model = YOLOv4(
         cfg = cfg,
         device = device,
-        img_size = args.img_size,
         num_classes = num_classes,
         conf_thresh = args.conf_thresh,
         nms_thresh = args.nms_thresh,

+ 2 - 4
models/yolov4/yolov4.py

@@ -14,7 +14,6 @@ class YOLOv4(nn.Module):
     def __init__(self,
                  cfg,
                  device,
-                 img_size=None,
                  num_classes=20,
                  conf_thresh=0.01,
                  topk=100,
@@ -23,7 +22,6 @@ class YOLOv4(nn.Module):
         super(YOLOv4, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # model config file
-        self.img_size = img_size                       # input image size
         self.device = device                           # cuda or cpu
         self.num_classes = num_classes                 # number of classes
         self.trainable = trainable                     # training flag
@@ -43,11 +41,11 @@ class YOLOv4(nn.Module):
         self.backbone, feats_dim = build_backbone(
             cfg['backbone'], trainable&cfg['pretrained'])
 
-        ## Neck network
+        ## Neck network: SPP module
         self.neck = build_neck(cfg, in_dim=feats_dim[-1], out_dim=feats_dim[-1])
         feats_dim[-1] = self.neck.out_dim
 
-        ## Feature pyramid
+        ## Neck network: feature pyramid
         self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=int(256*cfg['width']))
         self.head_dim = self.fpn.out_dim
 

+ 30 - 0
models/yolox/build.py

@@ -0,0 +1,30 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .loss import build_criterion
+from .yolox import YOLOX
+
+
+# build object detector
+def build_yolox(args, cfg, device, num_classes=80, trainable=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    model = YOLOX(
+        cfg = cfg,
+        device = device,
+        num_classes = num_classes,
+        conf_thresh = args.conf_thresh,
+        nms_thresh = args.nms_thresh,
+        topk = args.topk,
+        trainable = trainable
+        )
+
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, device, num_classes)
+    return model, criterion
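
A minimal usage sketch of the builder above, run from the repo root; `SimpleNamespace` stands in for parsed args and the thresholds are just example values:

```python
# Sketch: constructing the YOLOX detector in eval mode (no criterion).
from types import SimpleNamespace
import torch

from config.yolox_config import yolox_cfg
from models.yolox.build import build_yolox

args = SimpleNamespace(model='yolox', conf_thresh=0.001, nms_thresh=0.6, topk=100)
model, criterion = build_yolox(args, yolox_cfg, torch.device('cpu'),
                               num_classes=20, trainable=False)
print(type(model).__name__, criterion)  # YOLOX None
```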

+ 168 - 0
models/yolox/loss.py

@@ -0,0 +1,168 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from .matcher import SimOTA
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+
+
+class Criterion(object):
+    def __init__(self, 
+                 cfg, 
+                 device, 
+                 num_classes=80):
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        # loss weight
+        self.loss_obj_weight = cfg['loss_obj_weight']
+        self.loss_cls_weight = cfg['loss_cls_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
+        # matcher
+        matcher_config = cfg['matcher']
+        self.matcher = SimOTA(
+            num_classes=num_classes,
+            center_sampling_radius=matcher_config['center_sampling_radius'],
+            topk_candidate=matcher_config['topk_candidate']
+            )
+
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
+
+        return loss_box
+
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_obj']: List(Tensor) [B, M, 1]
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        fpn_strides = outputs['strides']
+        anchors = outputs['anchors']
+        # preds: [B, M, C]
+        obj_preds = torch.cat(outputs['pred_obj'], dim=1)
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+
+        # label assignment
+        cls_targets = []
+        box_targets = []
+        obj_targets = []
+        fg_masks = []
+
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                num_anchors = sum([ab.shape[0] for ab in anchors])
+                # There is no valid gt
+                cls_target = obj_preds.new_zeros((0, self.num_classes))
+                box_target = obj_preds.new_zeros((0, 4))
+                obj_target = obj_preds.new_zeros((num_anchors, 1))
+                fg_mask = obj_preds.new_zeros(num_anchors).bool()
+            else:
+                (
+                    gt_matched_classes,
+                    fg_mask,
+                    pred_ious_this_matching,
+                    matched_gt_inds,
+                    num_fg_img,
+                ) = self.matcher(
+                    fpn_strides = fpn_strides,
+                    anchors = anchors,
+                    pred_obj = obj_preds[batch_idx],
+                    pred_cls = cls_preds[batch_idx], 
+                    pred_box = box_preds[batch_idx],
+                    tgt_labels = tgt_labels,
+                    tgt_bboxes = tgt_bboxes
+                    )
+
+                obj_target = fg_mask.unsqueeze(-1)
+                cls_target = F.one_hot(gt_matched_classes.long(), self.num_classes)
+                cls_target = cls_target * pred_ious_this_matching.unsqueeze(-1)
+                box_target = tgt_bboxes[matched_gt_inds]
+
+            cls_targets.append(cls_target)
+            box_targets.append(box_target)
+            obj_targets.append(obj_target)
+            fg_masks.append(fg_mask)
+
+        cls_targets = torch.cat(cls_targets, 0)
+        box_targets = torch.cat(box_targets, 0)
+        obj_targets = torch.cat(obj_targets, 0)
+        fg_masks = torch.cat(fg_masks, 0)
+        num_fgs = fg_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # obj loss
+        loss_obj = self.loss_objectness(obj_preds.view(-1, 1), obj_targets.float())
+        loss_obj = loss_obj.sum() / num_fgs
+        
+        # cls loss
+        cls_preds_pos = cls_preds.view(-1, self.num_classes)[fg_masks]
+        loss_cls = self.loss_classes(cls_preds_pos, cls_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # regression loss
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets)
+        loss_box = loss_box.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        loss_dict = dict(
+                loss_obj = loss_obj,
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+def build_criterion(cfg, device, num_classes):
+    criterion = Criterion(
+        cfg=cfg,
+        device=device,
+        num_classes=num_classes
+        )
+
+    return criterion
+
+
+if __name__ == "__main__":
+    pass
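
A smoke-test sketch for the criterion (run from the repo root): one image, a single stride-8 level with a 4x4 feature map (M = 16 anchor points), one ground-truth box, and random but well-formed predictions. It only exercises shapes and the assignment/normalization flow:

```python
# Smoke test: random predictions are meaningless; this checks shapes and flow.
import torch
from models.yolox.loss import build_criterion

cfg = {'loss_obj_weight': 1.0, 'loss_cls_weight': 1.0, 'loss_box_weight': 5.0,
       'matcher': {'center_sampling_radius': 2.5, 'topk_candidate': 10}}
num_classes, B, M = 20, 1, 16
criterion = build_criterion(cfg, torch.device('cpu'), num_classes)

ys, xs = torch.meshgrid([torch.arange(4), torch.arange(4)])
anchors = torch.stack([xs, ys], dim=-1).float().view(-1, 2) * 8  # [16, 2]

xy = torch.rand(B, M, 2) * 24
wh = torch.rand(B, M, 2) * 16 + 1
outputs = {'pred_obj': [torch.randn(B, M, 1)],
           'pred_cls': [torch.randn(B, M, num_classes)],
           'pred_box': [torch.cat([xy, xy + wh], dim=-1)],  # valid xyxy boxes
           'strides': [8],
           'anchors': [anchors]}
targets = [{'boxes': torch.tensor([[4., 4., 28., 28.]]),
            'labels': torch.tensor([3])}]
loss_dict = criterion(outputs, targets)
print({k: float(v) for k, v in loss_dict.items()})
```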

+ 204 - 0
models/yolox/matcher.py

@@ -0,0 +1,204 @@
+import torch
+import torch.nn.functional as F
+from utils.box_ops import *
+
+
+
+# YOLOX SimOTA
+class SimOTA(object):
+    def __init__(self, 
+                 num_classes,
+                 center_sampling_radius,
+                 topk_candidate
+                 ) -> None:
+        self.num_classes = num_classes
+        self.center_sampling_radius = center_sampling_radius
+        self.topk_candidate = topk_candidate
+
+
+    @torch.no_grad()
+    def __call__(self, 
+                 fpn_strides, 
+                 anchors, 
+                 pred_obj, 
+                 pred_cls, 
+                 pred_box, 
+                 tgt_labels,
+                 tgt_bboxes):
+        # [M,]
+        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
+                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
+        # List[F, M, 2] -> [M, 2]
+        anchors = torch.cat(anchors, dim=0)
+        num_anchor = anchors.shape[0]        
+        num_gt = len(tgt_labels)
+
+        fg_mask, is_in_boxes_and_center = \
+            self.get_in_boxes_info(
+                tgt_bboxes,
+                anchors,
+                strides,
+                num_anchor,
+                num_gt
+                )
+
+        obj_preds_ = pred_obj[fg_mask]   # [Mp, 1]
+        cls_preds_ = pred_cls[fg_mask]   # [Mp, C]
+        box_preds_ = pred_box[fg_mask]   # [Mp, 4]
+        num_in_boxes_anchor = box_preds_.shape[0]
+
+        # [N, Mp]
+        pair_wise_ious, _ = box_iou(tgt_bboxes, box_preds_)
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8)
+
+        # [N, C] -> [N, Mp, C]
+        gt_cls = (
+            F.one_hot(tgt_labels.long(), self.num_classes)
+            .float()
+            .unsqueeze(1)
+            .repeat(1, num_in_boxes_anchor, 1)
+        )
+
+        with torch.cuda.amp.autocast(enabled=False):
+            score_preds_ = torch.sqrt(
+                cls_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+                * obj_preds_.float().unsqueeze(0).repeat(num_gt, 1, 1).sigmoid_()
+            ) # [N, Mp, C]
+            pair_wise_cls_loss = F.binary_cross_entropy(
+                score_preds_, gt_cls, reduction="none"
+            ).sum(-1) # [N, Mp]
+        del score_preds_
+
+        cost = (
+            pair_wise_cls_loss
+            + 3.0 * pair_wise_ious_loss
+            + 100000.0 * (~is_in_boxes_and_center)
+        ) # [N, Mp]
+
+        (
+            num_fg,
+            gt_matched_classes,         # [num_fg,]
+            pred_ious_this_matching,    # [num_fg,]
+            matched_gt_inds,            # [num_fg,]
+        ) = self.dynamic_k_matching(
+            cost,
+            pair_wise_ious,
+            tgt_labels,
+            num_gt,
+            fg_mask
+            )
+        del pair_wise_cls_loss, cost, pair_wise_ious, pair_wise_ious_loss
+
+        return (
+                gt_matched_classes,
+                fg_mask,
+                pred_ious_this_matching,
+                matched_gt_inds,
+                num_fg,
+        )
+
+
+    def get_in_boxes_info(
+        self,
+        gt_bboxes,   # [N, 4]
+        anchors,     # [M, 2]
+        strides,     # [M,]
+        num_anchors, # M
+        num_gt,      # N
+        ):
+        # anchor center
+        x_centers = anchors[:, 0]
+        y_centers = anchors[:, 1]
+
+        # [M,] -> [1, M] -> [N, M]
+        x_centers = x_centers.unsqueeze(0).repeat(num_gt, 1)
+        y_centers = y_centers.unsqueeze(0).repeat(num_gt, 1)
+
+        # [N,] -> [N, 1] -> [N, M]
+        gt_bboxes_l = gt_bboxes[:, 0].unsqueeze(1).repeat(1, num_anchors) # x1
+        gt_bboxes_t = gt_bboxes[:, 1].unsqueeze(1).repeat(1, num_anchors) # y1
+        gt_bboxes_r = gt_bboxes[:, 2].unsqueeze(1).repeat(1, num_anchors) # x2
+        gt_bboxes_b = gt_bboxes[:, 3].unsqueeze(1).repeat(1, num_anchors) # y2
+
+        b_l = x_centers - gt_bboxes_l
+        b_r = gt_bboxes_r - x_centers
+        b_t = y_centers - gt_bboxes_t
+        b_b = gt_bboxes_b - y_centers
+        bbox_deltas = torch.stack([b_l, b_t, b_r, b_b], 2)
+
+        is_in_boxes = bbox_deltas.min(dim=-1).values > 0.0
+        is_in_boxes_all = is_in_boxes.sum(dim=0) > 0
+        # in fixed center
+        center_radius = self.center_sampling_radius
+
+        # [N, 2]
+        gt_centers = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) * 0.5
+        
+        # [1, M]
+        center_radius_ = center_radius * strides.unsqueeze(0)
+
+        gt_bboxes_l = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # x1
+        gt_bboxes_t = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) - center_radius_ # y1
+        gt_bboxes_r = gt_centers[:, 0].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # x2
+        gt_bboxes_b = gt_centers[:, 1].unsqueeze(1).repeat(1, num_anchors) + center_radius_ # y2
+
+        c_l = x_centers - gt_bboxes_l
+        c_r = gt_bboxes_r - x_centers
+        c_t = y_centers - gt_bboxes_t
+        c_b = gt_bboxes_b - y_centers
+        center_deltas = torch.stack([c_l, c_t, c_r, c_b], 2)
+        is_in_centers = center_deltas.min(dim=-1).values > 0.0
+        is_in_centers_all = is_in_centers.sum(dim=0) > 0
+
+        # in boxes and in centers
+        is_in_boxes_anchor = is_in_boxes_all | is_in_centers_all
+
+        is_in_boxes_and_center = (
+            is_in_boxes[:, is_in_boxes_anchor] & is_in_centers[:, is_in_boxes_anchor]
+        )
+        return is_in_boxes_anchor, is_in_boxes_and_center
+    
+    
+    def dynamic_k_matching(
+        self, 
+        cost, 
+        pair_wise_ious, 
+        gt_classes, 
+        num_gt, 
+        fg_mask
+        ):
+        # Dynamic K
+        # ---------------------------------------------------------------
+        matching_matrix = torch.zeros_like(cost, dtype=torch.uint8)
+
+        ious_in_boxes_matrix = pair_wise_ious
+        n_candidate_k = min(self.topk_candidate, ious_in_boxes_matrix.size(1))
+        topk_ious, _ = torch.topk(ious_in_boxes_matrix, n_candidate_k, dim=1)
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+        dynamic_ks = dynamic_ks.tolist()
+        for gt_idx in range(num_gt):
+            _, pos_idx = torch.topk(
+                cost[gt_idx], k=dynamic_ks[gt_idx], largest=False
+            )
+            matching_matrix[gt_idx][pos_idx] = 1
+
+        del topk_ious, dynamic_ks, pos_idx
+
+        anchor_matching_gt = matching_matrix.sum(0)
+        if (anchor_matching_gt > 1).sum() > 0:
+            _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
+            matching_matrix[:, anchor_matching_gt > 1] *= 0
+            matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        num_fg = fg_mask_inboxes.sum().item()
+
+        fg_mask[fg_mask.clone()] = fg_mask_inboxes
+
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+        gt_matched_classes = gt_classes[matched_gt_inds]
+
+        pred_ious_this_matching = (matching_matrix * pair_wise_ious).sum(0)[
+            fg_mask_inboxes
+        ]
+        return num_fg, gt_matched_classes, pred_ious_this_matching, matched_gt_inds
+    
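
A toy walk-through of the dynamic-k rule above: the number of positives per ground truth is derived from the sum of its best candidate IoUs, so well-covered objects get more positives while every object keeps at least one:

```python
# Toy example of the dynamic-k rule in dynamic_k_matching: each gt gets
# k = clamp(int(sum of its top-n candidate IoUs), min=1) positive anchors,
# picked as its k lowest-cost candidates.
import torch

pair_wise_ious = torch.tensor([[0.8, 0.6, 0.3, 0.1],    # gt 0 vs 4 candidates
                               [0.2, 0.1, 0.1, 0.05]])  # gt 1 (poorly covered)
n_candidate_k = min(10, pair_wise_ious.size(1))
topk_ious, _ = torch.topk(pair_wise_ious, n_candidate_k, dim=1)
dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
print(dynamic_ks.tolist())  # [1, 1]: 1.8 -> int 1; 0.45 -> int 0 -> clamped to 1
```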

+ 274 - 0
models/yolox/yolox.py

@@ -0,0 +1,274 @@
+import torch
+import torch.nn as nn
+
+from .yolox_backbone import build_backbone
+from .yolox_pafpn import build_fpn
+from .yolox_head import build_head
+
+from utils.nms import multiclass_nms
+
+
+# YOLOX
+class YOLOX(nn.Module):
+    def __init__(self,
+                 cfg,
+                 device,
+                 num_classes=20,
+                 conf_thresh=0.01,
+                 topk=100,
+                 nms_thresh=0.5,
+                 trainable=False):
+        super(YOLOX, self).__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.device = device
+        self.stride = [8, 16, 32]
+        self.num_classes = num_classes
+        self.trainable = trainable
+        self.conf_thresh = conf_thresh
+        self.nms_thresh = nms_thresh
+        self.topk = topk
+        
+        # ------------------- Network Structure -------------------
+        ## Backbone network
+        self.backbone, feats_dim = build_backbone(cfg=cfg)
+        
+        ## Neck network: feature pyramid
+        self.fpn = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=int(256*cfg['width']))
+        self.head_dim = self.fpn.out_dim
+
+        ## Detection heads
+        self.non_shared_heads = nn.ModuleList(
+            [build_head(cfg, head_dim, head_dim, num_classes) 
+            for head_dim in self.head_dim
+            ])
+
+        ## Prediction layers
+        self.obj_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 1, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.cls_preds = nn.ModuleList(
+                            [nn.Conv2d(head.cls_out_dim, self.num_classes, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ]) 
+        self.reg_preds = nn.ModuleList(
+                            [nn.Conv2d(head.reg_out_dim, 4, kernel_size=1) 
+                                for head in self.non_shared_heads
+                              ])                 
+
+        # --------- Network Initialization ----------
+        # init bias
+        self.init_yolo()
+
+
+    def init_yolo(self): 
+        # Init yolo
+        for m in self.modules():
+            if isinstance(m, nn.BatchNorm2d):
+                m.eps = 1e-3
+                m.momentum = 0.03    
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        for obj_pred in self.obj_preds:
+            b = obj_pred.bias.view(1, -1)
+            b.data.fill_(bias_value.item())
+            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # cls pred
+        for cls_pred in self.cls_preds:
+            b = cls_pred.bias.view(1, -1)
+            b.data.fill_(bias_value.item())
+            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred
+        for reg_pred in self.reg_preds:
+            b = reg_pred.bias.view(-1, )
+            b.data.fill_(1.0)
+            reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+            w = reg_pred.weight
+            w.data.fill_(0.)
+            reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    def generate_anchors(self, level, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchor_xy *= self.stride[level]
+        anchors = anchor_xy.to(self.device)
+
+        return anchors
+        
+
+    def decode_boxes(self, anchors, reg_pred, stride):
+        """
+            anchors:  (List[Tensor]) [1, M, 2] or [M, 2]
+            reg_pred: (List[Tensor]) [B, M, 4] or [M, 4]
+        """
+        # center of bbox
+        pred_ctr_xy = anchors + reg_pred[..., :2] * stride
+        # size of bbox
+        pred_box_wh = reg_pred[..., 2:].exp() * stride
+
+        pred_x1y1 = pred_ctr_xy - 0.5 * pred_box_wh
+        pred_x2y2 = pred_ctr_xy + 0.5 * pred_box_wh
+        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return pred_box
+
+
+    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+        """
+        Input:
+            obj_preds: List(Tensor) [[H x W, 1], ...]
+            cls_preds: List(Tensor) [[H x W, C], ...]
+            reg_preds: List(Tensor) [[H x W, 4], ...]
+            anchors:  List(Tensor) [[H x W, 2], ...]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchors_i) in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
+            # (H x W x C,)
+            scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk, reg_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            labels = topk_idxs % self.num_classes
+
+            reg_pred_i = reg_pred_i[anchor_idxs]
+            anchors_i = anchors_i[anchor_idxs]
+
+            # decode box: [M, 4]
+            bboxes = self.decode_boxes(anchors_i, reg_pred_i, self.stride[level])
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+
+        return bboxes, scores, labels
+
+
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # backbone
+        pyramid_feats = self.backbone(x)
+
+        # fpn
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # non-shared heads
+        all_obj_preds = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_anchors = []
+        for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+            cls_feat, reg_feat = head(feat)
+
+            # [1, C, H, W]
+            obj_pred = self.obj_preds[level](reg_feat)
+            cls_pred = self.cls_preds[level](cls_feat)
+            reg_pred = self.reg_preds[level](reg_feat)
+
+            # anchors: [M, 2]
+            fmp_size = cls_pred.shape[-2:]
+            anchors = self.generate_anchors(level, fmp_size)
+
+            # [1, C, H, W] -> [H, W, C] -> [M, C]
+            obj_pred = obj_pred[0].permute(1, 2, 0).contiguous().view(-1, 1)
+            cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
+            reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
+
+            all_obj_preds.append(obj_pred)
+            all_cls_preds.append(cls_pred)
+            all_reg_preds.append(reg_pred)
+            all_anchors.append(anchors)
+
+        # post process
+        bboxes, scores, labels = self.post_process(
+            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
+        
+        return bboxes, scores, labels
+
+
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # backbone
+            pyramid_feats = self.backbone(x)
+
+            # fpn
+            pyramid_feats = self.fpn(pyramid_feats)
+
+            # non-shared heads
+            all_anchors = []
+            all_obj_preds = []
+            all_cls_preds = []
+            all_box_preds = []
+            for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
+                cls_feat, reg_feat = head(feat)
+
+                # [B, C, H, W]
+                obj_pred = self.obj_preds[level](reg_feat)
+                cls_pred = self.cls_preds[level](cls_feat)
+                reg_pred = self.reg_preds[level](reg_feat)
+
+                B, _, H, W = cls_pred.size()
+                fmp_size = [H, W]
+                # generate anchor points: [M, 2]
+                anchors = self.generate_anchors(level, fmp_size)
+                
+                # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+                obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
+                cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+                reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+
+                # decode box: [M, 4]
+                box_pred = self.decode_boxes(anchors, reg_pred, self.stride[level])
+
+                all_obj_preds.append(obj_pred)
+                all_cls_preds.append(cls_pred)
+                all_box_preds.append(box_pred)
+                all_anchors.append(anchors)
+            
+            # output dict
+            outputs = {"pred_obj": all_obj_preds,        # List(Tensor) [B, M, 1]
+                       "pred_cls": all_cls_preds,        # List(Tensor) [B, M, C]
+                       "pred_box": all_box_preds,        # List(Tensor) [B, M, 4]
+                       "anchors": all_anchors,           # List(Tensor) [B, M, 2]
+                       'strides': self.stride}           # List(Int) [8, 16, 32]
+
+            return outputs 
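
A worked example of the anchor-free decoding in `decode_boxes`, using the same arithmetic as the method on a single anchor:

```python
# Worked example of decode_boxes at stride 16: the anchor is the grid-cell
# corner scaled by the stride, (tx, ty) shift the center in stride units, and
# (tw, th) give width/height through an exponential.
import torch

stride = 16
anchor = torch.tensor([[2. * stride, 3. * stride]])    # cell (2, 3) -> (32, 48)
reg_pred = torch.tensor([[0.5, -0.25, 0.0, 0.6931]])   # tw=0 -> w=16, th=ln2 -> h~32
ctr = anchor + reg_pred[..., :2] * stride              # (40, 44)
wh = reg_pred[..., 2:].exp() * stride                  # (16, ~32)
box = torch.cat([ctr - 0.5 * wh, ctr + 0.5 * wh], dim=-1)
print(box)  # ~ tensor([[32., 28., 48., 60.]])
```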

+ 99 - 0
models/yolox/yolox_backbone.py

@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolox_basic import Conv, CSPBlock
+    from .yolox_neck import SPPF
+except:
+    from yolox_basic import Conv, CSPBlock
+    from yolox_neck import SPPF
+
+
+# CSPDarkNet
+class CSPDarkNet(nn.Module):
+    def __init__(self, depth=1.0, width=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(CSPDarkNet, self).__init__()
+        self.feat_dims = [int(256*width), int(512*width), int(1024*width)]
+
+        # P1
+        self.layer_1 = Conv(3, int(64*width), k=6, p=2, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        
+        # P2
+        self.layer_2 = nn.Sequential(
+            Conv(int(64*width), int(128*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(int(128*width), int(128*width), expand_ratio=0.5, nblocks=int(3*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P3
+        self.layer_3 = nn.Sequential(
+            Conv(int(128*width), int(256*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(int(256*width), int(256*width), expand_ratio=0.5, nblocks=int(9*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P4
+        self.layer_4 = nn.Sequential(
+            Conv(int(256*width), int(512*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            CSPBlock(int(512*width), int(512*width), expand_ratio=0.5, nblocks=int(9*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P5
+        self.layer_5 = nn.Sequential(
+            Conv(int(512*width), int(1024*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            SPPF(int(1024*width), int(1024*width), expand_ratio=0.5, act_type=act_type, norm_type=norm_type),
+            CSPBlock(int(1024*width), int(1024*width), expand_ratio=0.5, nblocks=int(3*depth),
+                     shortcut=True, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+def build_backbone(cfg): 
+    """Constructs a CSPDarkNet backbone.
+    Args:
+        cfg (dict): config with 'depth', 'width', 'bk_act', 'bk_norm' and 'bk_dpw' keys
+    """
+    backbone = CSPDarkNet(cfg['depth'], cfg['width'], cfg['bk_act'], cfg['bk_norm'], cfg['bk_dpw'])
+    feat_dims = backbone.feat_dims
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': False,
+        'bk_act': 'lrelu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'p6_feat': False,
+        'p7_feat': False,
+        'width': 1.0,
+        'depth': 1.0,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 256, 256)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 256, 256)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 137 - 0
models/yolox/yolox_basic.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+
+
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+
+# Basic conv layer
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='',          # activation
+                 norm_type='',         # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# ConvBlocks
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 kernel=[1, 3],
+                 shortcut=False,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
+        self.cv1 = Conv(in_dim, inter_dim, k=kernel[0], p=kernel[0]//2,
+                        norm_type=norm_type, act_type=act_type,
+                        depthwise=False if kernel[0] == 1 else depthwise)
+        self.cv2 = Conv(inter_dim, out_dim, k=kernel[1], p=kernel[1]//2,
+                        norm_type=norm_type, act_type=act_type,
+                        depthwise=False if kernel[1] == 1 else depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
+
+# CSP-stage block
+class CSPBlock(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio=0.5,
+                 kernel=[1, 3],
+                 nblocks=1,
+                 shortcut=False,
+                 depthwise=False,
+                 act_type='silu',
+                 norm_type='BN'):
+        super(CSPBlock, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.cv2 = Conv(in_dim, inter_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.cv3 = Conv(2 * inter_dim, out_dim, k=1, norm_type=norm_type, act_type=act_type)
+        self.m = nn.Sequential(*[
+            Bottleneck(inter_dim, inter_dim, expand_ratio=1.0, kernel=kernel, shortcut=shortcut,
+                       norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+                       for _ in range(nblocks)
+                       ])
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x1)
+        out = self.cv3(torch.cat([x3, x2], dim=1))
+
+        return out
+    
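
A quick shape check for the blocks above, with arbitrary sizes (run from the repo root):

```python
# Conv and CSPBlock preserve spatial size at stride 1; the CSP split runs the
# bottleneck stack on one branch and concatenates it with the bypass branch
# before the 1x1 fuse conv.
import torch
from models.yolox.yolox_basic import Conv, CSPBlock

x = torch.randn(1, 64, 32, 32)
conv = Conv(64, 128, k=3, p=1, s=1, act_type='silu', norm_type='BN')
csp = CSPBlock(128, 128, expand_ratio=0.5, nblocks=2, shortcut=True,
               act_type='silu', norm_type='BN')
print(csp(conv(x)).shape)  # torch.Size([1, 128, 32, 32])
```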

+ 137 - 0
models/yolox/yolox_head.py

@@ -0,0 +1,137 @@
+import torch
+import torch.nn as nn
+try:
+    from .yolox_basic import Conv
+except:
+    from yolox_basic import Conv
+
+
+class DecoupledHead(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim, num_classes=80):
+        super().__init__()
+        print('==============================')
+        print('Head: Decoupled Head')
+        self.in_dim = in_dim
+        self.num_cls_head=cfg['num_cls_head']
+        self.num_reg_head=cfg['num_reg_head']
+        self.act_type=cfg['head_act']
+        self.norm_type=cfg['head_norm']
+
+        # cls head
+        cls_feats = []
+        self.cls_out_dim = max(out_dim, num_classes)
+        for i in range(cfg['num_cls_head']):
+            if i == 0:
+                cls_feats.append(
+                    Conv(in_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                cls_feats.append(
+                    Conv(self.cls_out_dim, self.cls_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+                
+        # reg head
+        reg_feats = []
+        self.reg_out_dim = max(out_dim, 64)
+        for i in range(cfg['num_reg_head']):
+            if i == 0:
+                reg_feats.append(
+                    Conv(in_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+            else:
+                reg_feats.append(
+                    Conv(self.reg_out_dim, self.reg_out_dim, k=3, p=1, s=1, 
+                        act_type=self.act_type,
+                        norm_type=self.norm_type,
+                        depthwise=cfg['head_depthwise'])
+                        )
+
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+
+# build detection head
+def build_head(cfg, in_dim, out_dim, num_classes=80):
+    head = DecoupledHead(cfg, in_dim, out_dim, num_classes) 
+
+    return head
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'num_cls_head': 2,
+        'num_reg_head': 2,
+        'head_act': 'silu',
+        'head_norm': 'BN',
+        'head_depthwise': False,
+        'reg_max': 16,
+    }
+    fpn_dims = [256, 512, 512]
+    # Head-1
+    model = build_head(cfg, 256, fpn_dims[0], num_classes=80)  # out_dim must be an int, not the list
+    x = torch.randn(1, 256, 80, 80)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-1: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-1: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-2
+    model = build_head(cfg, 512, fpn_dims[1], num_classes=80)
+    x = torch.randn(1, 512, 40, 40)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-2: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-2: Params : {:.2f} M'.format(params / 1e6))
+
+    # Head-3
+    model = build_head(cfg, 512, fpn_dims[2], num_classes=80)
+    x = torch.randn(1, 512, 20, 20)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    # for out in outputs:
+    #     print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('Head-3: GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Head-3: Params : {:.2f} M'.format(params / 1e6))

+ 44 - 0
models/yolox/yolox_neck.py

@@ -0,0 +1,44 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolox_basic import Conv
+except:
+    from yolox_basic import Conv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, pooling_size=5, act_type='', norm_type=''):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.MaxPool2d(kernel_size=pooling_size, stride=1, padding=pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(
+            in_dim=in_dim,
+            out_dim=out_dim,
+            expand_ratio=cfg['expand_ratio'], 
+            pooling_size=cfg['pooling_size'],
+            act_type=cfg['neck_act'],
+            norm_type=cfg['neck_norm']
+            )
+
+    return neck
+    
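
The point of SPPF is that chaining 5x5 max-pools reproduces the parallel 5x5/9x9/13x13 branches of the original SPP while reusing intermediate results; a quick numeric check of that equivalence:

```python
# Numeric check: two chained 5x5 max-pools equal one 9x9, three equal one 13x13
# (stride 1, 'same' padding), which is why SPPF matches SPP but runs faster.
import torch
import torch.nn as nn

x = torch.randn(1, 8, 20, 20)
m5 = nn.MaxPool2d(5, stride=1, padding=2)
m9 = nn.MaxPool2d(9, stride=1, padding=4)
m13 = nn.MaxPool2d(13, stride=1, padding=6)
y1, y2, y3 = m5(x), m5(m5(x)), m5(m5(m5(x)))
print(torch.equal(y2, m9(x)), torch.equal(y3, m13(x)))  # True True
```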

+ 144 - 0
models/yolox/yolox_pafpn.py

@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+try:
+    from .yolox_basic import Conv, CSPBlock
+except:
+    from yolox_basic import Conv, CSPBlock
+
+
+# PaFPN-CSP
+class YoloPaFPN(nn.Module):
+    def __init__(self, 
+                 in_dims=[256, 512, 1024],
+                 out_dim=256,
+                 width=1.0,
+                 depth=1.0,
+                 act_type='silu',
+                 norm_type='BN',
+                 depthwise=False):
+        super(YoloPaFPN, self).__init__()
+        self.in_dims = in_dims
+        self.out_dim = out_dim
+        c3, c4, c5 = in_dims
+
+        # top down
+        ## P5 -> P4
+        self.reduce_layer_1 = Conv(c5, int(512*width), k=1, norm_type=norm_type, act_type=act_type)
+        self.top_down_layer_1 = CSPBlock(c4 + int(512*width),
+                                         int(512*width),
+                                         expand_ratio=0.5,
+                                         kernel=[1, 3],
+                                         nblocks=int(3*depth),
+                                         shortcut=False,
+                                         act_type=act_type,
+                                         norm_type=norm_type,
+                                         depthwise=depthwise
+                                         )
+
+        ## P4 -> P3
+        self.reduce_layer_2 = Conv(c4, int(256*width), k=1, norm_type=norm_type, act_type=act_type)  # 14
+        self.top_down_layer_2 = CSPBlock(c3 + int(256*width),
+                                         int(256*width),
+                                         expand_ratio=0.5,
+                                         kernel=[1, 3],
+                                         nblocks=int(3*depth),
+                                         shortcut=False,
+                                         act_type=act_type,
+                                         norm_type=norm_type,
+                                         depthwise=depthwise
+                                         )
+
+        # bottom up
+        ## P3 -> P4
+        self.reduce_layer_3 = Conv(int(256*width), int(256*width), k=3, p=1, s=2,
+                                   act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_1 = CSPBlock(int(256*width) + int(256*width),
+                                         int(512*width),
+                                         expand_ratio=0.5,
+                                         kernel=[1, 3],
+                                         nblocks=int(3*depth),
+                                         shortcut=False,
+                                         act_type=act_type,
+                                         norm_type=norm_type,
+                                         depthwise=depthwise
+                                         )
+
+        ## P4 -> P5
+        self.reduce_layer_4 = Conv(int(512*width), int(512*width), k=3, p=1, s=2,
+                                   act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_2 = CSPBlock(int(512*width) + int(512*width),
+                                         int(1024*width),
+                                         expand_ratio=0.5,
+                                         kernel=[1, 3],
+                                         nblocks=int(3*depth),
+                                         shortcut=False,
+                                         act_type=act_type,
+                                         norm_type=norm_type,
+                                         depthwise=depthwise
+                                         )
+
+        # output proj layers
+        if out_dim is not None:
+            # output proj layers
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                        norm_type=norm_type, act_type=act_type)
+                        for in_dim in [int(256 * width), int(512 * width), int(1024 * width)]
+                        ])
+            self.out_dim = [out_dim] * 3
+
+        else:
+            self.out_layers = None
+            self.out_dim = [int(256 * width), int(512 * width), int(1024 * width)]
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        c6 = self.reduce_layer_1(c5)
+        c7 = F.interpolate(c6, scale_factor=2.0)   # s32->s16
+        c8 = torch.cat([c7, c4], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        # P3/8
+        c10 = self.reduce_layer_2(c9)
+        c11 = F.interpolate(c10, scale_factor=2.0)   # s16->s8
+        c12 = torch.cat([c11, c3], dim=1)
+        c13 = self.top_down_layer_2(c12)  # to det
+        # p4/16
+        c14 = self.reduce_layer_3(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)  # to det
+        # p5/32
+        c17 = self.reduce_layer_4(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)  # to det
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+
+        # output proj layers
+        if self.out_layers is not None:
+            # output proj layers
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build neck
+    if model == 'yolo_pafpn':
+        fpn_net = YoloPaFPN(in_dims=in_dims,
+                             out_dim=out_dim,
+                             width=cfg['width'],
+                             depth=cfg['depth'],
+                             act_type=cfg['fpn_act'],
+                             norm_type=cfg['fpn_norm'],
+                             depthwise=cfg['fpn_depthwise']
+                             )
+
+
+    return fpn_net
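
A shape sketch for `YoloPaFPN` with a 640x640 input pyramid (run from the repo root): each level keeps its resolution and all three outputs are projected to `out_dim` channels:

```python
# Shape sketch: P3/P4/P5 at strides 8/16/32 go in, same resolutions come out,
# all projected to out_dim (256) channels by the 1x1 output layers.
import torch
from models.yolox.yolox_pafpn import YoloPaFPN

fpn = YoloPaFPN(in_dims=[256, 512, 1024], out_dim=256, width=1.0, depth=1.0)
feats = [torch.randn(1, 256, 80, 80),    # P3, stride 8 of a 640x640 input
         torch.randn(1, 512, 40, 40),    # P4, stride 16
         torch.randn(1, 1024, 20, 20)]   # P5, stride 32
for y in fpn(feats):
    print(y.shape)  # [1, 256, 80, 80] / [1, 256, 40, 40] / [1, 256, 20, 20]
```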

+ 2 - 2
test.py

@@ -23,7 +23,7 @@ def parse_args():
     parser = argparse.ArgumentParser(description='YOLO-Tutorial')
 
     # basic
-    parser.add_argument('-size', '--img_size', default=416, type=int,
+    parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='the max size of input image')
     parser.add_argument('--show', action='store_true', default=False,
                         help='show the visulization results.')
@@ -40,7 +40,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'], help='build yolo')
     parser.add_argument('--weight', default=None,
                         type=str, help='Trained state_dict file path to open')
     parser.add_argument('-ct', '--conf_thresh', default=0.1, type=float,

+ 1 - 1
train.py

@@ -56,7 +56,7 @@ def parse_args():
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
-                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4'], help='build yolo')
+                        choices=['yolov1', 'yolov2', 'yolov3', 'yolov4', 'yolox'], help='build yolo')
     parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
                         help='confidence threshold')
     parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,