yjh0410 2 rokov pred
rodič
commit
2259ef87f2
6 zmenil súbory, kde vykonal 19 pridanie a 8 odobranie
  1. 2 0
      demo.py
  2. 2 0
      eval.py
  3. 2 1
      models/detectors/rtcdet/build.py
  4. 9 7
      models/detectors/rtcdet/rtcdet.py
  5. 2 0
      test.py
  6. 2 0
      train.py

+ 2 - 0
demo.py

@@ -66,6 +66,8 @@ def parse_args():
                         help='fuse RepConv')
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                         help='fuse Conv & BN')
+    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
+                        help='Perform NMS operations regardless of category.')
 
     return parser.parse_args()
                     

+ 2 - 0
eval.py

@@ -42,6 +42,8 @@ def parse_args():
                         help="not decode in inference or yes")
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                         help='fuse Conv & BN')
+    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
+                        help='Perform NMS operations regardless of category.')
 
     # dataset
     parser.add_argument('--root', default='/mnt/share/ssd2/dataset',

+ 2 - 1
models/detectors/rtcdet/build.py

@@ -22,7 +22,8 @@ def build_rtcdet(args, cfg, device, num_classes=80, trainable=False, deploy=Fals
         conf_thresh=args.conf_thresh,
         nms_thresh=args.nms_thresh,
         topk=args.topk,
-        deploy=deploy
+        deploy=deploy,
+        nms_class_agnostic=args.nms_class_agnostic
         )
 
     # -------------- Initialize RTCDet --------------

+ 9 - 7
models/detectors/rtcdet/rtcdet.py

@@ -18,12 +18,13 @@ class RTCDet(nn.Module):
     def __init__(self, 
                  cfg,
                  device, 
-                 num_classes :int   = 20, 
-                 conf_thresh :float = 0.05,
-                 nms_thresh  :float = 0.6,
-                 topk        :int   = 1000,
-                 trainable   :bool  = False, 
-                 deploy      :bool  = False):
+                 num_classes        :int   = 20, 
+                 conf_thresh        :float = 0.05,
+                 nms_thresh         :float = 0.6,
+                 topk               :int   = 1000,
+                 trainable          :bool  = False, 
+                 deploy             :bool  = False,
+                 nms_class_agnostic :bool = False):
         super(RTCDet, self).__init__()
         # ---------------------- Basic Parameters ----------------------
         self.cfg = cfg
@@ -36,6 +37,7 @@ class RTCDet(nn.Module):
         self.nms_thresh = nms_thresh
         self.topk = topk
         self.deploy = deploy
+        self.nms_class_agnostic = nms_class_agnostic
         self.head_dim = round(256*cfg['width'])
         
         # ---------------------- Network Parameters ----------------------
@@ -111,7 +113,7 @@ class RTCDet(nn.Module):
 
         # nms
         scores, labels, bboxes = multiclass_nms(
-            scores, labels, bboxes, self.nms_thresh, self.num_classes, False)
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, self.nms_class_agnostic)
 
         return bboxes, scores, labels
 

+ 2 - 0
test.py

@@ -53,6 +53,8 @@ def parse_args():
                         help="not decode in inference or yes")
     parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
                         help='fuse Conv & BN')
+    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
+                        help='Perform NMS operations regardless of category.')
 
     # dataset
     parser.add_argument('--root', default='/mnt/share/ssd2/dataset',

+ 2 - 0
train.py

@@ -71,6 +71,8 @@ def parse_args():
                         help='load pretrained weight')
     parser.add_argument('-r', '--resume', default=None, type=str,
                         help='keep training')
+    parser.add_argument('--nms_class_agnostic', action='store_true', default=False,
+                        help='Perform NMS operations regardless of category.')
 
     # Dataset
     parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',