img_encoder.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import torch
  2. import torch.nn as nn
  3. from .cnn_backbone import build_backbone
  4. from .cnn_neck import build_neck
  5. from .cnn_pafpn import build_fpn
  6. # ------------------------ Image Encoder ------------------------
  7. class ImageEncoder(nn.Module):
  8. def __init__(self, cfg, trainable=False) -> None:
  9. super().__init__()
  10. ## Backbone
  11. self.backbone, feats_dim = build_backbone(cfg, cfg['pretrained']*trainable)
  12. ## Encoder
  13. self.encoder = build_neck(cfg, feats_dim[-1], feats_dim[-1])
  14. ## CSFM
  15. self.csfm = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(cfg['d_model']*cfg['width']))
  16. def position_embedding(self, x, temperature=10000):
  17. hs, ws = x.shape[-2:]
  18. device = x.device
  19. num_pos_feats = x.shape[1] // 2
  20. scale = 2 * 3.141592653589793
  21. # generate xy coord mat
  22. y_embed, x_embed = torch.meshgrid(
  23. [torch.arange(1, hs+1, dtype=torch.float32),
  24. torch.arange(1, ws+1, dtype=torch.float32)])
  25. y_embed = y_embed / (hs + 1e-6) * scale
  26. x_embed = x_embed / (ws + 1e-6) * scale
  27. # [H, W] -> [1, H, W]
  28. y_embed = y_embed[None, :, :].to(device)
  29. x_embed = x_embed[None, :, :].to(device)
  30. dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
  31. dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
  32. dim_t = temperature ** (2 * dim_t_)
  33. pos_x = torch.div(x_embed[:, :, :, None], dim_t)
  34. pos_y = torch.div(y_embed[:, :, :, None], dim_t)
  35. pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
  36. pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
  37. # [B, C, H, W]
  38. pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
  39. return pos_embed
  40. def forward(self, x):
  41. # Backbone
  42. pyramid_feats = self.backbone(x)
  43. # Encoder
  44. pyramid_feats[-1] = self.encoder(pyramid_feats[-1])
  45. # CSFM
  46. pyramid_feats = self.csfm(pyramid_feats)
  47. # Prepare memory & memoery_pos for Decoder
  48. memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
  49. memory = memory.permute(0, 2, 1).contiguous()
  50. memory_pos = torch.cat([self.position_embedding(feat).flatten(2)
  51. for feat in pyramid_feats], dim=-1)
  52. memory_pos = memory_pos.permute(0, 2, 1).contiguous()
  53. return memory, memory_pos
  54. # build img-encoder
  55. def build_img_encoder(cfg, trainable):
  56. return ImageEncoder(cfg, trainable)