yolof_upsampler.py 978 B

1234567891011121314151617181920212223242526272829
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolof_basic import BasicConv
  5. except:
  6. from yolof_basic import BasicConv
  7. class YolofUpsampler(nn.Module):
  8. def __init__(self, cfg, in_dim, out_dim):
  9. super(YolofUpsampler, self).__init__()
  10. # ----------- Basic parameters -----------
  11. self.upscale_factor = cfg.upscale_factor
  12. inter_dim = self.upscale_factor ** 2 * in_dim
  13. # ----------- Model parameters -----------
  14. self.input_proj = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
  15. self.output_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
  16. def forward(self, x):
  17. # [B, C, H, W] -> [B, 4*C, H, W]
  18. x = self.input_proj(x)
  19. # [B, 4*C, H, W] -> [B, C, 2*H, 2*W]
  20. x = torch.pixel_shuffle(x, upscale_factor=self.upscale_factor)
  21. x = self.output_proj(x)
  22. return x