yjh0410 2 years ago
parent
commit
bfbc824138

+ 3 - 5
eval.py

@@ -14,7 +14,7 @@ from dataset.data_augment import build_transform
 
 # load some utils
 from utils.misc import load_weight
-from utils.com_flops_params import FLOPs_and_Params
+from utils.misc import compute_flops
 
 from models import build_model
 from config import build_model_config, build_trans_config
@@ -41,8 +41,6 @@ def parse_args():
                         help='topk candidates for testing')
     parser.add_argument("--no_decode", action="store_true", default=False,
                         help="not decode in inference or yes")
-    parser.add_argument('--fuse_repconv', action='store_true', default=False,
-                        help='fuse RepConv')
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                         help='fuse Conv & BN')
 
@@ -140,14 +138,14 @@ if __name__ == '__main__':
     model = build_model(args, model_cfg, device, num_classes, False)
 
     # load trained weight
-    model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_repconv)
+    model = load_weight(model, args.weight, args.fuse_conv_bn)
     model.to(device).eval()
 
     # compute FLOPs and Params
     model_copy = deepcopy(model)
     model_copy.trainable = False
     model_copy.eval()
-    FLOPs_and_Params(
+    compute_flops(
         model=model_copy,
         img_size=args.img_size, 
         device=device)
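The FLOPs counter moves from utils.com_flops_params.FLOPs_and_Params into utils.misc as compute_flops; eval.py and test.py only swap the import and the call. As a rough sketch of what such a helper typically looks like (assuming it wraps thop.profile, as the old FLOPs_and_Params did; the actual body in utils/misc.py may differ):

    import torch
    from thop import profile

    def compute_flops(model, img_size, device):
        # hypothetical sketch: profile one forward pass at the eval resolution
        x = torch.randn(1, 3, img_size, img_size).to(device)
        macs, params = profile(model, inputs=(x, ), verbose=False)
        print('FLOPs : {:.2f} G'.format(macs * 2 / 1e9))   # thop reports MACs; x2 ~= FLOPs
        print('Params : {:.2f} M'.format(params / 1e6))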

+ 1 - 15
models/__init__.py

@@ -80,21 +80,7 @@ def build_model(args,
             checkpoint = torch.load(args.resume, map_location='cpu')
             # checkpoint state dict
             checkpoint_state_dict = checkpoint.pop("model")
-            # check
-            new_checkpoint_state_dict = {}
-
-            for k in list(checkpoint_state_dict.keys()):
-                v = checkpoint_state_dict[k]
-                if 'reduce_layer_3' in k:
-                    k_new = k.split('.')
-                    k_new[1] = 'downsample_layer_1'
-                    k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
-                elif 'reduce_layer_4' in k:
-                    k_new = k.split('.')
-                    k_new[1] = 'downsample_layer_2'
-                    k = k_new[0] + '.' + k_new[1] + '.' + k_new[2] + '.' + k_new[3] + '.' + k_new[4]
-                new_checkpoint_state_dict[k] = v
-            model.load_state_dict(new_checkpoint_state_dict)
+            model.load_state_dict(checkpoint_state_dict)
 
         return model, criterion
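Dropping the key-remapping loop means load_state_dict now fails fast (strict=True by default) on checkpoints saved before reduce_layer_3/reduce_layer_4 were renamed to downsample_layer_1/downsample_layer_2. If older checkpoints still need to load, a fallback along these lines would work (hypothetical, not part of this commit; str.replace also avoids the index surgery the removed loop used):

    import torch.nn as nn

    def load_compat(model: nn.Module, state_dict: dict):
        # Hypothetical fallback: retry with the old reduce_layer_* keys remapped.
        try:
            model.load_state_dict(state_dict)
        except RuntimeError:
            remap = {'reduce_layer_3': 'downsample_layer_1',
                     'reduce_layer_4': 'downsample_layer_2'}
            fixed = {}
            for k, v in state_dict.items():
                for old, new in remap.items():
                    k = k.replace(old, new)
                fixed[k] = v
            model.load_state_dict(fixed)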
 

+ 33 - 0
models/yolov1/build.py

@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import torch
+import torch.nn as nn
+
 from .loss import build_criterion
 from .yolov1 import YOLOv1
 
@@ -13,6 +16,7 @@ def build_yolov1(args, cfg, device, num_classes=80, trainable=False):
     print('==============================')
     print('Model Configuration: \n', cfg)
     
+    # -------------- Build YOLO --------------
     model = YOLOv1(
         cfg = cfg,
         device = device,
@@ -23,6 +27,35 @@ def build_yolov1(args, cfg, device, num_classes=80, trainable=False):
         trainable = trainable
         )
 
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # obj pred (YOLOv1 has a single detection head: obj_pred/cls_pred/reg_pred
+    # are single layers, not lists -- see yolov1.py below)
+    b = model.obj_pred.bias.view(1, -1)
+    b.data.fill_(bias_value.item())
+    model.obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # cls pred
+    b = model.cls_pred.bias.view(1, -1)
+    b.data.fill_(bias_value.item())
+    model.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    b = model.reg_pred.bias.view(-1, )
+    b.data.fill_(1.0)
+    model.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    w = model.reg_pred.weight
+    w.data.fill_(0.)
+    model.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
     criterion = None
     if trainable:
         # build criterion for training
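The network initialization moves out of each model class and into its builder (the same block recurs, with list-valued heads, in the YOLOv2-YOLOX builders below). The bias value -log((1 - p) / p) with p = 0.01 is the focal-loss prior trick: it makes sigmoid(bias) come out at p, so every location starts out predicting "almost certainly background" and the early objectness/classification loss stays small. A quick standalone check:

    import torch

    init_prob = 0.01
    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
    print(bias_value.item())                  # ~= -4.595
    print(torch.sigmoid(bias_value).item())   # ~= 0.01, the intended prior probability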

+ 1 - 13
models/yolov1/yolov1.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import numpy as np
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 from .yolov1_backbone import build_backbone
 from .yolov1_neck import build_neck
@@ -48,18 +48,6 @@ class YOLOv1(nn.Module):
         self.reg_pred = nn.Conv2d(head_dim, 4, kernel_size=1)
     
 
-        if self.trainable:
-            self.init_bias()
-
-
-    def init_bias(self):
-        # init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        nn.init.constant_(self.obj_pred.bias, bias_value)
-        nn.init.constant_(self.cls_pred.bias, bias_value)
-
-
     def create_grid(self, fmp_size):
         """ 
             Generate the grid matrix G, where every element is a pixel coordinate on the feature map.

+ 40 - 9
models/yolov2/build.py

@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import torch
+import torch.nn as nn
+
 from .loss import build_criterion
 from .yolov2 import YOLOv2
 
@@ -13,20 +16,48 @@ def build_yolov2(args, cfg, device, num_classes=80, trainable=False):
     print('==============================')
     print('Model Configuration: \n', cfg)
     
+    # -------------- Build YOLO --------------
     model = YOLOv2(
-        cfg = cfg,
-        device = device,
-        img_size = args.img_size,
-        num_classes = num_classes,
-        conf_thresh = args.conf_thresh,
-        nms_thresh = args.nms_thresh,
-        topk = args.topk,
-        trainable = trainable
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        conf_thresh=args.conf_thresh,
+        nms_thresh=args.nms_thresh,
+        topk=args.topk,
         )
 
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
     criterion = None
     if trainable:
         # build criterion for training
         criterion = build_criterion(cfg, device, num_classes)
-
     return model, criterion

+ 2 - 4
models/yolov2/yolov2.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 import numpy as np
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 from .yolov2_backbone import build_backbone
 from .yolov2_neck import build_neck
@@ -14,16 +14,14 @@ class YOLOv2(nn.Module):
     def __init__(self,
                  cfg,
                  device,
-                 img_size=None,
                  num_classes=20,
                  conf_thresh=0.01,
-                 topk=100,
                  nms_thresh=0.5,
+                 topk=100,
                  trainable=False):
         super(YOLOv2, self).__init__()
         # ------------------- Basic parameters -------------------
         self.cfg = cfg                                 # model config
-        self.img_size = img_size                       # input image size
         self.device = device                           # cuda or cpu
         self.num_classes = num_classes                 # number of classes
         self.trainable = trainable                     # training flag

+ 40 - 8
models/yolov3/build.py

@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import torch
+import torch.nn as nn
+
 from .loss import build_criterion
 from .yolov3 import YOLOv3
 
@@ -13,19 +16,48 @@ def build_yolov3(args, cfg, device, num_classes=80, trainable=False):
     print('==============================')
     print('Model Configuration: \n', cfg)
     
+    # -------------- Build YOLO --------------
     model = YOLOv3(
-        cfg = cfg,
-        device = device,
-        num_classes = num_classes,
-        conf_thresh = args.conf_thresh,
-        nms_thresh = args.nms_thresh,
-        topk = args.topk,
-        trainable = trainable
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        conf_thresh=args.conf_thresh,
+        nms_thresh=args.nms_thresh,
+        topk=args.topk,
         )
 
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
     criterion = None
     if trainable:
         # build criterion for training
         criterion = build_criterion(cfg, device, num_classes)
-
     return model, criterion

+ 30 - 63
models/yolov3/yolov3.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 from .yolov3_backbone import build_backbone
 from .yolov3_neck import build_neck
@@ -70,32 +70,8 @@ class YOLOv3(nn.Module):
                               ])                 
     
 
-        # --------- Network Initialization ----------
-        self.init_yolo()
-
-
-    def init_yolo(self): 
-        # Init yolo
-        for m in self.modules():
-            if isinstance(m, nn.BatchNorm2d):
-                m.eps = 1e-3
-                m.momentum = 0.03
-                
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # obj pred
-        for obj_pred in self.obj_preds:
-            b = obj_pred.bias.view(self.num_anchors, -1)
-            b.data.fill_(bias_value.item())
-            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # cls pred
-        for cls_pred in self.cls_preds:
-            b = cls_pred.bias.view(self.num_anchors, -1)
-            b.data.fill_(bias_value.item())
-            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-
-
+    # ---------------------- Basic Functions ----------------------
+    ## generate anchor points
     def generate_anchors(self, level, fmp_size):
         """
             fmp_size: (List) [H, W]
@@ -119,43 +95,25 @@ class YOLOv3(nn.Module):
 
         return anchors
         
-
-    def decode_boxes(self, level, anchors, reg_pred):
-        """
-            Convert txtytwth into the familiar x1y1x2y2 form.
-        """
-
-        # compute the center coordinates and width/height of the predicted boxes
-        pred_ctr = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride[level]
-        pred_wh = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
-
-        # convert the centers and sizes of all bboxes into x1y1x2y2 form
-        pred_x1y1 = pred_ctr - pred_wh * 0.5
-        pred_x2y2 = pred_ctr + pred_wh * 0.5
-        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-        return pred_box
-
-
-    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+    ## post-process
+    def post_process(self, obj_preds, cls_preds, box_preds):
         """
         Input:
-            obj_preds: List(Tensor) [[H x W, 1], ...]
-            cls_preds: List(Tensor) [[H x W, C], ...]
-            reg_preds: List(Tensor) [[H x W, 4], ...]
-            anchors:  List(Tensor) [[H x W, 2], ...]
+            obj_preds: List(Tensor) [[H x W x A, 1], ...]
+            cls_preds: List(Tensor) [[H x W x A, C], ...]
+            box_preds: List(Tensor) [[H x W x A, 4], ...]
         """
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchor_i) \
-                in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
             # (H x W x KA x C,)
             scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
 
             # Keep top k top scoring indices only.
-            num_topk = min(self.topk, reg_pred_i.size(0))
+            num_topk = min(self.topk, box_pred_i.size(0))
 
             # torch.sort is actually faster than .topk (at least on GPUs)
             predicted_prob, topk_idxs = scores_i.sort(descending=True)
@@ -170,11 +128,7 @@ class YOLOv3(nn.Module):
             anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
             labels = topk_idxs % self.num_classes
 
-            reg_pred_i = reg_pred_i[anchor_idxs]
-            anchor_i = anchor_i[anchor_idxs]
-
-            # decode box: [M, 4]
-            bboxes = self.decode_boxes(level, anchor_i, reg_pred_i)
+            bboxes = box_pred_i[anchor_idxs]
 
             all_scores.append(scores)
             all_labels.append(labels)
@@ -196,6 +150,7 @@ class YOLOv3(nn.Module):
         return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Inference ----------------------
     @torch.no_grad()
     def inference(self, x):
         # backbone
@@ -211,7 +166,7 @@ class YOLOv3(nn.Module):
         all_anchors = []
         all_obj_preds = []
         all_cls_preds = []
-        all_reg_preds = []
+        all_box_preds = []
         for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
             cls_feat, reg_feat = head(feat)
 
@@ -229,18 +184,26 @@ class YOLOv3(nn.Module):
             cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
             reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
 
+            # decode bbox
+            ctr_pred = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride[level]
+            wh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+            pred_x1y1 = ctr_pred - wh_pred * 0.5
+            pred_x2y2 = ctr_pred + wh_pred * 0.5
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
             all_obj_preds.append(obj_pred)
             all_cls_preds.append(cls_pred)
-            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
         # post process
         bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
-
+            all_obj_preds, all_cls_preds, all_box_preds)
+        
         return bboxes, scores, labels
 
 
+
     def forward(self, x):
         if not self.trainable:
             return self.inference(x)
@@ -279,7 +242,11 @@ class YOLOv3(nn.Module):
                 reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 4)
 
                 # decode bbox
-                box_pred = self.decode_boxes(level, anchors, reg_pred)
+                ctr_pred = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride[level]
+                wh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+                pred_x1y1 = ctr_pred - wh_pred * 0.5
+                pred_x2y2 = ctr_pred + wh_pred * 0.5
+                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
                 all_obj_preds.append(obj_pred)
                 all_cls_preds.append(cls_pred)
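Here decode_boxes is inlined at both call sites, so post_process now receives finished x1y1x2y2 boxes instead of raw txtytwth offsets plus anchors. The math is unchanged: the sigmoid keeps the predicted center inside its grid cell and exp scales the anchor prior. A standalone numeric sketch of the same decode (all values illustrative):

    import torch

    stride = 32                                          # stride of one pyramid level
    anchors = torch.tensor([[6., 4., 116., 90.]])        # [grid_x, grid_y, anchor_w, anchor_h]
    reg_pred = torch.tensor([[0.2, -0.1, 0.3, 0.1]])     # raw txtytwth prediction

    ctr = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * stride
    wh = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
    box = torch.cat([ctr - wh * 0.5, ctr + wh * 0.5], dim=-1)   # x1y1x2y2
    print(box)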

+ 40 - 8
models/yolov4/build.py

@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import torch
+import torch.nn as nn
+
 from .loss import build_criterion
 from .yolov4 import YOLOv4
 
@@ -13,19 +16,48 @@ def build_yolov4(args, cfg, device, num_classes=80, trainable=False):
     print('==============================')
     print('Model Configuration: \n', cfg)
     
+    # -------------- Build YOLO --------------
     model = YOLOv4(
-        cfg = cfg,
-        device = device,
-        num_classes = num_classes,
-        conf_thresh = args.conf_thresh,
-        nms_thresh = args.nms_thresh,
-        topk = args.topk,
-        trainable = trainable
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        conf_thresh=args.conf_thresh,
+        nms_thresh=args.nms_thresh,
+        topk=args.topk,
         )
 
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
     criterion = None
     if trainable:
         # build criterion for training
         criterion = build_criterion(cfg, device, num_classes)
-
     return model, criterion

+ 33 - 66
models/yolov4/yolov4.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 from .yolov4_backbone import build_backbone
 from .yolov4_neck import build_neck
@@ -16,8 +16,8 @@ class YOLOv4(nn.Module):
                  device,
                  num_classes=20,
                  conf_thresh=0.01,
-                 topk=100,
                  nms_thresh=0.5,
+                 topk=100,
                  trainable=False):
         super(YOLOv4, self).__init__()
         # ------------------- Basic parameters -------------------
@@ -70,32 +70,8 @@ class YOLOv4(nn.Module):
                               ])                 
     
 
-        # --------- Network Initialization ----------
-        self.init_yolo()
-
-
-    def init_yolo(self): 
-        # Init yolo
-        for m in self.modules():
-            if isinstance(m, nn.BatchNorm2d):
-                m.eps = 1e-3
-                m.momentum = 0.03
-                
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # obj pred
-        for obj_pred in self.obj_preds:
-            b = obj_pred.bias.view(self.num_anchors, -1)
-            b.data.fill_(bias_value.item())
-            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # cls pred
-        for cls_pred in self.cls_preds:
-            b = cls_pred.bias.view(self.num_anchors, -1)
-            b.data.fill_(bias_value.item())
-            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-
-
+    # ---------------------- Basic Functions ----------------------
+    ## generate anchor points
     def generate_anchors(self, level, fmp_size):
         """
             fmp_size: (List) [H, W]
@@ -108,8 +84,8 @@ class YOLOv4(nn.Module):
         anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
         anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
         # [HW, 2] -> [HW, KA, 2] -> [M, 2]
-        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
-        anchor_xy = anchor_xy.view(-1, 2).to(self.device) + 0.5
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1) + 0.5
+        anchor_xy = anchor_xy.view(-1, 2).to(self.device)
 
         # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2] -> [M, 2]
         anchor_wh = anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
@@ -119,43 +95,25 @@ class YOLOv4(nn.Module):
 
         return anchors
         
-
-    def decode_boxes(self, level, anchors, reg_pred):
-        """
-            Convert txtytwth into the familiar x1y1x2y2 form.
-        """
-
-        # compute the center coordinates and width/height of the predicted boxes
-        pred_ctr = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.stride[level]
-        pred_wh = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
-
-        # convert the centers and sizes of all bboxes into x1y1x2y2 form
-        pred_x1y1 = pred_ctr - pred_wh * 0.5
-        pred_x2y2 = pred_ctr + pred_wh * 0.5
-        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-        return pred_box
-
-
-    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+    ## post-process
+    def post_process(self, obj_preds, cls_preds, box_preds):
         """
         Input:
-            obj_preds: List(Tensor) [[H x W, 1], ...]
-            cls_preds: List(Tensor) [[H x W, C], ...]
-            reg_preds: List(Tensor) [[H x W, 4], ...]
-            anchors:  List(Tensor) [[H x W, 2], ...]
+            obj_preds: List(Tensor) [[H x W x A, 1], ...]
+            cls_preds: List(Tensor) [[H x W x A, C], ...]
+            box_preds: List(Tensor) [[H x W x A, 4], ...]
         """
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchor_i) \
-                in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
             # (H x W x KA x C,)
             scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
 
             # Keep top k top scoring indices only.
-            num_topk = min(self.topk, reg_pred_i.size(0))
+            num_topk = min(self.topk, box_pred_i.size(0))
 
             # torch.sort is actually faster than .topk (at least on GPUs)
             predicted_prob, topk_idxs = scores_i.sort(descending=True)
@@ -170,11 +128,7 @@ class YOLOv4(nn.Module):
             anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
             labels = topk_idxs % self.num_classes
 
-            reg_pred_i = reg_pred_i[anchor_idxs]
-            anchor_i = anchor_i[anchor_idxs]
-
-            # decode box: [M, 4]
-            bboxes = self.decode_boxes(level, anchor_i, reg_pred_i)
+            bboxes = box_pred_i[anchor_idxs]
 
             all_scores.append(scores)
             all_labels.append(labels)
@@ -196,6 +150,7 @@ class YOLOv4(nn.Module):
         return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Inference ----------------------
     @torch.no_grad()
     def inference(self, x):
         # backbone
@@ -211,7 +166,7 @@ class YOLOv4(nn.Module):
         all_anchors = []
         all_obj_preds = []
         all_cls_preds = []
-        all_reg_preds = []
+        all_box_preds = []
         for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
             cls_feat, reg_feat = head(feat)
 
@@ -229,18 +184,26 @@ class YOLOv4(nn.Module):
             cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
             reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
 
+            # decode bbox
+            ctr_pred = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.stride[level]
+            wh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+            pred_x1y1 = ctr_pred - wh_pred * 0.5
+            pred_x2y2 = ctr_pred + wh_pred * 0.5
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
             all_obj_preds.append(obj_pred)
             all_cls_preds.append(cls_pred)
-            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
         # post process
         bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
-
+            all_obj_preds, all_cls_preds, all_box_preds)
+        
         return bboxes, scores, labels
 
 
+
     def forward(self, x):
         if not self.trainable:
             return self.inference(x)
@@ -279,7 +242,11 @@ class YOLOv4(nn.Module):
                 reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(bs, -1, 4)
 
                 # decode bbox
-                box_pred = self.decode_boxes(level, anchors, reg_pred)
+                ctr_pred = (torch.sigmoid(reg_pred[..., :2]) * 3.0 - 1.5 + anchors[..., :2]) * self.stride[level]
+                wh_pred = torch.exp(reg_pred[..., 2:]) * anchors[..., 2:]
+                pred_x1y1 = ctr_pred - wh_pred * 0.5
+                pred_x2y2 = ctr_pred + wh_pred * 0.5
+                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
                 all_obj_preds.append(obj_pred)
                 all_cls_preds.append(cls_pred)
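YOLOv4's inlined decode differs from YOLOv3's only in the center term: sigmoid(t) * 3.0 - 1.5 widens the offset range from (0, 1) to (-1.5, 1.5), so a ground-truth center sitting on a cell boundary no longer pushes the logit toward infinity (the "grid sensitivity" fix). A two-line illustration:

    import torch

    t = torch.tensor([-10., 0., 10.])
    print(torch.sigmoid(t))               # ~[0.0, 0.5, 1.0]: plain sigmoid saturates at the cell edges
    print(torch.sigmoid(t) * 3.0 - 1.5)   # ~[-1.5, 0.0, 1.5]: the scaled form reaches past them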

+ 3 - 1
models/yolov5/yolov5.py

@@ -5,7 +5,7 @@ from .yolov5_backbone import build_backbone
 from .yolov5_pafpn import build_fpn
 from .yolov5_head import build_head
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 
 class YOLOv5(nn.Module):
@@ -140,6 +140,7 @@ class YOLOv5(nn.Module):
 
         return bboxes, scores, labels
 
+
     # ---------------------- Main Process for Inference ----------------------
     @torch.no_grad()
     def inference_single_image(self, x):
@@ -189,6 +190,7 @@ class YOLOv5(nn.Module):
         
         return bboxes, scores, labels
 
+
     # ---------------------- Main Process for Training ----------------------
     def forward(self, x):
         if not self.trainable:

+ 33 - 0
models/yolov7/build.py

@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import torch
+import torch.nn as nn
+
 from .loss import build_criterion
 from .yolov7 import YOLOv7
 
@@ -13,6 +16,7 @@ def build_yolov7(args, cfg, device, num_classes=80, trainable=False):
     print('==============================')
     print('Model Configuration: \n', cfg)
     
+    # -------------- Build YOLO --------------
     model = YOLOv7(
         cfg = cfg,
         device = device,
@@ -23,6 +27,35 @@ def build_yolov7(args, cfg, device, num_classes=80, trainable=False):
         trainable = trainable
         )
 
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
     criterion = None
     if trainable:
         # build criterion for training

+ 29 - 68
models/yolov7/yolov7.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 from .yolov7_backbone import build_backbone
 from .yolov7_neck import build_neck
@@ -61,40 +61,9 @@ class YOLOv7(nn.Module):
                                 for head in self.non_shared_heads
                               ])                 
 
-        # --------- Network Initialization ----------
-        # init bias
-        self.init_yolo()
-
-
-    def init_yolo(self): 
-        # Init yolo
-        for m in self.modules():
-            if isinstance(m, nn.BatchNorm2d):
-                m.eps = 1e-3
-                m.momentum = 0.03    
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # obj pred
-        for obj_pred in self.obj_preds:
-            b = obj_pred.bias.view(1, -1)
-            b.data.fill_(bias_value.item())
-            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # cls pred
-        for cls_pred in self.cls_preds:
-            b = cls_pred.bias.view(1, -1)
-            b.data.fill_(bias_value.item())
-            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred
-        for reg_pred in self.reg_preds:
-            b = reg_pred.bias.view(-1, )
-            b.data.fill_(1.0)
-            reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-            w = reg_pred.weight
-            w.data.fill_(0.)
-            reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
 
+    # ---------------------- Basic Functions ----------------------
+    ## generate anchor points
     def generate_anchors(self, level, fmp_size):
         """
             fmp_size: (List) [H, W]
@@ -110,42 +79,25 @@ class YOLOv7(nn.Module):
 
         return anchors
         
-
-    def decode_boxes(self, anchors, reg_pred, stride):
-        """
-            anchors:  (List[Tensor]) [1, M, 2] or [M, 2]
-            reg_pred: (List[Tensor]) [B, M, 4] or [M, 4]
-        """
-        # center of bbox
-        pred_ctr_xy = anchors + reg_pred[..., :2] * stride
-        # size of bbox
-        pred_box_wh = reg_pred[..., 2:].exp() * stride
-
-        pred_x1y1 = pred_ctr_xy - 0.5 * pred_box_wh
-        pred_x2y2 = pred_ctr_xy + 0.5 * pred_box_wh
-        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-        return pred_box
-
-
-    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+    ## post-process
+    def post_process(self, obj_preds, cls_preds, box_preds):
         """
         Input:
             obj_preds: List(Tensor) [[H x W, 1], ...]
             cls_preds: List(Tensor) [[H x W, C], ...]
-            reg_preds: List(Tensor) [[H x W, 4], ...]
-            anchors:  List(Tensor) [[H x W, 2], ...]
+            box_preds: List(Tensor) [[H x W, 4], ...]
         """
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchors_i) in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
-            # (H x W x C,)
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            # (H x W x KA x C,)
             scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
 
             # Keep top k top scoring indices only.
-            num_topk = min(self.topk, reg_pred_i.size(0))
+            num_topk = min(self.topk, box_pred_i.size(0))
 
             # torch.sort is actually faster than .topk (at least on GPUs)
             predicted_prob, topk_idxs = scores_i.sort(descending=True)
@@ -160,11 +112,7 @@ class YOLOv7(nn.Module):
             anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
             labels = topk_idxs % self.num_classes
 
-            reg_pred_i = reg_pred_i[anchor_idxs]
-            anchors_i = anchors_i[anchor_idxs]
-
-            # decode box: [M, 4]
-            bboxes = self.decode_boxes(anchors_i, reg_pred_i, self.stride[level])
+            bboxes = box_pred_i[anchor_idxs]
 
             all_scores.append(scores)
             all_labels.append(labels)
@@ -186,6 +134,7 @@ class YOLOv7(nn.Module):
         return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Inference ----------------------
     @torch.no_grad()
     def inference_single_image(self, x):
         # backbone
@@ -200,7 +149,7 @@ class YOLOv7(nn.Module):
         # detection heads
         all_obj_preds = []
         all_cls_preds = []
-        all_reg_preds = []
+        all_box_preds = []
         all_anchors = []
         for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
             cls_feat, reg_feat = head(feat)
@@ -219,18 +168,26 @@ class YOLOv7(nn.Module):
             cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
             reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
 
+            # decode bbox
+            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+            pred_x1y1 = ctr_pred - wh_pred * 0.5
+            pred_x2y2 = ctr_pred + wh_pred * 0.5
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
             all_obj_preds.append(obj_pred)
             all_cls_preds.append(cls_pred)
-            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
         # post process
         bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
+            all_obj_preds, all_cls_preds, all_box_preds)
         
         return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Training ----------------------
     def forward(self, x):
         if not self.trainable:
             return self.inference_single_image(x)
@@ -267,8 +224,12 @@ class YOLOv7(nn.Module):
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
                 reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
 
-                # decode box: [M, 4]
-                box_pred = self.decode_boxes(anchors, reg_pred, self.stride[level])
+                # decode bbox
+                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+                pred_x1y1 = ctr_pred - wh_pred * 0.5
+                pred_x2y2 = ctr_pred + wh_pred * 0.5
+                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
                 all_obj_preds.append(obj_pred)
                 all_cls_preds.append(cls_pred)
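All of the rewritten post_process methods share one indexing trick: the per-level scores over every (location, class) pair are flattened into a single vector and sorted once, and the surviving flat indices are split back into a location index and a class label with floor division and modulo. A toy example (illustrative values):

    import torch

    num_classes = 3
    scores = torch.tensor([[0.10, 0.90, 0.20],    # location 0
                           [0.80, 0.30, 0.05]])   # location 1
    flat = scores.flatten()                       # one score per (location, class) pair
    prob, topk_idxs = flat.sort(descending=True)
    topk_idxs = topk_idxs[:2]                     # keep the top-2: flat indices 1 and 3
    anchor_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')
    labels = topk_idxs % num_classes
    print(anchor_idxs.tolist(), labels.tolist())  # [0, 1] [1, 0]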

+ 28 - 0
models/yolov8/build.py

@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import torch
+import torch.nn as nn
+
 from .loss import build_criterion
 from .yolov8 import YOLOv8
 
@@ -13,6 +16,7 @@ def build_yolov8(args, cfg, device, num_classes=80, trainable=False):
     print('==============================')
     print('Model Configuration: \n', cfg)
     
+    # -------------- Build YOLO --------------
     model = YOLOv8(
         cfg=cfg,
         device=device, 
@@ -23,6 +27,30 @@ def build_yolov8(args, cfg, device, num_classes=80, trainable=False):
         topk=args.topk
         )
 
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
     criterion = None
     if trainable:
         # build criterion for training

+ 54 - 89
models/yolov8/yolov8.py

@@ -7,7 +7,7 @@ from .yolov8_neck import build_neck
 from .yolov8_pafpn import build_fpn
 from .yolov8_head import build_head
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 
 # Anchor-free YOLO
@@ -63,38 +63,9 @@ class YOLOv8(nn.Module):
                                 for head in self.non_shared_heads
                               ])                 
 
-        # --------- Network Initialization ----------
-        # init bias
-        self.init_yolo()
-
-
-    def init_yolo(self): 
-        # Init yolo
-        for m in self.modules():
-            if isinstance(m, nn.BatchNorm2d):
-                m.eps = 1e-3
-                m.momentum = 0.03    
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # cls pred
-        for cls_pred in self.cls_preds:
-            b = cls_pred.bias.view(1, -1)
-            b.data.fill_(bias_value.item())
-            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        for reg_pred in self.reg_preds:
-            b = reg_pred.bias.view(-1, )
-            b.data.fill_(1.0)
-            reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-            w = reg_pred.weight
-            w.data.fill_(0.)
-            reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-        self.proj = nn.Parameter(torch.linspace(0, self.reg_max, self.reg_max), requires_grad=False)
-        self.proj_conv.weight = nn.Parameter(self.proj.view([1, self.reg_max, 1, 1]).clone().detach(),
-                                                   requires_grad=False)
-
 
+    # ---------------------- Basic Functions ----------------------
+    ## generate anchor points
     def generate_anchors(self, level, fmp_size):
         """
             fmp_size: (List) [H, W]
@@ -109,70 +80,39 @@ class YOLOv8(nn.Module):
 
         return anchors
         
-
-    def decode_boxes(self, anchors, pred_regs, stride):
-        """
-        Input:
-            anchors:  (List[Tensor]) [1, M, 2]
-            pred_reg: (List[Tensor]) [B, M, 4*(reg_max)]
-        Output:
-            pred_box: (Tensor) [B, M, 4]
-        """
-        if self.use_dfl:
-            B, M = pred_regs.shape[:2]
-            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
-            pred_regs = pred_regs.reshape([B, M, 4, self.reg_max])
-            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
-            pred_regs = pred_regs.permute(0, 3, 2, 1).contiguous()
-            # [B, reg_max, 4, M] -> [B, 1, 4, M]
-            pred_regs = self.proj_conv(F.softmax(pred_regs, dim=1))
-            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
-            pred_regs = pred_regs.view(B, 4, M).permute(0, 2, 1).contiguous()
-
-        # tlbr -> xyxy
-        pred_x1y1 = anchors - pred_regs[..., :2] * stride
-        pred_x2y2 = anchors + pred_regs[..., 2:] * stride
-        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-        return pred_box
-
-
-    def post_process(self, cls_preds, reg_preds, anchors):
+    ## post-process
+    def post_process(self, cls_preds, box_preds):
         """
         Input:
-            cls_preds: List(Tensor) [[B, H x W, C], ...]
-            reg_preds: List(Tensor) [[B, H x W, 4*(reg_max)], ...]
+            cls_preds: List(Tensor) [[H x W, C], ...]
+            box_preds: List(Tensor) [[H x W, 4], ...]
         """
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for level, (cls_pred_i, reg_pred_i, anchors_i) in enumerate(zip(cls_preds, reg_preds, anchors)):
-            # [B, M, C] -> [M, C]
-            cur_cls_pred_i = cls_pred_i[0]
-            cur_reg_pred_i = reg_pred_i[0]
-            # [MC,]
-            scores_i = cur_cls_pred_i.sigmoid().flatten()
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            # (H x W x KA x C,)
+            scores_i = cls_pred_i.sigmoid().flatten()
 
             # Keep top k top scoring indices only.
-            num_topk = min(self.topk, cur_reg_pred_i.size(0))
+            num_topk = min(self.topk, box_pred_i.size(0))
 
             # torch.sort is actually faster than .topk (at least on GPUs)
             predicted_prob, topk_idxs = scores_i.sort(descending=True)
-            scores = predicted_prob[:num_topk]
+            topk_scores = predicted_prob[:num_topk]
             topk_idxs = topk_idxs[:num_topk]
 
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+
             anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
             labels = topk_idxs % self.num_classes
 
-            cur_reg_pred_i = cur_reg_pred_i[anchor_idxs]
-            anchors_i = anchors_i[anchor_idxs]
-
-            # decode box: [M, 4]
-            box_pred_i = self.decode_boxes(
-                anchors_i[None], cur_reg_pred_i[None], self.stride[level])
-            bboxes = box_pred_i[0]
+            bboxes = box_pred_i[anchor_idxs]
 
             all_scores.append(scores)
             all_labels.append(labels)
@@ -182,12 +122,6 @@ class YOLOv8(nn.Module):
         labels = torch.cat(all_labels)
         bboxes = torch.cat(all_bboxes)
 
-        # threshold
-        keep_idxs = scores.gt(self.conf_thresh)
-        scores = scores[keep_idxs]
-        labels = labels[keep_idxs]
-        bboxes = bboxes[keep_idxs]
-
         # to cpu & numpy
         scores = scores.cpu().numpy()
         labels = labels.cpu().numpy()
@@ -200,6 +134,7 @@ class YOLOv8(nn.Module):
         return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Inference ----------------------
     @torch.no_grad()
     def inference_single_image(self, x):
         # backbone
@@ -213,7 +148,7 @@ class YOLOv8(nn.Module):
 
         # non-shared heads
         all_cls_preds = []
-        all_reg_preds = []
+        all_box_preds = []
         all_anchors = []
         for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
             cls_feat, reg_feat = head(feat)
@@ -231,17 +166,33 @@ class YOLOv8(nn.Module):
             cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
             reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
 
+            # decode bbox
+            if self.use_dfl:
+                B, M = reg_pred.shape[:2]
+                # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+                reg_pred = reg_pred.reshape([B, M, 4, self.reg_max])
+                # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+                reg_pred = reg_pred.permute(0, 3, 2, 1).contiguous()
+                # [B, reg_max, 4, M] -> [B, 1, 4, M]
+                reg_pred = self.proj_conv(F.softmax(reg_pred, dim=1))
+                # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+                reg_pred = reg_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            pred_x1y1 = anchors - reg_pred[..., :2] * self.stride[level]
+            pred_x2y2 = anchors + reg_pred[..., 2:] * self.stride[level]
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
-            all_cls_preds.append(cls_pred)
-            all_reg_preds.append(reg_pred)
+            # keep batch element 0: post_process expects per-level [M, ...] tensors
+            all_cls_preds.append(cls_pred[0])
+            all_box_preds.append(box_pred[0])
             all_anchors.append(anchors)
 
         # post process
         bboxes, scores, labels = self.post_process(
-            all_cls_preds, all_reg_preds, all_anchors)
+            all_cls_preds, all_box_preds)
         
         return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Training ----------------------
     def forward(self, x):
         if not self.trainable:
             return self.inference_single_image(x)
@@ -277,8 +228,22 @@ class YOLOv8(nn.Module):
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
                 reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
 
-                # decode box: [B, M, 4]
-                box_pred = self.decode_boxes(anchors, reg_pred, self.stride[level])
+                # decode bbox
+                if self.use_dfl:
+                    B, M = reg_pred.shape[:2]
+                    # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max] -> [B, 4, M, reg_max]
+                    reg_pred_ = reg_pred.reshape([B, M, 4, self.reg_max]).clone()
+                    # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+                    reg_pred_ = reg_pred_.permute(0, 3, 2, 1).contiguous()
+                    # [B, reg_max, 4, M] -> [B, 1, 4, M]
+                    reg_pred_ = self.proj_conv(F.softmax(reg_pred_, dim=1))
+                    # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+                    reg_pred_ = reg_pred_.view(B, 4, M).permute(0, 2, 1).contiguous()
+                else:
+                    reg_pred_ = reg_pred
+                pred_x1y1 = anchors - reg_pred_[..., :2] * self.stride[level]
+                pred_x2y2 = anchors + reg_pred_[..., 2:] * self.stride[level]
+                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+                del reg_pred_
 
                 # stride tensor: [M, 1]
                 stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride[level]
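With use_dfl, each of the four box sides is predicted as a categorical distribution over reg_max distance bins; proj_conv is a frozen 1x1 convolution whose weights are the bin values, so applying it after the softmax computes the distribution's expected value. A standalone sketch of that expectation for a single side (bin values as in the removed init_yolo):

    import torch
    import torch.nn.functional as F

    reg_max = 16
    logits = torch.randn(reg_max)                    # predicted logits over distance bins
    proj = torch.linspace(0, reg_max, reg_max)       # bin values, as in the removed init_yolo
    dist = (F.softmax(logits, dim=0) * proj).sum()   # expected distance, in stride units
    print(dist)                                      # scalar in [0, reg_max]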

+ 6 - 11
models/yolov8/yolov8_pafpn.py

@@ -23,9 +23,10 @@ class Yolov8PaFPN(nn.Module):
         self.in_dims = in_dims
         self.width = width
         self.depth = depth
+        self.out_dim = [int(256 * width), int(512 * width), int(512 * width * ratio)]
         c3, c4, c5 = in_dims
 
-        # top dwon
+        # --------------------------- Top-down ---------------------------
         ## P5 -> P4
         self.head_elan_1 = ELAN_CSP_Block(in_dim=c5 + c4,
                                           out_dim=int(512*width),
@@ -36,8 +37,7 @@ class Yolov8PaFPN(nn.Module):
                                           norm_type=norm_type,
                                           act_type=act_type
                                           )
-
-        # P4 -> P3
+        ## P4 -> P3
         self.head_elan_2 = ELAN_CSP_Block(in_dim=int(512*width) + c3,
                                           out_dim=int(256*width),
                                           expand_ratio=0.5,
@@ -47,10 +47,8 @@ class Yolov8PaFPN(nn.Module):
                                           norm_type=norm_type,
                                           act_type=act_type
                                           )
-
-
-        # bottom up
-        # P3 -> P4
+        # --------------------------- Bottom-up ---------------------------
+        ## P3 -> P4
         self.mp1 = Conv(int(256*width), int(256*width), k=3, p=1, s=2,
                         act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.head_elan_3 = ELAN_CSP_Block(in_dim=int(256*width) + int(512*width),
@@ -62,8 +60,7 @@ class Yolov8PaFPN(nn.Module):
                                           norm_type=norm_type,
                                           act_type=act_type
                                           )
-
-        # P4 -> P5
+        ## P4 -> P5
         self.mp2 = Conv(int(512 * width), int(512 * width), k=3, p=1, s=2,
                         act_type=act_type, norm_type=norm_type, depthwise=depthwise)
         self.head_elan_4 = ELAN_CSP_Block(in_dim=int(512 * width) + c5,
@@ -76,8 +73,6 @@ class Yolov8PaFPN(nn.Module):
                                           act_type=act_type
                                           )
 
-        self.out_dim = [int(256 * width), int(512 * width), int(512 * width * ratio)]
-
 
     def forward(self, features):
         c3, c4, c5 = features

+ 40 - 7
models/yolox/build.py

@@ -1,6 +1,9 @@
 #!/usr/bin/env python3
 # -*- coding:utf-8 -*-
 
+import torch
+import torch.nn as nn
+
 from .loss import build_criterion
 from .yolox import YOLOX
 
@@ -13,16 +16,46 @@ def build_yolox(args, cfg, device, num_classes=80, trainable=False):
     print('==============================')
     print('Model Configuration: \n', cfg)
     
+    # -------------- Build YOLO --------------
     model = YOLOX(
-        cfg = cfg,
-        device = device,
-        num_classes = num_classes,
-        conf_thresh = args.conf_thresh,
-        nms_thresh = args.nms_thresh,
-        topk = args.topk,
-        trainable = trainable
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        conf_thresh=args.conf_thresh,
+        nms_thresh=args.nms_thresh,
+        topk=args.topk,
         )
 
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+    # Init bias
+    init_prob = 0.01
+    bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+    # obj pred
+    for obj_pred in model.obj_preds:
+        b = obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # cls pred
+    for cls_pred in model.cls_preds:
+        b = cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+    # reg pred
+    for reg_pred in model.reg_preds:
+        b = reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = reg_pred.weight
+        w.data.fill_(0.)
+        reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+
+    # -------------- Build criterion --------------
     criterion = None
     if trainable:
         # build criterion for training

+ 29 - 70
models/yolox/yolox.py

@@ -5,7 +5,7 @@ from .yolox_backbone import build_backbone
 from .yolox_fpn import build_fpn
 from .yolox_head import build_head
 
-from utils.nms import multiclass_nms
+from utils.misc import multiclass_nms
 
 
 # YOLOX
@@ -15,8 +15,8 @@ class YOLOX(nn.Module):
                  device,
                  num_classes=20,
                  conf_thresh=0.01,
-                 topk=100,
                  nms_thresh=0.5,
+                 topk=100,
                  trainable=False):
         super(YOLOX, self).__init__()
         # --------- Basic Parameters ----------
@@ -57,40 +57,8 @@ class YOLOX(nn.Module):
                                 for head in self.non_shared_heads
                               ])                 
 
-        # --------- Network Initialization ----------
-        # init bias
-        self.init_yolo()
-
-
-    def init_yolo(self): 
-        # Init yolo
-        for m in self.modules():
-            if isinstance(m, nn.BatchNorm2d):
-                m.eps = 1e-3
-                m.momentum = 0.03    
-        # Init bias
-        init_prob = 0.01
-        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
-        # obj pred
-        for obj_pred in self.obj_preds:
-            b = obj_pred.bias.view(1, -1)
-            b.data.fill_(bias_value.item())
-            obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # cls pred
-        for cls_pred in self.cls_preds:
-            b = cls_pred.bias.view(1, -1)
-            b.data.fill_(bias_value.item())
-            cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-        # reg pred
-        for reg_pred in self.reg_preds:
-            b = reg_pred.bias.view(-1, )
-            b.data.fill_(1.0)
-            reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
-            w = reg_pred.weight
-            w.data.fill_(0.)
-            reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
-
-
+    # ---------------------- Basic Functions ----------------------
+    ## generate anchor points
     def generate_anchors(self, level, fmp_size):
         """
             fmp_size: (List) [H, W]
@@ -106,42 +74,25 @@ class YOLOX(nn.Module):
 
         return anchors
         
-
-    def decode_boxes(self, anchors, reg_pred, stride):
-        """
-            anchors:  (List[Tensor]) [1, M, 2] or [M, 2]
-            reg_pred: (List[Tensor]) [B, M, 4] or [M, 4]
-        """
-        # center of bbox
-        pred_ctr_xy = anchors + reg_pred[..., :2] * stride
-        # size of bbox
-        pred_box_wh = reg_pred[..., 2:].exp() * stride
-
-        pred_x1y1 = pred_ctr_xy - 0.5 * pred_box_wh
-        pred_x2y2 = pred_ctr_xy + 0.5 * pred_box_wh
-        pred_box = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
-
-        return pred_box
-
-
-    def post_process(self, obj_preds, cls_preds, reg_preds, anchors):
+    ## post-process
+    def post_process(self, obj_preds, cls_preds, box_preds):
         """
         Input:
             obj_preds: List(Tensor) [[H x W, 1], ...]
             cls_preds: List(Tensor) [[H x W, C], ...]
-            reg_preds: List(Tensor) [[H x W, 4], ...]
-            anchors:  List(Tensor) [[H x W, 2], ...]
+            box_preds: List(Tensor) [[H x W, 4], ...]
         """
         all_scores = []
         all_labels = []
         all_bboxes = []
         
-        for level, (obj_pred_i, cls_pred_i, reg_pred_i, anchors_i) in enumerate(zip(obj_preds, cls_preds, reg_preds, anchors)):
-            # (H x W x C,)
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            # (H x W x KA x C,)
             scores_i = (torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid())).flatten()
 
             # Keep top k top scoring indices only.
-            num_topk = min(self.topk, reg_pred_i.size(0))
+            num_topk = min(self.topk, box_pred_i.size(0))
 
             # torch.sort is actually faster than .topk (at least on GPUs)
             predicted_prob, topk_idxs = scores_i.sort(descending=True)
@@ -156,11 +107,7 @@ class YOLOX(nn.Module):
             anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
             labels = topk_idxs % self.num_classes
 
-            reg_pred_i = reg_pred_i[anchor_idxs]
-            anchors_i = anchors_i[anchor_idxs]
-
-            # decode box: [M, 4]
-            bboxes = self.decode_boxes(anchors_i, reg_pred_i, self.stride[level])
+            bboxes = box_pred_i[anchor_idxs]
 
             all_scores.append(scores)
             all_labels.append(labels)
@@ -182,6 +129,7 @@ class YOLOX(nn.Module):
         return bboxes, scores, labels
 
 
+    # ---------------------- Main Process for Inference ----------------------
     @torch.no_grad()
     def inference_single_image(self, x):
         # backbone
@@ -193,7 +141,7 @@ class YOLOX(nn.Module):
         # non-shared heads
         all_obj_preds = []
         all_cls_preds = []
-        all_reg_preds = []
+        all_box_preds = []
         all_anchors = []
         for level, (feat, head) in enumerate(zip(pyramid_feats, self.non_shared_heads)):
             cls_feat, reg_feat = head(feat)
@@ -212,14 +160,21 @@ class YOLOX(nn.Module):
             cls_pred = cls_pred[0].permute(1, 2, 0).contiguous().view(-1, self.num_classes)
             reg_pred = reg_pred[0].permute(1, 2, 0).contiguous().view(-1, 4)
 
+            # decode bbox
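+            # This assumes reg_pred[..., :2] are center offsets in units of the
+            # stride and reg_pred[..., 2:] are log-space sizes, i.e.
+            #   cxcy = anchor + t_xy * stride,  wh = exp(t_wh) * stride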
+            ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+            wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+            pred_x1y1 = ctr_pred - wh_pred * 0.5
+            pred_x2y2 = ctr_pred + wh_pred * 0.5
+            box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
             all_obj_preds.append(obj_pred)
             all_cls_preds.append(cls_pred)
-            all_reg_preds.append(reg_pred)
+            all_box_preds.append(box_pred)
             all_anchors.append(anchors)
 
         # post process
         bboxes, scores, labels = self.post_process(
-            all_obj_preds, all_cls_preds, all_reg_preds, all_anchors)
+            all_obj_preds, all_cls_preds, all_box_preds)
         
         return bboxes, scores, labels
 
@@ -257,8 +212,12 @@ class YOLOX(nn.Module):
                 cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
                 reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
 
-                # decode box: [M, 4]
-                box_pred = self.decode_boxes(anchors, reg_pred, self.stride[level])
+                # decode bbox
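+                # (same center/size decoding as in inference_single_image)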
+                ctr_pred = reg_pred[..., :2] * self.stride[level] + anchors[..., :2]
+                wh_pred = torch.exp(reg_pred[..., 2:]) * self.stride[level]
+                pred_x1y1 = ctr_pred - wh_pred * 0.5
+                pred_x2y2 = ctr_pred + wh_pred * 0.5
+                box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
 
                 all_obj_preds.append(obj_pred)
                 all_cls_preds.append(cls_pred)

+ 3 - 6
test.py

@@ -11,9 +11,8 @@ from dataset.data_augment import build_transform
 
 # load some utils
 from utils.misc import build_dataset, load_weight
-from utils.com_flops_params import FLOPs_and_Params
+from utils.misc import compute_flops
 from utils.box_ops import rescale_bboxes
-from utils import fuse_conv_bn
 
 from models import build_model
 from config import build_model_config, build_trans_config
@@ -51,8 +50,6 @@ def parse_args():
                         help='topk candidates for testing')
     parser.add_argument("--no_decode", action="store_true", default=False,
                         help="not decode in inference or yes")
-    parser.add_argument('--fuse_repconv', action='store_true', default=False,
-                        help='fuse RepConv')
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                         help='fuse Conv & BN')
 
@@ -195,14 +192,14 @@ if __name__ == '__main__':
     model = build_model(args, model_cfg, device, num_classes, False)
 
     # load trained weight
-    model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_repconv)
+    model = load_weight(model, args.weight, args.fuse_conv_bn)
     model.to(device).eval()
 
     # compute FLOPs and Params
     model_copy = deepcopy(model)
     model_copy.trainable = False
     model_copy.eval()
-    FLOPs_and_Params(
+    compute_flops(
         model=model_copy,
         img_size=args.img_size, 
         device=device)

+ 2 - 2
train.py

@@ -9,7 +9,7 @@ import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from utils import distributed_utils
-from utils.com_flops_params import FLOPs_and_Params
+from utils.misc import compute_flops
 from utils.misc import ModelEMA, CollateFunc, build_dataset, build_dataloader
 from utils.solver.optimizer import build_optimizer
 from utils.solver.lr_scheduler import build_lr_scheduler
@@ -161,7 +161,7 @@ def train():
         model_copy = deepcopy(model_without_ddp)
         model_copy.trainable = False
         model_copy.eval()
-        FLOPs_and_Params(model=model_copy, 
+        compute_flops(model=model_copy, 
                          img_size=args.img_size, 
                          device=device)
         del model_copy

+ 0 - 15
utils/com_flops_params.py

@@ -1,15 +0,0 @@
-import torch
-from thop import profile
-
-
-def FLOPs_and_Params(model, img_size, device):
-    x = torch.randn(1, 3, img_size, img_size).to(device)
-    print('==============================')
-    flops, params = profile(model, inputs=(x, ), verbose=False)
-    print('==============================')
-    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))
-
-
-if __name__ == "__main__":
-    pass

+ 0 - 54
utils/fuse_conv_bn.py

@@ -1,54 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import torch
-import torch.nn as nn
-
-
-def _fuse_conv_bn(conv, bn):
-    """Fuse conv and bn into one module.
-    Args:
-        conv (nn.Module): Conv to be fused.
-        bn (nn.Module): BN to be fused.
-    Returns:
-        nn.Module: Fused module.
-    """
-    conv_w = conv.weight
-    conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
-        bn.running_mean)
-
-    factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
-    conv.weight = nn.Parameter(conv_w *
-                               factor.reshape([conv.out_channels, 1, 1, 1]))
-    conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
-    return conv
-
-
-def fuse_conv_bn(module):
-    """Recursively fuse conv and bn in a module.
-    During inference, the functionary of batch norm layers is turned off
-    but only the mean and var alone channels are used, which exposes the
-    chance to fuse it with the preceding conv layers to save computations and
-    simplify network structures.
-    Args:
-        module (nn.Module): Module to be fused.
-    Returns:
-        nn.Module: Fused module.
-    """
-    last_conv = None
-    last_conv_name = None
-
-    for name, child in module.named_children():
-        if isinstance(child,
-                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
-            if last_conv is None:  # only fuse BN that is after Conv
-                continue
-            fused_conv = _fuse_conv_bn(last_conv, child)
-            module._modules[last_conv_name] = fused_conv
-            # To reduce changes, set BN as Identity instead of deleting it.
-            module._modules[name] = nn.Identity()
-            last_conv = None
-        elif isinstance(child, nn.Conv2d):
-            last_conv = child
-            last_conv_name = name
-        else:
-            fuse_conv_bn(child)
-    return module

+ 240 - 13
utils/misc.py

@@ -3,10 +3,12 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, DistributedSampler
 
-import numpy as np
 import os
+import cv2
 import math
+import numpy as np
 from copy import deepcopy
+from thop import profile
 
 from evaluator.coco_evaluator import COCOAPIEvaluator
 from evaluator.voc_evaluator import VOCAPIEvaluator
@@ -17,9 +19,6 @@ from dataset.coco import COCODataset, coco_class_index, coco_class_labels
 from dataset.ourdataset import OurDataset, our_class_labels
 from dataset.data_augment import build_transform
 
-from utils import fuse_conv_bn
-from models.yolov7.yolov7_basic import RepConv
-
 
 # ---------------------------- For Dataset ----------------------------
 ## build dataset
@@ -143,8 +142,97 @@ class CollateFunc(object):
 
 
 # ---------------------------- For Model ----------------------------
+## fuse Conv & BN layer
+def fuse_conv_bn(module):
+    """Recursively fuse conv and bn in a module.
+    During inference, the functionality of batch norm layers is turned off;
+    only the per-channel running mean and var are used, which makes it
+    possible to fold BN into the preceding conv layers to save computation
+    and simplify the network structure.
+    Args:
+        module (nn.Module): Module to be fused.
+    Returns:
+        nn.Module: Fused module.
+    """
+    last_conv = None
+    last_conv_name = None
+    
+    def _fuse_conv_bn(conv, bn):
+        """Fuse conv and bn into one module.
+        Args:
+            conv (nn.Module): Conv to be fused.
+            bn (nn.Module): BN to be fused.
+        Returns:
+            nn.Module: Fused module.
+        """
+        conv_w = conv.weight
+        conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+            bn.running_mean)
+
+        factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+        conv.weight = nn.Parameter(conv_w *
+                                factor.reshape([conv.out_channels, 1, 1, 1]))
+        conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+        return conv
+
+    for name, child in module.named_children():
+        if isinstance(child,
+                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+            if last_conv is None:  # only fuse BN that is after Conv
+                continue
+            fused_conv = _fuse_conv_bn(last_conv, child)
+            module._modules[last_conv_name] = fused_conv
+            # To reduce changes, set BN as Identity instead of deleting it.
+            module._modules[name] = nn.Identity()
+            last_conv = None
+        elif isinstance(child, nn.Conv2d):
+            last_conv = child
+            last_conv_name = name
+        else:
+            fuse_conv_bn(child)
+    return module
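+# Minimal usage sketch (the model must be in eval mode with populated BN
+# running stats; fuse a copy if the original model is still needed):
+#   fused_model = fuse_conv_bn(deepcopy(model).eval())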
+
+## replace module
+def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
+    """
+    Replace all modules of a given type in `module` with a new type. Mostly used for deployment.
+
+    Args:
+        module (nn.Module): model to apply replace operation.
+        replaced_module_type (Type): module type to be replaced.
+        new_module_type (Type): module type to replace with.
+        replace_func (function): python function describing the replace logic. Default: None.
+
+    Returns:
+        model (nn.Module): module that already been replaced.
+    """
+
+    def default_replace_func(replaced_module_type, new_module_type):
+        return new_module_type()
+
+    if replace_func is None:
+        replace_func = default_replace_func
+
+    model = module
+    if isinstance(module, replaced_module_type):
+        model = replace_func(replaced_module_type, new_module_type)
+    else:  # recursively replace
+        for name, child in module.named_children():
+            # pass replace_func down so custom replace logic also applies to children
+            new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
+            if new_child is not child:  # child is already replaced
+                model.add_module(name, new_child)
+
+    return model
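+# Usage sketch, e.g. swapping every nn.SiLU for nn.ReLU before deployment:
+#   model = replace_module(model, nn.SiLU, nn.ReLU)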
+
+## compute FLOPs & Parameters
+def compute_flops(model, img_size, device):
+    x = torch.randn(1, 3, img_size, img_size).to(device)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
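+    # thop's profile() reports MACs; multiply by 2 to report FLOPs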
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
+
 ## load trained weight
-def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_repconv=False):
+def load_weight(model, path_to_ckpt, fuse_cbn=False):
     # check ckpt file
     if path_to_ckpt is None:
         print('no weight file ...')
@@ -155,17 +243,10 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_repconv=False):
 
         print('Finished loading model!')
 
-    # fuse repconv
-    if fuse_repconv:
-        print('Fusing RepConv block ...')
-        for m in model.modules():
-            if isinstance(m, RepConv):
-                m.fuse_repvgg_block()
-
     # fuse conv & bn
     if fuse_cbn:
         print('Fusing Conv & BN ...')
-        model = fuse_conv_bn.fuse_conv_bn(model)
+        model = fuse_conv_bn(model)
 
     return model
 
@@ -220,3 +301,149 @@ class ModelEMA(object):
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
         self.copy_attr(self.ema, model, include, exclude)
+
+
+# ---------------------------- NMS ----------------------------
+## basic NMS
+def nms(bboxes, scores, nms_thresh):
+    """"Pure Python NMS."""
+    x1 = bboxes[:, 0]  #xmin
+    y1 = bboxes[:, 1]  #ymin
+    x2 = bboxes[:, 2]  #xmax
+    y2 = bboxes[:, 3]  #ymax
+
+    areas = (x2 - x1) * (y2 - y1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        # compute iou
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(1e-10, xx2 - xx1)
+        h = np.maximum(1e-10, yy2 - yy1)
+        inter = w * h
+
+        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
+        # keep only the boxes whose IoU with box i is below the threshold
+        inds = np.where(iou <= nms_thresh)[0]
+        order = order[inds + 1]
+
+    return keep
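+# e.g. nms(np.array([[0., 0., 10., 10.], [1., 1., 9., 9.]]),
+#          np.array([0.9, 0.8]), nms_thresh=0.5) keeps only index 0:
+# the boxes overlap with IoU = 64 / (100 + 64 - 64) = 0.64 > 0.5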
+
+## class-agnostic NMS 
+def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
+    # nms
+    keep = nms(bboxes, scores, nms_thresh)
+
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## class-aware NMS 
+def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
+    # run NMS independently per class so that boxes of different
+    # classes never suppress each other
+    keep = np.zeros(len(bboxes), dtype=np.int32)  # np.int is deprecated in NumPy
+    for i in range(num_classes):
+        inds = np.where(labels == i)[0]
+        if len(inds) == 0:
+            continue
+        c_bboxes = bboxes[inds]
+        c_scores = scores[inds]
+        c_keep = nms(c_bboxes, c_scores, nms_thresh)
+        keep[inds[c_keep]] = 1
+
+    keep = np.where(keep > 0)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## multi-class NMS 
+def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
+    if class_agnostic:
+        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
+    else:
+        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
+
+
+# ---------------------------- Processor for Deployment ----------------------------
+## Pre-processor
+class PreProcessor(object):
+    def __init__(self, img_size):
+        self.img_size = img_size
+        self.input_size = [img_size, img_size]
+        
+
+    def __call__(self, image, swap=(2, 0, 1)):
+        """
+        Input:
+            image: (ndarray) [H, W, 3] or [H, W]
+            swap: (tuple) axis order used to transpose [H, W, C] -> [C, H, W]
+        """
+        if len(image.shape) == 3:
+            padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
+        else:
+            padded_img = np.ones(self.input_size, np.float32) * 114.
+        # resize
+        orig_h, orig_w = image.shape[:2]
+        r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
+        resize_size = (int(orig_w * r), int(orig_h * r))
+        if r != 1:
+            resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
+        else:
+            resized_img = image
+
+        # padding
+        padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
+        
+        # [H, W, C] -> [C, H, W]
+        padded_img = padded_img.transpose(swap)
+        padded_img = np.ascontiguousarray(padded_img, dtype=np.float32)
+
+        return padded_img, r
+
+## Post-processor
+class PostProcessor(object):
+    def __init__(self, img_size, strides, num_classes, conf_thresh=0.15, nms_thresh=0.5):
+        self.img_size = img_size
+        self.strides = strides
+        self.num_classes = num_classes
+        self.conf_thresh = conf_thresh
+        self.nms_thresh = nms_thresh
+
+
+    def __call__(self, predictions):
+        """
+        Input:
+            predictions: (ndarray) [n_anchors_all, 4+1+C]
+        """
+        bboxes = predictions[..., :4]
+        obj_preds = predictions[..., 4:5]
+        cls_preds = predictions[..., 5:]
+        scores = np.sqrt(obj_preds * cls_preds)
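+        # geometric mean of objectness and class confidence; both are
+        # assumed to be already sigmoid-activated by the deployed model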
+
+        # scores & labels
+        labels = np.argmax(scores, axis=1)                      # [M,]
+        scores = scores[(np.arange(scores.shape[0]), labels)]   # [M,]
+
+        # thresh
+        keep = np.where(scores > self.conf_thresh)
+        scores = scores[keep]
+        labels = labels[keep]
+        bboxes = bboxes[keep]
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
+
+        return bboxes, scores, labels
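+
+# Usage sketch (assumes 'outputs' is the decoded [M, 4+1+C] prediction array
+# from the deployed model; the names here are illustrative only):
+#   post = PostProcessor(img_size=640, strides=[8, 16, 32], num_classes=80)
+#   bboxes, scores, labels = post(outputs)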

+ 0 - 71
utils/nms.py

@@ -1,71 +0,0 @@
-import numpy as np
-
-
-def nms(bboxes, scores, nms_thresh):
-    """"Pure Python NMS."""
-    x1 = bboxes[:, 0]  #xmin
-    y1 = bboxes[:, 1]  #ymin
-    x2 = bboxes[:, 2]  #xmax
-    y2 = bboxes[:, 3]  #ymax
-
-    areas = (x2 - x1) * (y2 - y1)
-    order = scores.argsort()[::-1]
-
-    keep = []
-    while order.size > 0:
-        i = order[0]
-        keep.append(i)
-        # compute iou
-        xx1 = np.maximum(x1[i], x1[order[1:]])
-        yy1 = np.maximum(y1[i], y1[order[1:]])
-        xx2 = np.minimum(x2[i], x2[order[1:]])
-        yy2 = np.minimum(y2[i], y2[order[1:]])
-
-        w = np.maximum(1e-10, xx2 - xx1)
-        h = np.maximum(1e-10, yy2 - yy1)
-        inter = w * h
-
-        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
-        #reserve all the boundingbox whose ovr less than thresh
-        inds = np.where(iou <= nms_thresh)[0]
-        order = order[inds + 1]
-
-    return keep
-
-
-def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
-    # nms
-    keep = nms(bboxes, scores, nms_thresh)
-
-    scores = scores[keep]
-    labels = labels[keep]
-    bboxes = bboxes[keep]
-
-    return scores, labels, bboxes
-
-
-def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
-    # nms
-    keep = np.zeros(len(bboxes), dtype=np.int)
-    for i in range(num_classes):
-        inds = np.where(labels == i)[0]
-        if len(inds) == 0:
-            continue
-        c_bboxes = bboxes[inds]
-        c_scores = scores[inds]
-        c_keep = nms(c_bboxes, c_scores, nms_thresh)
-        keep[inds[c_keep]] = 1
-
-    keep = np.where(keep > 0)
-    scores = scores[keep]
-    labels = labels[keep]
-    bboxes = bboxes[keep]
-
-    return scores, labels, bboxes
-
-
-def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
-    if class_agnostic:
-        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
-    else:
-        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)