yjh0410 1 жил өмнө
parent
commit
f75326d9de

+ 8 - 3
models/detectors/rtdetr/rtdetr.py

@@ -207,16 +207,21 @@ if __name__ == '__main__':
     criterion = build_criterion(cfg, num_classes=20)
 
     # Model inference
-    t0 = time.time()
     outputs = model(image, targets)
-    t1 = time.time()
-    print('Infer time: ', t1 - t0)
 
     # Compute loss
     loss = criterion(outputs, targets)
     for k in loss.keys():
         print("{} : {}".format(k, loss[k].item()))
 
+    # Inference
+    with torch.no_grad():
+        model.eval()
+        t0 = time.time()
+        outputs = model(image)
+        t1 = time.time()
+        print('Infer time: ', t1 - t0)
+
     print('==============================')
     model.eval()
     flops, params = profile(model, inputs=(image, ), verbose=False)