build.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. import os
  2. try:
  3. # dataset class
  4. from .voc import VOCDataset
  5. from .coco import COCODataset
  6. from .customed import CustomedDataset
  7. # transform class
  8. from .data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  9. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  10. except:
  11. # dataset class
  12. from voc import VOCDataset
  13. from coco import COCODataset
  14. from customed import CustomedDataset
  15. # transform class
  16. from data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  17. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  18. # ------------------------------ Dataset ------------------------------
  19. def build_dataset(args, cfg, transform=None, is_train=False):
  20. # ------------------------- Build dataset -------------------------
  21. ## VOC dataset
  22. if args.dataset == 'voc':
  23. image_set = [('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')]
  24. cfg.num_classes = 20
  25. dataset = VOCDataset(cfg = cfg,
  26. data_dir = args.root,
  27. image_set = image_set,
  28. transform = transform,
  29. is_train = is_train,
  30. )
  31. ## COCO dataset
  32. elif args.dataset == 'coco':
  33. image_set = 'train2017' if is_train else 'val2017'
  34. cfg.num_classes = 80
  35. dataset = COCODataset(cfg = cfg,
  36. data_dir = args.root,
  37. image_set = image_set,
  38. transform = transform,
  39. is_train = is_train,
  40. )
  41. ## Custom dataset
  42. elif args.dataset == 'customed':
  43. image_set = 'train' if is_train else 'val'
  44. cfg.num_classes = 20
  45. dataset = CustomedDataset(cfg = cfg,
  46. data_dir = args.root,
  47. image_set = image_set,
  48. transform = transform,
  49. is_train = is_train,
  50. )
  51. cfg.class_labels = dataset.class_labels
  52. cfg.class_indexs = dataset.class_indexs
  53. cfg.num_classes = dataset.num_classes
  54. return dataset
  55. # ------------------------------ Transform ------------------------------
  56. def build_transform(cfg, is_train=False):
  57. # ---------------- Build transform ----------------
  58. ## YOLO style transform
  59. if cfg.aug_type == 'yolo':
  60. if is_train:
  61. transform = YOLOAugmentation(cfg.train_img_size,
  62. cfg.affine_params,
  63. cfg.use_ablu,
  64. cfg.pixel_mean,
  65. cfg.pixel_std,
  66. cfg.box_format,
  67. cfg.normalize_coords)
  68. else:
  69. transform = YOLOBaseTransform(cfg.test_img_size,
  70. cfg.max_stride,
  71. cfg.pixel_mean,
  72. cfg.pixel_std,
  73. cfg.box_format,
  74. cfg.normalize_coords)
  75. ## RT-DETR style transform
  76. elif cfg.aug_type == 'ssd':
  77. if is_train:
  78. transform = SSDAugmentation(cfg.train_img_size,
  79. cfg.pixel_mean,
  80. cfg.pixel_std,
  81. cfg.box_format,
  82. cfg.normalize_coords)
  83. else:
  84. transform = SSDBaseTransform(cfg.test_img_size,
  85. cfg.pixel_mean,
  86. cfg.pixel_std,
  87. cfg.box_format,
  88. cfg.normalize_coords)
  89. return transform