yjh0410 1 year ago
parent
commit
39fabc437a

+ 78 - 3
odlab/config/detr_config.py

@@ -9,7 +9,79 @@ def build_detr_config(args):
 
 class DetrBaseConfig(object):
     def __init__(self):
-        pass
+        # --------- Backbone ---------
+        self.backbone = "resnet50"
+        self.bk_norm  = "FrozeBN"
+        self.res5_dilation = False
+        self.use_pretrained = True
+        self.freeze_at = 1
+        self.max_stride = 32
+        self.out_stride = 32
+
+        # --------- Transformer ---------
+        self.transformer = "detr_transformer"
+        self.hidden_dim = 256
+        self.num_heads = 8
+        self.feedforward_dim = 2048
+        self.num_enc_layers = 6
+        self.num_dec_layers = 6
+        self.dropout = 0.1
+        self.tr_act = 'relu'
+        self.pre_norm = False
+        self.num_queries = 100   # number of object queries (DETR's default); read by DETR.__init__
+
+        # --------- Post-process ---------
+        self.train_topk = 300
+        self.train_conf_thresh = 0.05
+        self.test_topk = 300
+        self.test_conf_thresh = 0.3
+
+        # --------- Label Assignment ---------
+        self.matcher_hpy = {'cost_class': 1.0,
+                            'cost_bbox':  5.0,
+                            'cost_giou':  2.0,
+                            }
+
+        # --------- Loss weight ---------
+        self.loss_cls  = 1.0
+        self.loss_box  = 5.0
+        self.loss_giou = 2.0
+
+        # --------- Optimizer ---------
+        self.optimizer = 'adamw'
+        self.batch_size_base = 16
+        self.per_image_lr  = 0.0001 / 16
+        self.bk_lr_ratio   = 0.1
+        self.momentum      = None
+        self.weight_decay  = 1e-4
+        self.clip_max_norm = 0.1
+
+        # --------- LR Scheduler ---------
+        self.lr_scheduler = 'step'
+        self.warmup = 'linear'
+        self.warmup_iters = 100
+        self.warmup_factor = 0.00066667
+
+        # --------- Train epoch ---------
+        self.max_epoch = 500
+        self.lr_epoch  = [400]
+        self.eval_epoch = 2
+
+        # --------- Data process ---------
+        ## input size
+        self.train_min_size = [800]   # short edge of image
+        self.train_min_size2 = [400, 500, 600]
+        self.train_max_size = 1333
+        self.test_min_size  = [800]
+        self.test_max_size  = 1333
+        self.random_crop_size = [320, 600]
+        ## Pixel mean & std
+        self.pixel_mean = [0.485, 0.456, 0.406]
+        self.pixel_std  = [0.229, 0.224, 0.225]
+        ## Transforms
+        self.box_format = 'xywh'
+        self.normalize_coords = True
+        self.detr_style = True
+        self.trans_config = None
 
     def print_config(self):
         config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
@@ -19,5 +91,8 @@ class DetrBaseConfig(object):
 class Detr_R50_Config(DetrBaseConfig):
     def __init__(self) -> None:
         super().__init__()
-        ## Backbone
-        pass
+        # --------- Backbone ---------
+        self.backbone = "resnet50"
+        self.bk_norm  = "FrozeBN"
+        self.res5_dilation = False
+        self.use_pretrained = True
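
The config keeps every hyperparameter as a plain attribute, with variants overriding only what differs. A hedged sketch of how the LR fields are meant to combine; the exact scaling rule lives in the trainer, so treat the linear rule here as an assumption:

```python
per_image_lr = 0.0001 / 16   # from the config: 1e-4 at the reference batch size of 16
bk_lr_ratio  = 0.1           # backbone trains 10x slower than the rest of the model

for batch_size in (8, 16, 32):
    base_lr = per_image_lr * batch_size
    print(f"bs={batch_size}: lr={base_lr:.1e}, backbone lr={base_lr * bk_lr_ratio:.1e}")
# bs=8: lr=5.0e-05, backbone lr=5.0e-06
# bs=16: lr=1.0e-04, backbone lr=1.0e-05
# bs=32: lr=2.0e-04, backbone lr=2.0e-05
```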

+ 4 - 0
odlab/models/detectors/__init__.py

@@ -3,6 +3,7 @@ import torch
 
 from .fcos.build  import build_fcos, build_fcos_rt
 from .yolof.build import build_yolof
+from .detr.build  import build_detr
 
 
 def build_model(args, cfg, is_val=False):
@@ -16,6 +17,9 @@ def build_model(args, cfg, is_val=False):
     ## YOLOF    
     elif 'yolof' in args.model:
         model, criterion = build_yolof(cfg, is_val)
+    ## DETR    
+    elif 'detr' in args.model:
+        model, criterion = build_detr(cfg, is_val)
     else:
         raise NotImplementedError("Unknown detector: {}".format(args.model))
     

+ 23 - 0
odlab/models/detectors/detr/build.py

@@ -0,0 +1,23 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .criterion import SetCriterion
+from .detr import DETR
+
+
+# build DETR
+def build_detr(cfg, is_val=False):
+    # -------------- Build DETR --------------
+    model = DETR(cfg         = cfg,
+                 num_classes = cfg.num_classes,
+                 conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh,
+                 topk        = cfg.train_topk        if is_val else cfg.test_topk,
+                 )
+            
+    # -------------- Build Criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+
+    return model, criterion
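
For reference, the builder mirrors the YOLOF one: `is_val=True` returns the model with a `SetCriterion` attached, `is_val=False` returns the model alone. A hedged usage sketch; it assumes `odlab` is importable as a package and that `num_classes` is set on the config, since the base config does not define it:

```python
from odlab.config.detr_config import Detr_R50_Config
from odlab.models.detectors.detr.build import build_detr

cfg = Detr_R50_Config()
cfg.num_classes = 80   # e.g. COCO; required by build_detr but absent from the base config

model, criterion = build_detr(cfg, is_val=True)   # training/validation: criterion is a SetCriterion
model, _ = build_detr(cfg, is_val=False)          # inference: criterion is None
```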

+ 129 - 0
odlab/models/detectors/detr/criterion.py

@@ -0,0 +1,129 @@
+"""
+reference: 
+https://github.com/facebookresearch/detr/blob/main/models/detr.py
+
+by lyuwenyu
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+from .matcher import HungarianMatcher
+
+
+# --------------- Criterion for DETR ---------------
+class SetCriterion(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.num_classes = cfg.num_classes
+        self.losses = ['labels', 'boxes']
+        # -------- Loss weights --------
+        self.weight_dict = {'loss_cls':  cfg.loss_cls,
+                            'loss_box':  cfg.loss_box,
+                            'loss_giou': cfg.loss_giou}
+        aux_weight_dict = {}   # add per-decoder-layer keys for the auxiliary losses
+        for i in range(cfg.num_dec_layers - 1):
+            aux_weight_dict.update({k + f'_aux_{i}': v for k, v in self.weight_dict.items()})
+        self.weight_dict.update(aux_weight_dict)
+        # -------- Matcher --------
+        self.matcher = HungarianMatcher(**cfg.matcher_hpy)
+        # -------- Class weight: down-weight the no-object class (DETR's eos_coef = 0.1) --------
+        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight[-1] = 0.1
+        self.register_buffer('empty_weight', empty_weight)
+
+    def loss_labels(self, outputs, targets, indices, num_boxes):
+        assert 'pred_logits' in outputs
+        src_logits = outputs['pred_logits']
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+
+        # mean cross-entropy over (C+1) classes; 'empty_weight' down-weights the no-object class
+        loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
+
+        return {'loss_cls': loss_cls}
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+        loss_giou = 1 - torch.diag(generalized_box_iou(box_cxcywh_to_xyxy(src_boxes),
+                                                       box_cxcywh_to_xyxy(target_boxes)))
+
+        return {'loss_box': loss_bbox.sum() / num_boxes,
+                'loss_giou': loss_giou.sum() / num_boxes}
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'boxes': self.loss_boxes,
+            'labels': self.loss_labels,
+        }
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+    def forward(self, outputs, targets):
+        outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_boxes)
+        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
+            losses.update(l_dict)
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
+                    l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
+
+    @staticmethod
+    def get_cdn_matched_indices(dn_meta, targets):
+        '''Build matched (pred, gt) index pairs for denoising queries.
+        Used by DN-style variants (e.g. RT-DETR); vanilla DETR does not call it.
+        '''
+        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
+        num_gts = [len(t['labels']) for t in targets]
+        device = targets[0]['labels'].device
+        
+        dn_match_indices = []
+        for i, num_gt in enumerate(num_gts):
+            if num_gt > 0:
+                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
+                gt_idx = gt_idx.tile(dn_num_group)
+                assert len(dn_positive_idx[i]) == len(gt_idx)
+                dn_match_indices.append((dn_positive_idx[i], gt_idx))
+            else:
+                dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \
+                    torch.zeros(0, dtype=torch.int64,  device=device)))
+        
+        return dn_match_indices
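
The permutation helpers are the easiest place to get confused: the matcher returns per-image (pred, tgt) index pairs, and `_get_src_permutation_idx` flattens them into a (batch_idx, query_idx) pair that indexes `pred_logits`/`pred_boxes` directly. A standalone sketch with made-up indices:

```python
import torch

# Matcher output for a batch of 2 images: (pred_indices, tgt_indices) per image.
indices = [(torch.tensor([3, 7]), torch.tensor([0, 1])),   # image 0: queries 3, 7 matched to GTs 0, 1
           (torch.tensor([5]),    torch.tensor([0]))]      # image 1: query 5 matched to GT 0

batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
src_idx   = torch.cat([src for (src, _) in indices])
print(batch_idx.tolist(), src_idx.tolist())   # [0, 0, 1] [3, 7, 5]

# pred_boxes[batch_idx, src_idx] then gathers exactly the matched predictions:
pred_boxes = torch.rand(2, 100, 4)            # [B, Nq, 4]
print(pred_boxes[batch_idx, src_idx].shape)   # torch.Size([3, 4])
```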

+ 124 - 0
odlab/models/detectors/detr/detr.py

@@ -0,0 +1,124 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+# --------------- Model components ---------------
+from ...backbone    import build_backbone
+from ...transformer import build_transformer
+from ...basic.mlp   import MLP
+
+
+# Detection with Transformer
+class DETR(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 num_classes :int   = 80, 
+                 conf_thresh :float = 0.05,
+                 topk        :int   = 1000,
+                 ):
+        super().__init__()
+        # ---------------------- Basic Parameters ----------------------
+        self.cfg = cfg
+        self.topk = topk
+        self.num_classes = num_classes
+        self.conf_thresh = conf_thresh
+
+        # ---------------------- Network Parameters ----------------------
+        ## Backbone
+        self.backbone, feat_dims = build_backbone(cfg)
+
+        ## Input proj
+        self.input_proj = nn.Conv2d(feat_dims[-1], cfg.hidden_dim, kernel_size=1)
+
+        ## Object Queries
+        self.query_embed = nn.Embedding(cfg.num_queries, cfg.hidden_dim)
+        
+        ## Transformer
+        self.transformer = build_transformer(cfg, return_intermediate_dec=True)
+
+        ## Output
+        self.class_embed = nn.Linear(cfg.hidden_dim, num_classes + 1)
+        self.bbox_embed  = MLP(cfg.hidden_dim, cfg.feedforward_dim, 4, 3)
+
+    @torch.jit.unused
+    def set_aux_loss(self, outputs_class, outputs_coord):
+        return [{'pred_logits': a, 'pred_boxes': b}
+                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+    def post_process(self, cls_pred, box_pred):
+        """
+        Input:
+            cls_pred: (Tensor) [Nq, C]
+            box_pred: (Tensor) [Nq, 4]
+        """        
+        # [Nq x C,]
+        scores_i = cls_pred.flatten()
+
+        # Keep top k top scoring indices only.
+        num_topk = min(self.topk, box_pred.size(0))
+
+        # torch.sort is actually faster than .topk (at least on GPUs)
+        predicted_prob, topk_idxs = scores_i.sort(descending=True)
+        topk_scores = predicted_prob[:num_topk]
+        topk_idxs = topk_idxs[:num_topk]
+
+        # filter out the proposals with low confidence score
+        keep_idxs = topk_scores > self.conf_thresh
+        topk_idxs = topk_idxs[keep_idxs]
+
+        # final scores
+        scores = topk_scores[keep_idxs]
+        # final labels
+        labels = topk_idxs % self.num_classes
+        # final bboxes
+        anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+        bboxes = box_pred[anchor_idxs]
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        return bboxes, scores, labels
+
+    def forward(self, src, src_mask=None):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(src)
+        feat = self.input_proj(pyramid_feats[-1])
+
+        if src_mask is not None:
+            src_mask = F.interpolate(src_mask[None].float(), size=feat.shape[-2:]).bool()[0]
+        else:
+            src_mask = torch.zeros([feat.shape[0], *feat.shape[-2:]], device=feat.device, dtype=torch.bool)
+
+        # ---------------- Transformer ----------------
+        hs = self.transformer(feat, src_mask, self.query_embed.weight)[0]
+
+        # ---------------- Head ----------------
+        outputs_class = self.class_embed(hs)
+        outputs_coord = self.bbox_embed(hs).sigmoid()
+
+        if self.training:
+            outputs = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
+            outputs['aux_outputs'] = self.set_aux_loss(outputs_class, outputs_coord)
+        else:
+            # [B, Nq, C+1] -> [B, Nq, C]: softmax, then drop the no-object class.
+            # The post-processing below assumes batch size 1 at inference.
+            cls_pred = outputs_class[-1].softmax(-1)[..., :-1]
+            box_pred = outputs_coord[-1]
+
+            cxcy_pred = box_pred[..., :2]
+            bwbh_pred = box_pred[..., 2:]
+            x1y1_pred = cxcy_pred - 0.5 * bwbh_pred
+            x2y2_pred = cxcy_pred + 0.5 * bwbh_pred
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # Post-process (no NMS)
+            bboxes, scores, labels = self.post_process(cls_pred, box_pred)
+
+            outputs = {
+                'scores': scores,
+                'labels': labels,
+                'bboxes': bboxes
+            }
+
+        return outputs 
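
The test-time decoding packs (query, class) into one flat index over the flattened score tensor, so recovering labels and queries is pure modular arithmetic. A toy-sized sketch of that step:

```python
import torch

num_classes = 5
cls_pred = torch.rand(4, num_classes)   # [Nq, C] per-class scores
box_pred = torch.rand(4, 4)             # [Nq, 4] boxes

scores_i = cls_pred.flatten()           # flat index = query * C + class
predicted_prob, topk_idxs = scores_i.sort(descending=True)
topk_idxs = topk_idxs[:3]               # keep the top 3 for the demo

labels     = topk_idxs % num_classes                                    # recover the class id
query_idxs = torch.div(topk_idxs, num_classes, rounding_mode='floor')   # recover the query
print(labels.tolist(), query_idxs.tolist(), box_pred[query_idxs].shape)
```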

+ 51 - 0
odlab/models/detectors/detr/matcher.py

@@ -0,0 +1,51 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# https://github.com/facebookresearch/detr
+
+import torch
+import torch.nn as nn
+from scipy.optimize import linear_sum_assignment
+from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_bbox = cost_bbox
+        self.cost_giou = cost_giou
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # [B * num_queries, C] = [N, C]
+        out_prob = outputs["pred_logits"].flatten(0, 1).softmax(-1)
+        out_bbox = outputs["pred_boxes"].flatten(0, 1)
+
+        # [M,] where M is number of all targets in this batch
+        tgt_ids = torch.cat([v["labels"] for v in targets])
+        # [M, 4]
+        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # [N, M]
+        cost_class = -out_prob[:, tgt_ids] 
+
+        # [N, M]
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+        # [N, M]
+        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+
+        # Final cost matrix: [N, M]
+        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+        # [N, M] -> [B, num_queries, M]
+        C = C.view(bs, num_queries, -1).cpu()
+
+        # Solve the optimal assignment for each image in the batch
+        sizes = [len(v["boxes"]) for v in targets]
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+
+        return [(torch.as_tensor(i, dtype=torch.int64),   # pred (query) indices
+                 torch.as_tensor(j, dtype=torch.int64))   # tgt indices
+                 for i, j in indices]
+
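
Since the cost matrix is split per image before assignment, each image is solved independently. A tiny worked example of the Hungarian step on a fake cost matrix:

```python
import torch
from scipy.optimize import linear_sum_assignment

# One image, 3 queries x 2 targets; lower cost = better match.
C = torch.tensor([[0.9, 0.1],
                  [0.2, 0.8],
                  [0.5, 0.5]])
row, col = linear_sum_assignment(C)   # row: query indices, col: target indices
print(row, col)                       # [0 1] [1 0]: query 0 -> tgt 1, query 1 -> tgt 0
```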

+ 1 - 0
odlab/models/detectors/yolof/build.py

@@ -9,6 +9,7 @@ from .yolof import YOLOF
 def build_yolof(cfg, is_val=False):
     # -------------- Build YOLOF --------------
     model = YOLOF(cfg         = cfg,
+                  num_classes = cfg.num_classes,
                   conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh,
                   nms_thresh  = cfg.train_nms_thresh  if is_val else cfg.test_nms_thresh,
                   topk        = cfg.train_topk        if is_val else cfg.test_topk,

+ 17 - 0
odlab/models/transformer/__init__.py

@@ -0,0 +1,17 @@
+from .transformer import DETRTransformer
+
+
+def build_transformer(cfg, return_intermediate_dec):
+    if cfg.transformer == "detr_transformer":
+        return DETRTransformer(hidden_dim     = cfg.hidden_dim,
+                               num_heads      = cfg.num_heads,
+                               ffn_dim        = cfg.feedforward_dim,
+                               num_enc_layers = cfg.num_enc_layers,
+                               num_dec_layers = cfg.num_dec_layers,
+                               dropout        = cfg.dropout,
+                               act_type       = cfg.tr_act,
+                               pre_norm       = cfg.pre_norm,
+                               return_intermediate_dec=return_intermediate_dec)
+    else:
+        raise NotImplementedError("Unknown transformer: {}".format(cfg.transformer))
+    

+ 115 - 0
odlab/models/transformer/transformer.py

@@ -0,0 +1,115 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# https://github.com/facebookresearch/detr
+
+import torch
+import torch.nn as nn
+
+try:
+    from .transformer_encoder import TransformerEncoderLayer, TransformerEncoder
+    from .transformer_decoder import TransformerDecoderLayer, TransformerDecoder
+except ImportError:
+    from  transformer_encoder import TransformerEncoderLayer, TransformerEncoder
+    from  transformer_decoder import TransformerDecoderLayer, TransformerDecoder
+
+
+class DETRTransformer(nn.Module):
+    def __init__(self,
+                 hidden_dim     :int = 512,
+                 num_heads      :int = 8,
+                 ffn_dim        :int = 2048,
+                 num_enc_layers :int = 6,
+                 num_dec_layers :int = 6,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 pre_norm       :bool  = False,
+                 return_intermediate_dec :bool = False):
+        super().__init__()
+        # ---------- Basic parameters ----------
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.ffn_dim  = ffn_dim
+        self.act_type = act_type
+        self.pre_norm = pre_norm
+        self.num_enc_layers = num_enc_layers
+        self.num_dec_layers = num_dec_layers
+        self.return_intermediate_dec = return_intermediate_dec
+        # ---------- Model parameters ----------
+        ## Encoder module
+        encoder_layer = TransformerEncoderLayer(
+            hidden_dim, num_heads, ffn_dim, dropout, act_type, pre_norm)
+        encoder_norm = nn.LayerNorm(hidden_dim) if pre_norm else None
+        self.encoder = TransformerEncoder(encoder_layer, num_enc_layers, encoder_norm)
+        ## Decoder module
+        decoder_layer = TransformerDecoderLayer(
+            hidden_dim, num_heads, ffn_dim, dropout, act_type, pre_norm)
+        decoder_norm = nn.LayerNorm(hidden_dim)
+        self.decoder = TransformerDecoder(decoder_layer, num_dec_layers, decoder_norm,
+                                          return_intermediate=return_intermediate_dec)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        for p in self.parameters():
+            if p.dim() > 1:
+                nn.init.xavier_uniform_(p)
+
+    def get_posembed(self, embed_dim, src_mask, temperature=10000, normalize=False):
+        scale = 2 * torch.pi
+        num_pos_feats = embed_dim // 2
+        not_mask = ~src_mask
+
+        # [B, H, W]
+        y_embed = not_mask.cumsum(1, dtype=torch.float32)
+        x_embed = not_mask.cumsum(2, dtype=torch.float32)
+
+        # normalize grid coords
+        if normalize:
+            y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * scale
+            x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * scale
+    
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=src_mask.device)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[..., None], dim_t)
+        pos_y = torch.div(y_embed[..., None], dim_t)
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+
+        # [B, H, W, C] -> [B, C, H, W]
+        pos_embed = torch.cat((pos_y, pos_x), dim=-1).permute(0, 3, 1, 2)
+
+        return pos_embed
+
+    def forward(self, src, src_mask, query_embed):
+        bs, c, h, w = src.shape
+
+        # Get position embedding
+        pos_embed = self.get_posembed(c, src_mask, normalize=True)
+
+        # reshape: [B, C, H, W] -> [N, B, C], N = HW
+        src = src.flatten(2).permute(2, 0, 1)
+        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+        src_mask = src_mask.flatten(1)
+
+        # [Nq, C] -> [Nq, B, C]
+        query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
+
+        # Encoder
+        memory = self.encoder(src, src_mask=src_mask, pos_embed=pos_embed)
+
+        # Decoder
+        tgt = torch.zeros_like(query_embed)
+        hs = self.decoder(tgt         = tgt,
+                          tgt_mask    = None,
+                          memory      = memory,
+                          memory_mask = src_mask,
+                          memory_pos  = pos_embed,
+                          query_pos   = query_embed)
+        
+        # [M, Nq, B, C] -> [M, B, Nq, C]
+        hs = hs.transpose(1, 2)
+        # [N, B, C] -> [B, C, N] -> [B, C, H, W]
+        memory = memory.permute(1, 2, 0).view(bs, c, h, w)
+
+        return hs, memory
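
The sine position embedding interleaves sin/cos over half the channels for y and half for x. A standalone recomputation of just the shape logic, mirroring `get_posembed` with toy sizes:

```python
import torch

B, C, H, W = 2, 256, 10, 15
src_mask = torch.zeros(B, H, W, dtype=torch.bool)   # no padding

not_mask = ~src_mask
y_embed = not_mask.cumsum(1, dtype=torch.float32)   # [B, H, W], 1..H down each column
x_embed = not_mask.cumsum(2, dtype=torch.float32)   # [B, H, W], 1..W along each row

num_pos_feats = C // 2
dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats)

pos_x = x_embed[..., None] / dim_t                  # [B, H, W, C/2]
pos_y = y_embed[..., None] / dim_t
pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)

pos_embed = torch.cat((pos_y, pos_x), dim=-1).permute(0, 3, 1, 2)
print(pos_embed.shape)   # torch.Size([2, 256, 10, 15]): one channel-wise code per pixel
```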

+ 167 - 0
odlab/models/transformer/transformer_decoder.py

@@ -0,0 +1,167 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# https://github.com/facebookresearch/detr
+
+import torch
+import torch.nn as nn
+
+try:
+    from .utils import get_clones, get_activation_fn
+except ImportError:
+    from  utils import get_clones, get_activation_fn
+
+
+class TransformerDecoder(nn.Module):
+    def __init__(self,
+                 decoder_layer,
+                 num_layers,
+                 norm=None,
+                 return_intermediate=False):
+        super().__init__()
+        # --------- Basic parameters ---------
+        self.num_layers = num_layers
+        self.return_intermediate = return_intermediate
+        # --------- Model parameters ---------
+        self.layers = get_clones(decoder_layer, num_layers)
+        self.norm = norm
+
+    def forward(self,
+                tgt,
+                tgt_mask,
+                memory,
+                memory_mask,
+                memory_pos,
+                query_pos):
+        output = tgt
+
+        intermediate = []
+
+        for layer in self.layers:
+            output = layer(output,
+                           tgt_mask,
+                           memory,
+                           memory_mask,
+                           memory_pos,
+                           query_pos)
+            if self.return_intermediate:
+                intermediate.append(self.norm(output))
+
+        if self.norm is not None:
+            output = self.norm(output)
+            if self.return_intermediate:
+                intermediate.pop()
+                intermediate.append(output)
+
+        if self.return_intermediate:
+            return torch.stack(intermediate)
+
+        return output.unsqueeze(0)   # [1, Nq, B, C]
+
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self,
+                 hidden_dim,
+                 num_heads,
+                 ffn_dim=2048,
+                 dropout=0.1,
+                 act_type="relu",
+                 pre_norm=False):
+        super().__init__()
+        # ---------- Basic parameters ----------
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.ffn_dim  = ffn_dim
+        self.act_type = act_type
+        self.pre_norm = pre_norm
+        # ---------- Model parameters ----------
+        ## MHSA for object queries
+        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
+        self.dropout1  = nn.Dropout(dropout)
+        self.norm1     = nn.LayerNorm(hidden_dim)
+
+        ## MHCA for object queries
+        self.multihead_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.norm2    = nn.LayerNorm(hidden_dim)
+
+        ## Feedforward network
+        self.linear1    = nn.Linear(hidden_dim, ffn_dim)
+        self.activation = get_activation_fn(act_type)
+        self.dropout    = nn.Dropout(dropout)
+        self.linear2    = nn.Linear(ffn_dim, hidden_dim)
+        self.dropout3   = nn.Dropout(dropout)
+        self.norm3      = nn.LayerNorm(hidden_dim)
+
+
+    def with_pos_embed(self, tensor, pos_embed):
+        return tensor if pos_embed is None else tensor + pos_embed
+
+    def forward_post(self,
+                     tgt,
+                     tgt_mask,
+                     memory,
+                     memory_mask,
+                     memory_pos,
+                     query_pos,
+                     ):
+        # MHSA for object queries
+        q = k = self.with_pos_embed(tgt, query_pos)
+        tgt2 = self.self_attn(q, k, tgt, attn_mask=tgt_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # MHCA between object queries and image features
+        q = self.with_pos_embed(tgt, query_pos)
+        k = self.with_pos_embed(memory, memory_pos)
+        tgt2 = self.multihead_attn(q, k, memory, key_padding_mask=memory_mask)[0]
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # FFN
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+
+        return tgt
+
+    def forward_pre(self,
+                    tgt,
+                    tgt_mask,
+                    memory,
+                    memory_mask,
+                    memory_pos,
+                    query_pos,
+                    ):
+        # MHSA for object queries
+        tgt2 = self.norm1(tgt)
+        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(q, k, tgt2, attn_mask=tgt_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+
+        # MHCA between object queries and image features
+        tgt2 = self.norm2(tgt)
+        q = self.with_pos_embed(tgt2, query_pos)
+        k = self.with_pos_embed(memory, memory_pos)
+        tgt2 = self.multihead_attn(q, k, memory, key_padding_mask=memory_mask)[0]
+        tgt = tgt + self.dropout2(tgt2)
+
+        # FFN
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+
+        return tgt
+
+    def forward(self,
+                tgt,
+                tgt_mask,
+                memory,
+                memory_mask,
+                memory_pos,
+                query_pos,):
+        if self.pre_norm:
+            return self.forward_pre(tgt, tgt_mask, memory, memory_mask, memory_pos, query_pos)
+        else:
+            return self.forward_post(tgt, tgt_mask, memory, memory_mask, memory_pos, query_pos)
+
+
+if __name__ == "__main__":
+    pass
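
With `return_intermediate=True` the decoder stacks the normed output of every layer, which is what feeds the auxiliary losses. A quick shape check with hypothetical sizes; it assumes this module is importable via the fallback import above (i.e. run from this directory):

```python
import torch
import torch.nn as nn
from transformer_decoder import TransformerDecoderLayer, TransformerDecoder

Nq, B, C = 100, 2, 256
layer   = TransformerDecoderLayer(hidden_dim=C, num_heads=8)
decoder = TransformerDecoder(layer, num_layers=6, norm=nn.LayerNorm(C), return_intermediate=True)

tgt       = torch.zeros(Nq, B, C)   # DETR starts the decoder from zeros
memory    = torch.rand(150, B, C)   # flattened H*W encoder tokens
query_pos = torch.rand(Nq, B, C)

hs = decoder(tgt, None, memory, None, None, query_pos)
print(hs.shape)   # torch.Size([6, 100, 2, 256]): one output per decoder layer
```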

+ 105 - 0
odlab/models/transformer/transformer_encoder.py

@@ -0,0 +1,105 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# https://github.com/facebookresearch/detr
+
+import torch
+import torch.nn as nn
+
+try:
+    from .utils import get_clones, get_activation_fn
+except ImportError:
+    from  utils import get_clones, get_activation_fn
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self,
+                 encoder_layer,
+                 num_layers,
+                 norm=None):
+        super().__init__()
+        # -------- Basic parameters --------
+        self.num_layers = num_layers
+        # -------- Model parameters --------
+        self.layers = get_clones(encoder_layer, num_layers)
+        self.norm = norm
+
+    def forward(self, src, src_mask, pos_embed):
+        output = src
+
+        for layer in self.layers:
+            output = layer(output, src_mask, pos_embed)
+
+        if self.norm is not None:
+            output = self.norm(output)
+
+        return output
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self,
+                 hidden_dim :int = 256,
+                 num_heads  :int = 8,
+                 ffn_dim    :int = 2048,
+                 dropout    :float = 0.1,
+                 act_type   :str   = "relu",
+                 pre_norm   :bool  = False,):
+        super().__init__()
+        # ---------- Basic parameters ----------
+        self.hidden_dim = hidden_dim
+        self.num_heads = num_heads
+        self.ffn_dim  = ffn_dim
+        self.act_type = act_type
+        self.pre_norm = pre_norm
+        # ---------- Model parameters ----------
+        # Multi-head Self-Attn
+        self.self_attn = nn.MultiheadAttention(hidden_dim, num_heads, dropout=dropout)
+        self.dropout1  = nn.Dropout(dropout)
+        self.norm1     = nn.LayerNorm(hidden_dim)
+
+        ## Feedforward network
+        self.linear1    = nn.Linear(hidden_dim, ffn_dim)
+        self.activation = get_activation_fn(act_type)
+        self.dropout    = nn.Dropout(dropout)
+        self.linear2    = nn.Linear(ffn_dim, hidden_dim)
+        self.dropout2   = nn.Dropout(dropout)
+        self.norm2      = nn.LayerNorm(hidden_dim)
+
+
+    def with_pos_embed(self, tensor, pos_embed):
+        return tensor if pos_embed is None else tensor + pos_embed
+
+    def forward_post(self, src, src_mask, pos_embed):
+        # MHSA
+        q = k = self.with_pos_embed(src, pos_embed)
+        src2 = self.self_attn(q, k, src, key_padding_mask=src_mask)[0]
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+
+        # FFN
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+
+        return src
+
+    def forward_pre(self, src, src_mask, pos_embed):
+        # MHSA
+        src2 = self.norm1(src)
+        q = k = self.with_pos_embed(src2, pos_embed)
+        src2 = self.self_attn(q, k, src2, key_padding_mask=src_mask)[0]
+        src = src + self.dropout1(src2)
+
+        # FFN
+        src2 = self.norm2(src)
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
+        src = src + self.dropout2(src2)
+
+        return src
+
+    def forward(self, src, src_mask, pos_embed):
+        if self.pre_norm:
+            return self.forward_pre(src, src_mask, pos_embed)
+        else:
+            return self.forward_post(src, src_mask, pos_embed)
+
+
+if __name__ == "__main__":
+    pass
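
The `pre_norm` flag only moves LayerNorm relative to the residual branch; both paths preserve the [N, B, C] shape. A smoke test, again assuming the fallback import (run from this directory):

```python
import torch
from transformer_encoder import TransformerEncoderLayer

N, B, C = 150, 2, 256
src      = torch.rand(N, B, C)                  # flattened H*W tokens
src_mask = torch.zeros(B, N, dtype=torch.bool)  # True = padded position
pos      = torch.rand(N, B, C)

for pre_norm in (False, True):
    layer = TransformerEncoderLayer(hidden_dim=C, num_heads=8, pre_norm=pre_norm)
    out = layer(src, src_mask, pos)
    print(pre_norm, out.shape)   # shape is preserved either way
```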

+ 21 - 0
odlab/models/transformer/utils.py

@@ -0,0 +1,21 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+# https://github.com/facebookresearch/detr
+
+import copy
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+def get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+
+def get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
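
`get_clones` deep-copies the layer, so the N stacked layers do not share weights; `nn.ModuleList([module] * N)` would alias one layer N times. A quick check of that distinction:

```python
import copy
import torch.nn as nn

def get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])

layer  = nn.Linear(4, 4)
clones = get_clones(layer, 3)
shared = nn.ModuleList([layer] * 3)

print(clones[0].weight is clones[1].weight)   # False: independent parameters
print(shared[0].weight is shared[1].weight)   # True: the same tensor three times
```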