yolof_upsampler.py 1.1 KB

123456789101112131415161718192021222324252627282930
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .yolof_basic import BasicConv
  5. except:
  6. from yolof_basic import BasicConv
  7. class YolofUpsampler(nn.Module):
  8. def __init__(self, cfg, in_dims, out_dim):
  9. super(YolofUpsampler, self).__init__()
  10. # ----------- Model parameters -----------
  11. self.input_proj_1 = BasicConv(in_dims[-1], out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
  12. self.input_proj_2 = BasicConv(in_dims[-2], out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
  13. self.output_proj = nn.Sequential(
  14. BasicConv(out_dim * 2, out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm),
  15. BasicConv(out_dim, out_dim, kernel_size=3, padding=1, stride=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm),
  16. )
  17. def forward(self, pyramid_feats):
  18. x1 = self.input_proj_1(pyramid_feats[-1])
  19. x2 = self.input_proj_2(pyramid_feats[-2])
  20. x1_up = nn.functional.interpolate(x1, scale_factor=2.0)
  21. x3 = torch.cat([x2, x1_up], dim=1)
  22. out = self.output_proj(x3)
  23. return out