# build.py (553 B)

  1. from .convnet import ConvNet
  2. def build_convnet(args):
  3. if args.model == "convnet":
  4. model = ConvNet(img_size = args.img_size,
  5. in_dim = args.img_dim,
  6. hidden_dim = 64,
  7. num_classes = args.num_classes,
  8. act_type = "relu",
  9. norm_type = "bn",
  10. use_adavgpool = True)
  11. else:
  12. raise NotImplementedError("Unknown model: {}".format(args.model))
  13. return model