| 123456789101112131415161718 |
- import torch
- from thop import profile
- def FLOPs_and_Params(model, size, img_dim, device):
- x = torch.randn(1, img_dim, size, size).to(device)
- model.eval()
- flops, params = profile(model, inputs=(x, ))
- print('=================== FLOPs & Params ===================')
- print('- GFLOPs : ', flops / 1e9 * 2)
- print('- Params : ', params / 1e6, ' M')
- model.train()
- if __name__ == "__main__":
- pass
|