build.py 5.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. import os
  2. try:
  3. from .voc import VOCDataset
  4. from .coco import COCODataset
  5. from .crowdhuman import CrowdHumanDataset
  6. from .widerface import WiderFaceDataset
  7. from .customed import CustomedDataset
  8. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  9. from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  10. except:
  11. from .voc import VOCDataset
  12. from coco import COCODataset
  13. from crowdhuman import CrowdHumanDataset
  14. from widerface import WiderFaceDataset
  15. from dataset.customed import CustomedDataset
  16. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  17. from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  18. # ------------------------------ Dataset ------------------------------
  19. def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
  20. # ------------------------- Basic parameters -------------------------
  21. data_dir = os.path.join(args.root, data_cfg['data_name'])
  22. num_classes = data_cfg['num_classes']
  23. class_names = data_cfg['class_names']
  24. class_indexs = data_cfg['class_indexs']
  25. dataset_info = {
  26. 'num_classes': num_classes,
  27. 'class_names': class_names,
  28. 'class_indexs': class_indexs
  29. }
  30. # ------------------------- Build dataset -------------------------
  31. ## VOC dataset
  32. if args.dataset == 'voc':
  33. image_sets = [('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')]
  34. dataset = VOCDataset(img_size = args.img_size,
  35. data_dir = data_dir,
  36. image_sets = image_sets,
  37. transform = transform,
  38. trans_config = trans_config,
  39. is_train = is_train,
  40. load_cache = args.load_cache
  41. )
  42. ## COCO dataset
  43. elif args.dataset == 'coco':
  44. image_set = 'train2017' if is_train else 'val2017'
  45. dataset = COCODataset(img_size = args.img_size,
  46. data_dir = data_dir,
  47. image_set = image_set,
  48. transform = transform,
  49. trans_config = trans_config,
  50. is_train = is_train,
  51. load_cache = args.load_cache
  52. )
  53. ## CrowdHuman dataset
  54. elif args.dataset == 'crowdhuman':
  55. image_set = 'train' if is_train else 'val'
  56. dataset = CrowdHumanDataset(img_size = args.img_size,
  57. data_dir = data_dir,
  58. image_set = image_set,
  59. transform = transform,
  60. trans_config = trans_config,
  61. is_train = is_train,
  62. )
  63. ## WiderFace dataset
  64. elif args.dataset == 'widerface':
  65. image_set = 'train' if is_train else 'val'
  66. dataset = WiderFaceDataset(img_size = args.img_size,
  67. data_dir = data_dir,
  68. image_set = image_set,
  69. transform = transform,
  70. trans_config = trans_config,
  71. is_train = is_train,
  72. )
  73. ## Custom dataset
  74. elif args.dataset == 'customed':
  75. image_set = 'train' if is_train else 'val'
  76. dataset = CustomedDataset(data_dir = data_dir,
  77. img_size = args.img_size,
  78. image_set = image_set,
  79. transform = transform,
  80. trans_config = trans_config,
  81. is_train = is_train,
  82. load_cache = args.load_cache
  83. )
  84. return dataset, dataset_info
  85. # ------------------------------ Transform ------------------------------
  86. def build_transform(args, trans_config, max_stride=32, is_train=False):
  87. # Modify trans_config
  88. if is_train:
  89. ## mosaic prob.
  90. if args.mosaic is not None:
  91. trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
  92. else:
  93. trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
  94. ## mixup prob.
  95. if args.mixup is not None:
  96. trans_config['mixup_prob']=args.mixup if is_train else 0.0
  97. else:
  98. trans_config['mixup_prob']=trans_config['mixup_prob'] if is_train else 0.0
  99. # Transform
  100. if trans_config['aug_type'] == 'ssd':
  101. if is_train:
  102. transform = SSDAugmentation(img_size=args.img_size,)
  103. else:
  104. transform = SSDBaseTransform(img_size=args.img_size,)
  105. trans_config['mosaic_prob'] = 0.0
  106. trans_config['mixup_prob'] = 0.0
  107. elif trans_config['aug_type'] == 'yolov5':
  108. if is_train:
  109. transform = YOLOv5Augmentation(
  110. img_size=args.img_size,
  111. trans_config=trans_config,
  112. use_ablu=trans_config['use_ablu']
  113. )
  114. else:
  115. transform = YOLOv5BaseTransform(
  116. img_size=args.img_size,
  117. max_stride=max_stride
  118. )
  119. return transform, trans_config