|
@@ -8,22 +8,23 @@ except:
|
|
|
|
|
|
|
|
|
|
|
|
|
class YolofUpsampler(nn.Module):
|
|
class YolofUpsampler(nn.Module):
|
|
|
- def __init__(self, cfg, in_dim, out_dim):
|
|
|
|
|
|
|
+ def __init__(self, cfg, in_dims, out_dim):
|
|
|
super(YolofUpsampler, self).__init__()
|
|
super(YolofUpsampler, self).__init__()
|
|
|
- # ----------- Basic parameters -----------
|
|
|
|
|
- self.upscale_factor = cfg.upscale_factor
|
|
|
|
|
- inter_dim = self.upscale_factor ** 2 * in_dim
|
|
|
|
|
# ----------- Model parameters -----------
|
|
# ----------- Model parameters -----------
|
|
|
- self.input_proj = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
|
|
|
|
|
- self.output_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
|
|
|
|
|
|
|
+ self.input_proj_1 = BasicConv(in_dims[-1], out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
|
|
|
|
|
+ self.input_proj_2 = BasicConv(in_dims[-2], out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
|
|
|
|
|
+ self.output_proj = nn.Sequential(
|
|
|
|
|
+ BasicConv(out_dim * 2, out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm),
|
|
|
|
|
+ BasicConv(out_dim, out_dim, kernel_size=3, padding=1, stride=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm),
|
|
|
|
|
+ )
|
|
|
|
|
|
|
|
- def forward(self, x):
|
|
|
|
|
- # [B, C, H, W] -> [B, 4*C, H, W]
|
|
|
|
|
- x = self.input_proj(x)
|
|
|
|
|
-
|
|
|
|
|
- # [B, 4*C, H, W] -> [B, C, 2*H, 2*W]
|
|
|
|
|
- x = torch.pixel_shuffle(x, upscale_factor=self.upscale_factor)
|
|
|
|
|
|
|
+ def forward(self, pyramid_feats):
|
|
|
|
|
+ x1 = self.input_proj_1(pyramid_feats[-1])
|
|
|
|
|
+ x2 = self.input_proj_2(pyramid_feats[-2])
|
|
|
|
|
|
|
|
- x = self.output_proj(x)
|
|
|
|
|
|
|
+ x1_up = nn.functional.interpolate(x1, scale_factor=2.0)
|
|
|
|
|
+
|
|
|
|
|
+ x3 = torch.cat([x2, x1_up], dim=1)
|
|
|
|
|
+ out = self.output_proj(x3)
|
|
|
|
|
|
|
|
- return x
|
|
|
|
|
|
|
+ return out
|