convert_ours_to_coco.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149
  1. import os
  2. import json
  3. import xml.etree.ElementTree as ET
  4. import glob
  5. import sys
  6. sys.path.append("..")
  7. from dataset.customed import customed_class_labels
  8. num_classes = len(customed_class_labels)
  9. categories = customed_class_labels
  10. START_BOUNDING_BOX_ID = 1
  11. PRE_DEFINE_CATEGORIES = {categories[i]: i + 1 for i in range(num_classes)}
  12. def get(root, name):
  13. vars = root.findall(name)
  14. return vars
  15. def get_and_check(root, name, length):
  16. vars = root.findall(name)
  17. if len(vars) == 0:
  18. raise ValueError("Can not find %s in %s." % (name, root.tag))
  19. if length > 0 and len(vars) != length:
  20. raise ValueError(
  21. "The size of %s is supposed to be %d, but is %d."
  22. % (name, length, len(vars))
  23. )
  24. if length == 1:
  25. vars = vars[0]
  26. return vars
  27. def get_filename_as_int(filename):
  28. try:
  29. filename = filename.replace("\\", "/")
  30. filename = os.path.splitext(os.path.basename(filename))[0]
  31. return int(filename)
  32. except:
  33. raise ValueError("Filename %s is supposed to be an integer." % (filename))
  34. def get_categories(xml_files):
  35. """Generate category name to id mapping from a list of xml files.
  36. Arguments:
  37. xml_files {list} -- A list of xml file paths.
  38. Returns:
  39. dict -- category name to id mapping.
  40. """
  41. classes_names = []
  42. for xml_file in xml_files:
  43. tree = ET.parse(xml_file)
  44. root = tree.getroot()
  45. for member in root.findall("object"):
  46. classes_names.append(member[0].text)
  47. classes_names = list(set(classes_names))
  48. classes_names.sort()
  49. return {name: i for i, name in enumerate(classes_names)}
  50. def convert(xml_files, json_file):
  51. json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
  52. if PRE_DEFINE_CATEGORIES is not None:
  53. categories = PRE_DEFINE_CATEGORIES
  54. else:
  55. categories = get_categories(xml_files)
  56. bnd_id = START_BOUNDING_BOX_ID
  57. for i, xml_file in enumerate(xml_files):
  58. if i % 100 == 0:
  59. print('[{}] / [{}]'.format(i, len(xml_files)))
  60. tree = ET.parse(xml_file)
  61. root = tree.getroot()
  62. filename = get_and_check(root, "filename", 1).text
  63. ## The filename must be a number
  64. image_id = get_filename_as_int(filename)
  65. size = get_and_check(root, "size", 1)
  66. width = int(get_and_check(size, "width", 1).text)
  67. height = int(get_and_check(size, "height", 1).text)
  68. image = {
  69. "file_name": filename,
  70. "height": height,
  71. "width": width,
  72. "id": image_id,
  73. }
  74. json_dict["images"].append(image)
  75. ## Currently we do not support segmentation.
  76. # segmented = get_and_check(root, 'segmented', 1).text
  77. # assert segmented == '0'
  78. for obj in get(root, "object"):
  79. category = get_and_check(obj, "name", 1).text
  80. if category not in categories:
  81. new_id = len(categories)
  82. categories[category] = new_id
  83. category_id = categories[category]
  84. bndbox = get_and_check(obj, "bndbox", 1)
  85. xmin = int(get_and_check(bndbox, "xmin", 1).text) - 1
  86. ymin = int(get_and_check(bndbox, "ymin", 1).text) - 1
  87. xmax = int(get_and_check(bndbox, "xmax", 1).text)
  88. ymax = int(get_and_check(bndbox, "ymax", 1).text)
  89. assert xmax > xmin
  90. assert ymax > ymin
  91. o_width = abs(xmax - xmin)
  92. o_height = abs(ymax - ymin)
  93. ann = {
  94. "area": o_width * o_height,
  95. "iscrowd": 0,
  96. "image_id": image_id,
  97. "bbox": [xmin, ymin, o_width, o_height],
  98. "category_id": category_id,
  99. "id": bnd_id,
  100. "ignore": 0,
  101. "segmentation": [],
  102. }
  103. json_dict["annotations"].append(ann)
  104. bnd_id = bnd_id + 1
  105. for cate, cid in categories.items():
  106. cat = {"supercategory": "none", "id": cid, "name": cate}
  107. json_dict["categories"].append(cat)
  108. os.makedirs(os.path.dirname(json_file), exist_ok=True)
  109. json_fp = open(json_file, "w")
  110. json_str = json.dumps(json_dict)
  111. json_fp.write(json_str)
  112. json_fp.close()
  113. if __name__ == "__main__":
  114. import argparse
  115. parser = argparse.ArgumentParser(
  116. description="Convert VOC-style annotation labele by LabelImg to COCO format."
  117. )
  118. parser.add_argument("--root", default="path/to/customed_dataset", type=str,
  119. help="Directory path to dataset.", )
  120. parser.add_argument("--split", default='train',
  121. help="split of dataset.", type=str)
  122. parser.add_argument("-anno", "--annotations", default='annotations',
  123. help="Directory path to xml files.", type=str)
  124. parser.add_argument("-json", "--json_file", default='train.json',
  125. help="Output COCO format json file.", type=str)
  126. args = parser.parse_args()
  127. data_dir = os.path.join(args.root, args.split)
  128. anno_path = os.path.join(data_dir, args.annotations)
  129. xml_files = glob.glob(os.path.join(anno_path, "*.xml"))
  130. json_file = os.path.join(data_dir, args.annotations, '{}.json'.format(args.split))
  131. print("Number of xml files: {}".format(len(xml_files)))
  132. print("Converting to COCO format ...")
  133. convert(xml_files, json_file)
  134. print("Success: {}".format(json_file))