# build.py
  1. try:
  2. # dataset class
  3. from .voc import VOCDataset
  4. from .coco import COCODataset
  5. # transform class
  6. from .data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  7. from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  8. except:
  9. # dataset class
  10. from voc import VOCDataset
  11. from coco import COCODataset
  12. # transform class
  13. from data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
  14. from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
  15. # ------------------------------ Dataset ------------------------------
  16. def build_dataset(args, cfg, transform=None, is_train=False):
  17. # ------------------------- Build dataset -------------------------
  18. ## VOC dataset
  19. if args.dataset == 'voc':
  20. dataset = VOCDataset(cfg = cfg,
  21. data_dir = args.root,
  22. transform = transform,
  23. is_train = is_train,
  24. )
  25. ## COCO dataset
  26. elif args.dataset == 'coco':
  27. dataset = COCODataset(cfg = cfg,
  28. data_dir = args.root,
  29. transform = transform,
  30. is_train = is_train,
  31. )
  32. cfg.class_labels = dataset.class_labels
  33. cfg.class_indexs = dataset.class_indexs
  34. cfg.num_classes = dataset.num_classes
  35. return dataset
  36. # ------------------------------ Transform ------------------------------
  37. def build_transform(cfg, is_train=False):
  38. # ---------------- Build transform ----------------
  39. ## YOLO style transform
  40. if cfg.aug_type == 'yolo':
  41. if is_train:
  42. transform = YOLOAugmentation(cfg.train_img_size,
  43. cfg.affine_params,
  44. cfg.pixel_mean,
  45. cfg.pixel_std,
  46. cfg.box_format,
  47. cfg.normalize_coords)
  48. else:
  49. transform = YOLOBaseTransform(cfg.test_img_size,
  50. cfg.max_stride,
  51. cfg.pixel_mean,
  52. cfg.pixel_std,
  53. cfg.box_format,
  54. cfg.normalize_coords)
  55. ## RT-DETR style transform
  56. elif cfg.aug_type == 'ssd':
  57. if is_train:
  58. transform = SSDAugmentation(cfg.train_img_size,
  59. cfg.pixel_mean,
  60. cfg.pixel_std,
  61. cfg.box_format,
  62. cfg.normalize_coords)
  63. else:
  64. transform = SSDBaseTransform(cfg.test_img_size,
  65. cfg.pixel_mean,
  66. cfg.pixel_std,
  67. cfg.box_format,
  68. cfg.normalize_coords)
  69. return transform