com_flops_params.py 426 B

123456789101112131415161718
import torch
from thop import profile
  3. def FLOPs_and_Params(model, size, img_dim, device):
  4. x = torch.randn(1, img_dim, size, size).to(device)
  5. model.eval()
  6. flops, params = profile(model, inputs=(x, ))
  7. print('=================== FLOPs & Params ===================')
  8. print('- GFLOPs : ', flops / 1e9 * 2)
  9. print('- Params : ', params / 1e6, ' M')
  10. model.train()
  11. if __name__ == "__main__":
  12. pass