convert_ours_to_coco.py 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159
  1. import os
  2. import json
  3. import xml.etree.ElementTree as ET
  4. import glob
  5. START_BOUNDING_BOX_ID = 1
  6. PRE_DEFINE_CATEGORIES = None
  7. # If necessary, pre-define category and its id
  8. # PRE_DEFINE_CATEGORIES = {"aeroplane": 1, "bicycle": 2, "bird": 3, "boat": 4,
  9. # "bottle":5, "bus": 6, "car": 7, "cat": 8, "chair": 9,
  10. # "cow": 10, "diningtable": 11, "dog": 12, "horse": 13,
  11. # "motorbike": 14, "person": 15, "pottedplant": 16,
  12. # "sheep": 17, "sofa": 18, "train": 19, "tvmonitor": 20}
  13. def get(root, name):
  14. vars = root.findall(name)
  15. return vars
  16. def get_and_check(root, name, length):
  17. vars = root.findall(name)
  18. if len(vars) == 0:
  19. raise ValueError("Can not find %s in %s." % (name, root.tag))
  20. if length > 0 and len(vars) != length:
  21. raise ValueError(
  22. "The size of %s is supposed to be %d, but is %d."
  23. % (name, length, len(vars))
  24. )
  25. if length == 1:
  26. vars = vars[0]
  27. return vars
  28. def get_filename_as_int(filename):
  29. try:
  30. filename = filename.replace("\\", "/")
  31. filename = os.path.splitext(os.path.basename(filename))[0]
  32. return int(filename)
  33. except:
  34. raise ValueError("Filename %s is supposed to be an integer." % (filename))
  35. def get_categories(xml_files):
  36. """Generate category name to id mapping from a list of xml files.
  37. Arguments:
  38. xml_files {list} -- A list of xml file paths.
  39. Returns:
  40. dict -- category name to id mapping.
  41. """
  42. classes_names = []
  43. for xml_file in xml_files:
  44. tree = ET.parse(xml_file)
  45. root = tree.getroot()
  46. for member in root.findall("object"):
  47. classes_names.append(member[0].text)
  48. classes_names = list(set(classes_names))
  49. classes_names.sort()
  50. return {name: i for i, name in enumerate(classes_names)}
  51. def convert(xml_files, json_file):
  52. json_dict = {"images": [], "type": "instances", "annotations": [], "categories": []}
  53. if PRE_DEFINE_CATEGORIES is not None:
  54. categories = PRE_DEFINE_CATEGORIES
  55. else:
  56. categories = get_categories(xml_files)
  57. bnd_id = START_BOUNDING_BOX_ID
  58. for i, xml_file in enumerate(xml_files):
  59. if i % 100 == 0:
  60. print('[{}] / [{}]'.format(i, len(xml_files)))
  61. tree = ET.parse(xml_file)
  62. root = tree.getroot()
  63. path = get(root, "path")
  64. if len(path) == 1:
  65. filename = os.path.basename(path[0].text)
  66. elif len(path) == 0:
  67. filename = get_and_check(root, "filename", 1).text
  68. else:
  69. raise ValueError("%d paths found in %s" % (len(path), xml_file))
  70. ## The filename must be a number
  71. image_id = get_filename_as_int(filename)
  72. size = get_and_check(root, "size", 1)
  73. width = int(get_and_check(size, "width", 1).text)
  74. height = int(get_and_check(size, "height", 1).text)
  75. image = {
  76. "file_name": filename,
  77. "height": height,
  78. "width": width,
  79. "id": image_id,
  80. }
  81. json_dict["images"].append(image)
  82. ## Currently we do not support segmentation.
  83. # segmented = get_and_check(root, 'segmented', 1).text
  84. # assert segmented == '0'
  85. for obj in get(root, "object"):
  86. category = get_and_check(obj, "name", 1).text
  87. if category not in categories:
  88. new_id = len(categories)
  89. categories[category] = new_id
  90. category_id = categories[category]
  91. bndbox = get_and_check(obj, "bndbox", 1)
  92. xmin = int(get_and_check(bndbox, "xmin", 1).text) - 1
  93. ymin = int(get_and_check(bndbox, "ymin", 1).text) - 1
  94. xmax = int(get_and_check(bndbox, "xmax", 1).text)
  95. ymax = int(get_and_check(bndbox, "ymax", 1).text)
  96. assert xmax > xmin
  97. assert ymax > ymin
  98. o_width = abs(xmax - xmin)
  99. o_height = abs(ymax - ymin)
  100. ann = {
  101. "area": o_width * o_height,
  102. "iscrowd": 0,
  103. "image_id": image_id,
  104. "bbox": [xmin, ymin, o_width, o_height],
  105. "category_id": category_id,
  106. "id": bnd_id,
  107. "ignore": 0,
  108. "segmentation": [],
  109. }
  110. json_dict["annotations"].append(ann)
  111. bnd_id = bnd_id + 1
  112. for cate, cid in categories.items():
  113. cat = {"supercategory": "none", "id": cid, "name": cate}
  114. json_dict["categories"].append(cat)
  115. os.makedirs(os.path.dirname(json_file), exist_ok=True)
  116. json_fp = open(json_file, "w")
  117. json_str = json.dumps(json_dict)
  118. json_fp.write(json_str)
  119. json_fp.close()
  120. if __name__ == "__main__":
  121. import argparse
  122. parser = argparse.ArgumentParser(
  123. description="Convert VOC-style annotation labele by LabelImg to COCO format."
  124. )
  125. parser.add_argument("--root", help="Directory path to dataset.", type=str)
  126. parser.add_argument("--split", default='train',
  127. help="split of dataset.", type=str)
  128. parser.add_argument("-anno", "--annotations", default='annotations',
  129. help="Directory path to xml files.", type=str)
  130. parser.add_argument("-json", "--json_file", default='train.json',
  131. help="Output COCO format json file.", type=str)
  132. args = parser.parse_args()
  133. data_dir = os.path.join(args.root, args.split)
  134. anno_path = os.path.join(data_dir, args.annotations)
  135. xml_files = glob.glob(os.path.join(anno_path, "*.xml"))
  136. json_file = os.path.join(data_dir, args.annotations, '{}.json'.format(args.split))
  137. print("Number of xml files: {}".format(len(xml_files)))
  138. print("Converting to COCO format ...")
  139. convert(xml_files, json_file)
  140. print("Success: {}".format(args.json_file))