| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447 |
- import math
- import copy
- import torch
- import torch.nn as nn
- import torch.nn.functional as F
- import torch.utils.checkpoint as checkpoint
- try:
- from .basic import FFN, GlobalCrossAttention
- from .basic import trunc_normal_
- except:
- from basic import FFN, GlobalCrossAttention
- from basic import trunc_normal_
def get_clones(module, N):
    """Return ``N`` independent deep copies of ``module`` in an ``nn.ModuleList``.

    Returns None (not an empty list) when ``N <= 0``.
    """
    if N > 0:
        # deepcopy so every clone owns its own parameters.
        return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
    return None
def inverse_sigmoid(x, eps=1e-5):
    """Numerically safe logit, ``log(x / (1 - x))``.

    ``x`` is first clamped into [0, 1]. Numerator and denominator are then each
    floored at ``eps``, so the result stays finite at 0 and 1.
    """
    x = x.clamp(min=0., max=1.)
    numerator = x.clamp(min=eps)
    denominator = (1 - x).clamp(min=eps)
    return (numerator / denominator).log()
- # ----------------- Transformer modules -----------------
- ## Transformer Encoder layer
class TransformerEncoderLayer(nn.Module):
    """Post-norm encoder layer: MHSA + residual + LayerNorm, followed by an FFN.

    Tokens are batch-first, i.e. [B, N, C].
    """
    def __init__(self,
                 d_model   :int   = 256,
                 num_heads :int   = 8,
                 ffn_dim   :int   = 1024,
                 dropout   :float = 0.1,
                 act_type  :str   = "relu",
                 ):
        super().__init__()
        # Hyper-parameters kept for introspection. The dropout rate is not kept
        # as a float: `self.dropout` names the nn.Dropout module registered below.
        self.d_model = d_model
        self.num_heads = num_heads
        self.ffn_dim = ffn_dim
        self.act_type = act_type

        # Multi-head self-attention operating on batch-first tokens
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(d_model)
        # Feed-forward network (FFN from the local `basic` module)
        self.ffn = FFN(d_model, ffn_dim, dropout, act_type)

    def with_pos_embed(self, tensor, pos):
        """Add ``pos`` to ``tensor``; return ``tensor`` untouched when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def forward(self, src, pos_embed):
        """Encode a token sequence.

        Args:
            src: [B, N, C] tokens.
            pos_embed: positional embedding broadcastable to ``src``
                (the encoder passes [1, N, C]), or None.

        Returns:
            [B, N, C] encoded tokens.
        """
        # Positions are injected into queries and keys only; values stay raw.
        query_key = self.with_pos_embed(src, pos_embed)
        attn_out, _ = self.self_attn(query_key, query_key, value=src)
        # Residual + dropout, then post-norm, then the FFN.
        src = self.norm(src + self.dropout(attn_out))
        return self.ffn(src)
- ## Transformer Encoder
class TransformerEncoder(nn.Module):
    """Transformer encoder applied to a single 2-D feature map.

    The map is flattened into tokens and run through ``num_layers`` copies of
    TransformerEncoderLayer, using a fixed (non-learned) 2-D sin-cos position
    embedding. The result is reshaped back to [B, C, H, W].
    """
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 num_layers     :int   = 1,
                 ffn_dim        :int   = 1024,
                 pe_temperature :float = 10000.,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 ):
        super().__init__()
        # ----------- Basic parameters -----------
        self.d_model = d_model
        self.num_heads = num_heads
        self.num_layers = num_layers
        self.ffn_dim = ffn_dim
        self.dropout = dropout
        self.act_type = act_type
        self.pe_temperature = pe_temperature
        # Single-entry cache: the last position embedding built, plus the
        # (w, h, embed_dim, temperature, device) key it was built for.
        self.pos_embed = None
        self._pos_embed_key = None
        # ----------- Network parameters -----------
        # get_clones returns None when num_layers <= 0 (handled in forward).
        self.encoder_layers = get_clones(
            TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type), num_layers)

    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
        """Return the 2-D sin-cos position embedding for a ``w`` x ``h`` grid.

        Args:
            device: device the returned tensor is placed on.
            w, h: grid width and height.
            embed_dim: channel count; must be divisible by 4.
            temperature: frequency base of the sinusoids.

        Returns:
            Tensor of shape [1, w*h, embed_dim], laid out as
            [sin(w), cos(w), sin(h), cos(h)] with embed_dim // 4 channels each.
            The tensor is cached on the module and reused while every input
            above stays the same.

        Raises:
            AssertionError: if ``embed_dim`` is not divisible by 4.
        """
        assert embed_dim % 4 == 0, \
            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'

        # ----------- Check cached pos_embed -----------
        # Key on every input that affects the result, so an embedding built for
        # another size, width, temperature or device is never reused.
        cache_key = (int(w), int(h), int(embed_dim), float(temperature), torch.device(device))
        if self.pos_embed is not None and getattr(self, '_pos_embed_key', None) == cache_key:
            return self.pos_embed

        # ----------- Generate grid coords -----------
        grid_w = torch.arange(int(w), dtype=torch.float32)
        grid_h = torch.arange(int(h), dtype=torch.float32)
        # 'ij' indexing (the legacy implicit default): both grids are [W, H]
        # and flatten with the w index varying slowest.
        # NOTE(review): forward() flattens src [B, C, H, W] with the h index
        # varying slowest. For square maps the "w" channels therefore encode
        # the row and the "h" channels the column. For H != W the encoded pair
        # does not match the token's true (x, y). Left as-is because trained
        # weights presumably depend on this layout; confirm before changing.
        grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij')
        pos_dim = embed_dim // 4
        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
        omega = 1. / (temperature ** omega)
        out_w = grid_w.flatten()[..., None] @ omega[None]  # [N, pos_dim]
        out_h = grid_h.flatten()[..., None] @ omega[None]  # [N, pos_dim]
        # [1, N, embed_dim]
        pos_embed = torch.cat(
            [torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :]
        pos_embed = pos_embed.to(device)
        self.pos_embed = pos_embed
        self._pos_embed_key = cache_key
        return pos_embed

    def forward(self, src):
        """
        Input:
            src: [torch.Tensor] -> [B, C, H, W]
        Output:
            src: [torch.Tensor] -> [B, C, H, W]
        """
        channels, fmp_h, fmp_w = src.shape[1:]
        # [B, C, H, W] -> [B, N, C], N = H x W
        memory = src.flatten(2).permute(0, 2, 1)
        # PosEmbed: [1, N, C], broadcast over the batch
        pos_embed = self.build_2d_sincos_position_embedding(
            src.device, fmp_w, fmp_h, channels, self.pe_temperature)

        # encoder_layers is None when num_layers <= 0; treat that as an
        # identity encoder rather than failing on iteration.
        for encoder in (self.encoder_layers or ()):
            memory = encoder(memory, pos_embed=pos_embed)
        # [B, N, C] -> [B, C, N] -> [B, C, H, W]
        return memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
- ## PlainDETR's Decoder layer
class GlobalDecoderLayer(nn.Module):
    """One PlainDETR decoder layer.

    Runs self-attention over the object queries, then box-reparameterised
    global cross-attention into the image memory, then an FFN. ``pre_norm``
    selects whether the LayerNorms come before or after each residual branch.
    """
    def __init__(self,
                 d_model        :int   = 256,
                 num_heads      :int   = 8,
                 ffn_dim        :int   = 1024,
                 dropout        :float = 0.1,
                 act_type       :str   = "relu",
                 pre_norm       :bool  = False,
                 rpe_hidden_dim :int   = 512,
                 feature_stride :int   = 16,
                 ) -> None:
        super().__init__()
        # ------------ Basic parameters ------------
        self.d_model = d_model
        self.num_heads = num_heads
        self.rpe_hidden_dim = rpe_hidden_dim
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.pre_norm = pre_norm
        # ------------ Network parameters ------------
        # Query self-attention; built sequence-first (batch_first=False),
        # so _run_self_attn transposes around it.
        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(d_model)
        # Box-reparameterised global cross-attention (local `basic` module)
        self.cross_attn = GlobalCrossAttention(d_model, num_heads,
                                               rpe_hidden_dim=rpe_hidden_dim,
                                               feature_stride=feature_stride)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(d_model)
        # FFN also receives pre_norm; presumably it places its own norm accordingly.
        self.ffn = FFN(d_model, ffn_dim, dropout, act_type, pre_norm)

    @staticmethod
    def with_pos_embed(tensor, pos):
        """Add ``pos`` to ``tensor``; return ``tensor`` untouched when ``pos`` is None."""
        if pos is None:
            return tensor
        return tensor + pos

    def _run_self_attn(self, query, key, value, attn_mask):
        """Apply ``self_attn`` to batch-first [B, N, C] inputs and return [B, N, C]."""
        # [B, N, C] -> [N, B, C] for the sequence-first attention, then back.
        attn_out = self.self_attn(query.transpose(0, 1),
                                  key.transpose(0, 1),
                                  value.transpose(0, 1),
                                  attn_mask=attn_mask,
                                  )[0]
        return attn_out.transpose(0, 1)

    def _run_cross_attn(self, query, query_pos, reference_points, src, src_pos_embed,
                        src_spatial_shapes, src_padding_mask):
        """Apply ``cross_attn``; positions are added to the query and key inputs only."""
        return self.cross_attn(self.with_pos_embed(query, query_pos),
                               reference_points,
                               self.with_pos_embed(src, src_pos_embed),
                               src,
                               src_spatial_shapes,
                               src_padding_mask,
                               )

    def forward_pre_norm(self,
                         tgt,
                         query_pos,
                         reference_points,
                         src,
                         src_pos_embed,
                         src_spatial_shapes,
                         src_padding_mask=None,
                         self_attn_mask=None,
                         ):
        """Pre-norm variant: each branch sees a normalized input; residuals stay raw."""
        # ----------- Multi-head self attention -----------
        normed = self.norm1(tgt)
        qk = self.with_pos_embed(normed, query_pos)
        tgt = tgt + self.dropout1(self._run_self_attn(qk, qk, normed, self_attn_mask))
        # ----------- Global cross attention -----------
        normed = self.norm2(tgt)
        attended = self._run_cross_attn(normed, query_pos, reference_points, src,
                                        src_pos_embed, src_spatial_shapes, src_padding_mask)
        tgt = tgt + self.dropout2(attended)
        # ----------- FeedForward Network -----------
        return self.ffn(tgt)

    def forward_post_norm(self,
                          tgt,
                          query_pos,
                          reference_points,
                          src,
                          src_pos_embed,
                          src_spatial_shapes,
                          src_padding_mask=None,
                          self_attn_mask=None,
                          ):
        """Post-norm variant: LayerNorm is applied after each residual sum."""
        # ----------- Multi-head self attention -----------
        qk = self.with_pos_embed(tgt, query_pos)
        tgt = self.norm1(tgt + self.dropout1(self._run_self_attn(qk, qk, tgt, self_attn_mask)))
        # ----------- Global cross attention -----------
        attended = self._run_cross_attn(tgt, query_pos, reference_points, src,
                                        src_pos_embed, src_spatial_shapes, src_padding_mask)
        tgt = self.norm2(tgt + self.dropout2(attended))
        # ----------- FeedForward Network -----------
        return self.ffn(tgt)

    def forward(self,
                tgt,
                query_pos,
                reference_points,
                src,
                src_pos_embed,
                src_spatial_shapes,
                src_padding_mask=None,
                self_attn_mask=None,
                ):
        """Dispatch to the pre-norm or post-norm path according to ``self.pre_norm``.

        ``tgt`` and ``query_pos`` are batch-first [B, N, C] query tensors.
        """
        impl = self.forward_pre_norm if self.pre_norm else self.forward_post_norm
        return impl(tgt, query_pos, reference_points, src, src_pos_embed,
                    src_spatial_shapes, src_padding_mask, self_attn_mask)
- ## PlainDETR's Decoder
class GlobalDecoder(nn.Module):
    """PlainDETR-style global decoder.

    Stacks ``num_layers`` deep copies of GlobalDecoderLayer. When
    ``bbox_embed`` is set (it starts as None and is indexed per layer as
    ``bbox_embed[lid]`` in forward), each layer's output predicts box deltas.
    Those deltas refine the reference boxes passed to the next layer.
    """
    def __init__(self,
                 # Decoder layer params
                 d_model             :int   = 256,
                 num_heads           :int   = 8,
                 ffn_dim             :int   = 1024,
                 dropout             :float = 0.1,
                 act_type            :str   = "relu",
                 pre_norm            :bool  = False,
                 rpe_hidden_dim      :int   = 512,
                 feature_stride      :int   = 16,
                 num_layers          :int   = 6,
                 # Decoder params
                 return_intermediate :bool  = False,
                 use_checkpoint      :bool  = False,
                 ):
        super().__init__()
        # ------------ Basic parameters ------------
        self.d_model = d_model
        self.num_heads = num_heads
        self.rpe_hidden_dim = rpe_hidden_dim
        self.ffn_dim = ffn_dim
        self.act_type = act_type
        self.num_layers = num_layers
        self.return_intermediate = return_intermediate
        self.use_checkpoint = use_checkpoint
        # ------------ Network parameters ------------
        decoder_layer = GlobalDecoderLayer(
            d_model, num_heads, ffn_dim, dropout, act_type, pre_norm, rpe_hidden_dim, feature_stride,)
        self.layers = get_clones(decoder_layer, num_layers)
        # Prediction heads: None here; forward only uses bbox_embed when it is set.
        self.bbox_embed = None
        self.class_embed = None
        # Final LayerNorm only in pre-norm mode (presumably because pre-norm
        # layers end on an un-normalized residual).
        if pre_norm:
            self.final_layer_norm = nn.LayerNorm(d_model)
        else:
            self.final_layer_norm = None

    def _reset_parameters(self):
        """Re-initialize Linear and LayerNorm submodules (Swin Transformer scheme).

        nn.Linear: truncated-normal weight (std=0.02), zero bias.
        nn.LayerNorm: weight 1, bias 0.

        NOTE(review): not called from __init__ in this class; presumably the
        owning model invokes it. Confirm.
        """
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=0.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)

    def inverse_sigmoid(self, x, eps=1e-5):
        """Numerically safe logit: x is clamped to [0, 1], then both sides floored at eps."""
        x = x.clamp(min=0, max=1)
        x1 = x.clamp(min=eps)
        x2 = (1 - x).clamp(min=eps)
        return torch.log(x1 / x2)

    def box_xyxy_to_cxcywh(self, x):
        """Convert boxes from (x0, y0, x1, y1) to (cx, cy, w, h) along the last dim."""
        x0, y0, x1, y1 = x.unbind(-1)
        b = [(x0 + x1) / 2, (y0 + y1) / 2, (x1 - x0), (y1 - y0)]
        return torch.stack(b, dim=-1)

    def delta2bbox(self, proposals,
                   deltas,
                   max_shape=None,
                   wh_ratio_clip=16 / 1000,
                   clip_border=True,
                   add_ctr_clamp=False,
                   ctr_clamp=32):
        """Apply (dx, dy, dw, dh) deltas to (cx, cy, w, h) proposals.

        The center moves by ``dxy * wh``, and the size is scaled by ``exp(dwh)``.
        ``dwh`` is clipped to ``|log(wh_ratio_clip)|``; with ``add_ctr_clamp``
        only its upper bound is clipped, and the center shift is clamped to
        ``±ctr_clamp``.

        Returns:
            Boxes as (x1, y1, x2, y2). When ``clip_border`` is set and
            ``max_shape`` = (height, width) is given, they are clamped to
            [0, width] x [0, height].
        """
        dxy = deltas[..., :2]
        dwh = deltas[..., 2:]
        # Proposal centers and sizes
        pxy = proposals[..., :2]
        pwh = proposals[..., 2:]
        dxy_wh = pwh * dxy
        wh_ratio_clip = torch.as_tensor(wh_ratio_clip)
        max_ratio = torch.abs(torch.log(wh_ratio_clip)).item()

        if add_ctr_clamp:
            dxy_wh = torch.clamp(dxy_wh, max=ctr_clamp, min=-ctr_clamp)
            dwh = torch.clamp(dwh, max=max_ratio)
        else:
            dwh = dwh.clamp(min=-max_ratio, max=max_ratio)
        gxy = pxy + dxy_wh
        gwh = pwh * dwh.exp()
        x1y1 = gxy - (gwh * 0.5)
        x2y2 = gxy + (gwh * 0.5)
        bboxes = torch.cat([x1y1, x2y2], dim=-1)
        if clip_border and max_shape is not None:
            # Strided slices are views, so the in-place clamps write into bboxes.
            bboxes[..., 0::2].clamp_(min=0).clamp_(max=max_shape[1])
            bboxes[..., 1::2].clamp_(min=0).clamp_(max=max_shape[0])
        return bboxes

    def forward(self,
                tgt,
                reference_points,
                src,
                src_pos_embed,
                src_spatial_shapes,
                query_pos=None,
                src_padding_mask=None,
                self_attn_mask=None,
                max_shape=None,
                ):
        """Run every decoder layer, optionally refining the reference boxes.

        Args:
            tgt: query features (batch-first, per the layer's transpose comments).
            reference_points: reference boxes, (cx, cy, w, h) in the last dim
                (the layout delta2bbox expects).
            src, src_pos_embed, src_spatial_shapes, src_padding_mask: image
                memory and its metadata, passed unchanged to every layer.
            query_pos: optional positional embedding for the queries.
            self_attn_mask: optional mask for the query self-attention.
            max_shape: optional (height, width) used to clip refined boxes.

        Returns:
            If ``return_intermediate``: (stacked per-layer outputs, stacked
            per-layer reference points), each with a leading layer dimension.
            Otherwise: (last output, final reference points).
        """
        output = tgt
        intermediate = []
        intermediate_reference_points = []
        for lid, layer in enumerate(self.layers):
            # Adds a singleton dim before the box coords; presumably a level
            # axis expected by the cross-attention. Confirm.
            reference_points_input = reference_points[:, :, None]
            layer_args = (output, query_pos, reference_points_input, src, src_pos_embed,
                          src_spatial_shapes, src_padding_mask, self_attn_mask)
            if self.use_checkpoint:
                # Recompute this layer's activations during backward to save memory.
                output = checkpoint.checkpoint(layer, *layer_args)
            else:
                output = layer(*layer_args)

            if self.final_layer_norm is not None:
                output_after_norm = self.final_layer_norm(output)
            else:
                output_after_norm = output

            # Iterative bounding-box refinement
            if self.bbox_embed is not None:
                tmp = self.bbox_embed[lid](output_after_norm)
                new_reference_points = self.box_xyxy_to_cxcywh(
                    self.delta2bbox(reference_points, tmp, max_shape))
                # Detached so the next layer's references do not backprop here.
                reference_points = new_reference_points.detach()
            else:
                # No box head: this layer leaves the references unchanged.
                # (Previously new_reference_points was unbound on this path
                # and return_intermediate raised UnboundLocalError.)
                new_reference_points = reference_points

            if self.return_intermediate:
                intermediate.append(output_after_norm)
                # The un-detached prediction is recorded, so it keeps its gradient.
                intermediate_reference_points.append(new_reference_points)

        if self.return_intermediate:
            return torch.stack(intermediate), torch.stack(intermediate_reference_points)
        return output_after_norm, reference_points
|