| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- from typing import List
- # ------------------ Basic Feature Pyramid Network ------------------
class FcosFPN(nn.Module):
    """Basic top-down feature pyramid that emits four levels: [P3, P4, P5, P6].

    Each of the three input maps is projected to ``cfg.head_dim`` channels with
    a 1x1 convolution. Coarser levels are upsampled (nearest neighbour) and
    added to the next finer level, and every merged map is then passed through
    a 3x3 convolution. P6 is produced from the smoothed P5 output by a
    stride-2 3x3 convolution.

    Args:
        cfg: Configuration object; only ``cfg.head_dim`` is read (the channel
            count of every output level).
        in_dims: Channel counts of the C3, C4 and C5 inputs, in that order.
            Only the first three entries are used. Defaults to
            ``[512, 1024, 2048]`` when omitted or ``None``.
    """

    def __init__(self, cfg, in_dims: List[int] = None):
        super().__init__()
        # A ``None`` sentinel replaces the former mutable list default, which
        # was a single object shared by every call of the constructor.
        if in_dims is None:
            in_dims = [512, 1024, 2048]

        # ------------------ Basic parameters -------------------
        self.out_dim = cfg.head_dim

        # ------------------ Network parameters -------------------
        # Attribute names and registration order are kept as-is so that
        # existing state dicts keep loading.
        # 1x1 lateral projections: C3/C4/C5 -> out_dim channels.
        self.input_proj_1 = nn.Conv2d(in_dims[0], self.out_dim, kernel_size=1)
        self.input_proj_2 = nn.Conv2d(in_dims[1], self.out_dim, kernel_size=1)
        self.input_proj_3 = nn.Conv2d(in_dims[2], self.out_dim, kernel_size=1)
        # 3x3 convolutions applied to each merged level (spatial size kept).
        self.smooth_layer_1 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
        self.smooth_layer_2 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
        self.smooth_layer_3 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
        # Stride-2 convolution that derives P6 from the smoothed P5.
        self.p6_conv = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, stride=2, padding=1)

    def forward(self, feats):
        """Fuse the backbone features into a four-level pyramid.

        Args:
            feats: (List of Tensor) ``[C3, C4, C5]``. Exactly three items are
                required; their channel counts must match the ``in_dims``
                given to the constructor.

        Returns:
            List of four tensors ``[P3, P4, P5, P6]``, each with
            ``self.out_dim`` channels. P3-P5 keep the spatial size of C3-C5.
        """
        c3, c4, c5 = feats

        # -------- Input projection --------
        p3 = self.input_proj_1(c3)
        p4 = self.input_proj_2(c4)
        p5 = self.input_proj_3(c5)

        # -------- Feature fusion --------
        # The coarsest level has nothing above it, so it is only smoothed.
        out_p5 = self.smooth_layer_3(p5)

        # P5 -> P4: the *unsmoothed* P5 is resized to P4's spatial size.
        p4 = p4 + F.interpolate(p5, size=p4.shape[2:], mode='nearest')
        out_p4 = self.smooth_layer_2(p4)

        # P4 -> P3: the merged (unsmoothed) P4 is propagated further down.
        p3 = p3 + F.interpolate(p4, size=p3.shape[2:], mode='nearest')
        out_p3 = self.smooth_layer_1(p3)

        # P5 -> P6: downsample the smoothed P5 output.
        out_p6 = self.p6_conv(out_p5)

        # [P3, P4, P5, P6]
        return [out_p3, out_p4, out_p5, out_p6]
if __name__ == '__main__':
    # Smoke test: build the FPN, run one forward pass on random inputs, then
    # report the output shapes and the profiler's compute/parameter counts.
    import time
    from thop import profile

    class FcosBaseConfig:
        """Stand-in config for this demo; carries the fields set below."""

        def __init__(self) -> None:
            # ---------------- Model config ----------------
            self.width = 0.50
            self.depth = 0.34
            self.out_stride = [8, 16, 32, 64]
            ## Head
            self.head_dim = 256

    # Build the FPN
    in_dims = [128, 256, 512]
    cfg = FcosBaseConfig()
    fpn = FcosFPN(cfg, in_dims)

    # Random C3/C4/C5 inputs: batch of 1, square maps of side 80, 40 and 20.
    feats = [
        torch.randn(1, dim, side, side)
        for dim, side in zip(in_dims, (80, 40, 20))
    ]

    # Time a single forward pass.
    start = time.time()
    pyramid = fpn(feats)
    elapsed = time.time() - start
    print('Time: ', elapsed)

    print('====== FPN output ====== ')
    for level, feat in enumerate(pyramid):
        print("- Level-{} : ".format(level), feat.shape)

    # The profiler's first return value is doubled before being shown as GFLOPs.
    flops, params = profile(fpn, inputs=(feats, ), verbose=False)
    print('==============================')
    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
    print('Params : {:.2f} M'.format(params / 1e6))
|