# clean_coco.py
# Remove COCO annotations with degenerate (zero-width/height) boxes and drop
# images left without any valid annotation, writing a "*_clean.json" file.
  1. import os
  2. import json
  3. if __name__ == "__main__":
  4. import argparse
  5. parser = argparse.ArgumentParser(description='COCO-Dataset')
  6. # --------------- opt parameters ---------------
  7. parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
  8. help='data root')
  9. parser.add_argument('--image_set', type=str, default='val',
  10. help='augmentation type')
  11. parser.add_argument('--task', type=str, default='det',
  12. help='augmentation type')
  13. args = parser.parse_args()
  14. # --------------- load json ---------------
  15. if args.task == 'det':
  16. task_prefix = 'instances_{}2017.json'
  17. clean_task_prefix = 'instances_{}2017_clean.json'
  18. elif args.task == 'pos':
  19. task_prefix = 'person_keypoints_{}2017.json'
  20. clean_task_prefix = 'person_keypoints_{}2017_clean.json'
  21. else:
  22. raise NotImplementedError('Unkown task !')
  23. json_path = os.path.join(args.root, 'annotations', task_prefix.format(args.image_set))
  24. clean_json_file = dict()
  25. with open(json_path, 'r') as file:
  26. json_file = json.load(file)
  27. # json_file is a Dict: dict_keys(['info', 'licenses', 'images', 'annotations', 'categories'])
  28. clean_json_file['info'] = json_file['info']
  29. clean_json_file['licenses'] = json_file['licenses']
  30. clean_json_file['categories'] = json_file['categories']
  31. images_list = json_file['images']
  32. annots_list = json_file['annotations']
  33. num_images = len(images_list)
  34. # -------------- Filter annotations --------------
  35. print("Processing annotations ...")
  36. valid_image_ids = []
  37. clean_annots_list = []
  38. for i, anno in enumerate(annots_list):
  39. if i % 5000 == 0:
  40. print("[{}] / [{}] ...".format(i, len(annots_list)))
  41. x1, y1, bw, bh = anno['bbox']
  42. if bw > 0 and bh > 0:
  43. clean_annots_list.append(anno)
  44. if anno['image_id'] not in valid_image_ids:
  45. valid_image_ids.append(anno['image_id'])
  46. print("Valid number of images: ", len(valid_image_ids))
  47. print("Valid number of annots: ", len(clean_annots_list))
  48. print("Original number of annots: ", len(annots_list))
  49. # -------------- Filter images --------------
  50. print("Processing images ...")
  51. clean_images_list = []
  52. for i in range(num_images):
  53. if args.image_set == 'train' and i % 5000 == 0:
  54. print("[{}] / [{}] ...".format(i, num_images))
  55. if args.image_set == 'val' and i % 500 == 0:
  56. print("[{}] / [{}] ...".format(i, num_images))
  57. # A single image dict
  58. image_dict = images_list[i]
  59. image_id = image_dict['id']
  60. if image_id in valid_image_ids:
  61. clean_images_list.append(image_dict)
  62. print('Number of images after cleaning: ', len(clean_images_list))
  63. print('Number of annotations after cleaning: ', len(clean_annots_list))
  64. clean_json_file['images'] = clean_images_list
  65. clean_json_file['annotations'] = clean_annots_list
  66. # --------------- Save filterd json file ---------------
  67. new_json_path = os.path.join(args.root, 'annotations', clean_task_prefix.format(args.image_set))
  68. with open(new_json_path, 'w') as f:
  69. json.dump(clean_json_file, f)