|
|
@@ -19,7 +19,8 @@ class YOLOv2(nn.Module):
|
|
|
nms_thresh=0.5,
|
|
|
topk=100,
|
|
|
trainable=False,
|
|
|
- deploy=False):
|
|
|
+ deploy=False,
|
|
|
+ nms_class_agnostic=False):
|
|
|
super(YOLOv2, self).__init__()
|
|
|
# ------------------- Basic parameters -------------------
|
|
|
self.cfg = cfg # 模型配置文件
|
|
|
@@ -31,6 +32,7 @@ class YOLOv2(nn.Module):
|
|
|
self.topk = topk # topk
|
|
|
self.stride = 32 # 网络的最大步长
|
|
|
self.deploy = deploy
|
|
|
+ self.nms_class_agnostic = nms_class_agnostic
|
|
|
# ------------------- Anchor box -------------------
|
|
|
self.anchor_size = torch.as_tensor(cfg['anchor_size']).float().view(-1, 2) # [A, 2]
|
|
|
self.num_anchors = self.anchor_size.shape[0]
|
|
|
@@ -143,7 +145,7 @@ class YOLOv2(nn.Module):
|
|
|
|
|
|
# nms
|
|
|
scores, labels, bboxes = multiclass_nms(
|
|
|
- scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
|
|
|
+ scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
|
|
|
|
|
|
return bboxes, scores, labels
|
|
|
|