build.py 686 B

123456789101112131415161718192021
  1. from .resnet import ResNet
  2. from .modules import PlainResBlock, BottleneckResBlock
  3. def build_resnet(args):
  4. if args.model == 'resnet18':
  5. model = ResNet(in_dim=args.img_dim,
  6. block=PlainResBlock,
  7. expansion=1.0,
  8. num_blocks=[2, 2, 2, 2],
  9. )
  10. elif args.model == 'resnet50':
  11. model = ResNet(in_dim=args.img_dim,
  12. block=BottleneckResBlock,
  13. expansion=4.0,
  14. num_blocks=[3, 4, 6, 3],
  15. )
  16. else:
  17. raise NotImplementedError("Unknown resnet: {}".format(args.model))
  18. return model