build.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import os
  2. try:
  3. # dataset class
  4. from .voc import VOCDataset
  5. from .coco import COCODataset
  6. from .crowdhuman import CrowdHumanDataset
  7. from .widerface import WiderFaceDataset
  8. from .customed import CustomedDataset
  9. # transform class
  10. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  11. from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  12. except:
  13. # dataset class
  14. from voc import VOCDataset
  15. from coco import COCODataset
  16. from crowdhuman import CrowdHumanDataset
  17. from widerface import WiderFaceDataset
  18. from customed import CustomedDataset
  19. # transform class
  20. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  21. from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  # ------------------------------ Dataset ------------------------------
def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
    """Build the dataset selected by ``args.dataset``.

    Args:
        args: namespace read for ``root``, ``dataset`` and ``img_size``.
        data_cfg: dict read for ``data_name``, ``num_classes``,
            ``class_names`` and ``class_indexs``.
        trans_config: forwarded unchanged to the dataset constructor.
        transform: forwarded unchanged to the dataset constructor.
        is_train: when True the training split(s) are used, otherwise the
            evaluation split.

    Returns:
        ``(dataset, dataset_info)`` where ``dataset_info`` is a dict with the
        ``num_classes``, ``class_names`` and ``class_indexs`` taken from
        ``data_cfg``.

    Raises:
        ValueError: if ``args.dataset`` is not one of ``'voc'``, ``'coco'``,
            ``'crowdhuman'``, ``'widerface'`` or ``'customed'``. Without this
            check the function would fail with an UnboundLocalError on the
            final ``return``.
    """
    # ------------------------- Basic parameters -------------------------
    data_dir = os.path.join(args.root, data_cfg['data_name'])
    dataset_info = {
        'num_classes': data_cfg['num_classes'],
        'class_names': data_cfg['class_names'],
        'class_indexs': data_cfg['class_indexs'],
    }

    # Keyword arguments shared by every dataset constructor below.
    common_kwargs = dict(
        img_size=args.img_size,
        data_dir=data_dir,
        transform=transform,
        trans_config=trans_config,
        is_train=is_train,
    )

    # ------------------------- Build dataset -------------------------
    if args.dataset == 'voc':
        # VOC trains on 2007+2012 trainval and evaluates on 2007 test.
        image_sets = [('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')]
        dataset = VOCDataset(image_sets=image_sets, **common_kwargs)
    elif args.dataset == 'coco':
        image_set = 'train2017' if is_train else 'val2017'
        dataset = COCODataset(image_set=image_set, **common_kwargs)
    elif args.dataset == 'crowdhuman':
        image_set = 'train' if is_train else 'val'
        dataset = CrowdHumanDataset(image_set=image_set, **common_kwargs)
    elif args.dataset == 'widerface':
        image_set = 'train' if is_train else 'val'
        dataset = WiderFaceDataset(image_set=image_set, **common_kwargs)
    elif args.dataset == 'customed':
        image_set = 'train' if is_train else 'val'
        dataset = CustomedDataset(image_set=image_set, **common_kwargs)
    else:
        raise ValueError(
            f"Unknown dataset: {args.dataset!r}; expected one of "
            "'voc', 'coco', 'crowdhuman', 'widerface', 'customed'."
        )

    return dataset, dataset_info
  # ------------------------------ Transform ------------------------------
def build_transform(args, trans_config, max_stride=32, is_train=False):
    """Build the image/box transform selected by ``trans_config['aug_type']``.

    When ``is_train`` is True, ``args.mosaic`` / ``args.mixup`` (if not None)
    overwrite ``trans_config['mosaic_prob']`` / ``trans_config['mixup_prob']``
    IN PLACE; the (possibly modified) dict is also returned.

    Args:
        args: namespace read for ``img_size`` and, when training, ``mosaic``
            and ``mixup``.
        trans_config: dict read for ``aug_type`` and, for training YOLO-style
            augmentation, ``affine_params`` and ``use_ablu``.
        max_stride: forwarded to ``YOLOv5BaseTransform`` for evaluation.
        is_train: selects the training augmentation vs. the base transform.

    Returns:
        ``(transform, trans_config)``.

    Raises:
        ValueError: if ``trans_config['aug_type']`` is neither ``'ssd'`` nor
            ``'yolo'``. Without this check the function would fail with an
            UnboundLocalError on the final ``return``.
    """
    # ---------------- Modify trans_config ----------------
    if is_train:
        # Command-line overrides; None means "keep the configured probability".
        if args.mosaic is not None:
            trans_config['mosaic_prob'] = args.mosaic
        if args.mixup is not None:
            trans_config['mixup_prob'] = args.mixup

    # ---------------- Build transform ----------------
    aug_type = trans_config['aug_type']
    if aug_type == 'ssd':
        if is_train:
            transform = SSDAugmentation(args.img_size)
        else:
            transform = SSDBaseTransform(args.img_size)
    elif aug_type == 'yolo':
        if is_train:
            transform = YOLOv5Augmentation(args.img_size, trans_config['affine_params'], trans_config['use_ablu'])
        else:
            transform = YOLOv5BaseTransform(args.img_size, max_stride)
    else:
        raise ValueError(f"Unknown aug_type: {aug_type!r}; expected 'ssd' or 'yolo'.")

    return transform, trans_config