build.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374
  1. try:
  2. # dataset class
  3. from .voc import VOCDataset
  4. from .coco import COCODataset
  5. # transform class
  6. from .data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  7. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  8. except:
  9. # dataset class
  10. from voc import VOCDataset
  11. from coco import COCODataset
  12. # transform class
  13. from data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  14. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  15. # ------------------------------ Dataset ------------------------------
  16. def build_dataset(args, cfg, transform=None, is_train=False):
  17. # ------------------------- Build dataset -------------------------
  18. ## VOC dataset
  19. if args.dataset == 'voc':
  20. dataset = VOCDataset(cfg = cfg,
  21. data_dir = args.root,
  22. transform = transform,
  23. is_train = is_train,
  24. )
  25. ## COCO dataset
  26. elif args.dataset == 'coco':
  27. dataset = COCODataset(cfg = cfg,
  28. data_dir = args.root,
  29. transform = transform,
  30. is_train = is_train,
  31. )
  32. cfg.class_labels = dataset.class_labels
  33. cfg.class_indexs = dataset.class_indexs
  34. cfg.num_classes = dataset.num_classes
  35. return dataset
  36. # ------------------------------ Transform ------------------------------
  37. def build_transform(cfg, is_train=False):
  38. # ---------------- Build transform ----------------
  39. ## YOLO style transform
  40. if cfg.aug_type == 'yolo':
  41. if is_train:
  42. transform = YOLOAugmentation(cfg.train_img_size,
  43. cfg.affine_params,
  44. cfg.pixel_mean,
  45. cfg.pixel_std,
  46. )
  47. else:
  48. transform = YOLOBaseTransform(cfg.test_img_size,
  49. cfg.max_stride,
  50. cfg.pixel_mean,
  51. cfg.pixel_std,
  52. )
  53. ## RT-DETR style transform
  54. elif cfg.aug_type == 'ssd':
  55. if is_train:
  56. transform = SSDAugmentation(cfg.train_img_size,
  57. cfg.pixel_mean,
  58. cfg.pixel_std,
  59. )
  60. else:
  61. transform = SSDBaseTransform(cfg.test_img_size,
  62. cfg.pixel_mean,
  63. cfg.pixel_std,
  64. )
  65. return transform