
modify TAL

yjh0410 1 year ago
parent
commit
24de483708

+ 2 - 2
config/model_config/rtdetr_config.py

@@ -10,7 +10,7 @@ rtdetr_cfg = {
         'depth': 1.0,
         ## Image Encoder - Backbone
         'backbone': 'resnet18',
-        'backbone_norm': 'BN',
+        'backbone_norm': 'FrozeBN',
         'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',
@@ -34,7 +34,7 @@ rtdetr_cfg = {
         'de_num_layers': 3,
         'de_mlp_ratio': 4.0,
         'de_dropout': 0.0,
-        'de_act': 'gelu',
+        'de_act': 'relu',
         'de_num_points': 4,
         'num_queries': 300,
         'learnt_init_query': False,
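
Note on the 'FrozeBN' value: it selects a frozen BatchNorm for the ResNet backbone, matching the norm_type branch in the backbone builder added later in this commit (models/detectors/rtpdetr/basic_modules/backbone.py). A minimal sketch of that mapping, assuming the rtdetr backbone resolves the option the same way (the helper name resolve_backbone_norm is illustrative):

    import torch.nn as nn
    # FrozenBatchNorm2d is defined in models/detectors/rtpdetr/basic_modules/basic.py (added below)
    from models.detectors.rtpdetr.basic_modules.basic import FrozenBatchNorm2d

    def resolve_backbone_norm(norm_type: str):
        if norm_type == 'BN':
            return nn.BatchNorm2d        # trainable BatchNorm
        elif norm_type == 'FrozeBN':
            return FrozenBatchNorm2d     # fixed statistics and affine parameters
        raise NotImplementedError(norm_type)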

+ 4 - 2
engine.py

@@ -1137,7 +1137,7 @@ class RTRTrainer(object):
         self.heavy_eval = False
         # weak augmentation stage
         self.second_stage = False
-        self.second_stage_epoch = args.no_aug_epoch
+        self.second_stage_epoch = -1
         # path to save model
         self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
         os.makedirs(self.path_to_save, exist_ok=True)
@@ -1157,6 +1157,8 @@ class RTRTrainer(object):
             args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['out_stride'][-1], is_train=True)
         self.val_transform, _ = build_transform(
             args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['out_stride'][-1], is_train=False)
+        if self.trans_cfg["mosaic_prob"] > 0.5:
+            self.second_stage_epoch = 5
 
         # ---------------------------- Build Dataset & Dataloader ----------------------------
         self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
@@ -1169,7 +1171,7 @@ class RTRTrainer(object):
         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
         # ---------------------------- Build Optimizer ----------------------------
-        self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
+        self.optimizer_dict['lr0'] *= self.args.batch_size / 16.  # auto lr scaling
         self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
 
         # ---------------------------- Build LR Scheduler ----------------------------
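
The comment added above documents linear learning-rate scaling: the configured lr0 is defined for a total batch size of 16 and is scaled proportionally. A tiny worked example (values illustrative):

    optimizer_dict = {'lr0': 1e-4}              # base lr, defined for batch size 16
    batch_size = 32
    optimizer_dict['lr0'] *= batch_size / 16.   # -> 2e-4: twice the batch, twice the lr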

+ 5 - 5
models/detectors/rtcdet/matcher.py

@@ -64,7 +64,7 @@ class TaskAlignedAssigner(nn.Module):
     def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
         """Compute alignment metric given predicted and ground truth bounding boxes."""
         na = pd_bboxes.shape[-2]
-        mask_gt = mask_gt.bool()  # b, max_num_obj, h*w
+        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
         overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
         bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
 
@@ -72,12 +72,12 @@ class TaskAlignedAssigner(nn.Module):
         ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
         ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
         # Get the scores of each grid for each gt cls
-        bbox_scores[mask_gt] = pd_scores[ind[0], :, ind[1]][mask_gt]  # b, max_num_obj, h*w
+        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
 
         # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
-        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_gt]
-        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_gt]
-        overlaps[mask_gt] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
+        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
+        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
 
         align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
         return align_metric, overlaps
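
For reference, mask_in_gts marks the anchors whose centers fall inside each ground-truth box, and the task-aligned metric combines classification score and CIoU overlap as score**alpha * iou**beta. A minimal standalone sketch of that computation (tensor names and default exponents are illustrative, not the repo's API):

    import torch

    def task_aligned_metric(cls_scores, ious, mask_in_gts, alpha=1.0, beta=6.0):
        # cls_scores, ious, mask_in_gts: [bs, max_num_obj, h*w]
        metric = cls_scores.pow(alpha) * ious.pow(beta)
        # anchors outside the GT box contribute nothing to the alignment metric
        return metric * mask_in_gts.float()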

+ 4 - 3
models/detectors/rtdetr/basic_modules/dn_compoments.py

@@ -40,18 +40,19 @@ def get_contrastive_denoising_training_group(targets,
     input_query_class = torch.full([bs, max_gt_num], num_classes, device=class_embed.device).long()
     # [bs, max_gt_num, 4]
     input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=class_embed.device)
+    # [bs, max_gt_num]
     pad_gt_mask = torch.zeros([bs, max_gt_num], device=class_embed.device)
     for i in range(bs):
         num_gt = num_gts[i]
         if num_gt > 0:
-            input_query_class[i, :num_gt] = targets[i]["labels"].squeeze(-1)
+            input_query_class[i, :num_gt] = targets[i]["labels"]
             input_query_bbox[i, :num_gt] = targets[i]["boxes"]
             pad_gt_mask[i, :num_gt] = 1
 
     # each group has positive and negative queries.
     input_query_class = input_query_class.repeat(1, 2 * num_group)  # [bs, 2*num_denoising], num_denoising = 2 * num_group * max_gt_num
     input_query_bbox = input_query_bbox.repeat(1, 2 * num_group, 1) # [bs, 2*num_denoising, 4]
-    pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)
+    pad_gt_mask = pad_gt_mask.repeat(1, 2 * num_group)              # [bs, 2*num_denoising]
 
     # positive and negative mask
     negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=class_embed.device)
@@ -75,7 +76,7 @@ def get_contrastive_denoising_training_group(targets,
         chosen_idx = torch.nonzero(mask * pad_gt_mask).squeeze(-1)
         # randomly put a new one here
         new_label = torch.randint_like(
-            chosen_idx, 0, num_classes, dtype=input_query_class.dtype, device=class_embed.device)
+            chosen_idx, 0, num_classes, dtype=input_query_class.dtype, device=class_embed.device) # [b * num_denoising]
         # [bs * num_denoising]
         input_query_class = torch.scatter(input_query_class, 0, chosen_idx, new_label)
         # input_query_class.scatter_(chosen_idx, new_label)

+ 16 - 10
models/detectors/rtdetr/basic_modules/transformer.py

@@ -340,20 +340,26 @@ class TransformerEncoder(nn.Module):
     def forward(self, src):
         """
         Input:
-            src:       [torch.Tensor] -> [B, C, H, W]
+            src:  [torch.Tensor] -> [B, C, H, W]
         Output:
-            src:       [torch.Tensor] -> [B, N, C]
+            src:  [torch.Tensor] -> [B, C, H, W]
         """
         # -------- Transformer encoder --------
+        channels, fmp_h, fmp_w = src.shape[1:]
+        # [B, C, H, W] -> [B, N, C], N=HxW
+        src_flatten = src.flatten(2).permute(0, 2, 1)
+        memory = src_flatten
+
+        # PosEmbed: [1, N, C]
+        pos_embed = self.build_2d_sincos_position_embedding(
+            src.device, fmp_w, fmp_h, channels, self.pe_temperature)
+        
+        # Transformer Encoder layer
         for encoder in self.encoder_layers:
-            channels, fmp_h, fmp_w = src.shape[1:]
-            # [B, C, H, W] -> [B, N, C], N=HxW
-            src_flatten = src.flatten(2).permute(0, 2, 1)
-            pos_embed = self.build_2d_sincos_position_embedding(src.device,
-                    fmp_w, fmp_h, channels, self.pe_temperature)
-            memory = encoder(src_flatten, pos_embed=pos_embed)
-            # [B, N, C] -> [B, C, N] -> [B, C, H, W]
-            src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
+            memory = encoder(memory, pos_embed=pos_embed)
+
+        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
+        src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
 
         return src
 

+ 19 - 19
models/detectors/rtdetr/loss.py

@@ -28,11 +28,11 @@ class Criterion(object):
                                         cfg['matcher_hpy']['cost_giou'],
                                         alpha=0.25,
                                         gamma=2.0)
-        self.loss = DINOLoss(num_classes = num_classes,
-                                matcher     = self.matcher,
-                                aux_loss    = True,
-                                use_vfl     = cfg['use_vfl'],
-                                loss_coeff  = cfg['loss_coeff'])
+        self.loss = DINOLoss(num_classes   = num_classes,
+                                matcher    = self.matcher,
+                                aux_loss   = True,
+                                use_vfl    = cfg['use_vfl'],
+                                loss_coeff = cfg['loss_coeff'])
 
     def __call__(self, dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta, targets=None):
         assert targets is not None
@@ -43,13 +43,13 @@ class Criterion(object):
         if dn_meta is not None:
             if isinstance(dn_meta, list):
                 dual_groups = len(dn_meta) - 1
-                dec_out_bboxes = torch.split(
+                dec_out_bboxes = torch.chunk(
                     dec_out_bboxes, dual_groups + 1, dim=2)
-                dec_out_logits = torch.split(
+                dec_out_logits = torch.chunk(
                     dec_out_logits, dual_groups + 1, dim=2)
-                enc_topk_bboxes = torch.split(
+                enc_topk_bboxes = torch.chunk(
                     enc_topk_bboxes, dual_groups + 1, dim=1)
-                enc_topk_logits = torch.split(
+                enc_topk_logits = torch.chunk(
                     enc_topk_logits, dual_groups + 1, dim=1)
 
                 loss = {}
@@ -86,7 +86,7 @@ class Criterion(object):
                     # sum loss
                     for key, value in loss_gid.items():
                         loss.update({
-                            key: loss.get(key, torch.zeros([1])) + value
+                            key: loss.get(key, torch.zeros([1], device=out_bboxes_gid.device)) + value
                         })
 
                 # average across (dual_groups + 1)
@@ -124,9 +124,8 @@ class DETRLoss(nn.Module):
                  aux_loss=True,
                  use_vfl=False,
                  loss_coeff={'class': 1,
-                             'bbox': 5,
-                             'giou': 2,
-                             'no_object': 0.1,},
+                             'bbox':  5,
+                             'giou':  2,},
                  ):
         super(DETRLoss, self).__init__()
         self.num_classes = num_classes
@@ -186,20 +185,21 @@ class DETRLoss(nn.Module):
 
         loss = dict()
         if sum(len(a) for a in gt_bbox) == 0:
-            loss[name_bbox] = torch.as_tensor([0.])
-            loss[name_giou] = torch.as_tensor([0.])
+            loss[name_bbox] = torch.as_tensor([0.], device=boxes.device)
+            loss[name_giou] = torch.as_tensor([0.], device=boxes.device)
             return loss
 
         # prepare positive samples
         src_bbox, target_bbox = self._get_src_target_assign(boxes, gt_bbox, match_indices)
 
         # Compute L1 loss
-        loss[name_bbox] = self.loss_coeff['bbox'] * F.l1_loss(
-            src_bbox, target_bbox, reduction='sum') / num_gts
+        loss[name_bbox] = F.l1_loss(src_bbox, target_bbox, reduction='none')
+        loss[name_bbox] = loss[name_bbox].sum() / num_gts
+        loss[name_bbox] = self.loss_coeff['bbox'] * loss[name_bbox]
         
         # Compute GIoU loss
-        loss[name_giou] = self.giou_loss(
-            box_cxcywh_to_xyxy(src_bbox), box_cxcywh_to_xyxy(target_bbox))
+        loss[name_giou] = self.giou_loss(box_cxcywh_to_xyxy(src_bbox),
+                                         box_cxcywh_to_xyxy(target_bbox))
         loss[name_giou] = loss[name_giou].sum() / num_gts
         loss[name_giou] = self.loss_coeff['giou'] * loss[name_giou]
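
The switch from torch.split to torch.chunk earlier in this file matters because the second argument is interpreted differently: split takes a chunk size, while chunk takes a number of chunks, and dual_groups + 1 here is a group count. A minimal illustration:

    import torch

    x = torch.arange(12).reshape(1, 12)     # pretend 12 queries = (dual_groups + 1) * 4, dual_groups = 2
    by_count = torch.chunk(x, 3, dim=1)     # 3 tensors of shape [1, 4] -> one per group (intended)
    by_size  = torch.split(x, 3, dim=1)     # 4 tensors of shape [1, 3] -> chunks of size 3 (not intended)
    print(len(by_count), by_count[0].shape) # 3 torch.Size([1, 4])
    print(len(by_size), by_size[0].shape)   # 4 torch.Size([1, 3])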
 

+ 1 - 1
models/detectors/rtdetr/rtdetr.py

@@ -15,7 +15,7 @@ class RT_DETR(nn.Module):
                  cfg,
                  num_classes = 80,
                  conf_thresh = 0.1,
-                 topk        = 100,
+                 topk        = 300,
                  deploy      = False,
                  no_multi_labels = False,
                  ):

+ 43 - 33
models/detectors/rtdetr/rtdetr_decoder.py

@@ -2,7 +2,7 @@ import math
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from torch.nn.init import constant_, xavier_uniform_, uniform_
+from torch.nn.init import constant_, xavier_uniform_, uniform_, normal_
 from typing import List
 
 try:
@@ -96,7 +96,7 @@ class RTDETRTransformer(nn.Module):
         )
 
         ## Deformable transformer decoder
-        self.transformer_decoder = DeformableTransformerDecoder(
+        self.decoder = DeformableTransformerDecoder(
                                     d_model    = hidden_dim,
                                     num_heads  = num_heads,
                                     num_layers = num_layers,
@@ -116,7 +116,7 @@ class RTDETRTransformer(nn.Module):
         self.enc_class_head = nn.Linear(hidden_dim, num_classes)
         self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
 
-        ##  Detection head for Decoder
+        ## Detection head for Decoder
         self.dec_class_head = nn.ModuleList([
             nn.Linear(hidden_dim, num_classes)
             for _ in range(num_layers)
@@ -126,18 +126,18 @@ class RTDETRTransformer(nn.Module):
             for _ in range(num_layers)
         ])
 
-        ## Denoising part
-        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
-
         ## Object query
         if learnt_init_query:
             self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
         self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
 
+        ## Denoising part
+        self.denoising_class_embed = nn.Embedding(num_classes, hidden_dim)
+
         self._reset_parameters()
 
     def _reset_parameters(self):
-        def _linear_init(module):
+        def linear_init_(module):
             bound = 1 / math.sqrt(module.weight.shape[0])
             uniform_(module.weight, -bound, bound)
             if hasattr(module, "bias") and module.bias is not None:
@@ -146,17 +146,17 @@ class RTDETRTransformer(nn.Module):
         # class and bbox head init
         prior_prob = 0.01
         cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
-        _linear_init(self.enc_class_head)
+        linear_init_(self.enc_class_head)
         constant_(self.enc_class_head.bias, cls_bias_init)
         constant_(self.enc_bbox_head.layers[-1].weight, 0.)
         constant_(self.enc_bbox_head.layers[-1].bias, 0.)
         for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
-            _linear_init(cls_)
+            linear_init_(cls_)
             constant_(cls_.bias, cls_bias_init)
             constant_(reg_.layers[-1].weight, 0.)
             constant_(reg_.layers[-1].bias, 0.)
 
-        _linear_init(self.enc_output[0])
+        linear_init_(self.enc_output[0])
         xavier_uniform_(self.enc_output[0].weight)
         if self.learnt_init_query:
             xavier_uniform_(self.tgt_embed.weight)
@@ -164,21 +164,25 @@ class RTDETRTransformer(nn.Module):
         xavier_uniform_(self.query_pos_head.layers[1].weight)
         for l in self.input_proj_layers:
             xavier_uniform_(l.conv.weight)
+        normal_(self.denoising_class_embed.weight)
 
     def generate_anchors(self, spatial_shapes, grid_size=0.05):
         anchors = []
         for lvl, (h, w) in enumerate(spatial_shapes):
             grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
+            # [H, W, 2]
             grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
 
             valid_WH = torch.as_tensor([w, h]).float()
             grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
             wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
-            anchors.append(torch.cat([grid_xy, wh], -1).reshape([-1, h * w, 4]))
-
-        anchors = torch.cat(anchors, 1)
+            # [H, W, 4] -> [1, N, 4], N=HxW
+            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
+        # List[L, 1, N_i, 4] -> [1, N, 4], N=N_0 + N_1 + N_2 + ...
+        anchors = torch.cat(anchors, dim=1)
         valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
         anchors = torch.log(anchors / (1 - anchors))
+        # Equivalent to: anchors = anchors.masked_fill(~valid_mask, float("inf"))
         anchors = torch.where(valid_mask, anchors, torch.as_tensor(float("inf")))
         
         return anchors, valid_mask
@@ -193,15 +197,14 @@ class RTDETRTransformer(nn.Module):
         level_start_index = [0, ]
         for i, feat in enumerate(proj_feats):
             _, _, h, w = feat.shape
-            # [b, c, h, w] -> [b, h*w, c]
-            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
-            # [num_levels, 2]
             spatial_shapes.append([h, w])
             # [l], start index of each level
             level_start_index.append(h * w + level_start_index[-1])
+            # [B, C, H, W] -> [B, N, C], N=HxW
+            feat_flatten.append(feat.flatten(2).permute(0, 2, 1))
 
-        # [b, l, c]
-        feat_flatten = torch.cat(feat_flatten, 1)
+        # [B, N, C], N = N_0 + N_1 + ...
+        feat_flatten = torch.cat(feat_flatten, dim=1)
         level_start_index.pop()
 
         return (feat_flatten, spatial_shapes, level_start_index)
@@ -212,20 +215,26 @@ class RTDETRTransformer(nn.Module):
                           denoising_class=None,
                           denoising_bbox_unact=None):
         bs, _, _ = memory.shape
-        # prepare input for decoder
+        # Prepare input for decoder
         anchors, valid_mask = self.generate_anchors(spatial_shapes)
         anchors = anchors.to(memory.device)
         valid_mask = valid_mask.to(memory.device)
-        memory = torch.where(valid_mask, memory, torch.as_tensor(0.))
+        
+        # Process encoder's output
+        memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
         output_memory = self.enc_output(memory)
 
-        # [bs, num_quries, c]
+        # Head for encoder's output: [bs, num_queries, c]
         enc_outputs_class = self.enc_class_head(output_memory)
         enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
 
+        # Topk proposals from encoder's output
         topk = self.num_queries
-        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, topk]
-        reference_points_unact = torch.gather(enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4)) # [bs, topk, 4]
+        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, num_queries]
+        enc_topk_logits = torch.gather(
+            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
+        reference_points_unact = torch.gather(
+            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))    # [bs, num_queries, 4]
         enc_topk_bboxes = F.sigmoid(reference_points_unact)
 
         if denoising_bbox_unact is not None:
@@ -233,12 +242,13 @@ class RTDETRTransformer(nn.Module):
                 [denoising_bbox_unact, reference_points_unact], 1)
         if self.training:
             reference_points_unact = reference_points_unact.detach()
-        enc_topk_logits = torch.gather(enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, topk, nc]
 
-        # extract region features
+        # Extract region features
         if self.learnt_init_query:
+            # [num_queries, c] -> [b, num_queries, c]
             target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
         else:
+            # [num_queries, c] -> [b, num_queries, c]
             target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
             if self.training:
                 target = target.detach()
@@ -269,14 +279,14 @@ class RTDETRTransformer(nn.Module):
             memory, spatial_shapes, denoising_class, denoising_bbox_unact)
 
         # decoder
-        out_bboxes, out_logits = self.transformer_decoder(target,
-                                                          init_ref_points_unact,
-                                                          memory,
-                                                          spatial_shapes,
-                                                          self.dec_bbox_head,
-                                                          self.dec_class_head,
-                                                          self.query_pos_head,
-                                                          attn_mask)
+        out_bboxes, out_logits = self.decoder(target,
+                                              init_ref_points_unact,
+                                              memory,
+                                              spatial_shapes,
+                                              self.dec_bbox_head,
+                                              self.dec_class_head,
+                                              self.query_pos_head,
+                                              attn_mask)
         
         return out_bboxes, out_logits, enc_topk_bboxes, enc_topk_logits, dn_meta
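
The reordered top-k block above selects the num_queries highest-scoring encoder proposals and gathers their logits and unactivated boxes with the same per-image indices. A minimal sketch of that pattern with toy shapes (all sizes illustrative):

    import torch

    bs, N, nc, num_queries = 2, 8, 80, 4
    enc_outputs_class = torch.randn(bs, N, nc)
    enc_outputs_coord_unact = torch.randn(bs, N, 4)

    topk_ind = torch.topk(enc_outputs_class.max(-1)[0], num_queries, dim=1)[1]     # [bs, num_queries]
    enc_topk_logits = torch.gather(
        enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, nc))             # [bs, num_queries, nc]
    reference_points_unact = torch.gather(
        enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))        # [bs, num_queries, 4]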
 

+ 121 - 0
models/detectors/rtpdetr/basic_modules/backbone.py

@@ -0,0 +1,121 @@
+import torch
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from torchvision.models.resnet import (ResNet18_Weights,
+                                       ResNet34_Weights,
+                                       ResNet50_Weights,
+                                       ResNet101_Weights)
+try:
+    from .basic import FrozenBatchNorm2d
+except:
+    from basic  import FrozenBatchNorm2d
+   
+
+# IN1K pretrained weights
+pretrained_urls = {
+    # ResNet series
+    'resnet18':  ResNet18_Weights,
+    'resnet34':  ResNet34_Weights,
+    'resnet50':  ResNet50_Weights,
+    'resnet101': ResNet101_Weights,
+    # ShuffleNet series
+}
+
+
+# ----------------- Model functions -----------------
+## Build backbone network
+def build_backbone(cfg, pretrained):
+    print('==============================')
+    print('Backbone: {}'.format(cfg['backbone']))
+    # ResNet
+    if 'resnet' in cfg['backbone']:
+        pretrained_weight = cfg['pretrained_weight'] if pretrained else None
+        model, feats = build_resnet(cfg, pretrained_weight)
+    elif 'svnetv2' in cfg['backbone']:
+        pretrained_weight = cfg['pretrained_weight'] if pretrained else None
+        model, feats = build_scnetv2(cfg, pretrained_weight)
+    else:
+        raise NotImplementedError("Unknown backbone: {}.".format(cfg['backbone']))
+    
+    return model, feats
+
+
+# ----------------- ResNet Backbone -----------------
+class ResNet(nn.Module):
+    """ResNet backbone with frozen BatchNorm."""
+    def __init__(self, name: str, res5_dilation: bool, norm_type: str, pretrained_weights: str = "imagenet1k_v1"):
+        super().__init__()
+        # Pretrained
+        assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
+        if pretrained_weights is not None:
+            if name in ('resnet18', 'resnet34'):
+                pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
+            else:
+                if pretrained_weights == "imagenet1k_v1":
+                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
+                else:
+                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
+        else:
+            pretrained_weights = None
+        print('ImageNet pretrained weight: ', pretrained_weights)
+        # Norm layer
+        if norm_type == 'BN':
+            norm_layer = nn.BatchNorm2d
+        elif norm_type == 'FrozeBN':
+            norm_layer = FrozenBatchNorm2d
+        # Backbone
+        backbone = getattr(torchvision.models, name)(
+            replace_stride_with_dilation=[False, False, res5_dilation],
+            norm_layer=norm_layer, weights=pretrained_weights)
+        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
+        # Freeze
+        for name, parameter in backbone.named_parameters():
+            if 'layer2' not in name and 'layer3' not in name and 'layer4' not in name:
+                parameter.requires_grad_(False)
+
+    def forward(self, x):
+        xs = self.body(x)
+        fmp_list = []
+        for name, fmp in xs.items():
+            fmp_list.append(fmp)
+
+        return fmp_list
+
+def build_resnet(cfg, pretrained_weight=None):
+    # ResNet series
+    backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight)
+
+    return backbone, backbone.feat_dims
+
+
+# ----------------- ShuffleNet Backbone -----------------
+## TODO: Add shufflenet-v2
+class ShuffleNetv2:
+    pass
+
+def build_scnetv2(cfg, pretrained_weight=None):
+    return
+
+
+if __name__ == '__main__':
+    cfg = {
+        'backbone':      'resnet18',
+        'backbone_norm': 'BN',
+        'res5_dilation': False,
+        'pretrained': True,
+        'pretrained_weight': 'imagenet1k_v1',
+    }
+    model, feat_dim = build_backbone(cfg, cfg['pretrained'])
+    print(feat_dim)
+
+    x = torch.randn(2, 3, 320, 320)
+    output = model(x)
+    for y in output:
+        print(y.size())
+
+    for n, p in model.named_parameters():
+        print(n.split(".")[-1])
+

+ 93 - 0
models/detectors/rtpdetr/basic_modules/basic.py

@@ -0,0 +1,93 @@
+import torch
+import torch.nn as nn
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type == 'gelu':
+        return nn.GELU()
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+
+# ----------------- MLP modules -----------------
+class MLP(nn.Module):
+    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+class FFN(nn.Module):
+    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
+        super().__init__()
+        self.fpn_dim = round(d_model * mlp_ratio)
+        self.linear1 = nn.Linear(d_model, self.fpn_dim)
+        self.activation = get_activation(act_type)
+        self.dropout2 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(self.fpn_dim, d_model)
+        self.dropout3 = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, src):
+        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+        src = src + self.dropout3(src2)
+        src = self.norm(src)
+        
+        return src
+    
+
+# ----------------- Basic CNN Ops -----------------
+class FrozenBatchNorm2d(torch.nn.Module):
+    def __init__(self, n):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        num_batches_tracked_key = prefix + 'num_batches_tracked'
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super(FrozenBatchNorm2d, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        eps = 1e-5
+        scale = w * (rv + eps).rsqrt()
+        bias = b - rm * scale
+        return x * scale + bias
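
FrozenBatchNorm2d folds its fixed statistics into a per-channel scale and bias, so its output matches a regular BatchNorm2d evaluated with the same buffers. A quick sanity sketch (not part of the commit):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(8).eval()
    x = torch.randn(2, 8, 4, 4)
    scale = bn.weight.reshape(1, -1, 1, 1) * (bn.running_var.reshape(1, -1, 1, 1) + bn.eps).rsqrt()
    bias = bn.bias.reshape(1, -1, 1, 1) - bn.running_mean.reshape(1, -1, 1, 1) * scale
    print(torch.allclose(x * scale + bias, bn(x), atol=1e-6))   # True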

+ 288 - 0
models/detectors/rtpdetr/basic_modules/transformer.py

@@ -0,0 +1,288 @@
+import math
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import constant_, xavier_uniform_
+
+try:
+    from .basic import get_activation, MLP, FFN
+except:
+    from  basic import get_activation, MLP, FFN
+
+
+def get_clones(module, N):
+    if N <= 0:
+        return None
+    else:
+        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0., max=1.)
+    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
+
+
+# ----------------- Transformer modules -----------------
+## Transformer Encoder layer
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self,
+                 d_model         :int   = 256,
+                 num_heads       :int   = 8,
+                 mlp_ratio       :float = 4.0,
+                 dropout         :float = 0.1,
+                 act_type        :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        # ----------- Basic parameters -----------
+        # Multi-head Self-Attn
+        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
+        self.dropout = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+        # Feedforward Network
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+
+    def forward(self, src, pos_embed):
+        """
+        Input:
+            src:       [torch.Tensor] -> [B, N, C]
+            pos_embed: [torch.Tensor] -> [B, N, C]
+        Output:
+            src:       [torch.Tensor] -> [B, N, C]
+        """
+        q = k = self.with_pos_embed(src, pos_embed)
+
+        # -------------- MHSA --------------
+        src2 = self.self_attn(q, k, value=src)[0]
+        src = src + self.dropout(src2)
+        src = self.norm(src)
+
+        # -------------- FFN --------------
+        src = self.ffn(src)
+        
+        return src
+
+## Transformer Encoder
+class TransformerEncoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 mlp_ratio      :float = 4.0,
+                 pe_temperature : float = 10000.,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pe_temperature = pe_temperature
+        self.pos_embed = None
+        # ----------- Basic parameters -----------
+        self.encoder_layers = get_clones(
+            TransformerEncoderLayer(d_model, num_heads, mlp_ratio, dropout, act_type), num_layers)
+
+    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
+        assert embed_dim % 4 == 0, \
+            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        
+        # ----------- Check cached pos_embed -----------
+        if self.pos_embed is not None and \
+            self.pos_embed.shape[2:] == [h, w]:
+            return self.pos_embed
+        
+        # ----------- Generate grid coords -----------
+        grid_w = torch.arange(int(w), dtype=torch.float32)
+        grid_h = torch.arange(int(h), dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [H, W]
+
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
+        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
+
+        # shape: [1, N, C]
+        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], dim=1)[None, :, :]
+        pos_embed = pos_embed.to(device)
+        self.pos_embed = pos_embed
+
+        return pos_embed
+
+    def forward(self, src):
+        """
+        Input:
+            src:  [torch.Tensor] -> [B, C, H, W]
+        Output:
+            src:  [torch.Tensor] -> [B, C, H, W]
+        """
+        # -------- Transformer encoder --------
+        channels, fmp_h, fmp_w = src.shape[1:]
+        # [B, C, H, W] -> [B, N, C], N=HxW
+        src_flatten = src.flatten(2).permute(0, 2, 1)
+        memory = src_flatten
+
+        # PosEmbed: [1, N, C]
+        pos_embed = self.build_2d_sincos_position_embedding(
+            src.device, fmp_w, fmp_h, channels, self.pe_temperature)
+        
+        # Transformer Encoder layer
+        for encoder in self.encoder_layers:
+            memory = encoder(memory, pos_embed=pos_embed)
+
+        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
+        src = memory.permute(0, 2, 1).reshape([-1, channels, fmp_h, fmp_w])
+
+        return src
+
+## Transformer Decoder layer
+class TransformerDecoderLayer(nn.Module):
+    def __init__(self,
+                 d_model     :int   = 256,
+                 num_heads   :int   = 8,
+                 num_levels  :int   = 3,
+                 num_points  :int   = 4,
+                 mlp_ratio   :float = 4.0,
+                 dropout     :float = 0.1,
+                 act_type    :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        # ---------------- Network parameters ----------------
+        ## Multi-head Self-Attn
+        self.self_attn  = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.dropout1 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        ## CrossAttention
+        self.cross_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+        ## FFN
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self,
+                tgt,
+                reference_points,
+                memory,
+                memory_spatial_shapes,
+                attn_mask=None,
+                memory_mask=None,
+                query_pos_embed=None):
+        # ---------------- MSHA for Object Query -----------------
+        q = k = self.with_pos_embed(tgt, query_pos_embed)
+        if attn_mask is not None:
+            attn_mask = torch.where(
+                attn_mask.bool(),
+                torch.zeros(attn_mask.shape, dtype=tgt.dtype, device=attn_mask.device),
+                torch.full(attn_mask.shape, float("-inf"), dtype=tgt.dtype, device=attn_mask.device))
+        tgt2 = self.self_attn(q, k, value=tgt)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # ---------------- CMHA for Object Query and Image-feature -----------------
+        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
+                               reference_points,
+                               memory,
+                               memory_spatial_shapes,
+                               memory_mask)
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # ---------------- FeedForward Network -----------------
+        tgt = self.ffn(tgt)
+
+        return tgt
+
+## Transformer Decoder
+class TransformerDecoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
+                 mlp_ratio      :float = 4.0,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 return_intermediate :bool = False,
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.mlp_ratio = mlp_ratio
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pos_embed = None
+        # ----------- Network parameters -----------
+        self.decoder_layers = get_clones(
+            TransformerDecoderLayer(d_model, num_heads, num_levels, num_points, mlp_ratio, dropout, act_type), num_layers)
+        self.num_layers = num_layers
+        self.return_intermediate = return_intermediate
+
+    def forward(self,
+                tgt,
+                ref_points_unact,
+                memory,
+                memory_spatial_shapes,
+                bbox_head,
+                score_head,
+                query_pos_head,
+                attn_mask=None,
+                memory_mask=None):
+        output = tgt
+        dec_out_bboxes = []
+        dec_out_logits = []
+        ref_points_detach = F.sigmoid(ref_points_unact)
+        for i, layer in enumerate(self.decoder_layers):
+            ref_points_input = ref_points_detach.unsqueeze(2)
+            query_pos_embed = query_pos_head(ref_points_detach)
+
+            output = layer(output, ref_points_input, memory,
+                           memory_spatial_shapes, attn_mask,
+                           memory_mask, query_pos_embed)
+
+            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
+                ref_points_detach))
+
+            dec_out_logits.append(score_head[i](output))
+            if i == 0:
+                dec_out_bboxes.append(inter_ref_bbox)
+            else:
+                dec_out_bboxes.append(
+                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(
+                        ref_points)))
+
+            ref_points = inter_ref_bbox
+            ref_points_detach = inter_ref_bbox.detach()
+
+        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
+
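
A short usage sketch for the encoder above, assuming the module is importable as models.detectors.rtpdetr.basic_modules.transformer; input and output are both [B, C, H, W] maps, with C = d_model divisible by 4 for the sin-cos embedding:

    import torch
    from models.detectors.rtpdetr.basic_modules.transformer import TransformerEncoder

    encoder = TransformerEncoder(d_model=256, num_heads=8, num_layers=1)
    feat = torch.randn(2, 256, 20, 20)   # [B, C, H, W] feature map
    out = encoder(feat)                  # flattened to [B, N, C], encoded, reshaped back
    print(out.shape)                     # torch.Size([2, 256, 20, 20])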

+ 34 - 0
models/detectors/rtpdetr/build.py

@@ -0,0 +1,34 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from .loss import build_criterion
+from .rtpdetr import RT_PDETR
+
+
+# build object detector
+def build_rtpdetr(args, cfg, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    # -------------- Build RT-DETR --------------
+    model = RT_PDETR(cfg             = cfg,
+                     num_classes     = num_classes,
+                     conf_thresh     = args.conf_thresh,
+                     topk            = 100,
+                     deploy          = deploy,
+                     no_multi_labels = args.no_multi_labels,
+                     )
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, num_classes)
+        
+    return model, criterion

+ 200 - 0
models/detectors/rtpdetr/rtpdetr.py

@@ -0,0 +1,200 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .rtpdetr_encoder import build_image_encoder
+    from .rtpdetr_decoder import build_transformer
+except:
+    from  rtpdetr_encoder import build_image_encoder
+    from  rtpdetr_decoder import build_transformer
+
+
+# Real-time Plain Transformer-based Object Detector
+class RT_PDETR(nn.Module):
+    def __init__(self,
+                 cfg,
+                 num_classes = 80,
+                 conf_thresh = 0.1,
+                 topk        = 300,
+                 deploy      = False,
+                 no_multi_labels = False,
+                 ):
+        super().__init__()
+        # ----------- Basic setting -----------
+        self.num_classes = num_classes
+        self.num_topk = topk
+        self.conf_thresh = conf_thresh
+        self.no_multi_labels = no_multi_labels
+        self.deploy = deploy
+
+        # ----------- Network setting -----------
+        ## Image encoder
+        self.image_encoder = build_image_encoder(cfg)
+        self.feat_dim = self.image_encoder.fpn_dims[-1]
+
+        ## Detect decoder
+        self.detect_decoder = build_transformer(cfg, self.feat_dim, num_classes, return_intermediate=self.training)
+
+    def post_process(self, box_pred, cls_pred):
+        if self.no_multi_labels:
+            # [M,]
+            scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.num_topk, box_pred.size(0))
+
+            # Topk candidates
+            predicted_prob, topk_idxs = scores.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # Filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_idxs = topk_idxs[keep_idxs]
+
+            # Top-k results
+            topk_scores = topk_scores[keep_idxs]
+            topk_labels = labels[topk_idxs]
+            topk_bboxes = box_pred[topk_idxs]
+
+            return topk_bboxes, topk_scores, topk_labels
+        else:
+            # Top-k select
+            cls_pred = cls_pred[0].flatten().sigmoid_()
+            box_pred = box_pred[0]
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.num_topk, box_pred.size(0))
+
+            # Topk candidates
+            predicted_prob, topk_idxs = cls_pred.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # Filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+            topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+
+            ## Top-k results
+            topk_labels = topk_idxs % self.num_classes
+            topk_bboxes = box_pred[topk_box_idxs]
+
+        return topk_bboxes, topk_scores, topk_labels
+    
+    def forward(self, x, targets=None):
+        # ----------- Image Encoder -----------
+        pyramid_feats = self.image_encoder(x)
+
+        # ----------- Transformer -----------
+        transformer_outputs = self.detect_decoder(pyramid_feats, targets)
+
+        if self.training:
+            return transformer_outputs
+        else:
+            pred_boxes, pred_logits = transformer_outputs[0], transformer_outputs[1]
+            box_preds = pred_boxes[-1]
+            cls_preds = pred_logits[-1]
+            
+            # post-process
+            bboxes, scores, labels = self.post_process(box_preds, cls_preds)
+
+            outputs = {
+                "scores": scores.cpu().numpy(),
+                "labels": labels.cpu().numpy(),
+                "bboxes": bboxes.cpu().numpy(),
+            }
+
+            return outputs
+        
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    from loss import build_criterion
+
+    # Model config
+    cfg = {
+        'width': 1.0,
+        'depth': 1.0,
+        'out_stride': [8, 16, 32],
+        # Image Encoder - Backbone
+        'backbone': 'resnet18',
+        'backbone_norm': 'BN',
+        'res5_dilation': False,
+        'pretrained': True,
+        'pretrained_weight': 'imagenet1k_v1',
+        # Image Encoder - FPN
+        'fpn': 'hybrid_encoder',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        'hidden_dim': 256,
+        'en_num_heads': 8,
+        'en_num_layers': 1,
+        'en_mlp_ratio': 4.0,
+        'en_dropout': 0.1,
+        'pe_temperature': 10000.,
+        'en_act': 'gelu',
+        # Transformer Decoder
+        'transformer': 'rtdetr_transformer',
+        'hidden_dim': 256,
+        'de_num_heads': 8,
+        'de_num_layers': 6,
+        'de_mlp_ratio': 4.0,
+        'de_dropout': 0.0,
+        'de_act': 'gelu',
+        'de_num_points': 4,
+        'num_queries': 300,
+        'learnt_init_query': False,
+        'pe_temperature': 10000.,
+        'dn_num_denoising': 100,
+        'dn_label_noise_ratio': 0.5,
+        'dn_box_noise_scale': 1,
+        # Head
+        'det_head': 'dino_head',
+        # Matcher
+        'matcher_hpy': {'cost_class': 2.0,
+                        'cost_bbox': 5.0,
+                        'cost_giou': 2.0,},
+        # Loss
+        'use_vfl': True,
+        'loss_coeff': {'class': 1,
+                       'bbox': 5,
+                       'giou': 2,
+                       'no_object': 0.1,},
+        }
+    bs = 1
+    # Create a batch of images & targets
+    image = torch.randn(bs, 3, 640, 640)
+    targets = [{
+        'labels': torch.tensor([2, 4, 5, 8]).long(),
+        'boxes':  torch.tensor([[0, 0, 10, 10], [12, 23, 56, 70], [0, 10, 20, 30], [50, 60, 55, 150]]).float() / 640.
+    }] * bs
+
+    # Create model
+    model = RT_PDETR(cfg, num_classes=80)
+    model.train()
+
+    # Create criterion
+    criterion = build_criterion(cfg, num_classes=80)
+
+    # Model inference
+    t0 = time.time()
+    outputs = model(image, targets)
+    t1 = time.time()
+    print('Infer time: ', t1 - t0)
+
+    # Compute loss
+    loss = criterion(*outputs, targets)
+    for k in loss.keys():
+        print("{} : {}".format(k, loss[k].item()))
+
+    print('==============================')
+    model.eval()
+    flops, params = profile(model, inputs=(image, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))