fcos_backbone.py 1.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .resnet import build_resnet
  5. except:
  6. from resnet import build_resnet
  7. # --------------------- FCOS's Backbone -----------------------
  8. class FcosBackbone(nn.Module):
  9. def __init__(self, cfg):
  10. super().__init__()
  11. self.backbone, self.feat_dims = build_resnet(cfg.backbone, cfg.use_pretrained)
  12. def forward(self, x):
  13. pyramid_feats = self.backbone(x)
  14. return pyramid_feats # [C3, C4, C5]
  15. if __name__=='__main__':
  16. from thop import profile
  17. # YOLOv1 configuration
  18. class FcosBaseConfig(object):
  19. def __init__(self) -> None:
  20. # ---------------- Model config ----------------
  21. self.out_stride = [8, 16, 32]
  22. self.max_stride = 32
  23. ## Backbone
  24. self.backbone = 'resnet18'
  25. self.use_pretrained = True
  26. cfg = FcosBaseConfig()
  27. # Build backbone
  28. model = FcosBackbone(cfg)
  29. # Randomly generate a input data
  30. x = torch.randn(2, 3, 640, 640)
  31. # Inference
  32. outputs = model(x)
  33. print(' - the shape of input : ', x.shape)
  34. for i, out in enumerate(outputs):
  35. print(f' - the shape of level-{i} output : ', out.shape)
  36. x = torch.randn(1, 3, 640, 640)
  37. flops, params = profile(model, inputs=(x, ), verbose=False)
  38. print('============== FLOPs & Params ================')
  39. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  40. print(' - Params : {:.2f} M'.format(params / 1e6))