| 123456789101112131415161718192021222324252627282930 |
- import torch
- import torch.nn as nn
- try:
- from .yolof_basic import BasicConv
- except:
- from yolof_basic import BasicConv
class YolofUpsampler(nn.Module):
    """Fuse the two deepest pyramid levels into a single feature map.

    The last entry of ``pyramid_feats`` is projected to ``out_dim`` channels,
    upsampled to the spatial size of the second-to-last entry (also projected
    to ``out_dim`` channels), concatenated with it along the channel axis and
    passed through a 1x1 + 3x3 ``BasicConv`` stack.

    Args:
        cfg: Config object; only ``cfg.neck_act`` and ``cfg.neck_norm`` are
            read here and forwarded to every ``BasicConv``.
        in_dims: Sequence of input channel counts; only the last two entries
            (``in_dims[-1]`` and ``in_dims[-2]``) are used.
        out_dim: Number of channels of the projected and fused features.
    """

    def __init__(self, cfg, in_dims, out_dim):
        super().__init__()
        # ----------- Model parameters -----------
        # Every conv in this neck shares the same activation / normalization.
        conv_kwargs = dict(act_type=cfg.neck_act, norm_type=cfg.neck_norm)

        # 1x1 projections bringing both levels to a common channel count.
        self.input_proj_1 = BasicConv(in_dims[-1], out_dim, kernel_size=1, **conv_kwargs)
        self.input_proj_2 = BasicConv(in_dims[-2], out_dim, kernel_size=1, **conv_kwargs)

        # Fusion head: 1x1 conv reduces the concatenated 2 * out_dim channels
        # back to out_dim, followed by a 3x3 conv.
        self.output_proj = nn.Sequential(
            BasicConv(out_dim * 2, out_dim, kernel_size=1, **conv_kwargs),
            BasicConv(out_dim, out_dim, kernel_size=3, padding=1, stride=1, **conv_kwargs),
        )

    def forward(self, pyramid_feats):
        """Return the fused feature map at the resolution of ``pyramid_feats[-2]``.

        Args:
            pyramid_feats: Sequence of feature tensors; only the last two are
                used. Assumed to be 4-D ``(N, C, H, W)`` tensors -- TODO
                confirm against the backbone that feeds this module.

        Returns:
            Output of ``output_proj`` applied to the channel-wise
            concatenation of the two projected levels.
        """
        x1 = self.input_proj_1(pyramid_feats[-1])
        x2 = self.input_proj_2(pyramid_feats[-2])

        # Resize to x2's actual spatial size instead of a fixed 2x factor:
        # when the finer level has an odd height/width the coarser level is
        # not exactly half its size, and a hard-coded scale_factor=2.0 would
        # make the concatenation below fail with a shape mismatch. For the
        # exact-2x case this produces the same result as scale_factor=2.0.
        x1_up = nn.functional.interpolate(x1, size=x2.shape[-2:])

        x3 = torch.cat([x2, x1_up], dim=1)
        return self.output_proj(x3)
|