loss.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. import math
  2. import torch
  3. import torch.nn as nn
  4. import torch.nn.functional as F
  5. try:
  6. from .loss_utils import varifocal_loss_with_logits, sigmoid_focal_loss
  7. from .loss_utils import box_cxcywh_to_xyxy, bbox_iou
  8. from .loss_utils import is_dist_avail_and_initialized, get_world_size
  9. from .loss_utils import GIoULoss
  10. from .matcher import HungarianMatcher
  11. except:
  12. from loss_utils import varifocal_loss_with_logits, sigmoid_focal_loss
  13. from loss_utils import box_cxcywh_to_xyxy, bbox_iou
  14. from loss_utils import is_dist_avail_and_initialized, get_world_size
  15. from loss_utils import GIoULoss
  16. from matcher import HungarianMatcher
  17. # --------------- Criterion for RT-DETR ---------------
  18. def build_criterion(cfg, num_classes=80):
  19. return Criterion(cfg, num_classes)
  20. class Criterion(object):
  21. def __init__(self, cfg, num_classes=80):
  22. self.matcher = HungarianMatcher(cfg['matcher_hpy']['cost_class'],
  23. cfg['matcher_hpy']['cost_bbox'],
  24. cfg['matcher_hpy']['cost_giou'],
  25. alpha=0.25,
  26. gamma=2.0)
  27. self.loss = DINOLoss(num_classes = num_classes,
  28. matcher = self.matcher,
  29. aux_loss = True,
  30. use_vfl = cfg['use_vfl'],
  31. loss_coeff = cfg['loss_coeff'])
  32. def __call__(self, dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta, targets=None):
  33. assert targets is not None
  34. gt_labels = [t['labels'].to(dec_out_bboxes.device) for t in targets] # (List[torch.Tensor]) -> List[[N,]]
  35. gt_boxes = [t['boxes'].to(dec_out_bboxes.device) for t in targets] # (List[torch.Tensor]) -> List[[N, 4]]
  36. if dn_meta is not None:
  37. if isinstance(dn_meta, list):
  38. dual_groups = len(dn_meta) - 1
  39. dec_out_bboxes = torch.chunk(
  40. dec_out_bboxes, dual_groups + 1, dim=2)
  41. dec_out_logits = torch.chunk(
  42. dec_out_logits, dual_groups + 1, dim=2)
  43. enc_topk_bboxes = torch.chunk(
  44. enc_topk_bboxes, dual_groups + 1, dim=1)
  45. enc_topk_logits = torch.splchunkt(
  46. enc_topk_logits, dual_groups + 1, dim=1)
  47. loss = {}
  48. for g_id in range(dual_groups + 1):
  49. if dn_meta[g_id] is not None:
  50. dn_out_bboxes_gid, dec_out_bboxes_gid = torch.split(
  51. dec_out_bboxes[g_id],
  52. dn_meta[g_id]['dn_num_split'],
  53. dim=2)
  54. dn_out_logits_gid, dec_out_logits_gid = torch.split(
  55. dec_out_logits[g_id],
  56. dn_meta[g_id]['dn_num_split'],
  57. dim=2)
  58. else:
  59. dn_out_bboxes_gid, dn_out_logits_gid = None, None
  60. dec_out_bboxes_gid = dec_out_bboxes[g_id]
  61. dec_out_logits_gid = dec_out_logits[g_id]
  62. out_bboxes_gid = torch.cat([
  63. enc_topk_bboxes[g_id].unsqueeze(0),
  64. dec_out_bboxes_gid
  65. ])
  66. out_logits_gid = torch.cat([
  67. enc_topk_logits[g_id].unsqueeze(0),
  68. dec_out_logits_gid
  69. ])
  70. loss_gid = self.loss(
  71. out_bboxes_gid,
  72. out_logits_gid,
  73. gt_boxes,
  74. gt_labels,
  75. dn_out_bboxes=dn_out_bboxes_gid,
  76. dn_out_logits=dn_out_logits_gid,
  77. dn_meta=dn_meta[g_id])
  78. # sum loss
  79. for key, value in loss_gid.items():
  80. loss.update({
  81. key: loss.get(key, torch.zeros([1], device=out_bboxes_gid.device)) + value
  82. })
  83. # average across (dual_groups + 1)
  84. for key, value in loss.items():
  85. loss.update({key: value / (dual_groups + 1)})
  86. return loss
  87. else:
  88. dn_out_bboxes, dec_out_bboxes = torch.split(
  89. dec_out_bboxes, dn_meta['dn_num_split'], dim=2)
  90. dn_out_logits, dec_out_logits = torch.split(
  91. dec_out_logits, dn_meta['dn_num_split'], dim=2)
  92. else:
  93. dn_out_bboxes, dn_out_logits = None, None
  94. out_bboxes = torch.cat(
  95. [enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
  96. out_logits = torch.cat(
  97. [enc_topk_logits.unsqueeze(0), dec_out_logits])
  98. return self.loss(out_bboxes,
  99. out_logits,
  100. gt_boxes,
  101. gt_labels,
  102. dn_out_bboxes=dn_out_bboxes,
  103. dn_out_logits=dn_out_logits,
  104. dn_meta=dn_meta)
  105. # --------------- DETR series loss ---------------
  106. class DETRLoss(nn.Module):
  107. """Modified Paddle DETRLoss class without mask loss."""
  108. def __init__(self,
  109. num_classes=80,
  110. matcher='HungarianMatcher',
  111. aux_loss=True,
  112. use_vfl=False,
  113. loss_coeff={'class': 1,
  114. 'bbox': 5,
  115. 'giou': 2,},
  116. ):
  117. super(DETRLoss, self).__init__()
  118. self.num_classes = num_classes
  119. self.matcher = matcher
  120. self.loss_coeff = loss_coeff
  121. self.aux_loss = aux_loss
  122. self.use_vfl = use_vfl
  123. self.giou_loss = GIoULoss(reduction='none')
  124. def _get_loss_class(self,
  125. logits,
  126. gt_class,
  127. match_indices,
  128. bg_index,
  129. num_gts,
  130. postfix="",
  131. iou_score=None):
  132. # logits: [b, query, num_classes], gt_class: list[[n, 1]]
  133. name_class = "loss_class" + postfix
  134. target_label = torch.full(logits.shape[:2], bg_index, device=logits.device).long()
  135. bs, num_query_objects = target_label.shape
  136. num_gt = sum(len(a) for a in gt_class)
  137. if num_gt > 0:
  138. index, updates = self._get_index_updates(
  139. num_query_objects, gt_class, match_indices)
  140. target_label = target_label.reshape(-1, 1)
  141. target_label[index] = updates.long()[:, None]
  142. # target_label = paddle.scatter(target_label, index, updates.long())
  143. target_label = target_label.reshape(bs, num_query_objects)
  144. # one-hot label
  145. target_label = F.one_hot(target_label, self.num_classes + 1)[..., :-1].float()
  146. if iou_score is not None and self.use_vfl:
  147. target_score = torch.zeros([bs, num_query_objects], device=logits.device)
  148. if num_gt > 0:
  149. target_score = target_score.reshape(-1, 1)
  150. target_score[index] = iou_score.float()
  151. # target_score = paddle.scatter(target_score, index, iou_score)
  152. target_score = target_score.reshape(bs, num_query_objects, 1) * target_label
  153. loss_cls = varifocal_loss_with_logits(logits,
  154. target_score,
  155. target_label,
  156. num_gts / num_query_objects)
  157. else:
  158. loss_cls = sigmoid_focal_loss(logits,
  159. target_label,
  160. num_gts)
  161. return {name_class: loss_cls * self.loss_coeff['class']}
  162. def _get_loss_bbox(self, boxes, gt_bbox, match_indices, num_gts,
  163. postfix=""):
  164. # boxes: [b, query, 4], gt_bbox: list[[n, 4]]
  165. name_bbox = "loss_bbox" + postfix
  166. name_giou = "loss_giou" + postfix
  167. loss = dict()
  168. if sum(len(a) for a in gt_bbox) == 0:
  169. loss[name_bbox] = torch.as_tensor([0.], device=boxes.device)
  170. loss[name_giou] = torch.as_tensor([0.], device=boxes.device)
  171. return loss
  172. # prepare positive samples
  173. src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox, match_indices)
  174. # Compute L1 loss
  175. loss[name_bbox] = F.l1_loss(src_bbox, target_bbox, reduction='none')
  176. loss[name_bbox] = loss[name_bbox].sum() / num_gts
  177. loss[name_bbox] = self.loss_coeff['bbox'] * loss[name_bbox]
  178. # Compute GIoU loss
  179. loss[name_giou] = self.giou_loss(box_cxcywh_to_xyxy(src_bbox),
  180. box_cxcywh_to_xyxy(target_bbox))
  181. loss[name_giou] = loss[name_giou].sum() / num_gts
  182. loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
  183. return loss
  184. def _get_loss_aux(self,
  185. boxes,
  186. logits,
  187. gt_bbox,
  188. gt_class,
  189. bg_index,
  190. num_gts,
  191. dn_match_indices=None,
  192. postfix=""):
  193. loss_class = []
  194. loss_bbox, loss_giou = [], []
  195. if dn_match_indices is not None:
  196. match_indices = dn_match_indices
  197. for i, (aux_boxes, aux_logits) in enumerate(zip(boxes, logits)):
  198. if dn_match_indices is None:
  199. match_indices = self.matcher(
  200. aux_boxes,
  201. aux_logits,
  202. gt_bbox,
  203. gt_class,
  204. )
  205. if self.use_vfl:
  206. if sum(len(a) for a in gt_bbox) > 0:
  207. src_bbox, target_bbox = self._get_src_target_assign(
  208. aux_boxes.detach(), gt_bbox, match_indices)
  209. iou_score = bbox_iou(box_cxcywh_to_xyxy(src_bbox),
  210. box_cxcywh_to_xyxy(target_bbox))
  211. else:
  212. iou_score = None
  213. else:
  214. iou_score = None
  215. loss_class.append(
  216. self._get_loss_class(aux_logits, gt_class, match_indices,
  217. bg_index, num_gts, postfix, iou_score)[
  218. 'loss_class' + postfix])
  219. loss_ = self._get_loss_bbox(aux_boxes, gt_bbox, match_indices,
  220. num_gts, postfix)
  221. loss_bbox.append(loss_['loss_bbox' + postfix])
  222. loss_giou.append(loss_['loss_giou' + postfix])
  223. loss = {
  224. "loss_class_aux" + postfix: sum(loss_class),
  225. "loss_bbox_aux" + postfix: sum(loss_bbox),
  226. "loss_giou_aux" + postfix: sum(loss_giou)
  227. }
  228. return loss
  229. def _get_index_updates(self, num_query_objects, target, match_indices):
  230. batch_idx = torch.cat([
  231. torch.full_like(src, i) for i, (src, _) in enumerate(match_indices)
  232. ])
  233. src_idx = torch.cat([src for (src, _) in match_indices])
  234. src_idx += (batch_idx * num_query_objects)
  235. target_assign = torch.cat([
  236. torch.gather(t, 0, dst.to(t.device)) for t, (_, dst) in zip(target, match_indices)
  237. ])
  238. return src_idx, target_assign
  239. def _get_src_target_assign(self, src, target, match_indices):
  240. src_assign = torch.cat([t[I] if len(I) > 0 else torch.zeros([0, t.shape[-1]], device=src.device)
  241. for t, (I, _) in zip(src, match_indices)
  242. ])
  243. target_assign = torch.cat([t[J] if len(J) > 0 else torch.zeros([0, t.shape[-1]], device=src.device)
  244. for t, (_, J) in zip(target, match_indices)
  245. ])
  246. return src_assign, target_assign
  247. def _get_num_gts(self, targets):
  248. num_gts = sum(len(a) for a in targets)
  249. num_gts = torch.as_tensor([num_gts], device=targets[0].device).float()
  250. if is_dist_avail_and_initialized():
  251. torch.distributed.all_reduce(num_gts)
  252. num_gts = torch.clamp(num_gts / get_world_size(), min=1).item()
  253. return num_gts
  254. def _get_prediction_loss(self,
  255. boxes,
  256. logits,
  257. gt_bbox,
  258. gt_class,
  259. postfix="",
  260. dn_match_indices=None,
  261. num_gts=1):
  262. if dn_match_indices is None:
  263. match_indices = self.matcher(boxes, logits, gt_bbox, gt_class)
  264. else:
  265. match_indices = dn_match_indices
  266. if self.use_vfl:
  267. if sum(len(a) for a in gt_bbox) > 0:
  268. src_bbox, target_bbox = self._get_src_target_assign(
  269. boxes.detach(), gt_bbox, match_indices)
  270. iou_score = bbox_iou(box_cxcywh_to_xyxy(src_bbox),
  271. box_cxcywh_to_xyxy(target_bbox))
  272. else:
  273. iou_score = None
  274. else:
  275. iou_score = None
  276. loss = dict()
  277. loss.update(
  278. self._get_loss_class(logits, gt_class, match_indices,
  279. self.num_classes, num_gts, postfix, iou_score))
  280. loss.update(
  281. self._get_loss_bbox(boxes, gt_bbox, match_indices, num_gts,
  282. postfix))
  283. return loss
  284. def forward(self,
  285. boxes,
  286. logits,
  287. gt_bbox,
  288. gt_class,
  289. postfix="",
  290. **kwargs):
  291. r"""
  292. Args:
  293. boxes (Tensor): [l, b, query, 4]
  294. logits (Tensor): [l, b, query, num_classes]
  295. gt_bbox (List(Tensor)): list[[n, 4]]
  296. gt_class (List(Tensor)): list[[n, 1]]
  297. masks (Tensor, optional): [l, b, query, h, w]
  298. gt_mask (List(Tensor), optional): list[[n, H, W]]
  299. postfix (str): postfix of loss name
  300. """
  301. dn_match_indices = kwargs.get("dn_match_indices", None)
  302. num_gts = kwargs.get("num_gts", None)
  303. if num_gts is None:
  304. num_gts = self._get_num_gts(gt_class)
  305. total_loss = self._get_prediction_loss(
  306. boxes[-1],
  307. logits[-1],
  308. gt_bbox,
  309. gt_class,
  310. postfix=postfix,
  311. dn_match_indices=dn_match_indices,
  312. num_gts=num_gts)
  313. if self.aux_loss:
  314. total_loss.update(
  315. self._get_loss_aux(
  316. boxes[:-1],
  317. logits[:-1],
  318. gt_bbox,
  319. gt_class,
  320. self.num_classes,
  321. num_gts,
  322. dn_match_indices,
  323. postfix,
  324. ))
  325. return total_loss
  326. class DINOLoss(DETRLoss):
  327. def forward(self,
  328. boxes,
  329. logits,
  330. gt_bbox,
  331. gt_class,
  332. postfix="",
  333. dn_out_bboxes=None,
  334. dn_out_logits=None,
  335. dn_meta=None,
  336. **kwargs):
  337. num_gts = self._get_num_gts(gt_class)
  338. total_loss = super(DINOLoss, self).forward(
  339. boxes, logits, gt_bbox, gt_class, num_gts=num_gts)
  340. if dn_meta is not None:
  341. dn_positive_idx, dn_num_group = \
  342. dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
  343. assert len(gt_class) == len(dn_positive_idx)
  344. # denoising match indices
  345. dn_match_indices = self.get_dn_match_indices(
  346. gt_class, dn_positive_idx, dn_num_group)
  347. # compute denoising training loss
  348. num_gts *= dn_num_group
  349. dn_loss = super(DINOLoss, self).forward(
  350. dn_out_bboxes,
  351. dn_out_logits,
  352. gt_bbox,
  353. gt_class,
  354. postfix="_dn",
  355. dn_match_indices=dn_match_indices,
  356. num_gts=num_gts)
  357. total_loss.update(dn_loss)
  358. else:
  359. total_loss.update(
  360. {k + '_dn': torch.as_tensor([0.])
  361. for k in total_loss.keys()})
  362. return total_loss
  363. @staticmethod
  364. def get_dn_match_indices(labels, dn_positive_idx, dn_num_group):
  365. dn_match_indices = []
  366. for i in range(len(labels)):
  367. num_gt = len(labels[i])
  368. if num_gt > 0:
  369. gt_idx = torch.arange(num_gt).long()
  370. gt_idx = gt_idx.tile([dn_num_group])
  371. assert len(dn_positive_idx[i]) == len(gt_idx)
  372. dn_match_indices.append((dn_positive_idx[i], gt_idx))
  373. else:
  374. dn_match_indices.append((torch.zeros([0], device=labels[i].device).long(),
  375. torch.zeros([0], device=labels[i].device).long()))
  376. return dn_match_indices