# kmeans_anchor.py
# ---------------------------------------------------------------------
# This code is adapted from https://blog.csdn.net/xiaomifanhxx/article/details/81215051?ydreferer=aHR0cHM6Ly93d3cuYmFpZHUuY29tL2xpbms%2FdXJsPUlCVmlkNUlCMl9xRDNFWWpFU25rNDZ4NkhOOGNuRXluaG1rckNSaWowRjBPWUJ4eTVDSDhZVlBhU013bG1DZk5UVEdyMDd2aGNPX3NGSGlHcXFoZ1J4QzZNOGQySWhSZHhQS25IZkJFU2t5JndkPSZlcWlkPWE1ODhmMDI0MDAzNGZhOGIwMDAwMDAwNTY0Mzc0ZWMy
# ---------------------------------------------------------------------
import numpy as np
import random
import argparse
import os
import sys
sys.path.append('..')

from dataset.voc import VOCDetection
from dataset.coco import COCODataset


def parse_args():
    parser = argparse.ArgumentParser(description='kmeans for anchor box')
    parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
                        help='data root')
    parser.add_argument('-d', '--dataset', default='coco',
                        help='coco, widerface, crowdhuman')
    parser.add_argument('-na', '--num_anchorbox', default=5, type=int,
                        help='number of anchor box.')
    parser.add_argument('-size', '--img_size', default=416, type=int,
                        help='input size.')
    parser.add_argument('-ab', '--absolute', action='store_true', default=False,
                        help='absolute coords.')
    return parser.parse_args()


args = parse_args()
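
# Example invocation (the dataset root below is a placeholder for your own setup):
#   python kmeans_anchor.py --root /path/to/datasets -d voc -na 5 -size 416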


class Box():
    def __init__(self, x, y, w, h):
        self.x = x
        self.y = y
        self.w = w
        self.h = h


def iou(box1, box2):
    x1, y1, w1, h1 = box1.x, box1.y, box1.w, box1.h
    x2, y2, w2, h2 = box2.x, box2.y, box2.w, box2.h

    S_1 = w1 * h1
    S_2 = w2 * h2

    xmin_1, ymin_1 = x1 - w1 / 2, y1 - h1 / 2
    xmax_1, ymax_1 = x1 + w1 / 2, y1 + h1 / 2
    xmin_2, ymin_2 = x2 - w2 / 2, y2 - h2 / 2
    xmax_2, ymax_2 = x2 + w2 / 2, y2 + h2 / 2

    I_w = min(xmax_1, xmax_2) - max(xmin_1, xmin_2)
    I_h = min(ymax_1, ymax_2) - max(ymin_1, ymin_2)
    if I_w < 0 or I_h < 0:
        return 0
    I = I_w * I_h

    IoU = I / (S_1 + S_2 - I)

    return IoU
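
# A quick worked example of iou(): for two boxes that share a center, sized
# 4x4 and 2x2, the intersection is the smaller box (area 4) and the union is
# 16 + 4 - 4 = 16, so iou() returns 4 / 16 = 0.25. Identical boxes give 1.0
# and non-overlapping boxes give 0.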


def init_centroids(boxes, n_anchors):
    centroids = []
    boxes_num = len(boxes)

    centroid_index = int(np.random.choice(boxes_num, 1)[0])
    centroids.append(boxes[centroid_index])
    print(centroids[0].w, centroids[0].h)

    for centroid_index in range(0, n_anchors - 1):
        sum_distance = 0
        distance_thresh = 0
        distance_list = []
        cur_sum = 0

        for box in boxes:
            min_distance = 1
            for centroid_i, centroid in enumerate(centroids):
                distance = (1 - iou(box, centroid))
                if distance < min_distance:
                    min_distance = distance
            sum_distance += min_distance
            distance_list.append(min_distance)

        distance_thresh = sum_distance * np.random.random()

        for i in range(0, boxes_num):
            cur_sum += distance_list[i]
            if cur_sum > distance_thresh:
                centroids.append(boxes[i])
                print(boxes[i].w, boxes[i].h)
                break

    return centroids
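
# init_centroids() follows the k-means++ idea: the first centroid is drawn
# uniformly at random, and each subsequent one is drawn with probability
# proportional to its (1 - IoU) distance to the nearest centroid picked so far,
# so boxes far from the existing centroids are favored. For example, with
# per-box distances [0.1, 0.4, 0.5] the threshold is drawn from [0, 1.0), and a
# draw of 0.45 selects the second box (cumulative sum 0.1 + 0.4 = 0.5 > 0.45).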


def do_kmeans(n_anchors, boxes, centroids):
    loss = 0
    groups = []
    new_centroids = []
    for i in range(n_anchors):
        groups.append([])
        new_centroids.append(Box(0, 0, 0, 0))

    for box in boxes:
        min_distance = 1
        group_index = 0
        for centroid_index, centroid in enumerate(centroids):
            distance = (1 - iou(box, centroid))
            if distance < min_distance:
                min_distance = distance
                group_index = centroid_index
        groups[group_index].append(box)
        loss += min_distance
        new_centroids[group_index].w += box.w
        new_centroids[group_index].h += box.h

    for i in range(n_anchors):
        new_centroids[i].w /= max(len(groups[i]), 1)
        new_centroids[i].h /= max(len(groups[i]), 1)

    return new_centroids, groups, loss  # / len(boxes)
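
# Each do_kmeans() call is one Lloyd iteration under the 1 - IoU distance:
# every box is assigned to its nearest centroid, the summed minimum distances
# form the loss, and each new centroid's (w, h) is the mean (w, h) of its
# group. Since the boxes are built as Box(0, 0, bw, bh), the x/y fields stay
# at zero and only width and height drive the clustering.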


def anchor_box_kmeans(total_gt_boxes, n_anchors, loss_convergence, iters, plus=True):
    """
    Run k-means to find appropriate anchor boxes for the training dataset.
    Input:
        total_gt_boxes : list of Box -> all ground-truth boxes.
        n_anchors : int -> the number of anchor boxes.
        loss_convergence : float -> convergence threshold on the change in loss.
        iters : int -> the maximum number of k-means iterations.
    Output:
        centroids : list of Box -> the clustered anchors; read the w and h fields.
    """
    boxes = total_gt_boxes
    centroids = []
    if plus:
        # k-means++ style seeding
        centroids = init_centroids(boxes, n_anchors)
    else:
        # plain random seeding
        total_indexs = range(len(boxes))
        sample_indexs = random.sample(total_indexs, n_anchors)
        for i in sample_indexs:
            centroids.append(boxes[i])

    # iterate k-means
    centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
    iterations = 1
    while True:
        centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)
        iterations += 1
        print("Loss = %f" % loss)
        if abs(old_loss - loss) < loss_convergence or iterations > iters:
            break
        old_loss = loss
        for centroid in centroids:
            print(centroid.w, centroid.h)

    print("k-means result : ")
    for centroid in centroids:
        if args.absolute:
            print("w, h: ", round(centroid.w, 2), round(centroid.h, 2),
                  "area: ", round(centroid.w, 2) * round(centroid.h, 2))
        else:
            print("w, h: ", round(centroid.w / 32, 2), round(centroid.h / 32, 2),
                  "area: ", round(centroid.w / 32, 2) * round(centroid.h / 32, 2))

    return centroids
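
# Minimal usage sketch with synthetic boxes (independent of the dataset loaders
# in the main block below); the numbers here are made up:
#
#   toy_boxes = [Box(0, 0, w, h) for w, h in [(30, 60), (32, 58), (100, 90), (96, 88)]]
#   toy_anchors = anchor_box_kmeans(toy_boxes, n_anchors=2,
#                                   loss_convergence=1e-6, iters=100, plus=True)
#   # each returned centroid carries a clustered width/height in .w / .h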


if __name__ == "__main__":
    n_anchors = args.num_anchorbox
    img_size = args.img_size

    loss_convergence = 1e-6
    iters_n = 1000

    boxes = []
    if args.dataset == 'voc':
        data_root = os.path.join(args.root, 'VOCdevkit')
        dataset = VOCDetection(data_dir=data_root)

        # For VOC
        for i in range(len(dataset)):
            if i % 5000 == 0:
                print('Loading voc data [%d / %d]' % (i + 1, len(dataset)))

            img, _ = dataset.pull_image(i)
            img_h, img_w = img.shape[:2]
            _, annotation = dataset.pull_anno(i)

            # prepare bbox data
            for box_and_label in annotation:
                box = box_and_label[:-1]
                xmin, ymin, xmax, ymax = box
                # rescale the box so the longer image side maps to img_size
                bw = (xmax - xmin) / max(img_w, img_h) * img_size
                bh = (ymax - ymin) / max(img_w, img_h) * img_size
                # skip degenerate boxes
                if bw < 1.0 or bh < 1.0:
                    continue
                boxes.append(Box(0, 0, bw, bh))

    elif args.dataset == 'coco':
        data_root = os.path.join(args.root, 'COCO')
        dataset = COCODataset(data_dir=data_root, img_size=img_size)

        # For COCO
        for i in range(len(dataset)):
            if i % 5000 == 0:
                print('Loading coco data [%d / %d]' % (i + 1, len(dataset)))

            img, _ = dataset.pull_image(i)
            img_h, img_w = img.shape[:2]
            annotation = dataset.pull_anno(i)

            # prepare bbox data
            for box_and_label in annotation:
                box = box_and_label[:-1]
                xmin, ymin, xmax, ymax = box
                # rescale the box so the longer image side maps to img_size
                bw = (xmax - xmin) / max(img_w, img_h) * img_size
                bh = (ymax - ymin) / max(img_w, img_h) * img_size
                # skip degenerate boxes
                if bw < 1.0 or bh < 1.0:
                    continue
                boxes.append(Box(0, 0, bw, bh))

    print("Number of all bboxes: ", len(boxes))
    print("Start k-means !")
    centroids = anchor_box_kmeans(boxes, n_anchors, loss_convergence, iters_n, plus=True)
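
# Note on the box preprocessing above: each ground-truth box is rescaled so
# that the longer image side maps to img_size. For example, an 80x40 box in a
# 1280x720 image with img_size=416 becomes bw = 80 / 1280 * 416 = 26.0 and
# bh = 40 / 1280 * 416 = 13.0. With --absolute unset, the reported anchors are
# additionally divided by the stride 32 when printed.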