yjh0410 2 anos atrás
pai
commit
af0a49f80b
1 arquivos alterados com 1 adições e 1 exclusões
  1. 1 1
      models/detectors/rtrdet/rtrdet_transformer.py

+ 1 - 1
models/detectors/rtrdet/rtrdet_transformer.py

@@ -130,8 +130,8 @@ class RTRDetTransformer(nn.Module):
             src2 = encoder_layer(src2, pos2d_embed_2)
         
         ## Feature fusion
+        src2 = src2.permute(0, 2, 1).reshape(bs, c, h, w)
         if src1 is not None:
-            src2 = src2.permute(0, 2, 1).reshape(bs, c, h, w)
             src1 = src1 + nn.functional.interpolate(src2, scale_factor=2.0)
         else:
             src1 = src2