build.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104
  1. import os
  2. try:
  3. from .voc import VOCDetection
  4. from .coco import COCODataset
  5. from .ourdataset import OurDataset
  6. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  7. from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  8. except:
  9. print(999)
  10. from voc import VOCDetection
  11. from coco import COCODataset
  12. from ourdataset import OurDataset
  13. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  14. from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
  15. # ------------------------------ Dataset ------------------------------
  16. def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
  17. # ------------------------- Basic parameters -------------------------
  18. data_dir = os.path.join(args.root, data_cfg['data_name'])
  19. num_classes = data_cfg['num_classes']
  20. class_names = data_cfg['class_names']
  21. class_indexs = data_cfg['class_indexs']
  22. dataset_info = {
  23. 'num_classes': num_classes,
  24. 'class_names': class_names,
  25. 'class_indexs': class_indexs
  26. }
  27. # ------------------------- Build dataset -------------------------
  28. ## VOC dataset
  29. if args.dataset == 'voc':
  30. dataset = VOCDetection(
  31. img_size=args.img_size,
  32. data_dir=data_dir,
  33. image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
  34. transform=transform,
  35. trans_config=trans_config,
  36. load_cache=args.load_cache
  37. )
  38. ## COCO dataset
  39. elif args.dataset == 'coco':
  40. dataset = COCODataset(
  41. img_size=args.img_size,
  42. data_dir=data_dir,
  43. image_set='train2017' if is_train else 'val2017',
  44. transform=transform,
  45. trans_config=trans_config,
  46. load_cache=args.load_cache
  47. )
  48. ## Custom dataset
  49. elif args.dataset == 'ourdataset':
  50. dataset = OurDataset(
  51. data_dir=data_dir,
  52. img_size=args.img_size,
  53. image_set='train' if is_train else 'val',
  54. transform=transform,
  55. trans_config=trans_config,
  56. load_cache=args.load_cache
  57. )
  58. return dataset, dataset_info
  59. # ------------------------------ Transform ------------------------------
  60. def build_transform(args, trans_config, max_stride=32, is_train=False):
  61. # Modify trans_config
  62. if is_train:
  63. ## mosaic prob.
  64. if args.mosaic is not None:
  65. trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
  66. else:
  67. trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
  68. ## mixup prob.
  69. if args.mixup is not None:
  70. trans_config['mixup_prob']=args.mixup if is_train else 0.0
  71. else:
  72. trans_config['mixup_prob']=trans_config['mixup_prob'] if is_train else 0.0
  73. # Transform
  74. if trans_config['aug_type'] == 'ssd':
  75. if is_train:
  76. transform = SSDAugmentation(img_size=args.img_size,)
  77. else:
  78. transform = SSDBaseTransform(img_size=args.img_size,)
  79. trans_config['mosaic_prob'] = 0.0
  80. trans_config['mixup_prob'] = 0.0
  81. elif trans_config['aug_type'] == 'yolov5':
  82. if is_train:
  83. transform = YOLOv5Augmentation(
  84. img_size=args.img_size,
  85. trans_config=trans_config,
  86. use_ablu=trans_config['use_ablu']
  87. )
  88. else:
  89. transform = YOLOv5BaseTransform(
  90. img_size=args.img_size,
  91. max_stride=max_stride
  92. )
  93. return transform, trans_config