kmeans_anchor.py

# ---------------------------------------------------------------------
# This code is adapted from:
# https://blog.csdn.net/xiaomifanhxx/article/details/81215051?ydreferer=aHR0cHM6Ly93d3cuYmFpZHUuY29tL2xpbms%2FdXJsPUlCVmlkNUlCMl9xRDNFWWpFU25rNDZ4NkhOOGNuRXluaG1rckNSaWowRjBPWUJ4eTVDSDhZVlBhU013bG1DZk5UVEdyMDd2aGNPX3NGSGlHcXFoZ1J4QzZNOGQySWhSZHhQS25IZkJFU2t5JndkPSZlcWlkPWE1ODhmMDI0MDAzNGZhOGIwMDAwMDAwNTY0Mzc0ZWMy
# ---------------------------------------------------------------------
import argparse
import os
import random
import sys

import numpy as np

sys.path.append('..')
from dataset.voc import VOCDetection
from dataset.coco import COCODataset

def parse_args():
    parser = argparse.ArgumentParser(description='kmeans for anchor box')
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='voc, coco')
    parser.add_argument('-na', '--num_anchorbox', default=5, type=int,
                        help='number of anchor boxes.')
    parser.add_argument('-size', '--img_size', default=416, type=int,
                        help='input size.')
    parser.add_argument('--max_iter', default=1000, type=int,
                        help='maximum number of k-means iterations.')
    parser.add_argument('-ab', '--absolute', action='store_true', default=False,
                        help='print absolute coords (do not divide by 32).')
    return parser.parse_args()


args = parse_args()

class Box():
    def __init__(self, x, y, w, h):
        self.x = x
        self.y = y
        self.w = w
        self.h = h


def iou(box1, box2):
    """IoU between two boxes given as (center x, center y, w, h)."""
    x1, y1, w1, h1 = box1.x, box1.y, box1.w, box1.h
    x2, y2, w2, h2 = box2.x, box2.y, box2.w, box2.h

    S_1 = w1 * h1
    S_2 = w2 * h2

    xmin_1, ymin_1 = x1 - w1 / 2, y1 - h1 / 2
    xmax_1, ymax_1 = x1 + w1 / 2, y1 + h1 / 2
    xmin_2, ymin_2 = x2 - w2 / 2, y2 - h2 / 2
    xmax_2, ymax_2 = x2 + w2 / 2, y2 + h2 / 2

    I_w = min(xmax_1, xmax_2) - max(xmin_1, xmin_2)
    I_h = min(ymax_1, ymax_2) - max(ymin_1, ymin_2)
    if I_w < 0 or I_h < 0:
        return 0
    I = I_w * I_h

    IoU = I / (S_1 + S_2 - I)
    return IoU
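
# A minimal sanity check, added here for illustration only (it is not part of
# the referenced script and is never called by it): with the origin-centered
# (x = y = 0) boxes used throughout this file, iou() reduces to the
# width/height-only IoU that YOLOv2-style anchor clustering relies on.
def _iou_example():
    a = Box(0, 0, 10, 10)  # area 100
    b = Box(0, 0, 5, 5)    # area 25, fully contained in a
    # intersection = 25, union = 100 + 25 - 25 = 100 -> IoU = 0.25
    assert abs(iou(a, b) - 0.25) < 1e-6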

def init_centroids(boxes, n_anchors):
    """k-means++ style seeding: the first centroid is chosen at random, each
    following one is drawn with probability proportional to its (1 - IoU)
    distance to the closest centroid picked so far."""
    centroids = []
    boxes_num = len(boxes)

    centroid_index = int(np.random.choice(boxes_num, 1)[0])
    centroids.append(boxes[centroid_index])
    print(centroids[0].w, centroids[0].h)

    for centroid_index in range(0, n_anchors - 1):
        sum_distance = 0
        distance_thresh = 0
        distance_list = []
        cur_sum = 0

        # distance of every box to its nearest existing centroid
        for box in boxes:
            min_distance = 1
            for centroid_i, centroid in enumerate(centroids):
                distance = (1 - iou(box, centroid))
                if distance < min_distance:
                    min_distance = distance
            sum_distance += min_distance
            distance_list.append(min_distance)

        # roulette-wheel selection of the next centroid
        distance_thresh = sum_distance * np.random.random()
        for i in range(0, boxes_num):
            cur_sum += distance_list[i]
            if cur_sum > distance_thresh:
                centroids.append(boxes[i])
                print(boxes[i].w, boxes[i].h)
                break

    return centroids
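
# Illustration only (helper added for this write-up, not invoked by the
# script): seed two centroids from a synthetic set of origin-centered boxes.
# The sizes are made up; with two clearly separated size groups, the second
# seed is very likely drawn from the group far away from the first one.
def _init_centroids_example():
    demo_boxes = [Box(0, 0, w, h) for w, h in
                  [(10, 12), (11, 13), (90, 100), (95, 105)]]
    return init_centroids(demo_boxes, n_anchors=2)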

def do_kmeans(n_anchors, boxes, centroids):
    loss = 0
    groups = []
    new_centroids = []
    for i in range(n_anchors):
        groups.append([])
        new_centroids.append(Box(0, 0, 0, 0))

    # assignment step: each box joins the centroid with the smallest (1 - IoU)
    for box in boxes:
        min_distance = 1
        group_index = 0
        for centroid_index, centroid in enumerate(centroids):
            distance = (1 - iou(box, centroid))
            if distance < min_distance:
                min_distance = distance
                group_index = centroid_index
        groups[group_index].append(box)
        loss += min_distance
        new_centroids[group_index].w += box.w
        new_centroids[group_index].h += box.h

    # update step: each new centroid is the mean w/h of its group
    for i in range(n_anchors):
        new_centroids[i].w /= max(len(groups[i]), 1)
        new_centroids[i].h /= max(len(groups[i]), 1)

    return new_centroids, groups, loss  # / len(boxes)
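
# Hedged usage sketch (synthetic sizes invented for illustration; not called
# by the script): one assignment/update round of do_kmeans() on a handful of
# origin-centered boxes, seeded with two of the boxes themselves.
def _do_kmeans_example():
    demo_boxes = [Box(0, 0, w, h) for w, h in
                  [(10, 12), (12, 14), (80, 90), (90, 100)]]
    seeds = [demo_boxes[0], demo_boxes[2]]
    new_centroids, groups, loss = do_kmeans(2, demo_boxes, seeds)
    # expected: one centroid near (11, 13), the other near (85, 95)
    return new_centroids, groups, loss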

def anchor_box_kmeans(total_gt_boxes, n_anchors, loss_convergence, iters, plus=True):
    """
    Use k-means to get appropriate anchor boxes for the training dataset.
    Input:
        total_gt_boxes : list of Box -> ground-truth boxes (only w and h are used).
        n_anchors : int -> the number of anchor boxes.
        loss_convergence : float -> threshold for declaring convergence.
        iters : int -> the maximum number of k-means iterations.
        plus : bool -> use k-means++ style seeding instead of random sampling.
    Output:
        centroids : list of Box -> the anchor boxes, giving [(w1, h1), (w2, h2), ..., (wn, hn)].
    """
    boxes = total_gt_boxes
    centroids = []
    if plus:
        centroids = init_centroids(boxes, n_anchors)
    else:
        total_indexs = range(len(boxes))
        sample_indexs = random.sample(total_indexs, n_anchors)
        for i in sample_indexs:
            centroids.append(boxes[i])

    # iterate k-means
    centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
    iterations = 1
    while True:
        centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)
        iterations += 1
        print("Loss = %f" % loss)
        if abs(old_loss - loss) < loss_convergence or iterations > iters:
            break
        old_loss = loss
        for centroid in centroids:
            print(centroid.w, centroid.h)

    print("k-means result : ")
    for centroid in centroids:
        if args.absolute:
            print("w, h: ", round(centroid.w, 2), round(centroid.h, 2),
                  "area: ", round(centroid.w, 2) * round(centroid.h, 2))
        else:
            print("w, h: ", round(centroid.w / 32, 2), round(centroid.h / 32, 2),
                  "area: ", round(centroid.w / 32, 2) * round(centroid.h / 32, 2))

    return centroids
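
# Hedged end-to-end sketch (the box sizes below are synthetic and invented for
# illustration; the real script builds `boxes` from VOC/COCO in __main__):
# cluster a few made-up sizes into 2 anchors without loading any dataset.
# Note that anchor_box_kmeans() reads args.absolute, so parse_args() must
# already have run, as it does at module level above.
def _anchor_kmeans_example():
    demo_boxes = [Box(0, 0, w, h) for w, h in
                  [(12, 16), (14, 18), (60, 40), (64, 44), (200, 180), (190, 170)]]
    return anchor_box_kmeans(demo_boxes, n_anchors=2, loss_convergence=1e-6,
                             iters=100, plus=True)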

if __name__ == "__main__":
    # prepare
    boxes = []
    if args.dataset == 'voc':
        data_root = os.path.join(args.root, 'VOCdevkit')
        dataset = VOCDetection(data_dir=data_root)

        # VOC
        for i in range(len(dataset)):
            if i % 5000 == 0:
                print('Loading voc data [%d / %d]' % (i+1, len(dataset)))

            # For VOC
            img, _ = dataset.pull_image(i)
            img_h, img_w = img.shape[:2]
            _, annotation = dataset.pull_anno(i)

            # prepare bbox data
            for box_and_label in annotation:
                box = box_and_label[:-1]
                xmin, ymin, xmax, ymax = box
                bw = (xmax - xmin) / max(img_w, img_h) * args.img_size
                bh = (ymax - ymin) / max(img_w, img_h) * args.img_size
                # check bbox
                if bw < 1.0 or bh < 1.0:
                    continue
                boxes.append(Box(0, 0, bw, bh))
    elif args.dataset == 'coco':
        data_root = os.path.join(args.root, 'COCO')
        dataset = COCODataset(data_dir=data_root, img_size=args.img_size)

        # COCO
        for i in range(len(dataset)):
            if i % 5000 == 0:
                print('Loading coco data [%d / %d]' % (i+1, len(dataset)))

            # For COCO
            img, _ = dataset.pull_image(i)
            img_h, img_w = img.shape[:2]
            annotation = dataset.pull_anno(i)

            # prepare bbox data
            for box_and_label in annotation:
                box = box_and_label[:-1]
                xmin, ymin, xmax, ymax = box
                bw = (xmax - xmin) / max(img_w, img_h) * args.img_size
                bh = (ymax - ymin) / max(img_w, img_h) * args.img_size
                # check bbox
                if bw < 1.0 or bh < 1.0:
                    continue
                boxes.append(Box(0, 0, bw, bh))

    print("Number of all bboxes: ", len(boxes))
    print("Start k-means !")
    centroids = anchor_box_kmeans(boxes, args.num_anchorbox, 1e-6, args.max_iter, plus=True)
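
# Example invocation (the --root path is a placeholder; it must contain the
# VOCdevkit/ or COCO/ layout expected by dataset.voc.VOCDetection and
# dataset.coco.COCODataset):
#   python kmeans_anchor.py --root /path/to/data -d voc -na 5 -size 416
# Without -ab, the printed w/h pairs are divided by 32, presumably the
# network's output stride, so they are expressed in grid-cell units rather
# than input-image pixels.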