
modify loss of YOLOv1

yjh0410 2 years ago
parent
commit
1d8f16cdaa

+ 1 - 1
config/yolov1_config.py

@@ -6,7 +6,7 @@ yolov1_cfg = {
     # loss weight
     'loss_obj_weight': 1.0,
     'loss_cls_weight': 1.0,
-    'loss_reg_weight': 5.0,
+    'loss_box_weight': 5.0,
     # training configuration
     'no_aug_epoch': -1,
     # optimizer
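
Note: the renamed key matches the rewritten criterion in models/yolov1/loss.py below. A minimal sketch of the weighted sum it feeds (names as in this commit; the 5.0 box weight plays the role of YOLOv1's lambda_coord):

    # Sketch: how the renamed key is consumed by the criterion.
    weights = {'loss_obj_weight': 1.0, 'loss_cls_weight': 1.0, 'loss_box_weight': 5.0}

    def total_loss(loss_obj, loss_cls, loss_box, cfg=weights):
        # weighted sum computed at the end of Criterion.__call__
        return (cfg['loss_obj_weight'] * loss_obj
                + cfg['loss_cls_weight'] * loss_cls
                + cfg['loss_box_weight'] * loss_box)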

+ 2 - 4
dataset/coco.py

@@ -225,7 +225,7 @@ if __name__ == "__main__":
     import argparse
     from data_augment import build_transform
     
-    parser = argparse.ArgumentParser(description='VOC-Dataset')
+    parser = argparse.ArgumentParser(description='COCO-Dataset')
 
     # opt
     parser.add_argument('--root', default='D:\\python_work\\object-detection\\dataset\\COCO',
@@ -233,7 +233,7 @@ if __name__ == "__main__":
     
     args = parser.parse_args()
 
-    is_train = True
+    is_train = False
     img_size = 640
     yolov5_trans_config = {
         'aug_type': 'yolov5',
@@ -280,8 +280,6 @@ if __name__ == "__main__":
         image, target, deltas = dataset.pull_item(i)
         # to numpy
         image = image.permute(1, 2, 0).numpy()
-        # denormalize
-        image *= 255.
         # to uint8
         image = image.astype(np.uint8)
         image = image.copy()
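
Note: since the transforms below no longer divide by 255, pull_item already returns tensors in the 0-255 range, so the demo loop can cast straight to uint8. A sketch of the new assumption (dataset as built in this demo script):

    image, target, deltas = dataset.pull_item(0)
    image = image.permute(1, 2, 0).numpy()   # CHW tensor -> HWC array, values in [0, 255]
    image = image.astype(np.uint8).copy()    # no "* 255." denormalization needed any more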

+ 6 - 14
dataset/data_augment/ssd_augment.py

@@ -56,14 +56,6 @@ class ConvertFromInts(object):
         return image.astype(np.float32), boxes, labels
 
 
-class Normalize(object):
-    def __call__(self, image, boxes=None, labels=None):
-        image = image.astype(np.float32)
-        image /= 255.
-
-        return image, boxes, labels
-
-
 class ConvertColor(object):
     def __init__(self, current='BGR', transform='HSV'):
         self.transform = transform
@@ -363,13 +355,13 @@ class SSDAugmentation(object):
             Expand(),                                  # expand augmentation
             RandomSampleCrop(),                        # random crop
             RandomHorizontalFlip(),                    # random horizontal flip
-            Resize(self.img_size),                     # resize op
-            Normalize()                                # image color normalization
+            Resize(self.img_size)                      # resize op
         ])
 
     def __call__(self, image, target, mosaic=False):
         boxes = target['boxes'].copy()
         labels = target['labels'].copy()
+        deltas = None
         # augment
         image, boxes, labels = self.augment(image, boxes, labels)
 
@@ -379,7 +371,7 @@ class SSDAugmentation(object):
         target['labels'] = torch.from_numpy(labels).float()
         
 
-        return img_tensor, target, None
+        return img_tensor, target, deltas
     
 
 ## SSD-style valTransform
@@ -388,12 +380,12 @@ class SSDBaseTransform(object):
         self.img_size = img_size
 
     def __call__(self, image, target=None, mosaic=False):
+        deltas = None
         # resize
         orig_h, orig_w = image.shape[:2]
         image = cv2.resize(image, (self.img_size, self.img_size)).astype(np.float32)
         
-        # normalize
-        image /= 255.
+        # scale targets
         if target is not None:
             boxes = target['boxes'].copy()
             labels = target['labels'].copy()
@@ -408,4 +400,4 @@ class SSDBaseTransform(object):
             target['boxes'] = torch.from_numpy(boxes).float()
             target['labels'] = torch.from_numpy(labels).float()
             
-        return img_tensor, target, None
+        return img_tensor, target, deltas
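
Note: with Normalize() gone, SSD-style transforms hand back images in the 0-255 range; the single "/ 255." now lives at the model input (see the evaluator and test.py hunks below). A small sketch of the new contract, assuming the package re-export that the old evaluator imports used:

    import numpy as np
    from dataset.data_augment import SSDBaseTransform

    transform = SSDBaseTransform(img_size=640)
    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)

    img_tensor, target, deltas = transform(image)   # deltas is always None for SSD-style
    assert deltas is None
    assert float(img_tensor.max()) > 1.0            # still 0-255: not divided by 255 here

    x = img_tensor.unsqueeze(0) / 255.              # normalization happens once, at the model input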

+ 0 - 6
dataset/data_augment/yolov5_augment.py

@@ -360,9 +360,6 @@ class YOLOv5Augmentation(object):
         dh = self.img_size - img_h0
         dw = self.img_size - img_w0
 
-        # normalize
-        pad_image /= 255.
-
         return pad_image, target, [dw, dh]
 
 
@@ -414,7 +411,4 @@ class YOLOv5BaseTransform(object):
         pad_image = torch.ones([img_tensor.size(0), pad_img_h, pad_img_w]).float() * 114.
         pad_image[:, :img_h0, :img_w0] = img_tensor
 
-        # normalize
-        pad_image /= 255.
-
         return pad_image, target, [dw, dh]
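
Note: the YOLOv5-style transforms likewise stop normalizing; they still letterbox onto a 114-gray canvas and report the padding as [dw, dh], which the unified rescale_bboxes below consumes. A self-contained sketch of the letterbox step as this hunk leaves it:

    import torch

    def letterbox(img_tensor, img_size=640):
        # img_tensor: [C, H, W] after an aspect-preserving resize (H, W <= img_size)
        c, img_h0, img_w0 = img_tensor.shape
        pad_image = torch.ones([c, img_size, img_size]).float() * 114.  # gray canvas
        pad_image[:, :img_h0, :img_w0] = img_tensor                     # content sits top-left
        dh = img_size - img_h0
        dw = img_size - img_w0
        # no "/ 255." here any more -- normalization moved to the model input
        return pad_image, [dw, dh]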

+ 1 - 3
dataset/voc.py

@@ -256,7 +256,7 @@ if __name__ == "__main__":
     
     args = parser.parse_args()
 
-    is_train = True
+    is_train = False
     img_size = 640
     yolov5_trans_config = {
         'aug_type': 'yolov5',
@@ -301,8 +301,6 @@ if __name__ == "__main__":
         image, target, deltas = dataset.pull_item(i)
         # to numpy
         image = image.permute(1, 2, 0).numpy()
-        # denormalize
-        image *= 255.
         # to uint8
         image = image.astype(np.uint8)
         image = image.copy()

+ 5 - 11
evaluator/coco_evaluator.py

@@ -2,8 +2,7 @@ import json
 import tempfile
 import torch
 from dataset.coco import COCODataset
-from utils.box_ops import rescale_bboxes, rescale_bboxes_with_deltas
-from dataset.data_augment import SSDBaseTransform, YOLOv5BaseTransform
+from utils.box_ops import rescale_bboxes
 
 try:
     from pycocotools.cocoeval import COCOeval
@@ -71,7 +70,7 @@ class COCOAPIEvaluator():
 
             # preprocess
             x, _, deltas = self.transform(img)
-            x = x.unsqueeze(0).to(self.device)
+            x = x.unsqueeze(0).to(self.device) / 255.
             
             id_ = int(id_)
             ids.append(id_)
@@ -81,14 +80,9 @@ class COCOAPIEvaluator():
             bboxes, scores, cls_inds = outputs
 
             # rescale bboxes
-            if isinstance(self.transform, SSDBaseTransform):
-                origin_img_size = [orig_h, orig_w]
-                cur_img_size = [*x.shape[-2:]]
-                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size)
-            elif isinstance(self.transform, YOLOv5BaseTransform):
-                origin_img_size = [orig_h, orig_w]
-                cur_img_size = [*x.shape[-2:]]
-                bboxes = rescale_bboxes_with_deltas(bboxes, deltas, origin_img_size, cur_img_size)
+            origin_img_size = [orig_h, orig_w]
+            cur_img_size = [*x.shape[-2:]]
+            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
 
             # process outputs
             for i, box in enumerate(bboxes):
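
Note: moving the division out of the transform is numerically a no-op for the network; only where it happens changed. A quick equivalence check with a toy tensor:

    import torch

    img = torch.randint(0, 256, (3, 640, 640)).float()
    old_x = (img / 255.).unsqueeze(0)   # old: the transform divided by 255
    new_x = img.unsqueeze(0) / 255.     # new: the evaluator divides once, at the input
    assert torch.equal(old_x, new_x)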

+ 5 - 11
evaluator/voc_evaluator.py

@@ -11,8 +11,7 @@ import numpy as np
 import pickle
 import xml.etree.ElementTree as ET
 
-from utils.box_ops import rescale_bboxes, rescale_bboxes_with_deltas
-from dataset.data_augment import SSDBaseTransform, YOLOv5BaseTransform
+from utils.box_ops import rescale_bboxes
 
 
 class VOCAPIEvaluator():
@@ -67,7 +66,7 @@ class VOCAPIEvaluator():
 
             # preprocess
             x, _, deltas = self.transform(img)
-            x = x.unsqueeze(0).to(self.device)
+            x = x.unsqueeze(0).to(self.device) / 255.
 
             # forward
             t0 = time.time()
@@ -75,14 +74,9 @@ class VOCAPIEvaluator():
             detect_time = time.time() - t0
 
             # rescale bboxes
-            if isinstance(self.transform, SSDBaseTransform):
-                origin_img_size = [orig_h, orig_w]
-                cur_img_size = [*x.shape[-2:]]
-                bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size)
-            elif isinstance(self.transform, YOLOv5BaseTransform):
-                origin_img_size = [orig_h, orig_w]
-                cur_img_size = [*x.shape[-2:]]
-                bboxes = rescale_bboxes_with_deltas(bboxes, deltas, origin_img_size, cur_img_size)
+            origin_img_size = [orig_h, orig_w]
+            cur_img_size = [*x.shape[-2:]]
+            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
 
             for j in range(len(self.labelmap)):
                 inds = np.where(labels == j)[0]

+ 40 - 63
models/yolov1/loss.py

@@ -1,6 +1,8 @@
 import torch
 import torch.nn.functional as F
 from .matcher import YoloMatcher
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
 
 class Criterion(object):
@@ -10,112 +12,87 @@ class Criterion(object):
         self.num_classes = num_classes
         self.loss_obj_weight = cfg['loss_obj_weight']
         self.loss_cls_weight = cfg['loss_cls_weight']
-        self.loss_txty_weight = cfg['loss_txty_weight']
-        self.loss_twth_weight = cfg['loss_twth_weight']
+        self.loss_box_weight = cfg['loss_box_weight']
 
         # matcher
         self.matcher = YoloMatcher(num_classes=num_classes)
 
 
     def loss_objectness(self, pred_obj, gt_obj):
-        obj_score = torch.clamp(torch.sigmoid(pred_obj), min=1e-4, max=1.0 - 1e-4)
-        # obj loss
-        pos_id = (gt_obj==1.0).float()
-        pos_loss = pos_id * (obj_score - gt_obj)**2
-
-        # noobj loss
-        neg_id = (gt_obj==0.0).float()
-        neg_loss = neg_id * (obj_score)**2
-
-        # total loss
-        loss_obj = 5.0 * pos_loss + 1.0 * neg_loss
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
 
         return loss_obj
     
 
-    def loss_labels(self, pred_cls, gt_label):
-        loss_cls = F.cross_entropy(pred_cls, gt_label, reduction='none')
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
 
         return loss_cls
 
 
-    def loss_txty(self, pred_txty, gt_txty, gt_box_weight):
-        # txty loss
-        loss_txty = F.binary_cross_entropy_with_logits(
-            pred_txty, gt_txty, reduction='none').sum(-1)
-        loss_txty *= gt_box_weight
-
-        return loss_txty
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
 
-
-    def loss_twth(self, pred_twth, gt_twth, gt_box_weight):
-        # twth loss
-        loss_twth = F.mse_loss(pred_twth, gt_twth, reduction='none').sum(-1)
-        loss_twth *= gt_box_weight
-
-        return loss_twth
+        return loss_box
 
 
     def __call__(self, outputs, targets):
         device = outputs['pred_cls'][0].device
         stride = outputs['stride']
-        img_size = outputs['img_size']
+        fmp_size = outputs['fmp_size']
         (
             gt_objectness, 
-            gt_labels, 
+            gt_classes, 
             gt_bboxes,
-            gt_box_weight
-            ) = self.matcher(img_size=img_size, 
+            ) = self.matcher(fmp_size=fmp_size, 
                              stride=stride, 
                              targets=targets)
         # List[B, M, C] -> [B, M, C] -> [BM, C]
-        batch_size = outputs['pred_obj'].shape[0]
         pred_obj = outputs['pred_obj'].view(-1)                     # [BM,]
         pred_cls = outputs['pred_cls'].view(-1, self.num_classes)   # [BM, C]
-        pred_txty = outputs['pred_txty'].view(-1, 2)                # [BM, 2]
-        pred_twth = outputs['pred_twth'].view(-1, 2)                # [BM, 2]
+        pred_box = outputs['pred_box'].view(-1, 4)                  # [BM, 4]
        
-        gt_objectness = gt_objectness.view(-1).to(device).float()   # [BM,]
-        gt_labels = gt_labels.view(-1).to(device).long()            # [BM,]
-        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()        # [BM, 4]
-        gt_box_weight = gt_box_weight.view(-1).to(device).float()   # [BM,]
+        gt_objectness = gt_objectness.view(-1).to(device).float()               # [BM,]
+        gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()   # [BM, C]
+        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                    # [BM, 4]
 
         pos_masks = (gt_objectness > 0)
+        num_fgs = pos_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
 
-        # objectness loss
+        # obj loss
         loss_obj = self.loss_objectness(pred_obj, gt_objectness)
-        loss_obj = loss_obj.sum() / batch_size
+        loss_obj = loss_obj.sum() / num_fgs
 
-        # classification loss
+        # cls loss
         pred_cls_pos = pred_cls[pos_masks]
-        gt_labels_pos = gt_labels[pos_masks]
-        loss_cls = self.loss_labels(pred_cls_pos, gt_labels_pos)
-        loss_cls = loss_cls.sum() / batch_size
-
-        # txty loss
-        pred_txty_pos = pred_txty[pos_masks]
-        gt_txty_pos = gt_bboxes[pos_masks][..., :2]
-        gt_box_weight_pos = gt_box_weight[pos_masks]
-        loss_txty = self.loss_txty(pred_txty_pos, gt_txty_pos, gt_box_weight_pos)
-        loss_txty = loss_txty.sum() / batch_size
+        gt_classes_pos = gt_classes[pos_masks]
+        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # box loss
+        pred_box_pos = pred_box[pos_masks]
+        gt_bboxes_pos = gt_bboxes[pos_masks]
+        loss_box = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
+        loss_box = loss_box.sum() / num_fgs
         
-        # twth loss
-        pred_twth_pos = pred_twth[pos_masks]
-        gt_twth_pos = gt_bboxes[pos_masks][..., 2:]
-        loss_twth = self.loss_twth(pred_twth_pos, gt_twth_pos, gt_box_weight_pos)
-        loss_twth = loss_twth.sum() / batch_size
-
         # total loss
         losses = self.loss_obj_weight * loss_obj + \
                  self.loss_cls_weight * loss_cls + \
-                 self.loss_txty_weight * loss_txty + \
-                 self.loss_twth_weight * loss_twth
+                 self.loss_box_weight * loss_box
 
         loss_dict = dict(
                 loss_obj = loss_obj,
                 loss_cls = loss_cls,
-                loss_txty = loss_txty,
-                loss_twth = loss_twth,
+                loss_box = loss_box,
                 losses = losses
         )
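
Note: the box branch now regresses decoded xyxy boxes with a 1 - GIoU loss, and every term is normalized by the (DDP-averaged) number of foreground cells instead of the batch size. A self-contained sketch of what get_ious(..., iou_type='giou') is assumed to compute here:

    import torch

    def giou_loss(pred_box, gt_box, eps=1e-7):
        """1 - GIoU for xyxy boxes (sketch, not the repo's get_ious itself)."""
        # intersection
        x1 = torch.max(pred_box[..., 0], gt_box[..., 0])
        y1 = torch.max(pred_box[..., 1], gt_box[..., 1])
        x2 = torch.min(pred_box[..., 2], gt_box[..., 2])
        y2 = torch.min(pred_box[..., 3], gt_box[..., 3])
        inter = (x2 - x1).clamp(min=0) * (y2 - y1).clamp(min=0)
        # union
        area_p = (pred_box[..., 2] - pred_box[..., 0]) * (pred_box[..., 3] - pred_box[..., 1])
        area_g = (gt_box[..., 2] - gt_box[..., 0]) * (gt_box[..., 3] - gt_box[..., 1])
        union = area_p + area_g - inter + eps
        iou = inter / union
        # smallest enclosing box
        cx1 = torch.min(pred_box[..., 0], gt_box[..., 0])
        cy1 = torch.min(pred_box[..., 1], gt_box[..., 1])
        cx2 = torch.max(pred_box[..., 2], gt_box[..., 2])
        cy2 = torch.max(pred_box[..., 3], gt_box[..., 3])
        enclose = (cx2 - cx1) * (cy2 - cy1) + eps
        giou = iou - (enclose - union) / enclose
        return 1.0 - giou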
 

+ 22 - 43
models/yolov1/matcher.py

@@ -7,36 +7,8 @@ class YoloMatcher(object):
         self.num_classes = num_classes
 
 
-    def generate_dxdywh(self, gt_box, img_size, stride):
-        x1, y1, x2, y2 = gt_box
-        # xyxy -> cxcywh
-        xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
-        bw, bh = x2 - x1, y2 - y1
-
-        # check data validity
-        if bw < 1. or bh < 1.:
-            return False    
-
-        # compute the grid coordinates where the center point falls
-        xs_c = xc / stride
-        ys_c = yc / stride
-        grid_x = int(xs_c)
-        grid_y = int(ys_c)
-
-        # compute the labels for the center offsets and width/height
-        tx = xs_c - grid_x
-        ty = ys_c - grid_y
-        tw = np.log(bw)
-        th = np.log(bh)
-
-        # compute the loss weight for the box regression parameters
-        weight = 2.0 - (bh / img_size[0]) * (bw / img_size[1])
-
-        return grid_x, grid_y, tx, ty, tw, th, weight
-
-
     @torch.no_grad()
-    def __call__(self, img_size, stride, targets):
+    def __call__(self, fmp_size, stride, targets):
         """
-            img_size: (Int) input image size
+            fmp_size: (List) [fmp_h, fmp_w], feature map size of the YOLOv1 output
             stride: (Int) -> stride of YOLOv1 output.
@@ -46,11 +18,10 @@ class YoloMatcher(object):
         """
         # prepare
         bs = len(targets)
-        fmp_h, fmp_w = img_size[0] // stride, img_size[1] // stride
+        fmp_h, fmp_w = fmp_size
         gt_objectness = np.zeros([bs, fmp_h, fmp_w, 1]) 
-        gt_labels = np.zeros([bs, fmp_h, fmp_w, 1]) 
+        gt_classes = np.zeros([bs, fmp_h, fmp_w, self.num_classes]) 
         gt_bboxes = np.zeros([bs, fmp_h, fmp_w, 4])
-        gt_box_weight = np.zeros([bs, fmp_h, fmp_w, 1])
 
         for batch_index in range(bs):
             targets_per_image = targets[batch_index]
@@ -60,26 +31,34 @@ class YoloMatcher(object):
             tgt_box = targets_per_image['boxes'].numpy()
 
             for gt_box, gt_label in zip(tgt_box, tgt_cls):
-                result = self.generate_dxdywh(gt_box, img_size, stride)
-                if result:
-                    grid_x, grid_y, tx, ty, tw, th, weight = result
+                x1, y1, x2, y2 = gt_box
+                # xyxy -> cxcywh
+                xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
+                bw, bh = x2 - x1, y2 - y1
+
+                # skip invalid boxes
+                if bw < 1. or bh < 1.:
+                    continue
+
+                # grid
+                xs_c = xc / stride
+                ys_c = yc / stride
+                grid_x = int(xs_c)
+                grid_y = int(ys_c)
 
                 if grid_x < fmp_w and grid_y < fmp_h:
                     gt_objectness[batch_index, grid_y, grid_x] = 1.0
-                    gt_labels[batch_index, grid_y, grid_x] = gt_label
-                    gt_bboxes[batch_index, grid_y, grid_x] = np.array([tx, ty, tw, th])
-                    gt_box_weight[batch_index, grid_y, grid_x] = weight
+                    gt_classes[batch_index, grid_y, grid_x, int(gt_label)] = 1.0
+                    gt_bboxes[batch_index, grid_y, grid_x] = np.array([x1, y1, x2, y2])
 
         # [B, M, C]
         gt_objectness = gt_objectness.reshape(bs, -1, 1)
-        gt_labels = gt_labels.reshape(bs, -1, 1)
+        gt_classes = gt_classes.reshape(bs, -1, self.num_classes)
         gt_bboxes = gt_bboxes.reshape(bs, -1, 4)
-        gt_box_weight = gt_box_weight.reshape(bs, -1, 1)
 
         # to tensor
         gt_objectness = torch.from_numpy(gt_objectness).float()
-        gt_labels = torch.from_numpy(gt_labels).long()
+        gt_classes = torch.from_numpy(gt_classes).float()
         gt_bboxes = torch.from_numpy(gt_bboxes).float()
-        gt_box_weight = torch.from_numpy(gt_box_weight).float()
 
-        return gt_objectness, gt_labels, gt_bboxes, gt_box_weight
+        return gt_objectness, gt_classes, gt_bboxes
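
Note: the matcher now writes a one-hot class vector (for the BCE class loss) and the raw xyxy box (for the GIoU loss) into the single grid cell containing each box center. A worked example with an assumed stride of 32 and a 2x2 feature map:

    import numpy as np

    stride, fmp_h, fmp_w, num_classes = 32, 2, 2, 20
    gt_box, gt_label = np.array([10., 10., 50., 50.]), 7

    xc = (gt_box[0] + gt_box[2]) * 0.5          # 30.0
    yc = (gt_box[1] + gt_box[3]) * 0.5          # 30.0
    grid_x, grid_y = int(xc / stride), int(yc / stride)   # cell (0, 0)

    gt_objectness = np.zeros([fmp_h, fmp_w, 1])
    gt_classes = np.zeros([fmp_h, fmp_w, num_classes])
    gt_bboxes = np.zeros([fmp_h, fmp_w, 4])

    gt_objectness[grid_y, grid_x] = 1.0          # this cell is the positive sample
    gt_classes[grid_y, grid_x, gt_label] = 1.0   # one-hot label for the BCE class loss
    gt_bboxes[grid_y, grid_x] = gt_box           # raw xyxy, matching decode_boxes' output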

+ 26 - 76
models/yolov1/yolov1.py

@@ -2,6 +2,8 @@ import torch
 import torch.nn as nn
 import numpy as np
 
+from utils.nms import multiclass_nms
+
 from .yolov1_basic import Conv
 from .yolov1_neck import SPP
 from .yolov1_backbone import build_resnet
@@ -79,14 +81,6 @@ class YOLOv1(nn.Module):
         return grid_xy
 
 
-    def set_grid(self, img_size):
-        """
-            Used to reset the grid matrix.
-        """
-        self.img_size = img_size
-        self.grid_cell = self.create_grid(img_size)
-
-
     def decode_boxes(self, pred, fmp_size):
         """
             Convert txtytwth into the commonly used x1y1x2y2 form.
@@ -95,52 +89,17 @@ class YOLOv1(nn.Module):
         grid_cell = self.create_grid(fmp_size)
 
         # compute the center coordinates and width/height of the predicted boxes
-        pred[..., :2] = torch.sigmoid(pred[..., :2]) + grid_cell
+        pred[..., :2] = (torch.sigmoid(pred[..., :2]) + grid_cell) * self.stride
         pred[..., 2:] = torch.exp(pred[..., 2:])
 
         # convert all bbox center coordinates and widths/heights into x1y1x2y2 form
         output = torch.zeros_like(pred)
-        output[..., :2] = pred[..., :2] * self.stride - pred[..., 2:] * 0.5
-        output[..., 2:] = pred[..., :2] * self.stride + pred[..., 2:] * 0.5
+        output[..., :2] = pred[..., :2] - pred[..., 2:] * 0.5
+        output[..., 2:] = pred[..., :2] + pred[..., 2:] * 0.5
         
         return output
 
 
-    def nms(self, bboxes, scores):
-        """"Pure Python NMS baseline."""
-        x1 = bboxes[:, 0]  #xmin
-        y1 = bboxes[:, 1]  #ymin
-        x2 = bboxes[:, 2]  #xmax
-        y2 = bboxes[:, 3]  #ymax
-
-        areas = (x2 - x1) * (y2 - y1)
-        order = scores.argsort()[::-1]
-        
-        keep = []                                             
-        while order.size > 0:
-            i = order[0]
-            keep.append(i)
-            # compute the coordinates of the intersection's top-left and bottom-right corners
-            xx1 = np.maximum(x1[i], x1[order[1:]])
-            yy1 = np.maximum(y1[i], y1[order[1:]])
-            xx2 = np.minimum(x2[i], x2[order[1:]])
-            yy2 = np.minimum(y2[i], y2[order[1:]])
-            # compute the width and height of the intersection
-            w = np.maximum(1e-10, xx2 - xx1)
-            h = np.maximum(1e-10, yy2 - yy1)
-            # compute the area of the intersection
-            inter = w * h
-
-            # compute the IoU
-            iou = inter / (areas[i] + areas[order[1:]] - inter)
-
-            # filter out boxes whose IoU exceeds the NMS threshold
-            inds = np.where(iou <= self.nms_thresh)[0]
-            order = order[inds + 1]
-
-        return keep
-
-
     def postprocess(self, bboxes, scores):
         """
         Input:
@@ -161,21 +120,9 @@ class YOLOv1(nn.Module):
         scores = scores[keep]
         labels = labels[keep]
 
-        # NMS
-        keep = np.zeros(len(bboxes), dtype=np.int)
-        for i in range(self.num_classes):
-            inds = np.where(labels == i)[0]
-            if len(inds) == 0:
-                continue
-            c_bboxes = bboxes[inds]
-            c_scores = scores[inds]
-            c_keep = self.nms(c_bboxes, c_scores)
-            keep[inds[c_keep]] = 1
-
-        keep = np.where(keep > 0)
-        bboxes = bboxes[keep]
-        scores = scores[keep]
-        labels = labels[keep]
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
 
         return bboxes, scores, labels
 
@@ -201,27 +148,27 @@ class YOLOv1(nn.Module):
 
         # split pred into objectness, class, and bbox (txtytwth) predictions
         # [B, H*W, 1]
-        conf_pred = pred[..., :1]
+        obj_pred = pred[..., :1]
         # [B, H*W, num_cls]
         cls_pred = pred[..., 1:1+self.num_classes]
         # [B, H*W, 4]
-        txtytwth_pred = pred[..., 1+self.num_classes:]
+        reg_pred = pred[..., 1+self.num_classes:]
 
         # at test time, the batch size is assumed to be 1,
         # so the batch dimension can be stripped off with [0].
-        conf_pred = conf_pred[0]            #[H*W, 1]
-        cls_pred = cls_pred[0]              #[H*W, NC]
-        txtytwth_pred = txtytwth_pred[0]    #[H*W, 4]
+        obj_pred = obj_pred[0]       # [H*W, 1]
+        cls_pred = cls_pred[0]       # [H*W, NC]
+        reg_pred = reg_pred[0]       # [H*W, 4]
 
         # score of each bounding box
-        scores = torch.sigmoid(conf_pred) * torch.softmax(cls_pred, dim=-1)
+        scores = torch.sqrt(obj_pred.sigmoid() * cls_pred.sigmoid())
         
         # decode the bounding boxes: [H*W, 4]
-        bboxes = self.decode_boxes(txtytwth_pred, fmp_size)
+        bboxes = self.decode_boxes(reg_pred, fmp_size)
         
         # move predictions to the cpu for post-processing
-        scores = scores.to('cpu').numpy()
-        bboxes = bboxes.to('cpu').numpy()
+        scores = scores.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
         
         # post-processing
         bboxes, scores, labels = self.postprocess(bboxes, scores)
@@ -244,6 +191,7 @@ class YOLOv1(nn.Module):
 
             # prediction layer
             pred = self.pred(feat)
+            fmp_size = pred.shape[-2:]
 
             # adjust the view of pred for easier downstream processing
             # [B, C, H, W] -> [B, H, W, C] -> [B, H*W, C]
@@ -251,19 +199,21 @@ class YOLOv1(nn.Module):
 
             # split pred into objectness, class, and bbox (txtytwth) predictions
             # [B, H*W, 1]
-            conf_pred = pred[..., :1]
+            obj_pred = pred[..., :1]
             # [B, H*W, num_cls]
             cls_pred = pred[..., 1:1+self.num_classes]
             # [B, H*W, 4]
-            txtytwth_pred = pred[..., 1+self.num_classes:]
+            reg_pred = pred[..., 1+self.num_classes:]
+
+            # decode bbox
+            box_pred = self.decode_boxes(reg_pred, fmp_size)
 
             # network outputs
-            outputs = {"pred_obj": conf_pred,                  # (Tensor) [B, M, 1]
+            outputs = {"pred_obj": obj_pred,                  # (Tensor) [B, M, 1]
                        "pred_cls": cls_pred,                   # (Tensor) [B, M, C]
-                       "pred_txty": txtytwth_pred[..., :2],    # (Tensor) [B, M, 2]
-                       "pred_twth": txtytwth_pred[..., 2:],    # (Tensor) [B, M, 2]
+                       "pred_box": box_pred,                   # (Tensor) [B, M, 4]
                        "stride": self.stride,                  # (Int)
-                       "img_size": x.shape[-2:]                # (List) [img_h, img_w]
+                       "fmp_size": fmp_size                    # (List) [fmp_h, fmp_w]
                        }           
             return outputs
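
Note: decode_boxes now returns pixel-space x1y1x2y2 boxes directly (the "* self.stride" moved into the center computation), and the score became the geometric mean of objectness and class probability. A compact sketch of the decoding math:

    import torch

    def decode_boxes(reg_pred, grid_cell, stride):
        # txtytwth -> pixel-space x1y1x2y2 (sketch of the new decode_boxes)
        ctr = (torch.sigmoid(reg_pred[..., :2]) + grid_cell) * stride  # center in pixels
        wh = torch.exp(reg_pred[..., 2:])                              # width/height in pixels
        output = torch.zeros_like(reg_pred)
        output[..., :2] = ctr - wh * 0.5   # x1, y1
        output[..., 2:] = ctr + wh * 0.5   # x2, y2
        return output

    # score per cell: geometric mean of objectness and per-class confidence
    # scores = torch.sqrt(obj_pred.sigmoid() * cls_pred.sigmoid())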
         

+ 5 - 12
test.py

@@ -12,9 +12,8 @@ from dataset.data_augment import build_transform
 # load some utils
 from utils.misc import build_dataset, load_weight
 from utils.com_flops_params import FLOPs_and_Params
+from utils.box_ops import rescale_bboxes
 from utils import fuse_conv_bn
-from utils.box_ops import rescale_bboxes, rescale_bboxes_with_deltas
-from dataset.data_augment import SSDBaseTransform, YOLOv5BaseTransform
 
 from models import build_model
 from config import build_model_config, build_trans_config
@@ -132,7 +131,7 @@ def test(args,
 
         # prepare
         x, _, deltas = transforms(image)
-        x = x.unsqueeze(0).to(device)
+        x = x.unsqueeze(0).to(device) / 255.
 
         t0 = time.time()
         # inference
@@ -140,15 +139,9 @@ def test(args,
         print("detection time used ", time.time() - t0, "s")
         
         # rescale bboxes
-        if isinstance(transform, SSDBaseTransform):
-            origin_img_size = [orig_h, orig_w]
-            cur_img_size = [*x.shape[-2:]]
-            bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size)
-        elif isinstance(transform, YOLOv5BaseTransform):
-            origin_img_size = [orig_h, orig_w]
-            cur_img_size = x.shape[-2:]
-            print(origin_img_size, cur_img_size, deltas)
-            bboxes = rescale_bboxes_with_deltas(bboxes, deltas, origin_img_size, cur_img_size)
+        origin_img_size = [orig_h, orig_w]
+        cur_img_size = [*x.shape[-2:]]
+        bboxes = rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas)
 
         # vis detection
         img_processed = visualize(

+ 0 - 15
train.sh

@@ -16,18 +16,3 @@ python train.py \
         # --resume weights/coco/yolo_free_vx_pico/yolo_free_vx_pico_epoch_41_20.46.pth \
         # --pretrained weights/coco/yolo_free_medium/yolo_free_medium_39.46.pth \
         # --eval_first
-
-# # Debug FreeYOLO on VOC
-# python train.py \
-#         --cuda \
-#         -d voc \
-#         --root /mnt/share/ssd2/dataset/ \
-#         -v yolo_free_v2_tiny \
-#         -bs 16 \
-#         --max_epoch 25 \
-#         --wp_epoch 1 \
-#         --eval_epoch 5 \
-#         --ema \
-#         --fp16 \
-#         # --resume weights/coco/yolo_free_medium/yolo_free_medium_epoch_31_39.46.pth \
-#         # --pretrained weights/coco/yolo_free_medium/yolo_free_medium_39.46.pth \

+ 17 - 22
utils/box_ops.py

@@ -74,30 +74,25 @@ def get_ious(bboxes1,
         raise NotImplementedError
 
 
-def rescale_bboxes(bboxes, origin_img_size, cur_img_size):
+def rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas=None):
     origin_h, origin_w = origin_img_size
     cur_img_h, cur_img_w = cur_img_size
-    # rescale
-    bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / cur_img_w * origin_w
-    bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / cur_img_h * origin_h
-
-    # clip bboxes
-    bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
-    bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
-
-    return bboxes
-
-
-def rescale_bboxes_with_deltas(bboxes, deltas, origin_img_size, cur_img_size):
-    origin_h, origin_w = origin_img_size
-    cur_img_h, cur_img_w = cur_img_size
-    # rescale
-    bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / (cur_img_w - deltas[0]) * origin_w
-    bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / (cur_img_h - deltas[1]) * origin_h
-    
-    # clip bboxes
-    bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
-    bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
+    if deltas is None:
+        # rescale
+        bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / cur_img_w * origin_w
+        bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / cur_img_h * origin_h
+
+        # clip bboxes
+        bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
+        bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
+    else:
+        # rescale
+        bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / (cur_img_w - deltas[0]) * origin_w
+        bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / (cur_img_h - deltas[1]) * origin_h
+        
+        # clip bboxes
+        bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
+        bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
 
     return bboxes
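
Note: folding the two functions into one optional deltas argument keeps every caller identical. A worked example with assumed numbers -- a 480x640 image letterboxed onto a 640x640 canvas leaves dh = 160, and dividing by the content size (not the canvas size) maps boxes back exactly:

    import numpy as np
    from utils.box_ops import rescale_bboxes

    bboxes = np.array([[100., 100., 300., 220.]])
    origin_img_size = [480, 640]    # (orig_h, orig_w)
    cur_img_size = [640, 640]       # letterboxed canvas

    # y is divided by (640 - 160) = 480 and rescaled by orig_h = 480
    out = rescale_bboxes(bboxes.copy(), origin_img_size, cur_img_size, deltas=[0, 160])
    assert np.allclose(out, bboxes)   # the letterbox padding is undone exactly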
 

+ 71 - 0
utils/nms.py

@@ -0,0 +1,71 @@
+import numpy as np
+
+
+def nms(bboxes, scores, nms_thresh):
+    """"Pure Python NMS."""
+    x1 = bboxes[:, 0]  #xmin
+    y1 = bboxes[:, 1]  #ymin
+    x2 = bboxes[:, 2]  #xmax
+    y2 = bboxes[:, 3]  #ymax
+
+    areas = (x2 - x1) * (y2 - y1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        # compute iou
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(1e-10, xx2 - xx1)
+        h = np.maximum(1e-10, yy2 - yy1)
+        inter = w * h
+
+        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
+        # keep boxes whose IoU with the current box is below the threshold
+        inds = np.where(iou <= nms_thresh)[0]
+        order = order[inds + 1]
+
+    return keep
+
+
+def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
+    # nms
+    keep = nms(bboxes, scores, nms_thresh)
+
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+
+def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
+    # nms
+    keep = np.zeros(len(bboxes), dtype=np.int32)
+    for i in range(num_classes):
+        inds = np.where(labels == i)[0]
+        if len(inds) == 0:
+            continue
+        c_bboxes = bboxes[inds]
+        c_scores = scores[inds]
+        c_keep = nms(c_bboxes, c_scores, nms_thresh)
+        keep[inds[c_keep]] = 1
+
+    keep = np.where(keep > 0)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+
+def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
+    if class_agnostic:
+        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
+    else:
+        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
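
Note: the pure-Python NMS that used to live inside YOLOv1 is now a standalone helper. A quick toy check of the class-aware path (assumed threshold 0.5):

    import numpy as np
    from utils.nms import multiclass_nms

    bboxes = np.array([[0., 0., 10., 10.],
                       [1., 1., 11., 11.],     # IoU ~0.68 with the first box
                       [50., 50., 60., 60.]])
    scores = np.array([0.9, 0.8, 0.7])
    labels = np.array([0, 0, 1])

    scores, labels, bboxes = multiclass_nms(
        scores, labels, bboxes, nms_thresh=0.5, num_classes=2, class_agnostic=False)
    # the second box is suppressed by the first; boxes 0 and 2 survive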