build.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. import os
  2. try:
  3. from .voc import VOCDataset
  4. from .coco import COCODataset
  5. from .crowdhuman import CrowdHumanDataset
  6. from .ourdataset import OurDataset
  7. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  8. from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  9. except:
  10. from voc import VOCDataset
  11. from coco import COCODataset
  12. from crowdhuman import CrowdHumanDataset
  13. from ourdataset import OurDataset
  14. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  15. from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  16. # ------------------------------ Dataset ------------------------------
  def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
      """Build the dataset selected by ``args.dataset`` and its class metadata.

      Args:
          args: Namespace-like object. ``root`` and ``dataset`` are always read;
              ``img_size`` is read for every supported dataset and
              ``load_cache`` for 'voc', 'coco' and 'ourdataset'.
          data_cfg: Mapping providing the keys 'data_name', 'num_classes',
              'class_names' and 'class_indexs'.
          trans_config: Augmentation config, forwarded unchanged to the
              dataset constructor.
          transform: Transform object, forwarded unchanged to the dataset
              constructor.
          is_train: Selects the training split when True, otherwise the
              evaluation split of the chosen dataset.

      Returns:
          A ``(dataset, dataset_info)`` tuple, where ``dataset_info`` is a dict
          with the keys 'num_classes', 'class_names' and 'class_indexs' copied
          from ``data_cfg``.

      Raises:
          ValueError: If ``args.dataset`` is not one of 'voc', 'coco',
              'crowdhuman' or 'ourdataset'.
      """
      # ------------------------- Basic parameters -------------------------
      data_dir = os.path.join(args.root, data_cfg['data_name'])
      dataset_info = {
          'num_classes': data_cfg['num_classes'],
          'class_names': data_cfg['class_names'],
          'class_indexs': data_cfg['class_indexs'],
      }

      # ------------------------- Build dataset -------------------------
      ## VOC dataset
      if args.dataset == 'voc':
          if is_train:
              image_sets = [('2007', 'trainval'), ('2012', 'trainval')]
          else:
              image_sets = [('2007', 'test')]
          dataset = VOCDataset(
              img_size=args.img_size,
              data_dir=data_dir,
              image_sets=image_sets,
              transform=transform,
              trans_config=trans_config,
              is_train=is_train,
              load_cache=args.load_cache,
          )
      ## COCO dataset
      elif args.dataset == 'coco':
          dataset = COCODataset(
              img_size=args.img_size,
              data_dir=data_dir,
              image_set='train2017' if is_train else 'val2017',
              transform=transform,
              trans_config=trans_config,
              is_train=is_train,
              load_cache=args.load_cache,
          )
      ## CrowdHuman dataset (its constructor is called without load_cache)
      elif args.dataset == 'crowdhuman':
          dataset = CrowdHumanDataset(
              img_size=args.img_size,
              data_dir=data_dir,
              image_set='train' if is_train else 'val',
              transform=transform,
              trans_config=trans_config,
              is_train=is_train,
          )
      ## Custom dataset
      elif args.dataset == 'ourdataset':
          dataset = OurDataset(
              data_dir=data_dir,
              img_size=args.img_size,
              image_set='train' if is_train else 'val',
              transform=transform,
              trans_config=trans_config,
              is_train=is_train,
              load_cache=args.load_cache,
          )
      else:
          # Without this branch `dataset` would be unbound at the return below
          # and the caller would only see an opaque UnboundLocalError.
          raise ValueError(
              f"Unknown dataset: {args.dataset!r}. "
              "Supported: 'voc', 'coco', 'crowdhuman', 'ourdataset'."
          )

      return dataset, dataset_info
  73. # ------------------------------ Transform ------------------------------
  def build_transform(args, trans_config, max_stride=32, is_train=False):
      """Build the image transform described by ``trans_config['aug_type']``.

      ``trans_config`` is modified in place and also returned.

      Args:
          args: Namespace-like object. ``img_size`` is read for every supported
              augmentation type; ``mosaic`` and ``mixup`` are read only when
              ``is_train`` is True.
          trans_config: Mutable mapping with at least the key 'aug_type'
              ('ssd' or 'yolov5'). Training with 'yolov5' also reads
              'use_ablu'.
          max_stride: Forwarded to ``YOLOv5BaseTransform`` (evaluation only).
          is_train: Build the training augmentation when True, otherwise the
              base (evaluation) transform.

      Returns:
          A ``(transform, trans_config)`` tuple.

      Raises:
          ValueError: If ``trans_config['aug_type']`` is neither 'ssd' nor
              'yolov5'.
      """
      # Command-line overrides take precedence over the config during training;
      # a value of None leaves the configured probability as it is.
      if is_train:
          ## mosaic prob.
          if args.mosaic is not None:
              trans_config['mosaic_prob'] = args.mosaic
          ## mixup prob.
          if args.mixup is not None:
              trans_config['mixup_prob'] = args.mixup

      # Transform
      aug_type = trans_config['aug_type']
      if aug_type == 'ssd':
          if is_train:
              transform = SSDAugmentation(img_size=args.img_size)
          else:
              transform = SSDBaseTransform(img_size=args.img_size)
          # NOTE(review): the source this was recovered from had lost its
          # indentation; these resets are taken to apply to both the train and
          # eval paths of the 'ssd' branch -- confirm against upstream.
          trans_config['mosaic_prob'] = 0.0
          trans_config['mixup_prob'] = 0.0
      elif aug_type == 'yolov5':
          if is_train:
              transform = YOLOv5Augmentation(
                  img_size=args.img_size,
                  trans_config=trans_config,
                  use_ablu=trans_config['use_ablu'],
              )
          else:
              transform = YOLOv5BaseTransform(
                  img_size=args.img_size,
                  max_stride=max_stride,
              )
      else:
          # Without this branch `transform` would be unbound at the return
          # below and the caller would only see an opaque UnboundLocalError.
          raise ValueError(
              f"Unknown aug_type: {aug_type!r}. Supported: 'ssd', 'yolov5'."
          )

      return transform, trans_config