yjh0410 2 年之前
父节点
当前提交
af0a49f80b
共有 1 个文件被更改,包括 1 次插入1 次删除
  1. 1 1
      models/detectors/rtrdet/rtrdet_transformer.py

+ 1 - 1
models/detectors/rtrdet/rtrdet_transformer.py

@@ -130,8 +130,8 @@ class RTRDetTransformer(nn.Module):
             src2 = encoder_layer(src2, pos2d_embed_2)
         
         ## Feature fusion
+        src2 = src2.permute(0, 2, 1).reshape(bs, c, h, w)
         if src1 is not None:
-            src2 = src2.permute(0, 2, 1).reshape(bs, c, h, w)
             src1 = src1 + nn.functional.interpolate(src2, scale_factor=2.0)
         else:
             src1 = src2