yolov1_backbone.py 1.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .resnet import build_resnet
  5. except:
  6. from resnet import build_resnet
  # --------------------- Yolov1's Backbone -----------------------
class Yolov1Backbone(nn.Module):
    """ResNet feature extractor used as the YOLOv1 backbone.

    The underlying network and its feature dimension both come from
    ``build_resnet``. ``forward`` just runs the input through that network.

    Args:
        cfg: configuration object. Only ``cfg.backbone`` (the ResNet variant
            name handed to ``build_resnet``) and ``cfg.use_pretrained`` (also
            handed to ``build_resnet``) are read.

    Attributes:
        backbone: the module returned by ``build_resnet``.
        feat_dim: the second value returned by ``build_resnet``. This is
            presumably the channel count of the output feature map; verify
            against resnet.py.
    """

    def __init__(self, cfg):
        super().__init__()
        # build_resnet returns a (network, feature_dim) pair.
        net, channels = build_resnet(cfg.backbone, cfg.use_pretrained)
        self.backbone = net
        self.feat_dim = channels

    def forward(self, x):
        """Run ``x`` through the backbone and return its output.

        The output is presumably the last-stage (C5, stride-32) feature map;
        confirm against build_resnet.
        """
        return self.backbone(x)
  # --------------------- Smoke test & profiling -----------------------
if __name__ == '__main__':
    from thop import profile

    # YOLOv1 configuration
    class Yolov1BaseConfig(object):
        """Minimal stand-in for the YOLOv1 config used by this demo."""

        def __init__(self) -> None:
            # ---------------- Model config ----------------
            self.out_stride = 32
            self.max_stride = 32
            ## Backbone
            self.backbone = 'resnet18'
            # NOTE(review): passed to build_resnet; presumably triggers a
            # pretrained-weight load/download — confirm in resnet.py.
            self.use_pretrained = True

    cfg = Yolov1BaseConfig()

    # Build backbone
    model = Yolov1Backbone(cfg)
    # Switch to inference mode. Without this, training-mode layers (e.g.
    # BatchNorm) would update their running statistics from the random batch
    # below and behave differently from deployment.
    model.eval()

    # Randomly generate an input batch
    x = torch.randn(2, 3, 640, 640)

    # Inference (no autograd graph is needed for a shape check)
    with torch.no_grad():
        output = model(x)
    print(' - the shape of input : ', x.shape)
    print(' - the shape of output : ', output.shape)

    # Profile a single image. thop counts multiply-accumulate operations
    # (MACs); one MAC is two FLOPs, hence the factor of 2 when printing.
    x = torch.randn(1, 3, 640, 640)
    with torch.no_grad():
        flops, params = profile(model, inputs=(x, ), verbose=False)
    print('============== FLOPs & Params ================')
    print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
    print(' - Params : {:.2f} M'.format(params / 1e6))