@@ -4,19 +4,24 @@ import torch
import random
import numpy as np
-import sys
-sys.path.append("../")
-from utils import distributed_utils
-from dataset.voc import VOCDataset, VOC_CLASSES
-from dataset.coco import COCODataset, coco_class_labels, coco_class_index
-from config import build_trans_config, build_dataset_config
-
-
-def fix_random_seed(args):
-    seed = args.seed + distributed_utils.get_rank()
-    torch.manual_seed(seed)
-    np.random.seed(seed)
- random.seed(seed)
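+# The local imports below assume this script sits next to voc.py and coco.py,
+# which removes the need for the sys.path.append("../") hack.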
+from voc import VOCDataset, VOC_CLASSES
+from coco import COCODataset, coco_class_labels, coco_class_index
+
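+# Dataset configs are inlined below, replacing the build_dataset_config import
+# from config and making this script self-contained.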
+dataset_cfg = {
+    'voc': {
+        'data_name': 'VOCdevkit',
+        'num_classes': 20,
+        'class_indexs': None,
+ 'class_names': ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
+    },
+
+    'coco': {
+        'data_name': 'COCO',
+ 'num_classes': 80,
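+        # The 80 active COCO category IDs; the raw 91-ID label space has gaps
+        # (IDs 12, 26, 29, 30, 45, 66, 68, 69, 71 and 83 are unused).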
+ 'class_indexs': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90],
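+        # 91 entries ('background' plus the 90 original COCO categories),
+        # indexable directly by COCO category ID.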
+ 'class_names': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
+    },
+}
# ------------------------------ Dataset ------------------------------
def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
@@ -96,11 +101,19 @@ def visualize(image, target, dataset_name="voc"):
cv2.imshow('gt', image)
cv2.waitKey(0)
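+# Local replacement for the build_dataset_config helper formerly imported from
+# config: it selects the matching entry from the dataset_cfg table defined above.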
+def build_dataset_config(args):
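+    # All COCO splits ('coco', 'coco-val', 'coco-test') share one config;
+    # any other dataset name is used directly as a key into dataset_cfg.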
+    if args.dataset in ['coco', 'coco-val', 'coco-test']:
+        cfg = dataset_cfg['coco']
+    else:
+        cfg = dataset_cfg[args.dataset]
+
+    print('==============================')
+    print('Dataset Config: {} \n'.format(cfg))
+
+ return cfg
if __name__ == "__main__":
- import argparse
-    from build import build_transform
-
+    import argparse
parser = argparse.ArgumentParser(description='VOC-Dataset')
# Seed
@@ -126,40 +139,20 @@ if __name__ == "__main__":
help='mosaic augmentation.')
parser.add_argument('--mixup', default=None, type=float,
help='mixup augmentation.')
-    # DDP train
-    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
-                        help='distributed training')
-    parser.add_argument('--dist_url', default='env://',
-                        help='url used to set up distributed training')
-    parser.add_argument('--world_size', default=1, type=int,
-                        help='number of distributed processes')
-    parser.add_argument('--sybn', action='store_true', default=False,
-                        help='use sybn.')
# Output
    parser.add_argument('--output_dir', type=str, default='cache_data/',
                        help='directory to save cached data.')
args = parser.parse_args()
-
-    assert args.aug_type in ["yolov5_pico", "yolov5_nano", "yolov5_small", "yolov5_medium", "yolov5_large", "yolov5_huge",
-                             "yolox_pico", "yolox_nano", "yolox_small", "yolox_medium", "yolox_large", "yolox_huge"]
-
-    # ------------- Build transform config -------------
+    # ------------- Build dataset config -------------
dataset_cfg = build_dataset_config(args)
-    trans_config = build_trans_config(args.aug_type)
-
-    # ------------- Build transform -------------
-    transform, trans_config = build_transform(args, trans_config, max_stride=32, is_train=args.is_train)
# ------------- Build dataset -------------
- dataset, dataset_info = build_dataset(args, dataset_cfg, trans_config, transform, is_train=args.is_train)
+ dataset, dataset_info = build_dataset(args, dataset_cfg, trans_config=None, transform=None, is_train=args.is_train)
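+    # Caching raw samples needs no augmentation pipeline, so trans_config and
+    # transform are passed as None.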
print('Data length: ', len(dataset))
-    # ---------------------------- Fix random seed ----------------------------
-    fix_random_seed(args)
-
# ---------------------------- Main process ----------------------------
    # We only cache the training data
data_items = []