build.py 3.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697
  1. import os
  2. try:
  3. from .voc import VOCDetection
  4. from .coco import COCODataset
  5. from .ourdataset import OurDataset
  6. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  7. from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  8. except:
  9. from voc import VOCDetection
  10. from coco import COCODataset
  11. from ourdataset import OurDataset
  12. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  13. from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  14. # ------------------------------ Dataset ------------------------------
  15. def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
  16. # Basic parameters
  17. data_dir = os.path.join(args.root, data_cfg['data_name'])
  18. num_classes = data_cfg['num_classes']
  19. class_names = data_cfg['class_names']
  20. class_indexs = data_cfg['class_indexs']
  21. dataset_info = {
  22. 'num_classes': num_classes,
  23. 'class_names': class_names,
  24. 'class_indexs': class_indexs
  25. }
  26. # Build dataset class
  27. ## VOC dataset
  28. if args.dataset == 'voc':
  29. dataset = VOCDetection(
  30. img_size=args.img_size,
  31. data_dir=data_dir,
  32. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
  33. transform=transform,
  34. trans_config=trans_config
  35. )
  36. ## COCO dataset
  37. elif args.dataset == 'coco':
  38. dataset = COCODataset(
  39. img_size=args.img_size,
  40. data_dir=data_dir,
  41. image_set='train2017' if is_train else 'val2017',
  42. transform=transform,
  43. trans_config=trans_config
  44. )
  45. ## Custom dataset
  46. elif args.dataset == 'ourdataset':
  47. dataset = OurDataset(
  48. data_dir=data_dir,
  49. img_size=args.img_size,
  50. image_set='train' if is_train else 'val',
  51. transform=transform,
  52. trans_config=trans_config,
  53. )
  54. return dataset, dataset_info
  55. # ------------------------------ Transform ------------------------------
  56. def build_transform(args, trans_config, max_stride=32, is_train=False):
  57. # Modify trans_config
  58. if is_train:
  59. ## mosaic prob.
  60. if args.mosaic is not None:
  61. trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
  62. else:
  63. trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
  64. ## mixup prob.
  65. if args.mixup is not None:
  66. trans_config['mixup_prob']=args.mixup if is_train else 0.0
  67. else:
  68. trans_config['mixup_prob']=trans_config['mixup_prob'] if is_train else 0.0
  69. # Transform
  70. if trans_config['aug_type'] == 'ssd':
  71. if is_train:
  72. transform = SSDAugmentation(img_size=args.img_size,)
  73. else:
  74. transform = SSDBaseTransform(img_size=args.img_size,)
  75. elif trans_config['aug_type'] == 'yolov5':
  76. if is_train:
  77. transform = YOLOv5Augmentation(
  78. img_size=args.img_size,
  79. trans_config=trans_config
  80. )
  81. else:
  82. transform = YOLOv5BaseTransform(
  83. img_size=args.img_size,
  84. max_stride=max_stride
  85. )
  86. return transform, trans_config