| 123456789101112131415161718192021 |
- from .resnet import ResNet
- from .modules import PlainResBlock, BottleneckResBlock
def build_resnet(args):
    """Construct a ResNet variant selected by ``args.model``.

    Supported variant names:
      * ``'resnet18'`` -- ``PlainResBlock``, expansion ``1.0``,
        ``num_blocks=[2, 2, 2, 2]``.
      * ``'resnet50'`` -- ``BottleneckResBlock``, expansion ``4.0``,
        ``num_blocks=[3, 4, 6, 3]``.

    Args:
        args: Configuration object. Only two attributes are read:
            ``args.model`` (the variant name) and ``args.img_dim``
            (forwarded unchanged to ``ResNet`` as ``in_dim``).

    Returns:
        The ``ResNet`` instance built from the chosen configuration.

    Raises:
        NotImplementedError: If ``args.model`` is not a supported variant name.
    """
    variant = args.model

    # Resolve the (block class, expansion, num_blocks) triple for the variant.
    # num_blocks is presumably the block count per stage -- confirm in ResNet.
    if variant == 'resnet18':
        block_cls, width_mult, depths = PlainResBlock, 1.0, [2, 2, 2, 2]
    elif variant == 'resnet50':
        block_cls, width_mult, depths = BottleneckResBlock, 4.0, [3, 4, 6, 3]
    else:
        raise NotImplementedError("Unknown resnet: {}".format(variant))

    return ResNet(
        in_dim=args.img_dim,
        block=block_cls,
        expansion=width_mult,
        num_blocks=depths,
    )
|