| 1234567891011121314151617181920212223242526272829 |
- import torch
- import torch.nn as nn
- try:
- from .yolof_basic import BasicConv
- except:
- from yolof_basic import BasicConv
class YolofUpsampler(nn.Module):
    """Upsample a feature map by ``cfg.upscale_factor`` (``r``) via pixel shuffle.

    Pipeline:
      1. a 1x1 ``BasicConv`` from ``in_dim`` to ``r ** 2 * in_dim`` channels,
      2. ``torch.pixel_shuffle``, which trades every ``r ** 2`` channels for an
         ``r x r`` spatial block (``r ** 2 * in_dim`` -> ``in_dim`` channels,
         ``H, W`` -> ``r*H, r*W``),
      3. a 1x1 ``BasicConv`` from ``in_dim`` to ``out_dim`` channels.

    Both convolutions take their activation / normalization types from
    ``cfg.neck_act`` / ``cfg.neck_norm``.

    Args:
        cfg: Config object; ``upscale_factor``, ``neck_act`` and ``neck_norm``
            are read from it.
        in_dim: Number of input channels.
        out_dim: Number of output channels.

    Raises:
        ValueError: If ``cfg.upscale_factor`` is smaller than 1.
    """

    def __init__(self, cfg, in_dim, out_dim):
        super().__init__()
        # ----------- Basic parameters -----------
        self.upscale_factor = cfg.upscale_factor
        # Fail early with a clear message: a factor < 1 would otherwise yield a
        # zero/negative channel count or an opaque pixel_shuffle error later.
        if self.upscale_factor < 1:
            raise ValueError(
                f"upscale_factor must be >= 1, got {self.upscale_factor}"
            )
        # pixel_shuffle folds r**2 channels into one, so expand to r**2 * in_dim
        # to come back out with exactly in_dim channels.
        inter_dim = self.upscale_factor ** 2 * in_dim
        # ----------- Model parameters -----------
        self.input_proj = BasicConv(in_dim, inter_dim, kernel_size=1,
                                    act_type=cfg.neck_act, norm_type=cfg.neck_norm)
        self.output_proj = BasicConv(in_dim, out_dim, kernel_size=1,
                                     act_type=cfg.neck_act, norm_type=cfg.neck_norm)

    def forward(self, x):
        """Upsample ``x`` spatially by ``self.upscale_factor``.

        Args:
            x: Feature map, presumably ``[B, in_dim, H, W]``. The shape
               comments below assume the 1x1 ``BasicConv`` layers keep
               ``H, W`` unchanged (TODO confirm against ``yolof_basic``).

        Returns:
            Tensor of shape ``[B, out_dim, r*H, r*W]`` with ``r = self.upscale_factor``.
        """
        r = self.upscale_factor
        # [B, C, H, W] -> [B, r*r*C, H, W]
        x = self.input_proj(x)
        # [B, r*r*C, H, W] -> [B, C, r*H, r*W]
        x = torch.pixel_shuffle(x, upscale_factor=r)
        # [B, C, r*H, r*W] -> [B, out_dim, r*H, r*W]
        x = self.output_proj(x)
        return x
|