rtdetr_dethead.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283
  1. import torch
  2. import torch.nn as nn
  3. from .rtdetr_basic import MLP
  4. class DetectHead(nn.Module):
  5. def __init__(self, cfg, d_model, num_classes, with_box_refine=False):
  6. super().__init__()
  7. # --------- Basic Parameters ----------
  8. self.cfg = cfg
  9. self.num_classes = num_classes
  10. # --------- Network Parameters ----------
  11. self.class_embed = nn.ModuleList([nn.Linear(d_model, self.num_classes)])
  12. self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 4, 3)])
  13. if with_box_refine:
  14. self.class_embed = nn.ModuleList([
  15. self.class_embed[0] for _ in range(cfg['num_decoder_layers'])])
  16. self.bbox_embed = nn.ModuleList([
  17. self.bbox_embed[0] for _ in range(cfg['num_decoder_layers'])])
  18. self.init_weight()
  19. def init_weight(self):
  20. init_prob = 0.01
  21. bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
  22. # cls pred
  23. for class_embed in self.class_embed:
  24. class_embed.bias.data = torch.ones(self.num_classes) * bias_value
  25. # box pred
  26. for bbox_embed in self.bbox_embed:
  27. nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
  28. nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
  29. def inverse_sigmoid(self, x):
  30. x = x.clamp(min=0, max=1)
  31. return torch.log(x.clamp(min=1e-5)/(1 - x).clamp(min=1e-5))
  32. def decode_bbox(self, outputs_coords):
  33. ## cxcywh -> xyxy
  34. x1y1_pred = outputs_coords[..., :2] - outputs_coords[..., 2:] * 0.5
  35. x2y2_pred = outputs_coords[..., :2] + outputs_coords[..., 2:] * 0.5
  36. box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
  37. return box_pred
  38. def forward(self, hs, reference, multi_layer=False):
  39. if multi_layer:
  40. # class embed
  41. outputs_class = torch.stack([
  42. layer_cls_embed(layer_hs) for layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
  43. # bbox embed
  44. outputs_coords = []
  45. for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
  46. layer_delta_unsig = layer_bbox_embed(layer_hs)
  47. layer_outputs_unsig = layer_delta_unsig + self.inverse_sigmoid(layer_ref_sig)
  48. layer_outputs_unsig = layer_outputs_unsig.sigmoid()
  49. outputs_coords.append(layer_outputs_unsig)
  50. else:
  51. # class embed
  52. outputs_class = self.class_embed[-1](hs[-1])
  53. # bbox embed
  54. delta_unsig = self.bbox_embed[-1](hs[-1])
  55. ref_sig = reference[-2]
  56. ref_sig = self.inverse_sigmoid(ref_sig)
  57. outputs_unsig = delta_unsig + ref_sig
  58. outputs_coords = outputs_unsig.sigmoid()
  59. # decode bbox
  60. outputs_coords = self.decode_bbox(outputs_coords)
  61. return outputs_class, outputs_coords
  62. def build_dethead(cfg, d_model, num_classes, with_box_refine):
  63. return DetectHead(cfg, d_model, num_classes, with_box_refine)