|
@@ -4,19 +4,24 @@ import torch
|
|
|
import random
|
|
import random
|
|
|
import numpy as np
|
|
import numpy as np
|
|
|
|
|
|
|
|
-import sys
|
|
|
|
|
-sys.path.append("../")
|
|
|
|
|
-from utils import distributed_utils
|
|
|
|
|
-from dataset.voc import VOCDataset, VOC_CLASSES
|
|
|
|
|
-from dataset.coco import COCODataset, coco_class_labels, coco_class_index
|
|
|
|
|
-from config import build_trans_config, build_dataset_config
|
|
|
|
|
-
|
|
|
|
|
-
|
|
|
|
|
-def fix_random_seed(args):
|
|
|
|
|
- seed = args.seed + distributed_utils.get_rank()
|
|
|
|
|
- torch.manual_seed(seed)
|
|
|
|
|
- np.random.seed(seed)
|
|
|
|
|
- random.seed(seed)
|
|
|
|
|
|
|
+from voc import VOCDataset, VOC_CLASSES
|
|
|
|
|
+from coco import COCODataset, coco_class_labels, coco_class_index
|
|
|
|
|
+
|
|
|
|
|
+dataset_cfg = {
|
|
|
|
|
+ 'voc': {
|
|
|
|
|
+ 'data_name': 'VOCdevkit',
|
|
|
|
|
+ 'num_classes': 20,
|
|
|
|
|
+ 'class_indexs': None,
|
|
|
|
|
+ 'class_names': ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'),
|
|
|
|
|
+ },
|
|
|
|
|
+
|
|
|
|
|
+ 'coco':{
|
|
|
|
|
+ 'data_name': 'COCO',
|
|
|
|
|
+ 'num_classes': 80,
|
|
|
|
|
+ 'class_indexs': [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90],
|
|
|
|
|
+ 'class_names': ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'),
|
|
|
|
|
+ },
|
|
|
|
|
+}
|
|
|
|
|
|
|
|
# ------------------------------ Dataset ------------------------------
|
|
# ------------------------------ Dataset ------------------------------
|
|
|
def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
|
|
def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
|
|
@@ -96,11 +101,19 @@ def visualize(image, target, dataset_name="voc"):
|
|
|
cv2.imshow('gt', image)
|
|
cv2.imshow('gt', image)
|
|
|
cv2.waitKey(0)
|
|
cv2.waitKey(0)
|
|
|
|
|
|
|
|
|
|
+def build_dataset_config(args):
|
|
|
|
|
+ if args.dataset in ['coco', 'coco-val', 'coco-test']:
|
|
|
|
|
+ cfg = dataset_cfg['coco']
|
|
|
|
|
+ else:
|
|
|
|
|
+ cfg = dataset_cfg[args.dataset]
|
|
|
|
|
+
|
|
|
|
|
+ print('==============================')
|
|
|
|
|
+ print('Dataset Config: {} \n'.format(cfg))
|
|
|
|
|
+
|
|
|
|
|
+ return cfg
|
|
|
|
|
|
|
|
if __name__ == "__main__":
|
|
if __name__ == "__main__":
|
|
|
- import argparse
|
|
|
|
|
- from build import build_transform
|
|
|
|
|
-
|
|
|
|
|
|
|
+ import argparse
|
|
|
parser = argparse.ArgumentParser(description='VOC-Dataset')
|
|
parser = argparse.ArgumentParser(description='VOC-Dataset')
|
|
|
|
|
|
|
|
# Seed
|
|
# Seed
|
|
@@ -126,40 +139,20 @@ if __name__ == "__main__":
|
|
|
help='mosaic augmentation.')
|
|
help='mosaic augmentation.')
|
|
|
parser.add_argument('--mixup', default=None, type=float,
|
|
parser.add_argument('--mixup', default=None, type=float,
|
|
|
help='mixup augmentation.')
|
|
help='mixup augmentation.')
|
|
|
- # DDP train
|
|
|
|
|
- parser.add_argument('-dist', '--distributed', action='store_true', default=False,
|
|
|
|
|
- help='distributed training')
|
|
|
|
|
- parser.add_argument('--dist_url', default='env://',
|
|
|
|
|
- help='url used to set up distributed training')
|
|
|
|
|
- parser.add_argument('--world_size', default=1, type=int,
|
|
|
|
|
- help='number of distributed processes')
|
|
|
|
|
- parser.add_argument('--sybn', action='store_true', default=False,
|
|
|
|
|
- help='use sybn.')
|
|
|
|
|
# Output
|
|
# Output
|
|
|
parser.add_argument('--output_dir', type=str, default='cache_data/',
|
|
parser.add_argument('--output_dir', type=str, default='cache_data/',
|
|
|
help='data root')
|
|
help='data root')
|
|
|
|
|
|
|
|
args = parser.parse_args()
|
|
args = parser.parse_args()
|
|
|
|
|
|
|
|
-
|
|
|
|
|
- assert args.aug_type in ["yolov5_pico", "yolov5_nano", "yolov5_small", "yolov5_medium", "yolov5_large", "yolov5_huge",
|
|
|
|
|
- "yolox_pico", "yolox_nano", "yolox_small", "yolox_medium", "yolox_large", "yolox_huge"]
|
|
|
|
|
-
|
|
|
|
|
|
|
|
|
|
# ------------- Build transform config -------------
|
|
# ------------- Build transform config -------------
|
|
|
dataset_cfg = build_dataset_config(args)
|
|
dataset_cfg = build_dataset_config(args)
|
|
|
- trans_config = build_trans_config(args.aug_type)
|
|
|
|
|
-
|
|
|
|
|
- # ------------- Build transform -------------
|
|
|
|
|
- transform, trans_config = build_transform(args, trans_config, max_stride=32, is_train=args.is_train)
|
|
|
|
|
|
|
|
|
|
# ------------- Build dataset -------------
|
|
# ------------- Build dataset -------------
|
|
|
- dataset, dataset_info = build_dataset(args, dataset_cfg, trans_config, transform, is_train=args.is_train)
|
|
|
|
|
|
|
+ dataset, dataset_info = build_dataset(args, dataset_cfg, trans_config=None, transform=None, is_train=args.is_train)
|
|
|
print('Data length: ', len(dataset))
|
|
print('Data length: ', len(dataset))
|
|
|
|
|
|
|
|
- # ---------------------------- Fix random seed ----------------------------
|
|
|
|
|
- fix_random_seed(args)
|
|
|
|
|
-
|
|
|
|
|
# ---------------------------- Main process ----------------------------
|
|
# ---------------------------- Main process ----------------------------
|
|
|
# We only cache the taining data
|
|
# We only cache the taining data
|
|
|
data_items = []
|
|
data_items = []
|