| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149 |
- import os
- import json
- import xml.etree.ElementTree as ET
- import glob
- import sys
- sys.path.append("..")
- from yolo.dataset.custom import custom_class_labels
# Class labels come from the project's custom dataset definition.
categories = custom_class_labels
num_classes = len(categories)
# First annotation id handed out by convert().
START_BOUNDING_BOX_ID = 1
# Label name -> COCO category id, numbered from 1 in label order.
PRE_DEFINE_CATEGORIES = {label: idx for idx, label in enumerate(categories, start=1)}
def get(root, name):
    """Return all elements under *root* matching *name*.

    Arguments:
        root {xml.etree.ElementTree.Element} -- Element to search under.
        name {str} -- ElementPath expression passed to ``Element.findall``.

    Returns:
        list -- Matching elements; empty when nothing matches.
    """
    return root.findall(name)
def get_and_check(root, name, length):
    """Look up *name* under *root* and validate how many matches exist.

    Arguments:
        root {xml.etree.ElementTree.Element} -- Element to search under.
        name {str} -- ElementPath expression passed to ``Element.findall``.
        length {int} -- Expected number of matches. Values <= 0 skip the
            count check, but at least one match is still required.

    Returns:
        The single matching element when ``length == 1``, otherwise the
        list of matches.

    Raises:
        ValueError -- if nothing matches, or if ``length > 0`` and the number
            of matches differs from ``length``.
    """
    found = root.findall(name)
    if not found:
        raise ValueError("Can not find %s in %s." % (name, root.tag))
    count = len(found)
    if length > 0 and count != length:
        raise ValueError(
            "The size of %s is supposed to be %d, but is %d." % (name, length, count)
        )
    return found[0] if length == 1 else found
def get_filename_as_int(filename):
    """Convert an image filename such as ``"images/000123.jpg"`` to ``123``.

    Windows-style backslash separators are normalised to ``/`` first, so the
    directory part is stripped the same way on every platform.

    Arguments:
        filename {str} -- Image file name or path whose stem is an integer.

    Returns:
        int -- The integer value of the file stem.

    Raises:
        ValueError -- if *filename* is not a string or its stem is not an
            integer.
    """
    try:
        filename = filename.replace("\\", "/")
        filename = os.path.splitext(os.path.basename(filename))[0]
        return int(filename)
    except (AttributeError, TypeError, ValueError) as err:
        # Narrowed from a bare ``except:`` so KeyboardInterrupt/SystemExit and
        # unrelated bugs are not disguised as a bad filename; the original
        # cause is chained for debugging. AttributeError covers a missing
        # <filename> text (None).
        raise ValueError("Filename %s is supposed to be an integer." % (filename)) from err
def get_categories(xml_files):
    """Generate category name to id mapping from a list of xml files.

    The label of every ``<object>`` is read from its ``<name>`` child. Labels
    are de-duplicated, sorted alphabetically and numbered from 0.

    Arguments:
        xml_files {list} -- A list of xml file paths.

    Returns:
        dict -- category name to id mapping.

    Raises:
        ValueError -- if an ``<object>`` element has no ``<name>`` child.
    """
    class_names = set()
    for xml_file in xml_files:
        root = ET.parse(xml_file).getroot()
        for member in root.findall("object"):
            # Look the label up by tag rather than assuming it is the first
            # child; <name> is not guaranteed to come first inside <object>.
            name = member.findtext("name")
            if name is None:
                raise ValueError("Can not find name in %s." % member.tag)
            class_names.add(name)
    return {name: i for i, name in enumerate(sorted(class_names))}
def convert(xml_files, json_file):
    """Convert VOC-style xml annotations into a single COCO-format json file.

    Arguments:
        xml_files {list} -- Paths of the xml annotation files to convert.
        json_file {str} -- Output path. Missing parent directories are created.

    Notes:
        * Each image id is the integer stem of its ``<filename>``, so file
          names must be numeric.
        * ``xmin``/``ymin`` are shifted by -1 and each box is stored as
          ``[x, y, width, height]``.
        * A label missing from the category mapping is given the next unused
          id (one above the current maximum).

    Raises:
        ValueError -- on a missing or duplicated required xml element, a
            non-numeric filename, or a degenerate box (``xmax <= xmin`` or
            ``ymax <= ymin``).
    """
    json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
    # Work on a copy so labels discovered here never leak into the
    # module-level PRE_DEFINE_CATEGORIES mapping shared by later calls.
    if PRE_DEFINE_CATEGORIES is not None:
        name_to_id = dict(PRE_DEFINE_CATEGORIES)
    else:
        name_to_id = get_categories(xml_files)
    bnd_id = START_BOUNDING_BOX_ID
    total = len(xml_files)
    for i, xml_file in enumerate(xml_files):
        if i % 100 == 0:
            print('[{}] / [{}]'.format(i, total))
        root = ET.parse(xml_file).getroot()
        filename = get_and_check(root, "filename", 1).text
        # The filename must be a number: its stem doubles as the image id.
        image_id = get_filename_as_int(filename)
        size = get_and_check(root, "size", 1)
        width = int(get_and_check(size, "width", 1).text)
        height = int(get_and_check(size, "height", 1).text)
        json_dict["images"].append({
            "file_name": filename,
            "height": height,
            "width": width,
            "id": image_id,
        })
        # Segmentation is not supported; any <segmented> flag is ignored.
        for obj in get(root, "object"):
            category = get_and_check(obj, "name", 1).text
            if category not in name_to_id:
                # Use max+1 rather than len(): PRE_DEFINE_CATEGORIES numbers
                # from 1, so len() would reuse the last predefined id.
                name_to_id[category] = max(name_to_id.values(), default=0) + 1
            category_id = name_to_id[category]
            bndbox = get_and_check(obj, "bndbox", 1)
            # Shift the top-left corner by -1 (presumably VOC's 1-based pixel
            # indices -> 0-based); the bottom-right corner is kept as-is.
            xmin = int(get_and_check(bndbox, "xmin", 1).text) - 1
            ymin = int(get_and_check(bndbox, "ymin", 1).text) - 1
            xmax = int(get_and_check(bndbox, "xmax", 1).text)
            ymax = int(get_and_check(bndbox, "ymax", 1).text)
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if xmax <= xmin or ymax <= ymin:
                raise ValueError(
                    "Invalid bndbox %s in %s." % ([xmin, ymin, xmax, ymax], xml_file)
                )
            o_width = xmax - xmin
            o_height = ymax - ymin
            json_dict["annotations"].append({
                "area": o_width * o_height,
                "iscrowd": 0,
                "image_id": image_id,
                "bbox": [xmin, ymin, o_width, o_height],
                "category_id": category_id,
                "id": bnd_id,
                "ignore": 0,
                "segmentation": [],
            })
            bnd_id += 1
    json_dict["categories"] = [
        {"supercategory": "none", "id": cid, "name": cate}
        for cate, cid in name_to_id.items()
    ]
    out_dir = os.path.dirname(json_file)
    # A bare filename has dirname "", and os.makedirs("") raises.
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    with open(json_file, "w", encoding="utf-8") as json_fp:
        json.dump(json_dict, json_fp)
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Convert VOC-style annotation labeled by LabelImg to COCO format."
    )
    parser.add_argument("--root", default="path/to/custom_dataset", type=str,
                        help="Directory path to dataset.")
    parser.add_argument("--split", default='train', type=str,
                        help="split of dataset.")
    parser.add_argument("-anno", "--annotations", default='annotations', type=str,
                        help="Directory path to xml files.")
    parser.add_argument("-json", "--json_file", default=None, type=str,
                        help="Output COCO format json file "
                             "(default: <split>.json inside the annotations directory).")
    args = parser.parse_args()

    data_dir = os.path.join(args.root, args.split)
    anno_path = os.path.join(data_dir, args.annotations)
    # Sorted so image and annotation ids come out in a reproducible order.
    xml_files = sorted(glob.glob(os.path.join(anno_path, "*.xml")))
    # --json_file overrides the default "<split>.json" name. A relative path is
    # resolved against the annotations directory; an absolute path is used
    # as-is (os.path.join discards earlier parts before an absolute one).
    if args.json_file is not None:
        json_name = args.json_file
    else:
        json_name = '{}.json'.format(args.split)
    json_file = os.path.join(anno_path, json_name)
    print("Number of xml files: {}".format(len(xml_files)))
    if not xml_files:
        print("Warning: no xml files found in {}".format(anno_path))
    print("Converting to COCO format ...")
    convert(xml_files, json_file)
    print("Success: {}".format(json_file))
|