ssd_augment.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611
  1. # ------------------------------------------------------------
  2. # Data preprocessor for Real-time DETR
  3. # ------------------------------------------------------------
  4. import cv2
  5. import numpy as np
  6. from numpy import random
  7. import torch
  8. import torch.nn.functional as F
  9. # ------------------------- Augmentations -------------------------
  10. class Compose(object):
  11. """Composes several augmentations together.
  12. Args:
  13. transforms (List[Transform]): list of transforms to compose.
  14. Example:
  15. >>> augmentations.Compose([
  16. >>> transforms.CenterCrop(10),
  17. >>> transforms.ToTensor(),
  18. >>> ])
  19. """
  20. def __init__(self, transforms):
  21. self.transforms = transforms
  22. def __call__(self, image, target=None):
  23. for t in self.transforms:
  24. image, target = t(image, target)
  25. return image, target
  26. ## Convert color format
  27. class ConvertColorFormat(object):
  28. def __init__(self, color_format='rgb'):
  29. self.color_format = color_format
  30. def __call__(self, image, target=None):
  31. """
  32. Input:
  33. image: (np.array) a OpenCV image with BGR color format.
  34. target: None
  35. Output:
  36. image: (np.array) a OpenCV image with given color format.
  37. target: None
  38. """
  39. # Convert color format
  40. if self.color_format == 'rgb':
  41. image = image[..., (2, 1, 0)] # BGR -> RGB
  42. elif self.color_format == 'bgr':
  43. image = image
  44. else:
  45. raise NotImplementedError("Unknown color format: <{}>".format(self.color_format))
  46. return image, target
  47. ## Random color jitter
  48. class RandomDistort(object):
  49. def __init__(self,
  50. hue=[-18, 18, 0.5],
  51. saturation=[0.5, 1.5, 0.5],
  52. contrast=[0.5, 1.5, 0.5],
  53. brightness=[0.5, 1.5, 0.5],
  54. random_apply=True,
  55. count=4,
  56. random_channel=False,
  57. prob=1.0):
  58. super(RandomDistort, self).__init__()
  59. self.hue = hue
  60. self.saturation = saturation
  61. self.contrast = contrast
  62. self.brightness = brightness
  63. self.random_apply = random_apply
  64. self.count = count
  65. self.random_channel = random_channel
  66. self.prob = prob
  67. def apply_hue(self, image, target=None):
  68. if np.random.uniform(0., 1.) < self.prob:
  69. return image, target
  70. low, high, prob = self.hue
  71. image = image.astype(np.float32)
  72. # it works, but result differ from HSV version
  73. delta = np.random.uniform(low, high)
  74. u = np.cos(delta * np.pi)
  75. w = np.sin(delta * np.pi)
  76. bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
  77. tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
  78. [0.211, -0.523, 0.311]])
  79. ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
  80. [1.0, -1.107, 1.705]])
  81. t = np.dot(np.dot(ityiq, bt), tyiq).T
  82. image = np.dot(image, t)
  83. return image, target
  84. def apply_saturation(self, image, target=None):
  85. low, high, prob = self.saturation
  86. if np.random.uniform(0., 1.) < self.prob:
  87. return image, target
  88. delta = np.random.uniform(low, high)
  89. image = image.astype(np.float32)
  90. # it works, but result differ from HSV version
  91. gray = image * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
  92. gray = gray.sum(axis=2, keepdims=True)
  93. gray *= (1.0 - delta)
  94. image *= delta
  95. image += gray
  96. return image, target
  97. def apply_contrast(self, image, target=None):
  98. if np.random.uniform(0., 1.) < self.prob:
  99. return image, target
  100. low, high, prob = self.contrast
  101. delta = np.random.uniform(low, high)
  102. image = image.astype(np.float32)
  103. image *= delta
  104. return image, target
  105. def apply_brightness(self, image, target=None):
  106. if np.random.uniform(0., 1.) < self.prob:
  107. return image, target
  108. low, high, prob = self.brightness
  109. delta = np.random.uniform(low, high)
  110. image = image.astype(np.float32)
  111. image += delta
  112. return image, target
  113. def __call__(self, image, target=None):
  114. if random.random() > self.prob:
  115. return image, target
  116. if self.random_apply:
  117. functions = [
  118. self.apply_brightness, self.apply_contrast,
  119. self.apply_saturation, self.apply_hue
  120. ]
  121. distortions = np.random.permutation(functions)[:self.count]
  122. for func in distortions:
  123. image, target = func(image, target)
  124. image = np.clip(image, 0.0, 255.)
  125. return image, target
  126. image, target = self.apply_brightness(image, target)
  127. image = np.clip(image, 0.0, 255.)
  128. mode = np.random.randint(0, 2)
  129. if mode:
  130. image, target = self.apply_contrast(image, target)
  131. image = np.clip(image, 0.0, 255.)
  132. image, target = self.apply_saturation(image, target)
  133. image = np.clip(image, 0.0, 255.)
  134. image, target = self.apply_hue(image, target)
  135. image = np.clip(image, 0.0, 255.)
  136. if not mode:
  137. image, target = self.apply_contrast(image, target)
  138. image = np.clip(image, 0.0, 255.)
  139. if self.random_channel:
  140. if np.random.randint(0, 2):
  141. image = image[..., np.random.permutation(3)]
  142. return image, target
  143. ## Random scaling
  144. class RandomExpand(object):
  145. def __init__(self, fill_value) -> None:
  146. self.fill_value = fill_value
  147. def __call__(self, image, target=None):
  148. if random.randint(2):
  149. return image, target
  150. height, width, channels = image.shape
  151. ratio = random.uniform(1, 4)
  152. left = random.uniform(0, width*ratio - width)
  153. top = random.uniform(0, height*ratio - height)
  154. expand_image = np.ones(
  155. (int(height*ratio), int(width*ratio), channels),
  156. dtype=image.dtype) * self.fill_value
  157. expand_image[int(top):int(top + height),
  158. int(left):int(left + width)] = image
  159. image = expand_image
  160. boxes = target['boxes'].copy()
  161. boxes[:, :2] += (int(left), int(top))
  162. boxes[:, 2:] += (int(left), int(top))
  163. target['boxes'] = boxes
  164. return image, target
  165. ## Random IoU based Sample Crop
  166. class RandomIoUCrop(object):
  167. def __init__(self, p=0.5):
  168. self.p = p
  169. self.sample_options = (
  170. # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
  171. (0.1, None),
  172. (0.3, None),
  173. (0.5, None),
  174. (0.7, None),
  175. (0.9, None),
  176. None,
  177. )
  178. def intersect(self, box_a, box_b):
  179. max_xy = np.minimum(box_a[:, 2:], box_b[2:])
  180. min_xy = np.maximum(box_a[:, :2], box_b[:2])
  181. inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
  182. return inter[:, 0] * inter[:, 1]
  183. def compute_iou(self, box_a, box_b):
  184. inter = self.intersect(box_a, box_b)
  185. area_a = ((box_a[:, 2]-box_a[:, 0]) *
  186. (box_a[:, 3]-box_a[:, 1])) # [A,B]
  187. area_b = ((box_b[2]-box_b[0]) *
  188. (box_b[3]-box_b[1])) # [A,B]
  189. union = area_a + area_b - inter
  190. return inter / union # [A,B]
  191. def __call__(self, image, target=None):
  192. height, width, _ = image.shape
  193. # check target
  194. if len(target["boxes"]) == 0 or random.random() > self.p:
  195. return image, target
  196. while True:
  197. # randomly choose a mode
  198. sample_id = np.random.randint(len(self.sample_options))
  199. mode = self.sample_options[sample_id]
  200. if mode is None:
  201. return image, target
  202. boxes = target["boxes"]
  203. labels = target["labels"]
  204. min_iou, max_iou = mode
  205. if min_iou is None:
  206. min_iou = float('-inf')
  207. if max_iou is None:
  208. max_iou = float('inf')
  209. # max trails (50)
  210. for _ in range(50):
  211. current_image = image
  212. w = random.uniform(0.3 * width, width)
  213. h = random.uniform(0.3 * height, height)
  214. # aspect ratio constraint b/t .5 & 2
  215. if h / w < 0.5 or h / w > 2:
  216. continue
  217. left = random.uniform(width - w)
  218. top = random.uniform(height - h)
  219. # convert to integer rect x1,y1,x2,y2
  220. rect = np.array([int(left), int(top), int(left+w), int(top+h)])
  221. # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
  222. overlap = self.compute_iou(boxes, rect)
  223. # is min and max overlap constraint satisfied? if not try again
  224. if overlap.min() < min_iou and max_iou < overlap.max():
  225. continue
  226. # cut the crop from the image
  227. current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
  228. :]
  229. # keep overlap with gt box IF center in sampled patch
  230. centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
  231. # mask in all gt boxes that above and to the left of centers
  232. m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
  233. # mask in all gt boxes that under and to the right of centers
  234. m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
  235. # mask in that both m1 and m2 are true
  236. mask = m1 * m2
  237. # have any valid boxes? try again if not
  238. if not mask.any():
  239. continue
  240. # take only matching gt boxes
  241. current_boxes = boxes[mask, :].copy()
  242. # take only matching gt labels
  243. current_labels = labels[mask]
  244. # should we use the box left and top corner or the crop's
  245. current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
  246. rect[:2])
  247. # adjust to crop (by substracting crop's left,top)
  248. current_boxes[:, :2] -= rect[:2]
  249. current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
  250. rect[2:])
  251. # adjust to crop (by substracting crop's left,top)
  252. current_boxes[:, 2:] -= rect[:2]
  253. # update target
  254. target["boxes"] = current_boxes
  255. target["labels"] = current_labels
  256. return current_image, target
  257. ## Random JitterCrop
  258. class RandomJitterCrop(object):
  259. """Jitter and crop the image and box."""
  260. def __init__(self, fill_value, p=0.5, jitter_ratio=0.3):
  261. super().__init__()
  262. self.p = p
  263. self.jitter_ratio = jitter_ratio
  264. self.fill_value = fill_value
  265. def crop(self, image, pleft, pright, ptop, pbot, output_size):
  266. oh, ow = image.shape[:2]
  267. swidth, sheight = output_size
  268. src_rect = [pleft, ptop, swidth + pleft,
  269. sheight + ptop] # x1,y1,x2,y2
  270. img_rect = [0, 0, ow, oh]
  271. # rect intersection
  272. new_src_rect = [max(src_rect[0], img_rect[0]),
  273. max(src_rect[1], img_rect[1]),
  274. min(src_rect[2], img_rect[2]),
  275. min(src_rect[3], img_rect[3])]
  276. dst_rect = [max(0, -pleft),
  277. max(0, -ptop),
  278. max(0, -pleft) + new_src_rect[2] - new_src_rect[0],
  279. max(0, -ptop) + new_src_rect[3] - new_src_rect[1]]
  280. # crop the image
  281. cropped = np.ones([sheight, swidth, 3], dtype=image.dtype) * self.fill_value
  282. # cropped[:, :, ] = np.mean(image, axis=(0, 1))
  283. cropped[dst_rect[1]:dst_rect[3], dst_rect[0]:dst_rect[2]] = \
  284. image[new_src_rect[1]:new_src_rect[3],
  285. new_src_rect[0]:new_src_rect[2]]
  286. return cropped
  287. def __call__(self, image, target=None):
  288. if random.random() > self.p:
  289. return image, target
  290. else:
  291. oh, ow = image.shape[:2]
  292. dw = int(ow * self.jitter_ratio)
  293. dh = int(oh * self.jitter_ratio)
  294. pleft = np.random.randint(-dw, dw)
  295. pright = np.random.randint(-dw, dw)
  296. ptop = np.random.randint(-dh, dh)
  297. pbot = np.random.randint(-dh, dh)
  298. swidth = ow - pleft - pright
  299. sheight = oh - ptop - pbot
  300. output_size = (swidth, sheight)
  301. # crop image
  302. cropped_image = self.crop(image=image,
  303. pleft=pleft,
  304. pright=pright,
  305. ptop=ptop,
  306. pbot=pbot,
  307. output_size=output_size)
  308. # crop bbox
  309. if target is not None:
  310. bboxes = target['boxes'].copy()
  311. coords_offset = np.array([pleft, ptop], dtype=np.float32)
  312. bboxes[..., [0, 2]] = bboxes[..., [0, 2]] - coords_offset[0]
  313. bboxes[..., [1, 3]] = bboxes[..., [1, 3]] - coords_offset[1]
  314. swidth, sheight = output_size
  315. bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], 0, swidth - 1)
  316. bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], 0, sheight - 1)
  317. target['boxes'] = bboxes
  318. return cropped_image, target
  319. ## Random HFlip
  320. class RandomHorizontalFlip(object):
  321. def __init__(self, p=0.5):
  322. self.p = p
  323. def __call__(self, image, target=None):
  324. if random.random() < self.p:
  325. orig_h, orig_w = image.shape[:2]
  326. image = image[:, ::-1]
  327. if target is not None:
  328. if "boxes" in target:
  329. boxes = target["boxes"].copy()
  330. boxes[..., [0, 2]] = orig_w - boxes[..., [2, 0]]
  331. target["boxes"] = boxes
  332. return image, target
  333. ## Resize tensor image
  334. class Resize(object):
  335. def __init__(self, img_size=640):
  336. self.img_size = img_size
  337. def __call__(self, image, target=None):
  338. orig_h, orig_w = image.shape[:2]
  339. # resize
  340. image = cv2.resize(image, (self.img_size, self.img_size)).astype(np.float32)
  341. img_h, img_w = image.shape[:2]
  342. # rescale bboxes
  343. if target is not None:
  344. boxes = target["boxes"].astype(np.float32)
  345. boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
  346. boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
  347. target["boxes"] = boxes
  348. return image, target
  349. ## Normalize tensor image
  350. class Normalize(object):
  351. def __init__(self, pixel_mean, pixel_std, normalize_coords=False):
  352. self.pixel_mean = pixel_mean
  353. self.pixel_std = pixel_std
  354. self.normalize_coords = normalize_coords
  355. def __call__(self, image, target=None):
  356. # normalize image
  357. image = (image - self.pixel_mean) / self.pixel_std
  358. # normalize bbox
  359. if target is not None and self.normalize_coords:
  360. img_h, img_w = image.shape[:2]
  361. target["boxes"][..., [0, 2]] = target["boxes"][..., [0, 2]] / float(img_w)
  362. target["boxes"][..., [1, 3]] = target["boxes"][..., [1, 3]] / float(img_h)
  363. return image, target
  364. ## Convert ndarray to torch.Tensor
  365. class ToTensor(object):
  366. def __call__(self, image, target=None):
  367. # Convert torch.Tensor
  368. image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
  369. if target is not None:
  370. target["boxes"] = torch.as_tensor(target["boxes"]).float()
  371. target["labels"] = torch.as_tensor(target["labels"]).long()
  372. return image, target
  373. ## Convert BBox format
  374. class ConvertBoxFormat(object):
  375. def __init__(self, box_format="xyxy"):
  376. self.box_format = box_format
  377. def __call__(self, image, target=None):
  378. # convert box format
  379. if self.box_format == "xyxy" or target is None:
  380. pass
  381. elif self.box_format == "xywh":
  382. target = target.copy()
  383. if "boxes" in target:
  384. boxes_xyxy = target["boxes"]
  385. boxes_xywh = torch.zeros_like(boxes_xyxy)
  386. boxes_xywh[..., :2] = (boxes_xyxy[..., :2] + boxes_xyxy[..., 2:]) * 0.5 # cxcy
  387. boxes_xywh[..., 2:] = boxes_xyxy[..., 2:] - boxes_xyxy[..., :2] # bwbh
  388. target["boxes"] = boxes_xywh
  389. else:
  390. raise NotImplementedError("Unknown box format: {}".format(self.box_format))
  391. return image, target
  392. # ------------------------- Preprocessors -------------------------
  393. ## Transform for Train
  394. class SSDAugmentation(object):
  395. def __init__(self,
  396. img_size = 640,
  397. pixel_mean = [123.675, 116.28, 103.53],
  398. pixel_std = [58.395, 57.12, 57.375],
  399. box_format = 'xywh',
  400. normalize_coords = False):
  401. # ----------------- Basic parameters -----------------
  402. self.img_size = img_size
  403. self.box_format = box_format
  404. self.pixel_mean = pixel_mean # RGB format
  405. self.pixel_std = pixel_std # RGB format
  406. self.normalize_coords = normalize_coords
  407. self.color_format = 'rgb'
  408. print("================= Pixel Statistics =================")
  409. print("Pixel mean: {}".format(self.pixel_mean))
  410. print("Pixel std: {}".format(self.pixel_std))
  411. # ----------------- Transforms -----------------
  412. self.augment = Compose([
  413. RandomDistort(prob=0.5),
  414. RandomExpand(fill_value=self.pixel_mean[::-1]),
  415. RandomIoUCrop(p=0.8),
  416. RandomHorizontalFlip(p=0.5),
  417. Resize(img_size=self.img_size),
  418. ConvertColorFormat(self.color_format),
  419. Normalize(self.pixel_mean, self.pixel_std, normalize_coords),
  420. ToTensor(),
  421. ConvertBoxFormat(self.box_format),
  422. ])
  423. def __call__(self, image, target, mosaic=False):
  424. orig_h, orig_w = image.shape[:2]
  425. ratio = [self.img_size / orig_w, self.img_size / orig_h]
  426. image, target = self.augment(image, target)
  427. return image, target, ratio
  428. ## Transform for Eval
  429. class SSDBaseTransform(object):
  430. def __init__(self,
  431. img_size = 640,
  432. pixel_mean = [123.675, 116.28, 103.53],
  433. pixel_std = [58.395, 57.12, 57.375],
  434. box_format = 'xywh',
  435. normalize_coords = False):
  436. # ----------------- Basic parameters -----------------
  437. self.img_size = img_size
  438. self.box_format = box_format
  439. self.pixel_mean = pixel_mean # RGB format
  440. self.pixel_std = pixel_std # RGB format
  441. self.normalize_coords = normalize_coords
  442. self.color_format = 'rgb'
  443. print("================= Pixel Statistics =================")
  444. print("Pixel mean: {}".format(self.pixel_mean))
  445. print("Pixel std: {}".format(self.pixel_std))
  446. # ----------------- Transforms -----------------
  447. self.transform = Compose([
  448. Resize(img_size=self.img_size),
  449. ConvertColorFormat(self.color_format),
  450. Normalize(self.pixel_mean, self.pixel_std, self.normalize_coords),
  451. ToTensor(),
  452. ConvertBoxFormat(self.box_format),
  453. ])
  454. def __call__(self, image, target=None, mosaic=False):
  455. orig_h, orig_w = image.shape[:2]
  456. ratio = [self.img_size / orig_w, self.img_size / orig_h]
  457. image, target = self.transform(image, target)
  458. return image, target, ratio
  459. if __name__ == "__main__":
  460. image_path = "voc_image.jpg"
  461. is_train = True
  462. if is_train:
  463. ssd_augment = SSDAugmentation(img_size=416,
  464. pixel_mean=[0., 0., 0.],
  465. pixel_std=[255., 255., 255.],
  466. box_format="xyxy",
  467. normalize_coords=False,
  468. )
  469. else:
  470. ssd_augment = SSDBaseTransform(img_size=416,
  471. pixel_mean=[0., 0., 0.],
  472. pixel_std=[255., 255., 255.],
  473. box_format="xyxy",
  474. normalize_coords=False,
  475. )
  476. # 读取图像数据
  477. orig_image = cv2.imread(image_path)
  478. target = {
  479. "boxes": np.array([[86, 96, 256, 425], [132, 71, 243, 282]], dtype=np.float32),
  480. "labels": np.array([12, 14], dtype=np.int32),
  481. }
  482. # 绘制原始数据的边界框
  483. image_copy = orig_image.copy()
  484. for box in target["boxes"]:
  485. x1, y1, x2, y2 = box
  486. image_copy = cv2.rectangle(image_copy, (int(x1), int(y1)), (int(x2), int(y2)), [0, 0, 255], 2)
  487. cv2.imshow("original image", image_copy)
  488. cv2.waitKey(0)
  489. # 展示预处理后的输入图像数据和标签信息
  490. image_aug, target_aug, _ = ssd_augment(orig_image, target)
  491. # [c, h, w] -> [h, w, c]
  492. image_aug = image_aug.permute(1, 2, 0).contiguous().numpy()
  493. image_aug = np.clip(image_aug * 255, 0, 255).astype(np.uint8)
  494. image_aug = image_aug[:, :, (2, 1, 0)] # 切换为CV2默认的BGR通道顺序
  495. image_aug = image_aug.copy()
  496. # 绘制处理后的边界框
  497. for box in target_aug["boxes"]:
  498. x1, y1, x2, y2 = box
  499. image_aug = cv2.rectangle(image_aug, (int(x1), int(y1)), (int(x2), int(y2)), [0, 0, 255], 2)
  500. cv2.imshow("processed image", image_aug)
  501. cv2.waitKey(0)