yolof_backbone.py 1.4 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .resnet import build_resnet
  5. except:
  6. from resnet import build_resnet
# --------------------- YOLOF's Backbone -----------------------
  8. class YolofBackbone(nn.Module):
  9. def __init__(self, cfg):
  10. super().__init__()
  11. self.backbone, self.feat_dim = build_resnet(cfg.backbone, cfg.use_pretrained)
  12. def forward(self, x):
  13. pyramid_feats = self.backbone(x)
  14. return pyramid_feats # [C3, C4, C5]
  15. if __name__=='__main__':
  16. from thop import profile
  17. # YOLOv1 configuration
  18. class YolofBaseConfig(object):
  19. def __init__(self) -> None:
  20. # ---------------- Model config ----------------
  21. self.out_stride = 32
  22. ## Backbone
  23. self.backbone = 'resnet18'
  24. self.use_pretrained = True
  25. cfg = YolofBaseConfig()
  26. # Build backbone
  27. model = YolofBackbone(cfg)
  28. # Randomly generate a input data
  29. x = torch.randn(2, 3, 640, 640)
  30. # Inference
  31. output = model(x)
  32. print(' - the shape of input : ', x.shape)
  33. print(' - the shape of output : ', output.shape)
  34. x = torch.randn(1, 3, 640, 640)
  35. flops, params = profile(model, inputs=(x, ), verbose=False)
  36. print('============== FLOPs & Params ================')
  37. print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
  38. print(' - Params : {:.2f} M'.format(params / 1e6))