import torch.nn as nn

from utils import weight_init
from ..basic.conv import BasicConv
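
# NOTE: judging from its call sites below, BasicConv is assumed to be a
# Conv2d -> norm -> activation block with a signature like
#   BasicConv(in_dim, out_dim, kernel_size, padding=0, dilation=1,
#             act_type='relu', norm_type='BN')
# where act_type=None drops the activation. This is an inference from usage,
# not the actual definition in ..basic.conv.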


# Bottleneck
class Bottleneck(nn.Module):
    def __init__(self, in_dim, dilation, expand_ratio, act_type='relu', norm_type='BN'):
        super(Bottleneck, self).__init__()
        # ------------------ Basic parameters -------------------
        self.in_dim = in_dim
        self.dilation = dilation
        self.expand_ratio = expand_ratio
        inter_dim = round(in_dim * expand_ratio)
        # ------------------ Network parameters -------------------
        # 1x1 reduce -> 3x3 dilated conv -> 1x1 expand; padding=dilation keeps
        # the 3x3 conv spatially size-preserving.
        self.branch = nn.Sequential(
            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type),
            BasicConv(inter_dim, inter_dim, kernel_size=3, padding=dilation, dilation=dilation, act_type=act_type, norm_type=norm_type),
            BasicConv(inter_dim, in_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        )

    def forward(self, x):
        # Residual connection around the bottleneck branch.
        return x + self.branch(x)
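
# Quick shape check for Bottleneck (illustrative values; assumes BasicConv
# resolves and preserves spatial size, which padding=dilation guarantees for
# the 3x3 conv):
#   >>> import torch
#   >>> blk = Bottleneck(in_dim=512, dilation=2, expand_ratio=0.25)
#   >>> blk(torch.randn(1, 512, 32, 32)).shape
#   torch.Size([1, 512, 32, 32])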


# Dilated Encoder
class DilatedEncoder(nn.Module):
    def __init__(self, cfg, in_dim, out_dim):
        super(DilatedEncoder, self).__init__()
        # ------------------ Basic parameters -------------------
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.expand_ratio = cfg.neck_expand_ratio
        self.dilations = cfg.neck_dilations
        self.act_type = cfg.neck_act
        self.norm_type = cfg.neck_norm
        # ------------------ Network parameters -------------------
        ## proj layer: 1x1 channel projection + 3x3 conv, both without activation
        self.projector = nn.Sequential(
            BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=self.norm_type),
            BasicConv(out_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=self.norm_type)
        )
        ## encoder layers: one residual Bottleneck per dilation rate
        self.encoders = nn.Sequential(
            *[Bottleneck(out_dim, d, self.expand_ratio, self.act_type, self.norm_type) for d in self.dilations])

        self._init_weight()

    def _init_weight(self):
        # Caffe2-style Xavier init for the projector. Iterate over .modules()
        # (not the Sequential directly) so the Conv2d layers nested inside
        # each BasicConv are actually reached; the original duplicated
        # c2_xavier_fill call is also dropped here.
        for m in self.projector.modules():
            if isinstance(m, nn.Conv2d):
                weight_init.c2_xavier_fill(m)
            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        for m in self.encoders.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.projector(x)
        x = self.encoders(x)
        return x
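

if __name__ == '__main__':
    # Minimal smoke test sketch. The cfg attribute values below are
    # illustrative placeholders, not the project's real defaults. Because of
    # the relative BasicConv import, run this as a module from the package
    # root rather than as a standalone script.
    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(neck_expand_ratio=0.25,
                          neck_dilations=[2, 4, 6, 8],
                          neck_act='relu',
                          neck_norm='BN')
    model = DilatedEncoder(cfg, in_dim=2048, out_dim=512)
    x = torch.randn(2, 2048, 20, 20)
    # The neck only changes channels; spatial size is preserved.
    print(model(x).shape)  # torch.Size([2, 512, 20, 20])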