|
|
@@ -18,12 +18,13 @@ class RTCDet(nn.Module):
|
|
|
def __init__(self,
|
|
|
cfg,
|
|
|
device,
|
|
|
- num_classes :int = 20,
|
|
|
- conf_thresh :float = 0.05,
|
|
|
- nms_thresh :float = 0.6,
|
|
|
- topk :int = 1000,
|
|
|
- trainable :bool = False,
|
|
|
- deploy :bool = False):
|
|
|
+ num_classes :int = 20,
|
|
|
+ conf_thresh :float = 0.05,
|
|
|
+ nms_thresh :float = 0.6,
|
|
|
+ topk :int = 1000,
|
|
|
+ trainable :bool = False,
|
|
|
+ deploy :bool = False,
|
|
|
+ nms_class_agnostic :bool = False):
|
|
|
super(RTCDet, self).__init__()
|
|
|
# ---------------------- Basic Parameters ----------------------
|
|
|
self.cfg = cfg
|
|
|
@@ -36,6 +37,7 @@ class RTCDet(nn.Module):
|
|
|
self.nms_thresh = nms_thresh
|
|
|
self.topk = topk
|
|
|
self.deploy = deploy
|
|
|
+ self.nms_class_agnostic = nms_class_agnostic
|
|
|
self.head_dim = round(256*cfg['width'])
|
|
|
|
|
|
# ---------------------- Network Parameters ----------------------
|
|
|
@@ -111,7 +113,7 @@ class RTCDet(nn.Module):
|
|
|
|
|
|
# nms
|
|
|
scores, labels, bboxes = multiclass_nms(
|
|
|
- scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
|
|
|
+ scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
|
|
|
|
|
|
return bboxes, scores, labels
|
|
|
|