# convert_ours_to_coco.py

  1. import os
  2. import json
  3. import xml.etree.ElementTree as ET
  4. import glob
  5. import sys
  6. sys.path.append("..")
  7. from config.data_config.dataset_config import dataset_cfg
  8. data_config = dataset_cfg["ourdataset"]
  9. num_classes = data_config["num_classes"]
  10. categories = data_config["class_names"]
  11. START_BOUNDING_BOX_ID = 1
  12. PRE_DEFINE_CATEGORIES = {categories[i]: i + 1 for i in range(num_classes)}
  13. def get(root, name):
  14. vars = root.findall(name)
  15. return vars
  16. def get_and_check(root, name, length):
  17. vars = root.findall(name)
  18. if len(vars) == 0:
  19. raise ValueError("Can not find %s in %s." % (name, root.tag))
  20. if length > 0 and len(vars) != length:
  21. raise ValueError(
  22. "The size of %s is supposed to be %d, but is %d."
  23. % (name, length, len(vars))
  24. )
  25. if length == 1:
  26. vars = vars[0]
  27. return vars
  28. def get_filename_as_int(filename):
  29. try:
  30. filename = filename.replace("\\", "/")
  31. filename = os.path.splitext(os.path.basename(filename))[0]
  32. return int(filename)
  33. except:
  34. raise ValueError("Filename %s is supposed to be an integer." % (filename))
  35. def get_categories(xml_files):
  36. """Generate category name to id mapping from a list of xml files.
  37. Arguments:
  38. xml_files {list} -- A list of xml file paths.
  39. Returns:
  40. dict -- category name to id mapping.
  41. """
  42. classes_names = []
  43. for xml_file in xml_files:
  44. tree = ET.parse(xml_file)
  45. root = tree.getroot()
  46. for member in root.findall("object"):
  47. classes_names.append(member[0].text)
  48. classes_names = list(set(classes_names))
  49. classes_names.sort()
  50. return {name: i for i, name in enumerate(classes_names)}
  51. def convert(xml_files, json_file):
  52. json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
  53. if PRE_DEFINE_CATEGORIES is not None:
  54. categories = PRE_DEFINE_CATEGORIES
  55. else:
  56. categories = get_categories(xml_files)
  57. bnd_id = START_BOUNDING_BOX_ID
  58. for i, xml_file in enumerate(xml_files):
  59. if i % 100 == 0:
  60. print('[{}] / [{}]'.format(i, len(xml_files)))
  61. tree = ET.parse(xml_file)
  62. root = tree.getroot()
  63. filename = get_and_check(root, "filename", 1).text
  64. ## The filename must be a number
  65. image_id = get_filename_as_int(filename)
  66. size = get_and_check(root, "size", 1)
  67. width = int(get_and_check(size, "width", 1).text)
  68. height = int(get_and_check(size, "height", 1).text)
  69. image = {
  70. "file_name": filename,
  71. "height": height,
  72. "width": width,
  73. "id": image_id,
  74. }
  75. json_dict["images"].append(image)
  76. ## Currently we do not support segmentation.
  77. # segmented = get_and_check(root, 'segmented', 1).text
  78. # assert segmented == '0'
  79. for obj in get(root, "object"):
  80. category = get_and_check(obj, "name", 1).text
  81. if category not in categories:
  82. new_id = len(categories)
  83. categories[category] = new_id
  84. category_id = categories[category]
  85. bndbox = get_and_check(obj, "bndbox", 1)
  86. xmin = int(get_and_check(bndbox, "xmin", 1).text) - 1
  87. ymin = int(get_and_check(bndbox, "ymin", 1).text) - 1
  88. xmax = int(get_and_check(bndbox, "xmax", 1).text)
  89. ymax = int(get_and_check(bndbox, "ymax", 1).text)
  90. assert xmax > xmin
  91. assert ymax > ymin
  92. o_width = abs(xmax - xmin)
  93. o_height = abs(ymax - ymin)
  94. ann = {
  95. "area": o_width * o_height,
  96. "iscrowd": 0,
  97. "image_id": image_id,
  98. "bbox": [xmin, ymin, o_width, o_height],
  99. "category_id": category_id,
  100. "id": bnd_id,
  101. "ignore": 0,
  102. "segmentation": [],
  103. }
  104. json_dict["annotations"].append(ann)
  105. bnd_id = bnd_id + 1
  106. for cate, cid in categories.items():
  107. cat = {"supercategory": "none", "id": cid, "name": cate}
  108. json_dict["categories"].append(cat)
  109. os.makedirs(os.path.dirname(json_file), exist_ok=True)
  110. json_fp = open(json_file, "w")
  111. json_str = json.dumps(json_dict)
  112. json_fp.write(json_str)
  113. json_fp.close()
  114. if __name__ == "__main__":
  115. import argparse
  116. parser = argparse.ArgumentParser(
  117. description="Convert VOC-style annotation labele by LabelImg to COCO format."
  118. )
  119. parser.add_argument("--root", help="Directory path to dataset.", type=str)
  120. parser.add_argument("--split", default='train',
  121. help="split of dataset.", type=str)
  122. parser.add_argument("-anno", "--annotations", default='annotations',
  123. help="Directory path to xml files.", type=str)
  124. parser.add_argument("-json", "--json_file", default='train.json',
  125. help="Output COCO format json file.", type=str)
  126. args = parser.parse_args()
  127. data_dir = os.path.join(args.root, args.split)
  128. anno_path = os.path.join(data_dir, args.annotations)
  129. xml_files = glob.glob(os.path.join(anno_path, "*.xml"))
  130. json_file = os.path.join(data_dir, args.annotations, '{}.json'.format(args.split))
  131. print("Number of xml files: {}".format(len(xml_files)))
  132. print("Converting to COCO format ...")
  133. convert(xml_files, json_file)
  134. print("Success: {}".format(json_file))