# kmeans_anchor.py — k-means clustering of ground-truth boxes to derive anchor sizes.
  1. import numpy as np
  2. import random
  3. import argparse
  4. import os
  5. import sys
  6. sys.path.append('..')
  7. from dataset.voc import VOCDetection
  8. from dataset.coco import COCODataset
  9. def parse_args():
  10. parser = argparse.ArgumentParser(description='kmeans for anchor box')
  11. parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
  12. help='data root')
  13. parser.add_argument('-d', '--dataset', default='coco',
  14. help='coco, widerface, crowdhuman')
  15. parser.add_argument('-na', '--num_anchorbox', default=5, type=int,
  16. help='number of anchor box.')
  17. parser.add_argument('-size', '--img_size', default=416, type=int,
  18. help='input size.')
  19. parser.add_argument('-ab', '--absolute', action='store_true', default=False,
  20. help='absolute coords.')
  21. return parser.parse_args()
# Parsed CLI options, resolved once at import time and read as a module-level
# global elsewhere in this file (e.g. args.absolute inside anchor_box_kmeans).
args = parse_args()
  23. class Box():
  24. def __init__(self, x, y, w, h):
  25. self.x = x
  26. self.y = y
  27. self.w = w
  28. self.h = h
  29. def iou(box1, box2):
  30. x1, y1, w1, h1 = box1.x, box1.y, box1.w, box1.h
  31. x2, y2, w2, h2 = box2.x, box2.y, box2.w, box2.h
  32. S_1 = w1 * h1
  33. S_2 = w2 * h2
  34. xmin_1, ymin_1 = x1 - w1 / 2, y1 - h1 / 2
  35. xmax_1, ymax_1 = x1 + w1 / 2, y1 + h1 / 2
  36. xmin_2, ymin_2 = x2 - w2 / 2, y2 - h2 / 2
  37. xmax_2, ymax_2 = x2 + w2 / 2, y2 + h2 / 2
  38. I_w = min(xmax_1, xmax_2) - max(xmin_1, xmin_2)
  39. I_h = min(ymax_1, ymax_2) - max(ymin_1, ymin_2)
  40. if I_w < 0 or I_h < 0:
  41. return 0
  42. I = I_w * I_h
  43. IoU = I / (S_1 + S_2 - I)
  44. return IoU
  45. def init_centroids(boxes, n_anchors):
  46. centroids = []
  47. boxes_num = len(boxes)
  48. centroid_index = int(np.random.choice(boxes_num, 1)[0])
  49. centroids.append(boxes[centroid_index])
  50. print(centroids[0].w,centroids[0].h)
  51. for centroid_index in range(0, n_anchors-1):
  52. sum_distance = 0
  53. distance_thresh = 0
  54. distance_list = []
  55. cur_sum = 0
  56. for box in boxes:
  57. min_distance = 1
  58. for centroid_i, centroid in enumerate(centroids):
  59. distance = (1 - iou(box, centroid))
  60. if distance < min_distance:
  61. min_distance = distance
  62. sum_distance += min_distance
  63. distance_list.append(min_distance)
  64. distance_thresh = sum_distance * np.random.random()
  65. for i in range(0, boxes_num):
  66. cur_sum += distance_list[i]
  67. if cur_sum > distance_thresh:
  68. centroids.append(boxes[i])
  69. print(boxes[i].w, boxes[i].h)
  70. break
  71. return centroids
  72. def do_kmeans(n_anchors, boxes, centroids):
  73. loss = 0
  74. groups = []
  75. new_centroids = []
  76. for i in range(n_anchors):
  77. groups.append([])
  78. new_centroids.append(Box(0, 0, 0, 0))
  79. for box in boxes:
  80. min_distance = 1
  81. group_index = 0
  82. for centroid_index, centroid in enumerate(centroids):
  83. distance = (1 - iou(box, centroid))
  84. if distance < min_distance:
  85. min_distance = distance
  86. group_index = centroid_index
  87. groups[group_index].append(box)
  88. loss += min_distance
  89. new_centroids[group_index].w += box.w
  90. new_centroids[group_index].h += box.h
  91. for i in range(n_anchors):
  92. new_centroids[i].w /= max(len(groups[i]), 1)
  93. new_centroids[i].h /= max(len(groups[i]), 1)
  94. return new_centroids, groups, loss# / len(boxes)
  95. def anchor_box_kmeans(total_gt_boxes, n_anchors, loss_convergence, iters, plus=True):
  96. """
  97. This function will use k-means to get appropriate anchor boxes for train dataset.
  98. Input:
  99. total_gt_boxes:
  100. n_anchor : int -> the number of anchor boxes.
  101. loss_convergence : float -> threshold of iterating convergence.
  102. iters: int -> the number of iterations for training kmeans.
  103. Output: anchor_boxes : list -> [[w1, h1], [w2, h2], ..., [wn, hn]].
  104. """
  105. boxes = total_gt_boxes
  106. centroids = []
  107. if plus:
  108. centroids = init_centroids(boxes, n_anchors)
  109. else:
  110. total_indexs = range(len(boxes))
  111. sample_indexs = random.sample(total_indexs, n_anchors)
  112. for i in sample_indexs:
  113. centroids.append(boxes[i])
  114. # iterate k-means
  115. centroids, groups, old_loss = do_kmeans(n_anchors, boxes, centroids)
  116. iterations = 1
  117. while(True):
  118. centroids, groups, loss = do_kmeans(n_anchors, boxes, centroids)
  119. iterations += 1
  120. print("Loss = %f" % loss)
  121. if abs(old_loss - loss) < loss_convergence or iterations > iters:
  122. break
  123. old_loss = loss
  124. for centroid in centroids:
  125. print(centroid.w, centroid.h)
  126. print("k-means result : ")
  127. for centroid in centroids:
  128. if args.absolute:
  129. print("w, h: ", round(centroid.w, 2), round(centroid.h, 2),
  130. "area: ", round(centroid.w, 2) * round(centroid.h, 2))
  131. else:
  132. print("w, h: ", round(centroid.w / 32, 2), round(centroid.h / 32, 2),
  133. "area: ", round(centroid.w / 32, 2) * round(centroid.h / 32, 2))
  134. return centroids
  135. if __name__ == "__main__":
  136. n_anchors = args.num_anchorbox
  137. img_size = args.img_size
  138. loss_convergence = 1e-6
  139. iters_n = 1000
  140. boxes = []
  141. if args.dataset == 'voc':
  142. data_root = os.path.join(args.root, 'VOCdevkit')
  143. dataset = VOCDetection(data_dir=data_root)
  144. # VOC
  145. for i in range(len(dataset)):
  146. if i % 5000 == 0:
  147. print('Loading voc data [%d / %d]' % (i+1, len(dataset)))
  148. # For VOC
  149. img, _ = dataset.pull_image(i)
  150. img_h, img_w = img.shape[:2]
  151. _, annotation = dataset.pull_anno(i)
  152. # prepare bbox datas
  153. for box_and_label in annotation:
  154. box = box_and_label[:-1]
  155. xmin, ymin, xmax, ymax = box
  156. bw = (xmax - xmin) / max(img_w, img_h) * img_size
  157. bh = (ymax - ymin) / max(img_w, img_h) * img_size
  158. # check bbox
  159. if bw < 1.0 or bh < 1.0:
  160. continue
  161. boxes.append(Box(0, 0, bw, bh))
  162. break
  163. elif args.dataset == 'coco':
  164. data_root = os.path.join(args.root, 'COCO')
  165. dataset = COCODataset(data_dir=data_root, img_size=img_size)
  166. for i in range(len(dataset)):
  167. if i % 5000 == 0:
  168. print('Loading coco datat [%d / %d]' % (i+1, len(dataset)))
  169. # For COCO
  170. img, _ = dataset.pull_image(i)
  171. img_h, img_w = img.shape[:2]
  172. annotation = dataset.pull_anno(i)
  173. # prepare bbox datas
  174. for box_and_label in annotation:
  175. box = box_and_label[:-1]
  176. xmin, ymin, xmax, ymax = box
  177. bw = (xmax - xmin) / max(img_w, img_h) * img_size
  178. bh = (ymax - ymin) / max(img_w, img_h) * img_size
  179. # check bbox
  180. if bw < 1.0 or bh < 1.0:
  181. continue
  182. boxes.append(Box(0, 0, bw, bh))
  183. print("Number of all bboxes: ", len(boxes))
  184. print("Start k-means !")
  185. centroids = anchor_box_kmeans(boxes, n_anchors, loss_convergence, iters_n, plus=True)