build.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133
  1. import os
  2. try:
  3. # dataset class
  4. from .voc import VOCDataset
  5. from .coco import COCODataset
  6. from .crowdhuman import CrowdHumanDataset
  7. from .widerface import WiderFaceDataset
  8. from .customed import CustomedDataset
  9. # transform class
  10. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  11. from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  12. from .data_augment.rtdetr_augment import RTDetrAugmentation, RTDetrBaseTransform
  13. except:
  14. # dataset class
  15. from voc import VOCDataset
  16. from coco import COCODataset
  17. from crowdhuman import CrowdHumanDataset
  18. from widerface import WiderFaceDataset
  19. from customed import CustomedDataset
  20. # transform class
  21. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  22. from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  23. from data_augment.rtdetr_augment import RTDetrAugmentation, RTDetrBaseTransform
  24. # ------------------------------ Dataset ------------------------------
  25. def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
  26. # ------------------------- Basic parameters -------------------------
  27. data_dir = os.path.join(args.root, data_cfg['data_name'])
  28. num_classes = data_cfg['num_classes']
  29. class_names = data_cfg['class_names']
  30. class_indexs = data_cfg['class_indexs']
  31. dataset_info = {
  32. 'num_classes': num_classes,
  33. 'class_names': class_names,
  34. 'class_indexs': class_indexs
  35. }
  36. # ------------------------- Build dataset -------------------------
  37. ## VOC dataset
  38. if args.dataset == 'voc':
  39. image_sets = [('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')]
  40. dataset = VOCDataset(img_size = args.img_size,
  41. data_dir = data_dir,
  42. image_sets = image_sets,
  43. transform = transform,
  44. trans_config = trans_config,
  45. is_train = is_train,
  46. load_cache = args.load_cache
  47. )
  48. ## COCO dataset
  49. elif args.dataset == 'coco':
  50. image_set = 'train2017' if is_train else 'val2017'
  51. dataset = COCODataset(img_size = args.img_size,
  52. data_dir = data_dir,
  53. image_set = image_set,
  54. transform = transform,
  55. trans_config = trans_config,
  56. is_train = is_train,
  57. load_cache = args.load_cache
  58. )
  59. ## CrowdHuman dataset
  60. elif args.dataset == 'crowdhuman':
  61. image_set = 'train' if is_train else 'val'
  62. dataset = CrowdHumanDataset(img_size = args.img_size,
  63. data_dir = data_dir,
  64. image_set = image_set,
  65. transform = transform,
  66. trans_config = trans_config,
  67. is_train = is_train,
  68. )
  69. ## WiderFace dataset
  70. elif args.dataset == 'widerface':
  71. image_set = 'train' if is_train else 'val'
  72. dataset = WiderFaceDataset(img_size = args.img_size,
  73. data_dir = data_dir,
  74. image_set = image_set,
  75. transform = transform,
  76. trans_config = trans_config,
  77. is_train = is_train,
  78. )
  79. ## Custom dataset
  80. elif args.dataset == 'customed':
  81. image_set = 'train' if is_train else 'val'
  82. dataset = CustomedDataset(data_dir = data_dir,
  83. img_size = args.img_size,
  84. image_set = image_set,
  85. transform = transform,
  86. trans_config = trans_config,
  87. is_train = is_train,
  88. load_cache = args.load_cache
  89. )
  90. return dataset, dataset_info
  91. # ------------------------------ Transform ------------------------------
  92. def build_transform(args, trans_config, max_stride=32, is_train=False):
  93. # ---------------- Modify trans_config ----------------
  94. if is_train:
  95. ## mosaic prob.
  96. if args.mosaic is not None:
  97. trans_config['mosaic_prob'] = args.mosaic
  98. ## mixup prob.
  99. if args.mixup is not None:
  100. trans_config['mixup_prob'] = args.mixup
  101. # ---------------- Build transform ----------------
  102. ## SSD style transform
  103. if trans_config['aug_type'] == 'ssd':
  104. if is_train:
  105. transform = SSDAugmentation(img_size=args.img_size,)
  106. else:
  107. transform = SSDBaseTransform(img_size=args.img_size,)
  108. ## YOLO style transform
  109. elif trans_config['aug_type'] == 'yolov5':
  110. if is_train:
  111. transform = YOLOv5Augmentation(img_size=args.img_size, trans_config=trans_config, use_ablu=trans_config['use_ablu'])
  112. else:
  113. transform = YOLOv5BaseTransform(img_size=args.img_size,max_stride=max_stride)
  114. ## RT-DETR style transform
  115. elif trans_config['aug_type'] == 'rtdetr':
  116. if is_train:
  117. use_mosaic = False if trans_config['mosaic_prob'] < 0.2 else True
  118. transform = RTDetrAugmentation(
  119. img_size=args.img_size, pixel_mean=trans_config['pixel_mean'], pixel_std=trans_config['pixel_std'], use_mosaic=use_mosaic)
  120. else:
  121. transform = RTDetrBaseTransform(
  122. img_size=args.img_size, pixel_mean=trans_config['pixel_mean'], pixel_std=trans_config['pixel_std'])
  123. return transform, trans_config