| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import torch
- import torch.nn as nn
- from .rtdetr_basic import MLP
class DetectHead(nn.Module):
    """Class-logit and box-regression head applied to decoder hidden states.

    Args:
        cfg: config mapping; only ``cfg['num_decoder_layers']`` is read, and only
            when ``with_box_refine`` is True.
        d_model: feature size of the decoder hidden states fed to the heads.
        num_classes: number of class logits predicted per query.
        with_box_refine: if True, build one independent class head and one
            independent box head per decoder layer; otherwise a single pair of
            heads is built and shared by every layer.
    """

    def __init__(self, cfg, d_model, num_classes, with_box_refine=False):
        super().__init__()
        # --------- Basic Parameters ----------
        self.cfg = cfg
        self.num_classes = num_classes
        # --------- Network Parameters ----------
        # Each entry is a freshly constructed module. Repeating a single instance
        # inside a ModuleList would make every "per-layer" head share weights.
        num_heads = cfg['num_decoder_layers'] if with_box_refine else 1
        self.class_embed = nn.ModuleList([
            nn.Linear(d_model, self.num_classes) for _ in range(num_heads)])
        self.bbox_embed = nn.ModuleList([
            MLP(d_model, d_model, 4, 3) for _ in range(num_heads)])
        self.init_weight()

    def init_weight(self):
        """Initialize head parameters.

        Class biases are set to logit(0.01), so sigmoid(bias) == 0.01 for every
        class at start. The last layer of each box MLP gets zero weight and bias.
        This presumably makes the initial box delta zero, so predictions start at
        the reference boxes. That holds only if MLP applies nothing after
        ``layers[-1]``; TODO confirm in rtdetr_basic.
        """
        init_prob = 0.01
        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob)).item()
        # cls pred
        for class_embed in self.class_embed:
            nn.init.constant_(class_embed.bias, bias_value)
        # box pred
        for bbox_embed in self.bbox_embed:
            nn.init.constant_(bbox_embed.layers[-1].weight, 0)
            nn.init.constant_(bbox_embed.layers[-1].bias, 0)

    def inverse_sigmoid(self, x):
        """Return logit(x), with x clamped to [0, 1] and both terms floored at 1e-5.

        The floor keeps the result finite at x == 0 and x == 1.
        """
        x = x.clamp(min=0, max=1)
        return torch.log(x.clamp(min=1e-5) / (1 - x).clamp(min=1e-5))

    def decode_bbox(self, outputs_coords):
        """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2) along the last axis."""
        ## cxcywh -> xyxy
        x1y1_pred = outputs_coords[..., :2] - outputs_coords[..., 2:] * 0.5
        x2y2_pred = outputs_coords[..., :2] + outputs_coords[..., 2:] * 0.5
        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)

        return box_pred

    @staticmethod
    def _per_layer_heads(heads, num_layers):
        """Return the heads to use for each layer.

        A lone head is reused for all ``num_layers`` layers. Otherwise the heads
        are returned in order.
        """
        if len(heads) == 1:
            return [heads[0]] * num_layers
        return list(heads)

    def forward(self, hs, reference, multi_layer=False):
        """Predict class logits and xyxy boxes from decoder outputs.

        Args:
            hs: decoder hidden states, indexed per layer along the first axis.
                Assumed shape is (num_layers, batch, num_queries, d_model); TODO
                confirm against the decoder.
            reference: per-layer reference boxes in sigmoid space (they are passed
                through ``inverse_sigmoid``). Entry i is used with ``hs[i]``, so
                this presumably holds ``len(hs) + 1`` entries; verify with the
                caller.
            multi_layer: if True, predict for every decoder layer and stack the
                results along a new leading axis. Otherwise predict for the last
                layer only.

        Returns:
            ``(outputs_class, outputs_coords)``: class logits, and boxes decoded
            to (x1, y1, x2, y2).
        """
        if multi_layer:
            num_layers = len(hs)
            cls_heads = self._per_layer_heads(self.class_embed, num_layers)
            box_heads = self._per_layer_heads(self.bbox_embed, num_layers)
            # class embed
            outputs_class = torch.stack([
                cls_head(layer_hs) for cls_head, layer_hs in zip(cls_heads, hs)])
            # bbox embed: refine each layer's reference in logit space, then map
            # back to [0, 1]. Stacked so decode_bbox receives a tensor.
            outputs_coords = torch.stack([
                (box_head(layer_hs) + self.inverse_sigmoid(layer_ref_sig)).sigmoid()
                for box_head, layer_hs, layer_ref_sig
                in zip(box_heads, hs, reference[:-1])])
        else:
            # class embed
            outputs_class = self.class_embed[-1](hs[-1])
            # bbox embed: reference[-2] pairs with hs[-1], mirroring reference[:-1] above
            delta_unsig = self.bbox_embed[-1](hs[-1])
            ref_unsig = self.inverse_sigmoid(reference[-2])
            outputs_coords = (delta_unsig + ref_unsig).sigmoid()
        # decode bbox
        outputs_coords = self.decode_bbox(outputs_coords)
        return outputs_class, outputs_coords
def build_dethead(cfg, d_model, num_classes, with_box_refine):
    """Construct a ``DetectHead``; every argument is forwarded unchanged."""
    head = DetectHead(cfg, d_model, num_classes, with_box_refine=with_box_refine)
    return head
|