build.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. import os
  2. try:
  3. # dataset class
  4. from .voc import VOCDataset
  5. from .coco import COCODataset
  6. from .custom import CustomDataset
  7. # transform class
  8. from .data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  9. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  10. except:
  11. # dataset class
  12. from voc import VOCDataset
  13. from coco import COCODataset
  14. from yolo.dataset.custom import CustomDataset
  15. # transform class
  16. from data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  17. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  18. # ------------------------------ Dataset ------------------------------
  19. def build_dataset(args, cfg, transform=None, is_train=False):
  20. # ------------------------- Build dataset -------------------------
  21. ## VOC dataset
  22. if args.dataset == 'voc':
  23. image_set = [('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')]
  24. dataset = VOCDataset(cfg = cfg,
  25. data_dir = args.root,
  26. image_set = image_set,
  27. transform = transform,
  28. is_train = is_train,
  29. )
  30. ## COCO dataset
  31. elif args.dataset == 'coco':
  32. image_set = 'train2017' if is_train else 'val2017'
  33. dataset = COCODataset(cfg = cfg,
  34. data_dir = args.root,
  35. image_set = image_set,
  36. transform = transform,
  37. is_train = is_train,
  38. )
  39. ## Custom dataset
  40. elif args.dataset == 'custom':
  41. image_set = 'train' if is_train else 'val'
  42. dataset = CustomDataset(cfg = cfg,
  43. data_dir = args.root,
  44. image_set = image_set,
  45. transform = transform,
  46. is_train = is_train,
  47. )
  48. cfg.class_labels = dataset.class_labels
  49. cfg.class_indexs = dataset.class_indexs
  50. cfg.num_classes = dataset.num_classes
  51. return dataset
  52. # ------------------------------ Transform ------------------------------
  53. def build_transform(cfg, is_train=False):
  54. # ---------------- Build transform ----------------
  55. ## YOLO style transform
  56. if cfg.aug_type == 'yolo':
  57. if is_train:
  58. transform = YOLOAugmentation(cfg.train_img_size,
  59. cfg.affine_params,
  60. cfg.use_ablu,
  61. cfg.pixel_mean,
  62. cfg.pixel_std,
  63. cfg.box_format,
  64. cfg.normalize_coords)
  65. else:
  66. transform = YOLOBaseTransform(cfg.test_img_size,
  67. cfg.max_stride,
  68. cfg.pixel_mean,
  69. cfg.pixel_std,
  70. cfg.box_format,
  71. cfg.normalize_coords)
  72. ## RT-DETR style transform
  73. elif cfg.aug_type == 'ssd':
  74. if is_train:
  75. transform = SSDAugmentation(cfg.train_img_size,
  76. cfg.pixel_mean,
  77. cfg.pixel_std,
  78. cfg.box_format,
  79. cfg.normalize_coords)
  80. else:
  81. transform = SSDBaseTransform(cfg.test_img_size,
  82. cfg.pixel_mean,
  83. cfg.pixel_std,
  84. cfg.box_format,
  85. cfg.normalize_coords)
  86. return transform