rtdetr_augment.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471
  1. # ------------------------------------------------------------
  2. # Data preprocessor for Real-time DETR
  3. # ------------------------------------------------------------
  4. import cv2
  5. import numpy as np
  6. from numpy import random
  7. import torch
  8. import torch.nn.functional as F
  9. # ------------------------- Augmentations -------------------------
  10. class Compose(object):
  11. """Composes several augmentations together.
  12. Args:
  13. transforms (List[Transform]): list of transforms to compose.
  14. Example:
  15. >>> augmentations.Compose([
  16. >>> transforms.CenterCrop(10),
  17. >>> transforms.ToTensor(),
  18. >>> ])
  19. """
  20. def __init__(self, transforms):
  21. self.transforms = transforms
  22. def __call__(self, image, target=None):
  23. for t in self.transforms:
  24. image, target = t(image, target)
  25. return image, target
  26. ## Convert color format
  27. class ConvertColorFormat(object):
  28. def __init__(self, color_format='rgb'):
  29. self.color_format = color_format
  30. def __call__(self, image, target=None):
  31. """
  32. Input:
  33. image: (np.array) a OpenCV image with BGR color format.
  34. target: None
  35. Output:
  36. image: (np.array) a OpenCV image with given color format.
  37. target: None
  38. """
  39. # Convert color format
  40. if self.color_format == 'rgb':
  41. image = image[..., (2, 1, 0)] # BGR -> RGB
  42. elif self.color_format == 'bgr':
  43. image = image
  44. else:
  45. raise NotImplementedError("Unknown color format: <{}>".format(self.color_format))
  46. return image, target
  47. ## Random Photometric Distort
  48. class RandomPhotometricDistort(object):
  49. """
  50. Distort image w.r.t hue, saturation and exposure.
  51. """
  52. def __init__(self, hue=0.1, saturation=1.5, exposure=1.5):
  53. super().__init__()
  54. self.hue = hue
  55. self.saturation = saturation
  56. self.exposure = exposure
  57. def __call__(self, image: np.ndarray, target=None) -> np.ndarray:
  58. """
  59. Args:
  60. img (ndarray): of shape HxW, HxWxC, or NxHxWxC. The array can be
  61. of type uint8 in range [0, 255], or floating point in range
  62. [0, 1] or [0, 255].
  63. Returns:
  64. ndarray: the distorted image(s).
  65. """
  66. if random.random() < 0.5:
  67. dhue = np.random.uniform(low=-self.hue, high=self.hue)
  68. dsat = self._rand_scale(self.saturation)
  69. dexp = self._rand_scale(self.exposure)
  70. image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
  71. image = np.asarray(image, dtype=np.float32) / 255.
  72. image[:, :, 1] *= dsat
  73. image[:, :, 2] *= dexp
  74. H = image[:, :, 0] + dhue * 179 / 255.
  75. if dhue > 0:
  76. H[H > 1.0] -= 1.0
  77. else:
  78. H[H < 0.0] += 1.0
  79. image[:, :, 0] = H
  80. image = (image * 255).clip(0, 255).astype(np.uint8)
  81. image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
  82. image = np.asarray(image, dtype=np.uint8)
  83. return image, target
  84. def _rand_scale(self, upper_bound):
  85. """
  86. Calculate random scaling factor.
  87. Args:
  88. upper_bound (float): range of the random scale.
  89. Returns:
  90. random scaling factor (float) whose range is
  91. from 1 / s to s .
  92. """
  93. scale = np.random.uniform(low=1, high=upper_bound)
  94. if np.random.rand() > 0.5:
  95. return scale
  96. return 1 / scale
## Random expand (zoom out onto a larger canvas)
  98. class RandomExpand(object):
  99. def __init__(self, fill_value) -> None:
  100. self.fill_value = fill_value
  101. def __call__(self, image, target=None):
  102. if random.randint(2):
  103. return image, target
  104. height, width, channels = image.shape
  105. ratio = random.uniform(1, 4)
  106. left = random.uniform(0, width*ratio - width)
  107. top = random.uniform(0, height*ratio - height)
  108. expand_image = np.ones(
  109. (int(height*ratio), int(width*ratio), channels),
  110. dtype=image.dtype) * self.fill_value
  111. expand_image[int(top):int(top + height),
  112. int(left):int(left + width)] = image
  113. image = expand_image
  114. boxes = target['boxes'].copy()
  115. boxes[:, :2] += (int(left), int(top))
  116. boxes[:, 2:] += (int(left), int(top))
  117. target['boxes'] = boxes
  118. return image, target
  119. ## Random IoU based Sample Crop
  120. class RandomSampleCrop(object):
  121. def __init__(self):
  122. self.sample_options = (
  123. # using entire original input image
  124. None,
  125. # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
  126. (0.1, None),
  127. (0.3, None),
  128. (0.5, None),
  129. (0.7, None),
  130. (0.9, None),
  131. # randomly sample a patch
  132. (None, None),
  133. )
  134. def intersect(self, box_a, box_b):
  135. max_xy = np.minimum(box_a[:, 2:], box_b[2:])
  136. min_xy = np.maximum(box_a[:, :2], box_b[:2])
  137. inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
  138. return inter[:, 0] * inter[:, 1]
  139. def compute_iou(self, box_a, box_b):
  140. inter = self.intersect(box_a, box_b)
  141. area_a = ((box_a[:, 2]-box_a[:, 0]) *
  142. (box_a[:, 3]-box_a[:, 1])) # [A,B]
  143. area_b = ((box_b[2]-box_b[0]) *
  144. (box_b[3]-box_b[1])) # [A,B]
  145. union = area_a + area_b - inter
  146. return inter / union # [A,B]
  147. def __call__(self, image, target=None):
  148. height, width, _ = image.shape
  149. # check target
  150. if len(target["boxes"]) == 0:
  151. return image, target
  152. while True:
  153. # randomly choose a mode
  154. sample_id = np.random.randint(len(self.sample_options))
  155. mode = self.sample_options[sample_id]
  156. if mode is None:
  157. return image, target
  158. boxes = target["boxes"]
  159. labels = target["labels"]
  160. min_iou, max_iou = mode
  161. if min_iou is None:
  162. min_iou = float('-inf')
  163. if max_iou is None:
  164. max_iou = float('inf')
  165. # max trails (50)
  166. for _ in range(50):
  167. current_image = image
  168. w = random.uniform(0.3 * width, width)
  169. h = random.uniform(0.3 * height, height)
  170. # aspect ratio constraint b/t .5 & 2
  171. if h / w < 0.5 or h / w > 2:
  172. continue
  173. left = random.uniform(width - w)
  174. top = random.uniform(height - h)
  175. # convert to integer rect x1,y1,x2,y2
  176. rect = np.array([int(left), int(top), int(left+w), int(top+h)])
  177. # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
  178. overlap = self.compute_iou(boxes, rect)
  179. # is min and max overlap constraint satisfied? if not try again
  180. if overlap.min() < min_iou and max_iou < overlap.max():
  181. continue
  182. # cut the crop from the image
  183. current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
  184. :]
  185. # keep overlap with gt box IF center in sampled patch
  186. centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
  187. # mask in all gt boxes that above and to the left of centers
  188. m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
  189. # mask in all gt boxes that under and to the right of centers
  190. m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
  191. # mask in that both m1 and m2 are true
  192. mask = m1 * m2
  193. # have any valid boxes? try again if not
  194. if not mask.any():
  195. continue
  196. # take only matching gt boxes
  197. current_boxes = boxes[mask, :].copy()
  198. # take only matching gt labels
  199. current_labels = labels[mask]
  200. # should we use the box left and top corner or the crop's
  201. current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
  202. rect[:2])
  203. # adjust to crop (by substracting crop's left,top)
  204. current_boxes[:, :2] -= rect[:2]
  205. current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
  206. rect[2:])
  207. # adjust to crop (by substracting crop's left,top)
  208. current_boxes[:, 2:] -= rect[:2]
  209. # update target
  210. target["boxes"] = current_boxes
  211. target["labels"] = current_labels
  212. return current_image, target
  213. ## Random JitterCrop
  214. class RandomJitterCrop(object):
  215. """Jitter and crop the image and box."""
  216. def __init__(self, fill_value, p=0.5, jitter_ratio=0.3):
  217. super().__init__()
  218. self.p = p
  219. self.jitter_ratio = jitter_ratio
  220. self.fill_value = fill_value
  221. def crop(self, image, pleft, pright, ptop, pbot, output_size):
  222. oh, ow = image.shape[:2]
  223. swidth, sheight = output_size
  224. src_rect = [pleft, ptop, swidth + pleft,
  225. sheight + ptop] # x1,y1,x2,y2
  226. img_rect = [0, 0, ow, oh]
  227. # rect intersection
  228. new_src_rect = [max(src_rect[0], img_rect[0]),
  229. max(src_rect[1], img_rect[1]),
  230. min(src_rect[2], img_rect[2]),
  231. min(src_rect[3], img_rect[3])]
  232. dst_rect = [max(0, -pleft),
  233. max(0, -ptop),
  234. max(0, -pleft) + new_src_rect[2] - new_src_rect[0],
  235. max(0, -ptop) + new_src_rect[3] - new_src_rect[1]]
  236. # crop the image
  237. cropped = np.ones([sheight, swidth, 3], dtype=image.dtype) * self.fill_value
  238. # cropped[:, :, ] = np.mean(image, axis=(0, 1))
  239. cropped[dst_rect[1]:dst_rect[3], dst_rect[0]:dst_rect[2]] = \
  240. image[new_src_rect[1]:new_src_rect[3],
  241. new_src_rect[0]:new_src_rect[2]]
  242. return cropped
  243. def __call__(self, image, target=None):
  244. if random.random() > self.p:
  245. return image, target
  246. else:
  247. oh, ow = image.shape[:2]
  248. dw = int(ow * self.jitter_ratio)
  249. dh = int(oh * self.jitter_ratio)
  250. pleft = np.random.randint(-dw, dw)
  251. pright = np.random.randint(-dw, dw)
  252. ptop = np.random.randint(-dh, dh)
  253. pbot = np.random.randint(-dh, dh)
  254. swidth = ow - pleft - pright
  255. sheight = oh - ptop - pbot
  256. output_size = (swidth, sheight)
  257. # crop image
  258. cropped_image = self.crop(image=image,
  259. pleft=pleft,
  260. pright=pright,
  261. ptop=ptop,
  262. pbot=pbot,
  263. output_size=output_size)
  264. # crop bbox
  265. if target is not None:
  266. bboxes = target['boxes'].copy()
  267. coords_offset = np.array([pleft, ptop], dtype=np.float32)
  268. bboxes[..., [0, 2]] = bboxes[..., [0, 2]] - coords_offset[0]
  269. bboxes[..., [1, 3]] = bboxes[..., [1, 3]] - coords_offset[1]
  270. swidth, sheight = output_size
  271. bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], 0, swidth - 1)
  272. bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], 0, sheight - 1)
  273. target['boxes'] = bboxes
  274. return cropped_image, target
  275. ## Random HFlip
  276. class RandomHorizontalFlip(object):
  277. def __init__(self, p=0.5):
  278. self.p = p
  279. def __call__(self, image, target=None):
  280. if random.random() < self.p:
  281. orig_h, orig_w = image.shape[:2]
  282. image = image[:, ::-1]
  283. if target is not None:
  284. if "boxes" in target:
  285. boxes = target["boxes"].copy()
  286. boxes[..., [0, 2]] = orig_w - boxes[..., [2, 0]]
  287. target["boxes"] = boxes
  288. return image, target
## Resize image (ndarray) and rescale boxes
  290. class Resize(object):
  291. def __init__(self, img_size=640):
  292. self.img_size = img_size
  293. def __call__(self, image, target=None):
  294. orig_h, orig_w = image.shape[:2]
  295. # resize
  296. image = cv2.resize(image, (self.img_size, self.img_size)).astype(np.float32)
  297. img_h, img_w = image.shape[:2]
  298. # rescale bboxes
  299. if target is not None:
  300. boxes = target["boxes"]
  301. boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
  302. boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
  303. target["boxes"] = boxes
  304. return image, target
## Normalize image (ndarray)
  306. class Normalize(object):
  307. def __init__(self, pixel_mean, pixel_std):
  308. self.pixel_mean = pixel_mean
  309. self.pixel_std = pixel_std
  310. def __call__(self, image, target=None):
  311. # normalize image
  312. image = (image - self.pixel_mean) / self.pixel_std
  313. return image, target
  314. ## Convert ndarray to torch.Tensor
  315. class ToTensor(object):
  316. def __call__(self, image, target=None):
  317. # Convert torch.Tensor
  318. image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
  319. if target is not None:
  320. target["boxes"] = torch.as_tensor(target["boxes"]).float()
  321. target["labels"] = torch.as_tensor(target["labels"]).long()
  322. return image, target
# ------------------------- Preprocessors -------------------------
  324. ## Transform for Train
  325. class RTDetrAugmentation(object):
  326. def __init__(self, img_size=640, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375]):
  327. # ----------------- Basic parameters -----------------
  328. self.img_size = img_size
  329. self.pixel_mean = pixel_mean # RGB format
  330. self.pixel_std = pixel_std # RGB format
  331. self.color_format = 'rgb'
  332. print("================= Pixel Statistics =================")
  333. print("Pixel mean: {}".format(self.pixel_mean))
  334. print("Pixel std: {}".format(self.pixel_std))
  335. # ----------------- Transforms -----------------
  336. self.augment = Compose([
  337. RandomPhotometricDistort(hue=0.5, saturation=1.5, exposure=1.5),
  338. RandomJitterCrop(p=0.8, jitter_ratio=0.3, fill_value=self.pixel_mean[::-1]),
  339. RandomHorizontalFlip(p=0.5),
  340. Resize(img_size=self.img_size),
  341. ConvertColorFormat(self.color_format),
  342. Normalize(self.pixel_mean, self.pixel_std),
  343. ToTensor()
  344. ])
  345. def reset_weak_augment(self):
  346. print("Reset transform with weak augmentation ...")
  347. self.augment = Compose([
  348. RandomHorizontalFlip(p=0.5),
  349. Resize(img_size=self.img_size),
  350. ConvertColorFormat(self.color_format),
  351. Normalize(self.pixel_mean, self.pixel_std),
  352. ToTensor()
  353. ])
  354. def __call__(self, image, target, mosaic=False):
  355. orig_h, orig_w = image.shape[:2]
  356. ratio = [self.img_size / orig_w, self.img_size / orig_h]
  357. image, target = self.augment(image, target)
  358. return image, target, ratio
  359. ## Transform for Eval
  360. class RTDetrBaseTransform(object):
  361. def __init__(self, img_size=640, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375]):
  362. # ----------------- Basic parameters -----------------
  363. self.img_size = img_size
  364. self.pixel_mean = pixel_mean # RGB format
  365. self.pixel_std = pixel_std # RGB format
  366. self.color_format = 'rgb'
  367. print("================= Pixel Statistics =================")
  368. print("Pixel mean: {}".format(self.pixel_mean))
  369. print("Pixel std: {}".format(self.pixel_std))
  370. # ----------------- Transforms -----------------
  371. self.transform = Compose([
  372. Resize(img_size=self.img_size),
  373. ConvertColorFormat(self.color_format),
  374. Normalize(self.pixel_mean, self.pixel_std),
  375. ToTensor()
  376. ])
  377. def __call__(self, image, target=None, mosaic=False):
  378. orig_h, orig_w = image.shape[:2]
  379. ratio = [self.img_size / orig_w, self.img_size / orig_h]
  380. image, target = self.transform(image, target)
  381. return image, target, ratio