# backbone.py (6.5 KB)

  import torch
  import torch.nn as nn
  from torch import Tensor
  from typing import Callable, List, Optional, Type, Union

  try:
      # Relative import: works when this module is loaded as part of a package.
      from .basic import conv1x1, BasicBlock, Bottleneck
  except ImportError:
      # Fallback for loading the file without a parent package (e.g. run directly).
      # Only ImportError is caught: a bare ``except`` would also swallow
      # KeyboardInterrupt and unrelated errors raised while executing ``.basic``,
      # hiding them behind a confusing failure of this fallback import.
      from basic import conv1x1, BasicBlock, Bottleneck
  # ImageNet-1K (IN1K) pretrained checkpoint locations, keyed by backbone name.
  # Every value is currently None: no checkpoint URL has been filled in yet.
  pretrained_urls = {
      # --- ResNet family ---
      'resnet18': None,
      'resnet34': None,
      'resnet50': None,
      'resnet101': None,
      'resnet152': None,
      # --- ShuffleNet family: no entries yet ---
  }
  # ----------------- Model functions -----------------
  ## Build backbone network
  def build_backbone(cfg, pretrained):
      """Build the backbone network selected by ``cfg['backbone']``.

      Args:
          cfg: configuration mapping; only the ``'backbone'`` key is read here.
          pretrained: forwarded unchanged to the family-specific builder.

      Returns:
          The ``(model, feats)`` pair produced by the family-specific builder.

      Raises:
          NotImplementedError: if no builder handles the requested backbone name.
      """
      name = cfg['backbone']
      if 'resnet' in name:
          # Any name containing "resnet" is routed to the ResNet builder, which
          # validates the exact depth itself.
          return build_resnet(cfg, pretrained)
      # The old message used "<>" with str.format, so the name never appeared.
      raise NotImplementedError(f"Unknown backbone: {name}.")
  ## Load pretrained weight
  def load_pretrained(model_name):
      """Placeholder for loading pretrained weights for ``model_name``.

      Not implemented yet: the argument is ignored and nothing is loaded.

      Returns:
          None.
      """
      return None
  # ----------------- ResNet Backbone -----------------
  class ResNet(nn.Module):
      """ResNet built from residual ``block`` modules.

      Layout: a stem (7x7 stride-2 conv, norm, ReLU, 3x3 stride-2 max-pool),
      four residual stages (``layer1`` .. ``layer4``), then global average
      pooling and a linear classifier (``fc``).

      Args:
          block: residual block class (``BasicBlock`` or ``Bottleneck``); its
              ``expansion`` attribute scales every stage's output channels.
          layers: number of blocks in each of the four stages.
          num_classes: output size of the final linear layer.
          zero_init_residual: after the standard init, zero ``bn3.weight`` of
              every Bottleneck and ``bn2.weight`` of every BasicBlock.
          groups: passed to every block as ``groups``.
          width_per_group: passed to every block as ``base_width``.
          replace_stride_with_dilation: three flags for ``layer2``..``layer4``;
              a True flag keeps that stage at stride 1 and multiplies the running
              dilation by the stage's stride instead.
          norm_layer: normalisation-layer factory; ``nn.BatchNorm2d`` if None.

      Raises:
          ValueError: if ``replace_stride_with_dilation`` does not have 3 entries.
      """

      def __init__(self,
                   block: Type[Union[BasicBlock, Bottleneck]],
                   layers: List[int],
                   num_classes: int = 1000,
                   zero_init_residual: bool = False,
                   groups: int = 1,
                   width_per_group: int = 64,
                   replace_stride_with_dilation: Optional[List[bool]] = None,
                   norm_layer: Optional[Callable[..., nn.Module]] = None,
                   ) -> None:
          super().__init__()
          # Running state read (and advanced) by _make_layer while stages are built.
          self.groups = groups
          self.base_width = width_per_group
          self.inplanes = 64
          self.dilation = 1
          self.zero_init_residual = zero_init_residual
          if replace_stride_with_dilation is None:
              replace_stride_with_dilation = [False, False, False]
          self.replace_stride_with_dilation = replace_stride_with_dilation
          if len(self.replace_stride_with_dilation) != 3:
              raise ValueError(
                  "replace_stride_with_dilation should be None "
                  f"or a 3-element tuple, got {self.replace_stride_with_dilation}"
              )
          self._norm_layer = norm_layer if norm_layer is not None else nn.BatchNorm2d

          # Stem. Submodules are created in a fixed order; keep it, because it
          # decides registration order, state_dict key order and the order in
          # which random initialisation consumes the RNG.
          stem_width = self.inplanes
          self.conv1 = nn.Conv2d(3, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
          self.bn1 = self._norm_layer(stem_width)
          self.relu = nn.ReLU(inplace=True)
          self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

          # Residual stages; layer2..layer4 may trade stride for dilation.
          dilate_flags = self.replace_stride_with_dilation
          self.layer1 = self._make_layer(block, 64, layers[0])
          self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=dilate_flags[0])
          self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=dilate_flags[1])
          self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=dilate_flags[2])

          # Classification head.
          self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
          self.fc = nn.Linear(512 * block.expansion, num_classes)

          self._init_layer()

      def _init_layer(self):
          """Initialise weights of all submodules in place.

          Conv2d: Kaiming-normal (``fan_out``, ReLU gain). BatchNorm2d / GroupNorm:
          weight 1, bias 0. Optionally zeroes the residual-branch norm weights.
          """
          for module in self.modules():
              if isinstance(module, nn.Conv2d):
                  nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="relu")
              elif isinstance(module, (nn.BatchNorm2d, nn.GroupNorm)):
                  nn.init.constant_(module.weight, 1)
                  nn.init.constant_(module.bias, 0)

          if not self.zero_init_residual:
              return
          for module in self.modules():
              if isinstance(module, Bottleneck) and module.bn3.weight is not None:
                  nn.init.constant_(module.bn3.weight, 0)  # type: ignore[arg-type]
              elif isinstance(module, BasicBlock) and module.bn2.weight is not None:
                  nn.init.constant_(module.bn2.weight, 0)  # type: ignore[arg-type]

      def _make_layer(
          self,
          block: Type[Union[BasicBlock, Bottleneck]],
          planes: int,
          blocks: int,
          stride: int = 1,
          dilate: bool = False,
      ) -> nn.Sequential:
          """Build one residual stage of ``blocks`` blocks.

          The first block gets ``stride`` and, when ``stride != 1`` or the channel
          count changes, a 1x1-conv + norm shortcut. It also uses the dilation
          that was in effect before this stage. Side effects: advances
          ``self.inplanes`` to the stage's output width and, when ``dilate`` is
          set, multiplies ``self.dilation`` by ``stride``.
          """
          norm = self._norm_layer
          entry_dilation = self.dilation
          if dilate:
              # Keep resolution: move this stage's stride into the dilation.
              self.dilation *= stride
              stride = 1

          out_channels = planes * block.expansion
          shortcut = None
          if stride != 1 or self.inplanes != out_channels:
              shortcut = nn.Sequential(
                  conv1x1(self.inplanes, out_channels, stride),
                  norm(out_channels),
              )

          stage = [
              block(self.inplanes, planes, stride, shortcut, self.groups,
                    self.base_width, entry_dilation, norm)
          ]
          self.inplanes = out_channels
          stage.extend(
              block(
                  self.inplanes,
                  planes,
                  groups=self.groups,
                  base_width=self.base_width,
                  dilation=self.dilation,
                  norm_layer=norm,
              )
              for _ in range(1, blocks)
          )
          return nn.Sequential(*stage)

      def forward(self, x: Tensor) -> Tensor:
          """Run the stem, the four residual stages and the classifier head.

          NOTE(review): this returns the ``fc`` output only. If the class is used
          as a detection backbone (``build_resnet`` reports per-stage channel
          counts), confirm that callers do not expect intermediate feature maps.
          """
          x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
          x = self.layer1(x)
          x = self.layer2(x)
          x = self.layer3(x)
          x = self.layer4(x)
          x = torch.flatten(self.avgpool(x), 1)
          return self.fc(x)
  def _resnet(block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], **kwargs) -> ResNet:
      """Instantiate :class:`ResNet` with ``block`` and per-stage depths ``layers``.

      Any extra keyword arguments are forwarded unchanged to the constructor.
      """
      network = ResNet(block, layers, **kwargs)
      return network
  def build_resnet(cfg, pretrained=False, **kwargs):
      """Build a ResNet of the depth named by ``cfg['backbone']``.

      Args:
          cfg: configuration mapping; ``cfg['backbone']`` must be one of
              ``'resnet18'``, ``'resnet34'``, ``'resnet50'``, ``'resnet101'``,
              ``'resnet152'``.
          pretrained: request ImageNet-1K weights. Loading is not implemented
              yet, so a ``UserWarning`` is emitted and the model keeps its
              random initialisation.
          **kwargs: forwarded to the :class:`ResNet` constructor.

      Returns:
          ``(model, feats)``. ``feats`` is a hard-coded list of channel counts for
          the last three stages. It assumes ``BasicBlock.expansion == 1`` and
          ``Bottleneck.expansion == 4``, as in torchvision; confirm against
          ``basic.py``.

      Raises:
          NotImplementedError: if ``cfg['backbone']`` is not a supported depth.
      """
      import warnings

      name = cfg['backbone']
      # name -> (block kind, blocks per stage, reported channels of layer2..layer4)
      specs = {
          'resnet18': ('basic', [2, 2, 2, 2], [128, 256, 512]),
          'resnet34': ('basic', [3, 4, 6, 3], [128, 256, 512]),
          'resnet50': ('bottleneck', [3, 4, 6, 3], [512, 1024, 2048]),
          'resnet101': ('bottleneck', [3, 4, 23, 3], [512, 1024, 2048]),
          'resnet152': ('bottleneck', [3, 8, 36, 3], [512, 1024, 2048]),
      }
      if name not in specs:
          # Previously an unknown name (e.g. "resnet200", which passes
          # build_backbone's substring check) crashed with UnboundLocalError.
          raise NotImplementedError(f"Unknown backbone: {name}.")

      kind, layers, feats = specs[name]
      block = BasicBlock if kind == 'basic' else Bottleneck
      model = _resnet(block, layers, **kwargs)

      # ---------- Load pretrained ----------
      if pretrained:
          # TODO: load IN1K pretrained weights (pretrained_urls has no URLs yet).
          warnings.warn(
              f"Pretrained weights for {name} are not available yet; "
              "the model keeps its random initialisation.",
              stacklevel=2,
          )
      return model, feats
  # ----------------- ShuffleNet Backbone -----------------
  ## TODO: Add a ShuffleNet-v2 backbone.