build.py 3.9 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. import os
  2. try:
  3. # dataset class
  4. from .voc import VOCDataset
  5. from .coco import COCODataset
  6. from .custom import CustomDataset
  7. # transform class
  8. from .data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  9. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  10. except:
  11. # dataset class
  12. from voc import VOCDataset
  13. from coco import COCODataset
  14. from yolo.dataset.custom import CustomDataset
  15. # transform class
  16. from data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  17. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  18. # ------------------------------ Dataset ------------------------------
  19. def build_dataset(args, cfg, transform=None, is_train=False):
  20. # ------------------------- Build dataset -------------------------
  21. ## VOC dataset
  22. if args.dataset == 'voc':
  23. image_set = [('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')]
  24. dataset = VOCDataset(cfg = cfg,
  25. data_dir = args.root,
  26. image_set = image_set,
  27. transform = transform,
  28. is_train = is_train,
  29. )
  30. ## COCO dataset
  31. elif args.dataset == 'coco':
  32. image_set = 'train2017' if is_train else 'val2017'
  33. dataset = COCODataset(cfg = cfg,
  34. data_dir = args.root,
  35. image_set = image_set,
  36. transform = transform,
  37. is_train = is_train,
  38. )
  39. ## Custom dataset
  40. elif args.dataset == 'custom':
  41. image_set = 'train' if is_train else 'val'
  42. dataset = CustomDataset(cfg = cfg,
  43. data_dir = args.root,
  44. image_set = image_set,
  45. transform = transform,
  46. is_train = is_train,
  47. )
  48. cfg.class_labels = dataset.class_labels
  49. cfg.class_indexs = dataset.class_indexs
  50. cfg.num_classes = dataset.num_classes
  51. return dataset
  52. # ------------------------------ Transform ------------------------------
  53. def build_transform(cfg, is_train=False):
  54. # ---------------- Build transform ----------------
  55. ## YOLO style transform
  56. if cfg.aug_type == 'yolo':
  57. if is_train:
  58. transform = YOLOAugmentation(cfg.train_img_size,
  59. cfg.affine_params,
  60. cfg.pixel_mean,
  61. cfg.pixel_std,
  62. cfg.box_format,
  63. cfg.normalize_coords)
  64. else:
  65. transform = YOLOBaseTransform(cfg.test_img_size,
  66. cfg.max_stride,
  67. cfg.pixel_mean,
  68. cfg.pixel_std,
  69. cfg.box_format,
  70. cfg.normalize_coords)
  71. ## RT-DETR style transform
  72. elif cfg.aug_type == 'ssd':
  73. if is_train:
  74. transform = SSDAugmentation(cfg.train_img_size,
  75. cfg.pixel_mean,
  76. cfg.pixel_std,
  77. cfg.box_format,
  78. cfg.normalize_coords)
  79. else:
  80. transform = SSDBaseTransform(cfg.test_img_size,
  81. cfg.pixel_mean,
  82. cfg.pixel_std,
  83. cfg.box_format,
  84. cfg.normalize_coords)
  85. return transform