com_flops_params.py 439 B

123456789101112131415
  1. import torch
  2. from thop import profile
  3. def FLOPs_and_Params(model, img_size, device):
  4. x = torch.randn(1, 3, img_size, img_size).to(device)
  5. print('==============================')
  6. flops, params = profile(model, inputs=(x, ), verbose=False)
  7. print('==============================')
  8. print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
  9. print('Params : {:.2f} M'.format(params / 1e6))
  10. if __name__ == "__main__":
  11. pass