# dilated_encoder.py

import torch.nn as nn

from utils import weight_init
from ..basic.conv import BasicConv


# Bottleneck
class Bottleneck(nn.Module):
    def __init__(self, in_dim, dilation, expand_ratio, act_type='relu', norm_type='BN'):
        super(Bottleneck, self).__init__()
        # ------------------ Basic parameters -------------------
        self.in_dim = in_dim
        self.dilation = dilation
        self.expand_ratio = expand_ratio
        inter_dim = round(in_dim * expand_ratio)
        # ------------------ Network parameters -------------------
        # 1x1 reduce -> 3x3 dilated conv -> 1x1 restore
        self.branch = nn.Sequential(
            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type),
            BasicConv(inter_dim, inter_dim, kernel_size=3, padding=dilation, dilation=dilation, act_type=act_type, norm_type=norm_type),
            BasicConv(inter_dim, in_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
        )

    def forward(self, x):
        # Residual connection: the branch returns in_dim channels, matching x.
        return x + self.branch(x)
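
# Shape sketch (illustrative numbers, not taken from any config): with in_dim=512
# and expand_ratio=0.25, the branch maps 512 -> 128 via the first 1x1 conv, keeps
# 128 channels through the 3x3 dilated conv (padding=dilation preserves spatial
# size), then restores 128 -> 512 via the last 1x1 conv, so the residual add
# lines up channel-wise.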


# Dilated Encoder
class DilatedEncoder(nn.Module):
    def __init__(self, cfg, in_dim, out_dim):
        super(DilatedEncoder, self).__init__()
        # ------------------ Basic parameters -------------------
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.expand_ratio = cfg.neck_expand_ratio
        self.dilations = cfg.neck_dilations
        self.act_type = cfg.neck_act
        self.norm_type = cfg.neck_norm
        # ------------------ Network parameters -------------------
        ## proj layer
        self.projector = nn.Sequential(
            BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=self.norm_type),
            BasicConv(out_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=self.norm_type)
        )
        ## encoder layers: one residual Bottleneck per dilation rate
        self.encoders = nn.Sequential(
            *[Bottleneck(out_dim, d, self.expand_ratio, self.act_type, self.norm_type) for d in self.dilations])

        self._init_weight()
    def _init_weight(self):
        # Projector: Caffe2-style Xavier init for convs, identity init for norms.
        # Iterate .modules() so the convs inside the BasicConv wrappers are reached.
        for m in self.projector.modules():
            if isinstance(m, nn.Conv2d):
                weight_init.c2_xavier_fill(m)
            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
        # Encoder blocks: small-variance normal init for convs.
        for m in self.encoders.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, mean=0, std=0.01)
                if hasattr(m, 'bias') and m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
    def forward(self, x):
        # Project the backbone feature to out_dim, then apply the stack of
        # dilated residual blocks; spatial resolution is unchanged throughout.
        x = self.projector(x)
        x = self.encoders(x)
        return x
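

# Minimal smoke test (a sketch, not part of the original module). The cfg fields
# mirror exactly what __init__ reads; the values below (expand ratio, dilation
# rates, input shape) are illustrative assumptions, not values from this repo.
# Run it as part of the package (e.g. python -m <pkg>.dilated_encoder) so the
# relative import of BasicConv resolves.
if __name__ == '__main__':
    import torch
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        neck_expand_ratio=0.25,       # assumed bottleneck reduction ratio
        neck_dilations=[2, 4, 6, 8],  # assumed rates, one Bottleneck per rate
        neck_act='relu',
        neck_norm='BN',
    )
    model = DilatedEncoder(cfg, in_dim=512, out_dim=512)
    x = torch.randn(1, 512, 32, 32)  # dummy C5-like feature map
    y = model(x)
    print(y.shape)  # expected: torch.Size([1, 512, 32, 32]), spatial size preserved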