# yolov2_backbone.py (1.3 KB)
  1. import torch
  2. import torch.nn as nn
  3. try:
  4. from .resnet import build_resnet
  5. except:
  6. from resnet import build_resnet
  7. # --------------------- Yolov2's Backbone -----------------------
  8. class Yolov2Backbone(nn.Module):
  9. def __init__(self, cfg):
  10. super().__init__()
  11. self.backbone, self.feat_dim = build_resnet(cfg.backbone, cfg.use_pretrained)
  12. def forward(self, x):
  13. c5 = self.backbone(x)
  14. return c5
  15. if __name__=='__main__':
  16. import time
  17. from thop import profile
  18. # YOLOv8-Base config
  19. class Yolov2BaseConfig(object):
  20. def __init__(self) -> None:
  21. # ---------------- Model config ----------------
  22. self.out_stride = 32
  23. self.max_stride = 32
  24. ## Backbone
  25. self.backbone = 'resnet18'
  26. self.use_pretrained = True
  27. cfg = Yolov2BaseConfig()
  28. # Build backbone
  29. model = Yolov2Backbone(cfg)
  30. # Inference
  31. x = torch.randn(1, 3, 640, 640)
  32. t0 = time.time()
  33. output = model(x)
  34. t1 = time.time()
  35. print('Time: ', t1 - t0)
  36. print(output.shape)
  37. print('==============================')
  38. flops, params = profile(model, inputs=(x, ), verbose=False)
  39. print('==============================')
  40. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  41. print('Params : {:.2f} M'.format(params / 1e6))