| 123456789101112131415161718192021222324252627282930313233343536 |
- # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
- import torch
- from .fcos.build import build_fcos
- from .yolof.build import build_yolof
- from .detr.build import build_detr
def build_model(args, cfg, num_classes=80, is_val=False):
    """Build the object detector selected by ``args.model``.

    The detector family is picked by substring match on ``args.model``,
    tested in the order 'fcos', 'yolof', 'detr'; the first match wins.

    Args:
        args: namespace-like object. ``args.model`` (str) selects the
            detector family; ``args.resume`` (checkpoint path or None) is
            read only when ``is_val`` is True.
        cfg: detector config, forwarded unchanged to the family builder.
        num_classes (int): number of object classes. Defaults to 80.
        is_val (bool): forwarded to the family builder. When True, weights
            from ``args.resume`` (if not None) are loaded into the model and
            ``(model, criterion)`` is returned; when False only ``model`` is
            returned.
            NOTE(review): despite its name, the True branch is the one that
            resumes training -- confirm the intended meaning with callers.

    Returns:
        ``(model, criterion)`` if ``is_val`` is True, otherwise ``model``.

    Raises:
        NotImplementedError: if ``args.model`` matches no known detector.
    """
    # ------------ build object detector ------------
    if 'fcos' in args.model:
        model, criterion = build_fcos(cfg, num_classes, is_val)
    elif 'yolof' in args.model:
        model, criterion = build_yolof(cfg, num_classes, is_val)
    elif 'detr' in args.model:
        model, criterion = build_detr(cfg, num_classes, is_val)
    else:
        raise NotImplementedError("Unknown detector: {}".format(args.model))

    if not is_val:
        return model

    # ------------ Keep training from the given weight ------------
    if args.resume is not None:
        print('keep training: ', args.resume)
        # map_location='cpu' lets a checkpoint saved on GPU load on any host.
        # torch.load is pickle-based: only load checkpoint files you trust.
        checkpoint = torch.load(args.resume, map_location='cpu')
        # The checkpoint is expected to be a dict whose "model" entry holds
        # the model state dict; other entries (if any) are ignored here.
        checkpoint_state_dict = checkpoint.pop("model")
        model.load_state_dict(checkpoint_state_dict)
    return model, criterion
-
|