build.py 3.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. import os
  2. try:
  3. # dataset class
  4. from .voc import VOCDataset
  5. from .coco import COCODataset
  6. from .custom import CustomDataset
  7. # transform class
  8. from .data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  9. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  10. except:
  11. # dataset class
  12. from voc import VOCDataset
  13. from coco import COCODataset
  14. from custom import CustomDataset
  15. # transform class
  16. from data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  17. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  18. # ------------------------------ Dataset ------------------------------
  19. def build_dataset(args, cfg, transform=None, is_train=False):
  20. # ------------------------- Build dataset -------------------------
  21. ## VOC dataset
  22. if args.dataset == 'voc':
  23. dataset = VOCDataset(cfg = cfg,
  24. data_dir = args.root,
  25. transform = transform,
  26. is_train = is_train,
  27. )
  28. ## COCO dataset
  29. elif args.dataset == 'coco':
  30. dataset = COCODataset(cfg = cfg,
  31. data_dir = args.root,
  32. transform = transform,
  33. is_train = is_train,
  34. )
  35. ## Custom dataset
  36. elif args.dataset == 'custom':
  37. image_set = 'train' if is_train else 'val'
  38. dataset = CustomDataset(cfg = cfg,
  39. data_dir = args.root,
  40. image_set = image_set,
  41. transform = transform,
  42. is_train = is_train,
  43. )
  44. cfg.class_labels = dataset.class_labels
  45. cfg.class_indexs = dataset.class_indexs
  46. cfg.num_classes = dataset.num_classes
  47. return dataset
  48. # ------------------------------ Transform ------------------------------
  49. def build_transform(cfg, is_train=False):
  50. # ---------------- Build transform ----------------
  51. ## YOLO style transform
  52. if cfg.aug_type == 'yolo':
  53. if is_train:
  54. transform = YOLOAugmentation(cfg.train_img_size,
  55. cfg.affine_params,
  56. cfg.pixel_mean,
  57. cfg.pixel_std,
  58. cfg.box_format,
  59. cfg.normalize_coords)
  60. else:
  61. transform = YOLOBaseTransform(cfg.test_img_size,
  62. cfg.max_stride,
  63. cfg.pixel_mean,
  64. cfg.pixel_std,
  65. cfg.box_format,
  66. cfg.normalize_coords)
  67. ## RT-DETR style transform
  68. elif cfg.aug_type == 'ssd':
  69. if is_train:
  70. transform = SSDAugmentation(cfg.train_img_size,
  71. cfg.pixel_mean,
  72. cfg.pixel_std,
  73. cfg.box_format,
  74. cfg.normalize_coords)
  75. else:
  76. transform = SSDBaseTransform(cfg.test_img_size,
  77. cfg.pixel_mean,
  78. cfg.pixel_std,
  79. cfg.box_format,
  80. cfg.normalize_coords)
  81. return transform