|
|
@@ -4,14 +4,15 @@ import numpy as np
|
|
|
|
|
|
from utils.nms import multiclass_nms
|
|
|
|
|
|
-from .yolov1_basic import Conv
|
|
|
-from .yolov1_neck import SPP
|
|
|
-from .yolov1_backbone import build_resnet
|
|
|
+from .yolov1_backbone import build_backbone
|
|
|
+from .yolov1_neck import build_neck
|
|
|
+from .yolov1_head import build_head
|
|
|
|
|
|
|
|
|
# YOLOv1
|
|
|
class YOLOv1(nn.Module):
|
|
|
def __init__(self,
|
|
|
+ cfg,
|
|
|
device,
|
|
|
img_size=None,
|
|
|
num_classes=20,
|
|
|
@@ -20,6 +21,7 @@ class YOLOv1(nn.Module):
|
|
|
trainable=False):
|
|
|
super(YOLOv1, self).__init__()
|
|
|
# ------------------- Basic parameters -------------------
|
|
|
+ self.cfg = cfg # 模型配置文件
|
|
|
self.img_size = img_size # 输入图像大小
|
|
|
self.device = device # cuda或者是cpu
|
|
|
self.num_classes = num_classes # 类别的数量
|
|
|
@@ -30,24 +32,20 @@ class YOLOv1(nn.Module):
|
|
|
|
|
|
# ------------------- Network Structure -------------------
|
|
|
## backbone: built from cfg['backbone']
|
|
|
- self.backbone, feat_dim = build_resnet('resnet18', pretrained=trainable)
|
|
|
+ self.backbone, feat_dim = build_backbone(
|
|
|
+ cfg['backbone'], trainable&cfg['pretrained'])
|
|
|
|
|
|
## neck: SPP
|
|
|
- self.neck = nn.Sequential(
|
|
|
- SPP(),
|
|
|
- Conv(feat_dim*4, feat_dim, k=1),
|
|
|
- )
|
|
|
+ self.neck = build_neck(cfg, feat_dim, out_dim=256)
|
|
|
+ head_dim = self.neck.out_dim
|
|
|
|
|
|
## head
|
|
|
- self.convsets = nn.Sequential(
|
|
|
- Conv(feat_dim, feat_dim//2, k=1),
|
|
|
- Conv(feat_dim//2, feat_dim, k=3, p=1),
|
|
|
- Conv(feat_dim, feat_dim//2, k=1),
|
|
|
- Conv(feat_dim//2, feat_dim, k=3, p=1)
|
|
|
- )
|
|
|
+ self.head = build_head(cfg, head_dim, head_dim, num_classes)
|
|
|
|
|
|
## pred
|
|
|
- self.pred = nn.Conv2d(feat_dim, 1 + self.num_classes + 4, 1)
|
|
|
+ self.obj_pred = nn.Conv2d(head_dim, 1, kernel_size=1)
|
|
|
+ self.cls_pred = nn.Conv2d(head_dim, num_classes, kernel_size=1)
|
|
|
+ self.reg_pred = nn.Conv2d(head_dim, 4, kernel_size=1)
|
|
|
|
|
|
|
|
|
if self.trainable:
|
|
|
@@ -58,8 +56,8 @@ class YOLOv1(nn.Module):
|
|
|
# init bias
|
|
|
init_prob = 0.01
|
|
|
bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
|
|
|
- nn.init.constant_(self.pred.bias[..., :1], bias_value)
|
|
|
- nn.init.constant_(self.pred.bias[..., 1:1+self.num_classes], bias_value)
|
|
|
+ nn.init.constant_(self.obj_pred.bias, bias_value)
|
|
|
+ nn.init.constant_(self.cls_pred.bias, bias_value)
|
|
|
|
|
|
|
|
|
def create_grid(self, fmp_size):
|
|
|
@@ -90,7 +88,7 @@ class YOLOv1(nn.Module):
|
|
|
|
|
|
# 计算预测边界框的中心点坐标和宽高
|
|
|
pred_ctr = (torch.sigmoid(pred[..., :2]) + grid_cell) * self.stride
|
|
|
- pred_wh = torch.exp(pred[..., 2:])
|
|
|
+ pred_wh = torch.exp(pred[..., 2:]) * self.stride
|
|
|
|
|
|
# 将所有bbox的中心点坐标和宽高换算成x1y1x2y2形式
|
|
|
pred_x1y1 = pred_ctr - pred_wh * 0.5
|
|
|
@@ -129,30 +127,26 @@ class YOLOv1(nn.Module):
|
|
|
|
|
|
@torch.no_grad()
|
|
|
def inference(self, x):
|
|
|
- # backbone主干网络
|
|
|
+ # 主干网络
|
|
|
feat = self.backbone(x)
|
|
|
|
|
|
- # neck网络
|
|
|
+ # 颈部网络
|
|
|
feat = self.neck(feat)
|
|
|
|
|
|
- # detection head网络
|
|
|
- feat = self.convsets(feat)
|
|
|
+ # 检测头
|
|
|
+ cls_feat, reg_feat = self.head(feat)
|
|
|
|
|
|
# 预测层
|
|
|
- pred = self.pred(feat)
|
|
|
- fmp_size = pred.shape[-2:]
|
|
|
+ obj_pred = self.obj_pred(cls_feat)
|
|
|
+ cls_pred = self.cls_pred(cls_feat)
|
|
|
+ reg_pred = self.reg_pred(reg_feat)
|
|
|
+ fmp_size = obj_pred.shape[-2:]
|
|
|
|
|
|
- # 对pred 的size做一些view调整,便于后续的处理
|
|
|
+ # 对 pred 的size做一些view调整,便于后续的处理
|
|
|
# [B, C, H, W] -> [B, H, W, C] -> [B, H*W, C]
|
|
|
- pred = pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
-
|
|
|
- # 从pred中分离出objectness预测、类别class预测、bbox的txtytwth预测
|
|
|
- # [B, H*W, 1]
|
|
|
- obj_pred = pred[..., :1]
|
|
|
- # [B, H*W, num_cls]
|
|
|
- cls_pred = pred[..., 1:1+self.num_classes]
|
|
|
- # [B, H*W, 4]
|
|
|
- reg_pred = pred[..., 1+self.num_classes:]
|
|
|
+ obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
+ cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
+ reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
|
|
|
# 测试时,笔者默认batch是1,
|
|
|
# 因此,我们不需要用batch这个维度,用[0]将其取走。
|
|
|
@@ -180,30 +174,26 @@ class YOLOv1(nn.Module):
|
|
|
if not self.trainable:
|
|
|
return self.inference(x)
|
|
|
else:
|
|
|
- # backbone主干网络
|
|
|
+ # 主干网络
|
|
|
feat = self.backbone(x)
|
|
|
|
|
|
- # neck网络
|
|
|
+ # 颈部网络
|
|
|
feat = self.neck(feat)
|
|
|
|
|
|
- # detection head网络
|
|
|
- feat = self.convsets(feat)
|
|
|
+ # 检测头
|
|
|
+ cls_feat, reg_feat = self.head(feat)
|
|
|
|
|
|
# 预测层
|
|
|
- pred = self.pred(feat)
|
|
|
- fmp_size = pred.shape[-2:]
|
|
|
+ obj_pred = self.obj_pred(cls_feat)
|
|
|
+ cls_pred = self.cls_pred(cls_feat)
|
|
|
+ reg_pred = self.reg_pred(reg_feat)
|
|
|
+ fmp_size = obj_pred.shape[-2:]
|
|
|
|
|
|
- # 对pred 的size做一些view调整,便于后续的处理
|
|
|
+ # 对 pred 的size做一些view调整,便于后续的处理
|
|
|
# [B, C, H, W] -> [B, H, W, C] -> [B, H*W, C]
|
|
|
- pred = pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
-
|
|
|
- # 从pred中分离出objectness预测、类别class预测、bbox的txtytwth预测
|
|
|
- # [B, H*W, 1]
|
|
|
- obj_pred = pred[..., :1]
|
|
|
- # [B, H*W, num_cls]
|
|
|
- cls_pred = pred[..., 1:1+self.num_classes]
|
|
|
- # [B, H*W, 4]
|
|
|
- reg_pred = pred[..., 1+self.num_classes:]
|
|
|
+ obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
+ cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
+ reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
|
|
|
|
|
|
# decode bbox
|
|
|
box_pred = self.decode_boxes(reg_pred, fmp_size)
|