build.py 3.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. import os
  2. try:
  3. from .voc import VOCDetection
  4. from .coco import COCODataset
  5. from .ourdataset import OurDataset
  6. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  7. from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  8. except:
  9. import sys
  10. sys.path.append('.')
  11. from voc import VOCDetection
  12. from coco import COCODataset
  13. from ourdataset import OurDataset
  14. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  15. from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  16. # ------------------------------ Dataset ------------------------------
  17. def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
  18. # Basic parameters
  19. data_dir = os.path.join(args.root, data_cfg['data_name'])
  20. num_classes = data_cfg['num_classes']
  21. class_names = data_cfg['class_names']
  22. class_indexs = data_cfg['class_indexs']
  23. dataset_info = {
  24. 'num_classes': num_classes,
  25. 'class_names': class_names,
  26. 'class_indexs': class_indexs
  27. }
  28. # Build dataset class
  29. ## VOC dataset
  30. if args.dataset == 'voc':
  31. dataset = VOCDetection(
  32. img_size=args.img_size,
  33. data_dir=data_dir,
  34. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
  35. transform=transform,
  36. trans_config=trans_config
  37. )
  38. ## COCO dataset
  39. elif args.dataset == 'coco':
  40. dataset = COCODataset(
  41. img_size=args.img_size,
  42. data_dir=data_dir,
  43. image_set='train2017' if is_train else 'val2017',
  44. transform=transform,
  45. trans_config=trans_config
  46. )
  47. ## Custom dataset
  48. elif args.dataset == 'ourdataset':
  49. dataset = OurDataset(
  50. data_dir=data_dir,
  51. img_size=args.img_size,
  52. image_set='train' if is_train else 'val',
  53. transform=transform,
  54. trans_config=trans_config,
  55. )
  56. return dataset, dataset_info
  57. # ------------------------------ Transform ------------------------------
  58. def build_transform(args, trans_config, max_stride=32, is_train=False):
  59. # Modify trans_config
  60. if is_train:
  61. ## mosaic prob.
  62. if args.mosaic is not None:
  63. trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
  64. else:
  65. trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
  66. ## mixup prob.
  67. if args.mixup is not None:
  68. trans_config['mixup_prob']=args.mixup if is_train else 0.0
  69. else:
  70. trans_config['mixup_prob']=trans_config['mixup_prob'] if is_train else 0.0
  71. # Transform
  72. if trans_config['aug_type'] == 'ssd':
  73. if is_train:
  74. transform = SSDAugmentation(img_size=args.img_size,)
  75. else:
  76. transform = SSDBaseTransform(img_size=args.img_size,)
  77. elif trans_config['aug_type'] == 'yolov5':
  78. if is_train:
  79. transform = YOLOv5Augmentation(
  80. img_size=args.img_size,
  81. trans_config=trans_config
  82. )
  83. else:
  84. transform = YOLOv5BaseTransform(
  85. img_size=args.img_size,
  86. max_stride=max_stride
  87. )
  88. return transform, trans_config