# build.py (426 B)
import torch.nn as nn

from .loss import SetCriterion
from .rtcdet import RTCDet
  4. # build object detector
  5. def build_rtcdet(cfg, is_val=False):
  6. # -------------- Build YOLO --------------
  7. model = RTCDet(cfg, is_val)
  8. # -------------- Build criterion --------------
  9. criterion = None
  10. if is_val:
  11. # build criterion for training
  12. criterion = SetCriterion(cfg)
  13. return model, criterion