yjh0410 преди 2 години
родител
ревизия
ec6f6f81df
променени са 2 файла, в които са добавени 3 реда и са изтрити 4 реда
  1. 3 3
      models/detectors/rtdetr/rtdetr.py
  2. 0 1
      test.py

+ 3 - 3
models/detectors/rtdetr/rtdetr.py

@@ -30,7 +30,7 @@ class RTDETR(nn.Module):
         
         # --------- Network Parameters ----------
         ## Encoder
-        self.img_encoder = build_encoder(cfg, trainable, 'img_encoder')
+        self.encoder = build_encoder(cfg, trainable, 'img_encoder')
 
         ## Decoder
         self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)
@@ -57,7 +57,7 @@ class RTDETR(nn.Module):
     @torch.no_grad()
     def inference_single_image(self, x):
         # -------------------- Encoder --------------------
-        memory, memory_pos = self.img_encoder(x)
+        memory, memory_pos = self.encoder(x)
 
         # -------------------- Decoder --------------------
         hs, reference = self.decoder(memory, memory_pos)
@@ -93,7 +93,7 @@ class RTDETR(nn.Module):
             return self.inference_single_image(x)
         else:
             # -------------------- Encoder --------------------
-            memory, memory_pos = self.img_encoder(x)
+            memory, memory_pos = self.encoder(x)
 
             # -------------------- Decoder --------------------
             hs, reference = self.decoder(memory, memory_pos)

+ 0 - 1
test.py

@@ -134,7 +134,6 @@ def test(args,
         t0 = time.time()
         # inference
         bboxes, scores, labels = model(x)
-        print(bboxes, scores, labels)
         print("detection time used ", time.time() - t0, "s")
         
         # rescale bboxes