Răsfoiți Sursa

add RTRDet: Real-time Detection with Transformer

yjh0410 2 ani în urmă
părinte
comite
9ea2e63afd

+ 12 - 2
config/__init__.py

@@ -30,6 +30,8 @@ from .data_config.transform_config import (
     yolox_medium_trans_config,
     yolox_large_trans_config,
     yolox_huge_trans_config,
+    # RTRDet-Style
+    rtrdet_large_trans_config,
     # SSD-Style
     ssd_trans_config,
 )
@@ -70,6 +72,10 @@ def build_trans_config(trans_config='ssd'):
     elif trans_config == 'yolox_huge':
         cfg = yolox_huge_trans_config
 
+    # RTRDet-style transform
+    elif trans_config == 'rtrdet_large':
+        cfg = rtrdet_large_trans_config
+        
     print('Transform Config: {} \n'.format(cfg))
 
     return cfg
@@ -84,9 +90,10 @@ from .model_config.yolov4_config import yolov4_cfg
 from .model_config.yolov5_config import yolov5_cfg
 from .model_config.yolov7_config import yolov7_cfg
 from .model_config.yolox_config import yolox_cfg
-## My RTMDet series
+## My RTCDet series
 from .model_config.rtcdet_config import rtcdet_cfg
-
+## My RTRDet series
+from .model_config.rtrdet_config import rtrdet_cfg
 
 def build_model_config(args):
     print('==============================')
@@ -115,6 +122,9 @@ def build_model_config(args):
     # RTCDet
     elif args.model in ['rtcdet_p', 'rtcdet_n', 'rtcdet_t', 'rtcdet_s', 'rtcdet_m', 'rtcdet_l', 'rtcdet_x']:
         cfg = rtcdet_cfg[args.model]
+    # RTRDet
+    elif args.model in ['rtrdet_p', 'rtrdet_n', 'rtrdet_t', 'rtrdet_s', 'rtrdet_m', 'rtrdet_l', 'rtrdet_x']:
+        cfg = rtrdet_cfg[args.model]
 
     return cfg
 

+ 21 - 0
config/data_config/transform_config.py

@@ -233,6 +233,27 @@ yolox_pico_trans_config = {
 }
 
 
+# ----------------------- RTRDet-Style Transform -----------------------
+rtrdet_large_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': [0.1, 2.0],
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 0.0,
+    'mixup_prob': 0.0,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolox_mixup',
+    'mixup_scale': [0.5, 1.5]
+}
+
+
 # ----------------------- SSD-Style Transform -----------------------
 ssd_trans_config = {
     'aug_type': 'ssd',

+ 58 - 0
config/model_config/rtrdet_config.py

@@ -0,0 +1,58 @@
+# Real-time Detection with Transformer
+
+
+rtrdet_cfg = {
+    'rtrdet_l':{
+        # ---------------- Model config ----------------
+        ## Backbone
+        'backbone': 'elannet',
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_depthwise': False,
+        'width': 1.0,
+        'depth': 1.0,
+        'max_stride': 16,
+        'd_model': 512,
+        ## Transformer Encoder
+        'num_encoder': 1,
+        'encoder_num_head': 8,
+        'encoder_mlp_ratio': 4.0,
+        'encoder_dropout': 0.1,
+        'neck_depthwise': False,
+        'encoder_act': 'relu',
+        ## Transformer Decoder
+        'num_decoder': 6,
+        'stop_layer_id': -1,
+        'decoder_num_head': 8,
+        'decoder_mlp_ratio': 4.0,
+        'decoder_dropout': 0.1,
+        'decoder_act': 'relu',
+        'decoder_num_queries': 300,
+        'decoder_num_pattern': 3,
+        'spatial_prior': 'learned',  # 'learned', 'grid'
+        'num_topk': 100,
+        # ---------------- Train config ----------------
+        ## Input
+        'multi_scale': [0.5, 1.5], # 320 -> 960
+        'trans_type': 'rtrdet_large',
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        'matcher': "hungarian_matcher",
+        'matcher_hpy': {"hungarian_matcher": {'cost_cls_weight':  2.0,
+                                              'cost_box_weight':  5.0,
+                                              'cost_giou_weight': 2.0,
+                                              },
+                        },
+        # ---------------- Loss config ----------------
+        ## Loss weight
+        'ema_update': False,
+        'loss_weights': {"hungarian_matcher": {'loss_cls_weight':  1.0,
+                                               'loss_box_weight':  5.0,
+                                               'loss_giou_weight': 2.0},
+                        },
+        # ---------------- Train config ----------------
+        'trainer_type': 'rtrdet',
+    },
+
+}

+ 159 - 113
engine.py

@@ -1098,24 +1098,30 @@ class RTCTrainer(object):
         self.train_loader.dataset.transform = self.train_transform
         
 
-# Trainer for DETR
-class DetrTrainer(object):
+# RTRDet Trainer
+class RTRTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
-        # ------------------- basic parameters -------------------
+        # ------------------- Basic parameters -------------------
         self.args = args
         self.epoch = 0
         self.best_map = -1.
-        self.last_opt_step = 0
-        self.no_aug_epoch = args.no_aug_epoch
-        self.clip_grad = -1
         self.device = device
         self.criterion = criterion
         self.world_size = world_size
-        self.second_stage = False
+        self.grad_accumulate = args.grad_accumulate
+        self.clip_grad = 35
         self.heavy_eval = False
-        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.001, 'backbone_lr_raio': 0.1}
+        # weak augmentatino stage
+        self.second_stage = False
+        self.third_stage = False
+        self.second_stage_epoch = args.no_aug_epoch
+        self.third_stage_epoch = args.no_aug_epoch // 2
+        # path to save model
+
+        # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
+        self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.001, 'backbone_lr_ratio': 0.1}
         self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
-        self.lr_schedule_dict = {'scheduler': 'linear', 'lrf': 0.01}
+        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
         self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}        
 
         # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
@@ -1125,26 +1131,26 @@ class DetrTrainer(object):
 
         # ---------------------------- Build Transform ----------------------------
         self.train_transform, self.trans_cfg = build_transform(
-            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.val_transform, _ = build_transform(
-            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
 
         # ---------------------------- Build Dataset & Dataloader ----------------------------
-        self.dataset, self.dataset_info = build_dataset(self.args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
-        self.train_loader = build_dataloader(self.args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
+        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
+        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // self.world_size, CollateFunc())
 
         # ---------------------------- Build Evaluator ----------------------------
-        self.evaluator = build_evluator(self.args, self.data_cfg, self.val_transform, self.device)
+        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
 
         # ---------------------------- Build Grad. Scaler ----------------------------
-        self.scaler = torch.cuda.amp.GradScaler(enabled=self.args.fp16)
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
         # ---------------------------- Build Optimizer ----------------------------
-        self.optimizer_dict['lr0'] *= self.args.batch_size / 16.
+        self.optimizer_dict['lr0'] *= self.args.batch_size / 64.
         self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
 
         # ---------------------------- Build LR Scheduler ----------------------------
-        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, self.args.max_epoch)
+        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
         self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
         if self.args.resume:
             self.lr_scheduler.step()
@@ -1157,49 +1163,38 @@ class DetrTrainer(object):
             self.model_ema = None
 
 
-    def check_second_stage(self):
-        # set second stage
-        print('============== Second stage of Training ==============')
-        self.second_stage = True
-
-        # close mosaic augmentation
-        if self.train_loader.dataset.mosaic_prob > 0.:
-            print(' - Close < Mosaic Augmentation > ...')
-            self.train_loader.dataset.mosaic_prob = 0.
-            self.heavy_eval = True
-
-        # close mixup augmentation
-        if self.train_loader.dataset.mixup_prob > 0.:
-            print(' - Close < Mixup Augmentation > ...')
-            self.train_loader.dataset.mixup_prob = 0.
-            self.heavy_eval = True
-
-        # close rotation augmentation
-        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
-            print(' - Close < degress of rotation > ...')
-            self.trans_cfg['degrees'] = 0.0
-        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
-            print(' - Close < shear of rotation >...')
-            self.trans_cfg['shear'] = 0.0
-        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
-            print(' - Close < perspective of rotation > ...')
-            self.trans_cfg['perspective'] = 0.0
-
-        # build a new transform for second stage
-        print(' - Rebuild transforms ...')
-        self.train_transform, self.trans_cfg = build_transform(
-            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
-        self.train_loader.dataset.transform = self.train_transform
-        
-
     def train(self, model):
         for epoch in range(self.start_epoch, self.args.max_epoch):
             if self.args.distributed:
                 self.train_loader.batch_sampler.sampler.set_epoch(epoch)
 
             # check second stage
-            if epoch >= (self.args.max_epoch - self.no_aug_epoch - 1) and not self.second_stage:
+            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
                 self.check_second_stage()
+                # save model of the last mosaic epoch
+                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch + 1))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
+
+            # check third stage
+            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
+                self.check_third_stage()
+                # save model of the last mosaic epoch
+                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch + 1))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
 
             # train one epoch
             self.epoch = epoch
@@ -1219,23 +1214,19 @@ class DetrTrainer(object):
         # chech model
         model_eval = model if self.model_ema is None else self.model_ema.ema
 
-        # path to save model
-        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
-        os.makedirs(path_to_save, exist_ok=True)
-
         if distributed_utils.is_main_process():
             # check evaluator
             if self.evaluator is None:
                 print('No evaluator ... save model and go on training.')
                 print('Saving state, epoch: {}'.format(self.epoch + 1))
                 weight_name = '{}_no_eval.pth'.format(self.args.model)
-                checkpoint_path = os.path.join(path_to_save, weight_name)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
                 torch.save({'model': model_eval.state_dict(),
                             'mAP': -1.,
                             'optimizer': self.optimizer.state_dict(),
                             'epoch': self.epoch,
                             'args': self.args}, 
-                            checkpoint_path)  
+                            checkpoint_path)               
             else:
                 print('eval ...')
                 # set eval mode
@@ -1254,7 +1245,7 @@ class DetrTrainer(object):
                     # save model
                     print('Saving state, epoch:', self.epoch + 1)
                     weight_name = '{}_best.pth'.format(self.args.model)
-                    checkpoint_path = os.path.join(path_to_save, weight_name)
+                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
                     torch.save({'model': model_eval.state_dict(),
                                 'mAP': round(self.best_map*100, 1),
                                 'optimizer': self.optimizer.state_dict(),
@@ -1278,7 +1269,7 @@ class DetrTrainer(object):
         t0 = time.time()
         nw = epoch_size * self.args.wp_epoch
 
-        # train one epoch
+        # Train one epoch
         for iter_i, (images, targets) in enumerate(self.train_loader):
             ni = iter_i + self.epoch * epoch_size
             # Warmup
@@ -1286,10 +1277,9 @@ class DetrTrainer(object):
                 xi = [0, nw]  # x interp
                 for j, x in enumerate(self.optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
-                    x['lr'] = np.interp(
-                        ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
+                    x['lr'] = np.interp( ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
                     if 'momentum' in x:
-                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
+                        x['momentum'] = np.interp(ni, xi, [self.warmup_dict['warmup_momentum'], self.optimizer_dict['momentum']])
                                 
             # To device
             images = images.to(self.device, non_blocking=True).float() / 255.
@@ -1297,11 +1287,14 @@ class DetrTrainer(object):
             # Multi scale
             if self.args.multi_scale:
                 images, targets, img_size = self.rescale_image_targets(
-                    images, targets, self.model_cfg['stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
+                    images, targets, self.model_cfg['max_stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
             else:
-                targets = self.refine_targets(targets, self.args.min_box_size, img_size)
+                targets = self.refine_targets(targets, self.args.min_box_size)
+
+            # Normalize bbox
+            targets = self.normalize_bbox(targets, img_size)
                 
-            # Visualize targets
+            # Visualize train targets
             if self.args.vis_tgt:
                 vis_data(images*255, targets)
 
@@ -1311,6 +1304,9 @@ class DetrTrainer(object):
                 # Compute loss
                 loss_dict = self.criterion(outputs=outputs, targets=targets, epoch=self.epoch)
                 losses = loss_dict['losses']
+                # Grad Accumulate
+                if self.grad_accumulate > 1:
+                    losses /= self.grad_accumulate
 
                 loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
 
@@ -1318,21 +1314,22 @@ class DetrTrainer(object):
             self.scaler.scale(losses).backward()
 
             # Optimize
-            if self.clip_grad > 0:
-                # unscale gradients
-                self.scaler.unscale_(self.optimizer)
-                # clip gradients
-                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
-            self.scaler.step(self.optimizer)
-            self.scaler.update()
-            self.optimizer.zero_grad()
-
-            # Model EMA
-            if self.model_ema is not None:
-                self.model_ema.update(model)
-            self.last_opt_step = ni
-
-            # Log
+            if ni % self.grad_accumulate == 0:
+                grad_norm = None
+                if self.clip_grad > 0:
+                    # unscale gradients
+                    self.scaler.unscale_(self.optimizer)
+                    # clip gradients
+                    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.clip_grad)
+                # optimizer.step
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.optimizer.zero_grad()
+                # ema
+                if self.model_ema is not None:
+                    self.model_ema.update(model)
+
+            # Logs
             if distributed_utils.is_main_process() and iter_i % 10 == 0:
                 t1 = time.time()
                 cur_lr = [param_group['lr']  for param_group in self.optimizer.param_groups]
@@ -1342,13 +1339,12 @@ class DetrTrainer(object):
                 log += '[lr: {:.6f}]'.format(cur_lr[0])
                 # loss infor
                 for k in loss_dict_reduced.keys():
-                    if self.args.vis_aux_loss:
-                        log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
-                    else:
-                        if k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
-                            log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
-
+                    loss_val = loss_dict_reduced[k]
+                    if k == 'losses':
+                        loss_val *= self.grad_accumulate
+                    log += '[{}: {:.2f}]'.format(k, loss_val)
                 # other infor
+                log += '[grad_norm: {:.2f}]'.format(grad_norm)
                 log += '[time: {:.2f}]'.format(t1 - t0)
                 log += '[size: {}]'.format(img_size)
 
@@ -1357,33 +1353,35 @@ class DetrTrainer(object):
                 
                 t0 = time.time()
         
-        # LR Scheduler
-        self.lr_scheduler.step()
+        # LR Schedule
+        if not self.second_stage:
+            self.lr_scheduler.step()
         
 
-    def refine_targets(self, targets, min_box_size, img_size):
+    def refine_targets(self, targets, min_box_size):
         # rescale targets
         for tgt in targets:
-            boxes = tgt["boxes"]
-            labels = tgt["labels"]
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
             # refine tgt
             tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
             min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
             keep = (min_tgt_size >= min_box_size)
-            # xyxy -> cxcywh
-            new_boxes = torch.zeros_like(boxes)
-            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
-            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
-            # normalize
-            new_boxes /= img_size
-            del boxes
-
-            tgt["boxes"] = new_boxes[keep]
+
+            tgt["boxes"] = boxes[keep]
             tgt["labels"] = labels[keep]
         
         return targets
 
 
+    def normalize_bbox(self, targets, img_size):
+        # normalize targets
+        for tgt in targets:
+            tgt["boxes"] /= img_size
+        
+        return targets
+
+
     def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
         """
             Deployed for Multi scale trick.
@@ -1416,20 +1414,68 @@ class DetrTrainer(object):
             tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
             min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
             keep = (min_tgt_size >= min_box_size)
-            # xyxy -> cxcywh
-            new_boxes = torch.zeros_like(boxes)
-            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
-            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
-            # normalize
-            new_boxes /= new_img_size
-            del boxes
-
-            tgt["boxes"] = new_boxes[keep]
+
+            tgt["boxes"] = boxes[keep]
             tgt["labels"] = labels[keep]
 
         return images, targets, new_img_size
 
 
+    def check_second_stage(self):
+        # set second stage
+        print('============== Second stage of Training ==============')
+        self.second_stage = True
+
+        # close mosaic augmentation
+        if self.train_loader.dataset.mosaic_prob > 0.:
+            print(' - Close < Mosaic Augmentation > ...')
+            self.train_loader.dataset.mosaic_prob = 0.
+            self.heavy_eval = True
+
+        # close mixup augmentation
+        if self.train_loader.dataset.mixup_prob > 0.:
+            print(' - Close < Mixup Augmentation > ...')
+            self.train_loader.dataset.mixup_prob = 0.
+            self.heavy_eval = True
+
+        # close rotation augmentation
+        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
+            print(' - Close < degress of rotation > ...')
+            self.trans_cfg['degrees'] = 0.0
+        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
+            print(' - Close < shear of rotation >...')
+            self.trans_cfg['shear'] = 0.0
+        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
+            print(' - Close < perspective of rotation > ...')
+            self.trans_cfg['perspective'] = 0.0
+
+        # build a new transform for second stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+        
+
+    def check_third_stage(self):
+        # set third stage
+        print('============== Third stage of Training ==============')
+        self.third_stage = True
+
+        # close random affine
+        if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
+            print(' - Close < translate of affine > ...')
+            self.trans_cfg['translate'] = 0.0
+        if 'scale' in self.trans_cfg.keys():
+            print(' - Close < scale of affine >...')
+            self.trans_cfg['scale'] = [1.0, 1.0]
+
+        # build a new transform for second stage
+        print(' - Rebuild transforms ...')
+        self.train_transform, self.trans_cfg = build_transform(
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.train_loader.dataset.transform = self.train_transform
+        
+
 # Build Trainer
 def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
     if model_cfg['trainer_type'] == 'yolov8':
@@ -1438,8 +1484,8 @@ def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion
         return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
     elif model_cfg['trainer_type'] == 'rtcdet':
         return RTCTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
-    elif model_cfg['trainer_type'] == 'detr':
-        return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
+    elif model_cfg['trainer_type'] == 'rtrdet':
+        return RTRTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
     else:
         raise NotImplementedError
     

+ 7 - 1
models/detectors/__init__.py

@@ -9,9 +9,11 @@ from .yolov3.build import build_yolov3
 from .yolov4.build import build_yolov4
 from .yolov5.build import build_yolov5
 from .yolov7.build import build_yolov7
+from .yolox.build import build_yolox
 # My RTCDet
 from .rtcdet.build import build_rtcdet
-from .yolox.build import build_yolox
+# My RTRDet
+from .rtrdet.build import build_rtrdet
 
 
 # build object detector
@@ -53,6 +55,10 @@ def build_model(args,
     elif args.model in ['rtcdet_p', 'rtcdet_n', 'rtcdet_t', 'rtcdet_s', 'rtcdet_m', 'rtcdet_l', 'rtcdet_x']:
         model, criterion = build_rtcdet(
             args, model_cfg, device, num_classes, trainable, deploy)
+    # RTRDet
+    elif args.model in ['rtrdet_p', 'rtrdet_n', 'rtrdet_t', 'rtrdet_s', 'rtrdet_m', 'rtrdet_l', 'rtrdet_x']:
+        model, criterion = build_rtrdet(
+            args, model_cfg, device, num_classes, trainable, deploy)
 
     if trainable:
         # Load pretrained weight

+ 31 - 0
models/detectors/rtrdet/build.py

@@ -0,0 +1,31 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+import torch.nn as nn
+
+from .loss import build_criterion
+from .rtrdet import RTRDet
+
+
+# build object detector
+def build_rtrdet(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+        
+    # -------------- Build RTRDet --------------
+    model = RTRDet(cfg         = cfg,
+                   device      = device, 
+                   num_classes = num_classes,
+                   trainable   = trainable,
+                   aux_loss    = True if trainable else False,
+                   deploy      = deploy
+                   )
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, num_classes, aux_loss=True)
+
+    return model, criterion

+ 165 - 0
models/detectors/rtrdet/loss.py

@@ -0,0 +1,165 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .matcher import build_matcher
+from utils.misc import sigmoid_focal_loss
+from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+from utils.distributed_utils import is_dist_avail_and_initialized, get_world_size
+
+
+class Criterion(nn.Module):
+    """ This class computes the loss for DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
+        """ Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            eos_coef: relative classification weight applied to the no-object category
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.losses = losses
+        self.focal_alpha = focal_alpha
+
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+
+    def loss_labels(self, outputs, targets, indices, num_boxes):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert 'pred_logits' in outputs
+        src_logits = outputs['pred_logits']
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]).to(src_logits.device)
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+
+        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+
+        target_classes_onehot = target_classes_onehot[:, :, :-1]
+        loss_cls = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * \
+                  src_logits.shape[1]
+        losses = {'loss_cls': loss_cls}
+
+        return losses
+
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(src_boxes.device)
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+        losses = {}
+        losses['loss_box'] = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(generalized_box_iou(
+            box_cxcywh_to_xyxy(src_boxes),
+            box_cxcywh_to_xyxy(target_boxes)))
+        losses['loss_giou'] = loss_giou.sum() / num_boxes
+        return losses
+
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'labels': self.loss_labels,
+            'boxes': self.loss_boxes,
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+
+    def forward(self, outputs, targets, epoch=0):
+        """ This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depends on the losses applied, see each loss' doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes accross all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_boxes)
+        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    kwargs = {}
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        # compute total losses
+        total_loss = sum(losses[k] * self.weight_dict[k] for k in losses.keys() if k in self.weight_dict)
+        losses['losses'] = total_loss
+
+        return losses
+
+
+# build criterion
+def build_criterion(cfg, num_classes, aux_loss=False):
+    # build matcher
+    matcher_type = cfg['matcher']
+    matcher = build_matcher(cfg)
+    
+    # build criterion
+    weight_dict = {'loss_cls':  cfg['loss_weights'][matcher_type]['loss_cls_weight'],
+                   'loss_box':  cfg['loss_weights'][matcher_type]['loss_box_weight'],
+                   'loss_giou': cfg['loss_weights'][matcher_type]['loss_giou_weight']}
+
+    if aux_loss:
+        aux_weight_dict = {}
+        for i in range(cfg['num_decoder'] - 1):
+            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
+        weight_dict.update(aux_weight_dict)
+    losses = ['labels', 'boxes']
+    criterion = Criterion(num_classes, matcher, weight_dict, losses)
+
+    return criterion
+    

+ 102 - 0
models/detectors/rtrdet/matcher.py

@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+from scipy.optimize import linear_sum_assignment
+from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
+        """Creates the matcher
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+        """
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_bbox = cost_bbox
+        self.cost_giou = cost_giou
+        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs cant be 0"
+
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """ Performs the matching
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        # [B * num_queries, C] = [N, C], where N is B * num_queries
+        out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
+        # [B * num_queries, 4] = [N, 4]
+        out_bbox = outputs["pred_boxes"].flatten(0, 1)
+
+        # Also concat the target labels and boxes
+        # [M,] where M is number of all targets in this batch
+        tgt_ids = torch.cat([v["labels"] for v in targets])
+        # [M, 4] where M is number of all targets in this batch
+        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost.
+        alpha = 0.25
+        gamma = 2.0
+        neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+
+        # Compute the L1 cost between boxes
+        # [N, M]
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox.to(out_bbox.device), p=1)
+
+        # Compute the giou cost betwen boxes
+        # [N, M]
+        cost_giou = -generalized_box_iou(
+            box_cxcywh_to_xyxy(out_bbox),
+            box_cxcywh_to_xyxy(tgt_bbox.to(out_bbox.device)))
+
+        # Final cost matrix: [N, M]
+        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+        # [N, M] -> [B, num_queries, M]
+        C = C.view(bs, num_queries, -1).cpu()
+
+        # The number of boxes in each image
+        sizes = [len(v["boxes"]) for v in targets]
+        # In the last dimension of C, we divide it into B costs, and each cost is [B, num_querys, M_i]
+        # where sum(Mi) = M.
+        # i is the batch index and c is cost_i = [B, num_querys, M_i].
+        # Therefore c[i] is the cost between the i-th sample and i-th prediction.
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+        # As for each (i, j) in indices, i is the prediction indexes and j is the target indexes
+        # i contains row indexes of cost matrix: array([row_1, row_2, row_3]) 
+        # j contains col indexes of cost matrix: array([col_1, col_2, col_3])
+        # len(i) == len(j)
+        # len(indices) = batch_size
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+def build_matcher(cfg):
+    matcher_type = cfg['matcher']
+    if matcher_type == 'hungarian_matcher':
+        return HungarianMatcher(cfg['matcher_hpy'][matcher_type]['cost_cls_weight'],
+                                cfg['matcher_hpy'][matcher_type]['cost_box_weight'],
+                                cfg['matcher_hpy'][matcher_type]['cost_giou_weight'])

+ 152 - 0
models/detectors/rtrdet/rtrdet.py

@@ -0,0 +1,152 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .rtrdet_backbone import build_backbone
+from .rtrdet_encoder import build_encoder
+from .rtrdet_decoder import build_decoder
+
+
+# Real-time Detection with Transformer
+class RTRDet(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 device, 
+                 num_classes :int = 20, 
+                 trainable   :bool = False, 
+                 aux_loss    :bool = False,
+                 deploy      :bool = False):
+        super(RTRDet, self).__init__()
+        # ------------------ Basic parameters ------------------
+        self.cfg = cfg
+        self.device = device
+        self.max_stride = cfg['max_stride']
+        self.num_topk = cfg['num_topk']
+        self.d_model = round(cfg['d_model'] * cfg['width'])
+        self.num_classes = num_classes
+        self.aux_loss = aux_loss
+        self.trainable = trainable
+        self.deploy = deploy
+        
+        # ------------------ Network parameters ------------------
+        ## Backbone
+        self.backbone, self.feat_dims = build_backbone(cfg, trainable&cfg['pretrained'])
+        self.input_proj1 = nn.Conv2d(self.feat_dims[-1], self.d_model, kernel_size=1)
+        self.input_proj2 = nn.Conv2d(self.feat_dims[-2], self.d_model, kernel_size=1)
+
+        ## Transformer Encoder
+        self.encoder = build_encoder(cfg)
+
+        ## Transformer Decoder
+        self.decoder = build_decoder(cfg, num_classes, return_intermediate=aux_loss)
+
+
+    # ---------------------- Basic Functions ----------------------
+    def position_embedding(self, x, temperature=10000):
+        hs, ws = x.shape[-2:]
+        device = x.device
+        num_pos_feats = x.shape[1] // 2       
+        scale = 2 * 3.141592653589793
+
+        # generate xy coord mat
+        y_embed, x_embed = torch.meshgrid(
+            [torch.arange(1, hs+1, dtype=torch.float32),
+             torch.arange(1, ws+1, dtype=torch.float32)])
+        y_embed = y_embed / (hs + 1e-6) * scale
+        x_embed = x_embed / (ws + 1e-6) * scale
+    
+        # [H, W] -> [1, H, W]
+        y_embed = y_embed[None, :, :].to(device)
+        x_embed = x_embed[None, :, :].to(device)
+
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[:, :, :, None], dim_t)
+        pos_y = torch.div(y_embed[:, :, :, None], dim_t)
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+
+        # [B, C, H, W]
+        pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        
+        return pos_embed
+        
+    @torch.jit.unused
+    def set_aux_loss(self, outputs_class, outputs_coord):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [{'pred_logits': a, 'pred_boxes': b}
+                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+    # ---------------------- Main Process for Inference ----------------------
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # -------------------- Inference --------------------
+        ## Backbone
+        pyramid_feats = self.backbone(x)
+        high_level_feat = self.input_proj1(pyramid_feats[-1])
+        bs, c, h, w = high_level_feat.size()
+
+        ## Transformer Encoder
+        pos_embed1 = self.position_embedding(high_level_feat)
+        high_level_feat = self.encoder(high_level_feat, pos_embed1, self.decoder.adapt_pos2d)
+        high_level_feat = high_level_feat.permute(0, 2, 1).reshape(bs, c, h, w)
+        p4_level_feat = self.input_proj2(pyramid_feats[-2]) + F.interpolate(high_level_feat, scale_factor=2.0)
+
+        ## Transformer Decoder
+        pos_embed2 = self.position_embedding(p4_level_feat)
+        output_classes, output_coords = self.decoder(p4_level_feat, pos_embed2)
+
+        # -------------------- Post-process --------------------
+        ## Top-k
+        cls_pred, box_pred = output_classes[-1].flatten().sigmoid_(), output_coords[-1]
+        cls_pred = cls_pred[0].flatten().sigmoid_()
+        box_pred = box_pred[0]
+        predicted_prob, topk_idxs = cls_pred.sort(descending=True)
+        topk_idxs = topk_idxs[:self.num_topk]
+        topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+        topk_scores = predicted_prob[:self.num_topk]
+        topk_labels = topk_idxs % self.num_classes
+        topk_bboxes = box_pred[topk_box_idxs]
+        ## Denormalize bbox
+        img_h, img_w = x.shape[-2:]
+        topk_bboxes[..., 0::2] *= img_w
+        topk_bboxes[..., 1::2] *= img_h
+
+        if self.deploy:
+            return topk_bboxes, topk_scores, topk_labels
+        else:
+            return topk_bboxes.cpu().numpy(), topk_scores.cpu().numpy(), topk_labels.cpu().numpy()
+        
+
+    # ---------------------- Main Process for Training ----------------------
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # -------------------- Inference --------------------
+            ## Backbone
+            pyramid_feats = self.backbone(x)
+            high_level_feat = self.input_proj1(pyramid_feats[-1])
+            bs, c, h, w = high_level_feat.size()
+
+            ## Transformer Encoder
+            pos_embed1 = self.position_embedding(high_level_feat)
+            high_level_feat = self.encoder(high_level_feat, pos_embed1, self.decoder.adapt_pos2d)
+            high_level_feat = high_level_feat.permute(0, 2, 1).reshape(bs, c, h, w)
+            p4_level_feat = self.input_proj2(pyramid_feats[-2]) + F.interpolate(high_level_feat, scale_factor=2.0)
+
+            ## Transformer Decoder
+            pos_embed2 = self.position_embedding(p4_level_feat)
+            output_classes, output_coords = self.decoder(p4_level_feat, pos_embed2)
+
+            outputs = {'pred_logits': output_classes[-1], 'pred_boxes': output_coords[-1]}
+            if self.aux_loss:
+                outputs['aux_outputs'] = self.set_aux_loss(output_classes, output_coords)
+            
+            return outputs
+    

+ 157 - 0
models/detectors/rtrdet/rtrdet_backbone.py

@@ -0,0 +1,157 @@
+import torch
+import torch.nn as nn
+try:
+    from .rtrdet_basic import Conv, ELANBlock, DSBlock
+except:
+    from rtrdet_basic import Conv, ELANBlock, DSBlock
+
+
+model_urls = {
+    'elannet_pico':   "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_pico.pth",
+    'elannet_nano':   "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_nano.pth",
+    'elannet_tiny':   "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_tiny.pth",
+    'elannet_small':  "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_small.pth",
+    'elannet_medium': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_medium.pth",
+    'elannet_large':  "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_large.pth",
+    'elannet_huge':   "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_huge.pth",
+}
+
+
+# ---------------------------- Backbones ----------------------------
+# ELANNet-P5
+class ELANNet(nn.Module):
+    def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANNet, self).__init__()
+        # ------------------ Basic parameters ------------------
+        self.width = width
+        self.depth = depth
+        self.expand_ratios = [0.5, 0.5, 0.5, 0.25]
+        self.feat_dims = [round(64*width), round(128*width), round(256*width), round(512*width), round(1024*width), round(1024*width)]
+        
+        # ------------------ Network parameters ------------------
+        ## P1/2
+        self.layer_1 = nn.Sequential(
+            Conv(3, self.feat_dims[0], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            Conv(self.feat_dims[0], self.feat_dims[0], k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        ## P2/4
+        self.layer_2 = nn.Sequential(   
+            Conv(self.feat_dims[0], self.feat_dims[1], k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
+            ELANBlock(self.feat_dims[1], self.feat_dims[2], self.expand_ratios[0], self.depth, act_type, norm_type, depthwise)
+        )
+        ## P3/8
+        self.layer_3 = nn.Sequential(
+            DSBlock(self.feat_dims[2], self.feat_dims[2], act_type, norm_type, depthwise),             
+            ELANBlock(self.feat_dims[2], self.feat_dims[3], self.expand_ratios[1], self.depth, act_type, norm_type, depthwise)
+        )
+        ## P4/16
+        self.layer_4 = nn.Sequential(
+            DSBlock(self.feat_dims[3], self.feat_dims[3], act_type, norm_type, depthwise),             
+            ELANBlock(self.feat_dims[3], self.feat_dims[4], self.expand_ratios[2], self.depth, act_type, norm_type, depthwise)
+        )
+        ## P5/32
+        self.layer_5 = nn.Sequential(
+            DSBlock(self.feat_dims[4], self.feat_dims[4], act_type, norm_type, depthwise),             
+            ELANBlock(self.feat_dims[4], self.feat_dims[5], self.expand_ratios[3], self.depth, act_type, norm_type, depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## load pretrained weight
+def load_weight(model, model_name):
+    # load weight
+    print('Loading pretrained weight ...')
+    url = model_urls[model_name]
+    if url is not None:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url=url, map_location="cpu", check_hash=True)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("model")
+        # model state dict
+        model_state_dict = model.state_dict()
+        # check
+        for k in list(checkpoint_state_dict.keys()):
+            if k in model_state_dict:
+                shape_model = tuple(model_state_dict[k].shape)
+                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                if shape_model != shape_checkpoint:
+                    checkpoint_state_dict.pop(k)
+                    print(k)
+            else:
+                checkpoint_state_dict.pop(k)
+                print(k)
+
+        model.load_state_dict(checkpoint_state_dict)
+    else:
+        print('No pretrained for {}'.format(model_name))
+
+    return model
+
+## build ELAN-Net
+def build_backbone(cfg, pretrained=False): 
+    # model
+    backbone = ELANNet(
+        width=cfg['width'],
+        depth=cfg['depth'],
+        act_type=cfg['bk_act'],
+        norm_type=cfg['bk_norm'],
+        depthwise=cfg['bk_depthwise']
+        )
+    # check whether to load imagenet pretrained weight
+    if pretrained:
+        if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['bk_depthwise']:
+            backbone = load_weight(backbone, model_name='elannet_pico')
+        elif cfg['width'] == 0.25 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='elannet_nano')
+        elif cfg['width'] == 0.375 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='elannet_tiny')
+        elif cfg['width'] == 0.5 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='elannet_small')
+        elif cfg['width'] == 0.75 and cfg['depth'] == 0.67:
+            backbone = load_weight(backbone, model_name='elannet_medium')
+        elif cfg['width'] == 1.0 and cfg['depth'] == 1.0:
+            backbone = load_weight(backbone, model_name='elannet_large')
+        elif cfg['width'] == 1.25 and cfg['depth'] == 1.34:
+            backbone = load_weight(backbone, model_name='elannet_huge')
+    feat_dims = backbone.feat_dims[-3:]
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_depthwise': False,
+        'width': 1.0,
+        'depth': 1.0,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 399 - 0
models/detectors/rtrdet/rtrdet_basic.py

@@ -0,0 +1,399 @@
+import copy
+import torch
+import torch.nn as nn
+from typing import Optional
+from torch import Tensor
+
+
+# ---------------------------- Basic functions ----------------------------
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+
+def get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    
+
+# ---------------------------- 2D CNN ----------------------------
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # padding
+                 d=1,                  # dilation
+                 act_type='lrelu',     # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        p = p if d == 1 else d
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# ------------------------------- MLP -------------------------------
+class MLP(nn.Module):
+    """ Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+class FFN(nn.Module):
+    def __init__(self, d_model=256, mlp_ratio=4.0, dropout=0., act_type='relu'):
+        super().__init__()
+        self.fpn_dim = round(d_model * mlp_ratio)
+        self.linear1 = nn.Linear(d_model, self.fpn_dim)
+        self.activation = get_activation(act_type)
+        self.dropout2 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(self.fpn_dim, d_model)
+        self.dropout3 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+
+    def forward(self, src):
+        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+        src = src + self.dropout3(src2)
+        src = self.norm2(src)
+        return src
+    
+# ---------------------------- Attention ----------------------------
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model, num_heads, dropout=0.) -> None:
+        super().__init__()
+        # --------------- Basic parameters ---------------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.scale = (d_model // num_heads) ** -0.5
+
+        # --------------- Network parameters ---------------
+        self.q_proj = nn.Linear(d_model, d_model, bias = False) # W_q, W_k, W_v
+        self.k_proj = nn.Linear(d_model, d_model, bias = False) # W_q, W_k, W_v
+        self.v_proj = nn.Linear(d_model, d_model, bias = False) # W_q, W_k, W_v
+
+        self.out_proj = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+
+    def forward(self, query, key, value):
+        """
+        Inputs:
+            query : (Tensor) -> [B, Nq, C]
+            key   : (Tensor) -> [B, Nk, C]
+            value : (Tensor) -> [B, Nk, C]
+        """
+        bs = query.shape[0]
+        Nq = query.shape[1]
+        Nk = key.shape[1]
+
+        # ----------------- Input proj -----------------
+        query = self.q_proj(query)
+        key   = self.k_proj(key)
+        value = self.v_proj(value)
+
+        # ----------------- Multi-head Attn -----------------
+        ## [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
+        query = query.view(bs, Nq, self.num_heads, self.d_model // self.num_heads)
+        query = query.permute(0, 2, 1, 3).contiguous()
+        key   = key.view(bs, Nk, self.num_heads, self.d_model // self.num_heads)
+        key   = key.permute(0, 2, 1, 3).contiguous()
+        value = value.view(bs, Nk, self.num_heads, self.d_model // self.num_heads)
+        value = value.permute(0, 2, 1, 3).contiguous()
+        # Attention
+        ## [B, H, Nq, C_h] X [B, H, C_h, Nk] = [B, H, Nq, Nk]
+        sim_matrix = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        sim_matrix = torch.softmax(sim_matrix, dim=-1)
+
+        # ----------------- Output -----------------
+        out = torch.matmul(sim_matrix, value)  # [B, H, Nq, C_h]
+        out = out.permute(0, 2, 1, 3).contiguous().view(bs, Nq, -1)
+        out = self.out_proj(out)
+
+        return out
+        
+
+# ---------------------------- Modified YOLOv7's Modules ----------------------------
+class ELANBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANBlock, self).__init__()
+        if isinstance(expand_ratio, float):
+            inter_dim = int(in_dim * expand_ratio)
+            inter_dim2 = inter_dim
+        elif isinstance(expand_ratio, list):
+            assert len(expand_ratio) == 2
+            e1, e2 = expand_ratio
+            inter_dim = int(in_dim * e1)
+            inter_dim2 = int(inter_dim * e2)
+        # branch-1
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-2
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-3
+        for idx in range(round(3*depth)):
+            if idx == 0:
+                cv3 = [Conv(inter_dim, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            else:
+                cv3.append(Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+        self.cv3 = nn.Sequential(*cv3)
+        # branch-4
+        self.cv4 = nn.Sequential(*[
+            Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(round(3*depth))
+        ])
+        # output
+        self.out = Conv(inter_dim*2 + inter_dim2*2, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C_in, H, W]
+        Output:
+            out: [B, C_out, H, W]
+        """
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.cv3(x2)
+        x4 = self.cv4(x3)
+
+        # [B, C, H, W] -> [B, 2C, H, W]
+        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
+
+        return out
+
+class ELANBlockFPN(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio :float=0.5, branch_depth :int=1, shortcut=False, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        # ----------- Basic Parameters -----------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.inter_dim1 = round(out_dim * expand_ratio)
+        self.inter_dim2 = round(self.inter_dim1 * expand_ratio)
+        self.expand_ratio = expand_ratio
+        self.branch_depth = branch_depth
+        self.shortcut = shortcut
+        # ----------- Network Parameters -----------
+        ## branch-1
+        self.cv1 = Conv(in_dim, self.inter_dim1, k=1, act_type=act_type, norm_type=norm_type)
+        ## branch-2
+        self.cv2 = Conv(in_dim, self.inter_dim1, k=1, act_type=act_type, norm_type=norm_type)
+        ## branch-3
+        self.cv3 = []
+        for i in range(branch_depth):
+            if i == 0:
+                self.cv3.append(Conv(self.inter_dim1, self.inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+            else:
+                self.cv3.append(Conv(self.inter_dim2, self.inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+        self.cv3 = nn.Sequential(*self.cv3)
+        ## branch-4
+        self.cv4 = nn.Sequential(*[
+            Conv(self.inter_dim2, self.inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(branch_depth)
+        ])
+        ## branch-5
+        self.cv5 = nn.Sequential(*[
+            Conv(self.inter_dim2, self.inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(branch_depth)
+        ])
+        ## branch-6
+        self.cv6 = nn.Sequential(*[
+            Conv(self.inter_dim2, self.inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(branch_depth)
+        ])
+        ## output proj
+        self.out = Conv(self.inter_dim1*2 + self.inter_dim2*4, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.cv3(x2)
+        x4 = self.cv4(x3)
+        x5 = self.cv5(x4)
+        x6 = self.cv6(x5)
+
+        # [B, C, H, W] -> [B, 2C, H, W]
+        out = self.out(torch.cat([x1, x2, x3, x4, x5, x6], dim=1))
+
+        return out
+    
+class DSBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.mp = nn.MaxPool2d((2, 2), 2)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = nn.Sequential(
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+    def forward(self, x):
+        x1 = self.cv1(self.mp(x))
+        x2 = self.cv2(x)
+        out = torch.cat([x1, x2], dim=1)
+
+        return out
+
+
+# ---------------------------- Transformer Modules ----------------------------
+class TREncoderLayer(nn.Module):
+    def __init__(self,
+                 d_model,
+                 num_heads,
+                 mlp_ratio=4.0,
+                 dropout=0.1,
+                 act_type="relu",
+                 ):
+        super().__init__()
+        # Multi-head Self-Attn
+        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
+        self.dropout = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+        # Feedforwaed Network
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+
+    def forward(self, src, pos):
+        """
+        Input:
+            src: [torch.Tensor] -> [B, N, C]
+            pos: [torch.Tensor] -> [B, N, C]
+        Output:
+            src: [torch.Tensor] -> [B, N, C]
+        """
+        q = k = self.with_pos_embed(src, pos)
+
+        # self-attn
+        src2 = self.self_attn(q, k, value=src)
+
+        # reshape: [B, N, C] -> [B, C, H, W]
+        src = src + self.dropout(src2)
+        src = self.norm(src)
+
+        # ffpn
+        src = self.ffn(src)
+        
+        return src
+
+class TRDecoderLayer(nn.Module):
+    def __init__(self,
+                 d_model,
+                 num_heads,
+                 mlp_ratio=4.0,
+                 dropout=0.1,
+                 act_type="relu"):
+        super().__init__()
+        self.scale = 2 * 3.141592653589793
+        self.d_model = d_model
+        # self attention
+        self.self_attn = MultiHeadAttention(d_model, num_heads, dropout)
+        self.dropout1 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        # cross attention
+        self.cross_attn = MultiHeadAttention(d_model, num_heads, dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+        # FFN
+        self.ffn = FFN(d_model, mlp_ratio, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+    def pos2posemb2d(self, pos, temperature=10000):
+        pos = pos * self.scale
+        num_pos_feats = self.d_model // 2
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
+        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        pos_x = pos[..., 0, None] / dim_t
+        pos_y = pos[..., 1, None] / dim_t
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+        posemb = torch.cat((pos_y, pos_x), dim=-1)
+        
+        return posemb
+    
+    def forward(self, tgt, memory, query_pos, memory_pos):
+        # self attention
+        q1 = k1 = self.with_pos_embed(tgt, query_pos)
+        v1 = tgt
+        tgt2 = self.self_attn(q1, k1, v1)
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # cross attention
+        q2 = self.with_pos_embed(tgt, query_pos)
+        k2 = self.with_pos_embed(memory, memory_pos)
+        v2 = memory
+        tgt2 = self.cross_attn(q2, k2, v2)
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # ffn
+        tgt = self.ffn(tgt)
+
+        return tgt

+ 123 - 0
models/detectors/rtrdet/rtrdet_decoder.py

@@ -0,0 +1,123 @@
+import torch
+import torch.nn as nn
+import math
+
+from .rtrdet_basic import get_clones, TRDecoderLayer, MLP
+
+
+# Transformer Decoder Module
+class TransformerDecoder(nn.Module):
+    def __init__(self, cfg, num_classes, return_intermediate=False):
+        super().__init__()
+        # -------------------- Basic Parameters ---------------------
+        self.d_model = round(cfg['d_model'] * cfg['width'])
+        self.num_queries = cfg['decoder_num_queries']
+        self.num_pattern = cfg['decoder_num_pattern']
+        self.num_deocder = cfg['num_decoder']
+        self.num_classes = num_classes
+        self.stop_layer_id = cfg['num_decoder'] if cfg['stop_layer_id'] == -1 else cfg['stop_layer_id']
+        self.return_intermediate = return_intermediate
+        self.scale = 2 * 3.141592653589793
+
+        # -------------------- Network Parameters ---------------------
+        ## Decoder
+        decoder_layer = TRDecoderLayer(d_model   = self.d_model,
+                                       num_heads = cfg['decoder_num_head'],
+                                       mlp_ratio = cfg['decoder_mlp_ratio'],
+                                       dropout   = cfg['decoder_dropout'],
+                                       act_type  = cfg['decoder_act']
+                                       )
+        self.decoder_layers = get_clones(decoder_layer, self.num_deocder)
+        ## Pattern embed
+        self.pattern = nn.Embedding(self.num_pattern, self.d_model)
+        ## Spatial embed
+        self.position = nn.Embedding(self.num_queries, 2)
+        ## Output head
+        self.class_embed = nn.Linear(self.d_model, self.num_classes)
+        self.bbox_embed  = MLP(self.d_model, self.d_model, 4, 3)
+        # Adaptive pos_embed
+        self.adapt_pos2d = nn.Sequential(
+            nn.Linear(self.d_model, self.d_model),
+            nn.ReLU(),
+            nn.Linear(self.d_model, self.d_model),
+        )
+
+        self._reset_parameters()
+
+
+    def _reset_parameters(self):
+        prior_prob = 0.01
+        bias_value = -math.log((1 - prior_prob) / prior_prob)
+        self.class_embed.bias.data = torch.ones(self.num_classes) * bias_value
+
+        nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0)
+        nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0)
+        nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
+        nn.init.uniform_(self.position.weight.data, 0, 1)
+
+        self.class_embed = nn.ModuleList([self.class_embed for _ in range(self.num_deocder)])
+        self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(self.num_deocder)])
+
+
+    def pos2posemb2d(self, pos, temperature=10000):
+        pos = pos * self.scale
+        num_pos_feats = self.d_model // 2
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=pos.device)
+        dim_t = temperature ** (2 * (dim_t // 2) / num_pos_feats)
+        pos_x = pos[..., 0, None] / dim_t
+        pos_y = pos[..., 1, None] / dim_t
+        pos_x = torch.stack((pos_x[..., 0::2].sin(), pos_x[..., 1::2].cos()), dim=-1).flatten(-2)
+        pos_y = torch.stack((pos_y[..., 0::2].sin(), pos_y[..., 1::2].cos()), dim=-1).flatten(-2)
+        posemb = torch.cat((pos_y, pos_x), dim=-1)
+        
+        return posemb
+    
+
+    def forward(self, memory, memory_pos):
+        # reshape: [B, C, H, W] -> [B, N, C], N = HW
+        memory = memory.flatten(2).permute(0, 2, 1).contiguous()
+        memory_pos = memory_pos.flatten(2).permute(0, 2, 1).contiguous()
+        memory_pos = self.adapt_pos2d(memory_pos)
+        bs, _, channels = memory.size()
+
+        # reshape: [Na, C] -> [1, Na, 1, C] -> [1, Na, Np, C] -> [1, Nq, C], Nq = Na*Np
+        tgt = self.pattern.weight.reshape(1, self.num_pattern, 1, channels).repeat(bs, 1, self.num_queries, 1)
+        tgt = tgt.reshape(bs, self.num_pattern * self.num_queries, channels)
+        
+        # Reference points
+        reference_points = self.position.weight.unsqueeze(0).repeat(bs, self.num_pattern, 1)
+
+        # Decoder
+        output_classes = []
+        output_coords = []
+        for layer_id, layer in enumerate(self.decoder_layers):
+            # query embed
+            query_pos = self.adapt_pos2d(self.pos2posemb2d(reference_points))
+            tgt = layer(tgt, memory, query_pos, memory_pos)
+            reference = self.inverse_sigmoid(reference_points)
+            ## class
+            outputs_class = self.class_embed[layer_id](tgt)
+            ## bbox
+            tmp = self.bbox_embed[layer_id](tgt)
+            tmp[..., :2] += reference
+            outputs_coord = tmp.sigmoid()
+
+            output_classes.append(outputs_class)
+            output_coords.append(outputs_coord)
+
+            if layer_id == self.stop_layer_id:
+                break
+
+        return torch.stack(output_classes), torch.stack(output_coords)
+
+
+    def inverse_sigmoid(self, x):
+        x = x.clamp(min=0, max=1)
+        return torch.log(x.clamp(min=1e-5)/(1 - x).clamp(min=1e-5))
+
+    
+# build detection head
+def build_decoder(cfg, num_classes, return_intermediate=False):
+    decoder = TransformerDecoder(cfg, num_classes, return_intermediate) 
+
+    return decoder

+ 41 - 0
models/detectors/rtrdet/rtrdet_encoder.py

@@ -0,0 +1,41 @@
+import torch
+import torch.nn as nn
+
+from .rtrdet_basic import get_clones, TREncoderLayer
+
+
+# Transformer Encoder Module
+class TransformerEncoder(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # -------------------- Basic Parameters ---------------------
+        self.d_model = round(cfg['d_model']*cfg['width'])
+        self.num_encoder = cfg['num_encoder']
+
+        # -------------------- Network Parameters ---------------------
+        encoder_layer = TREncoderLayer(d_model   = self.d_model,
+                                       num_heads = cfg['encoder_num_head'],
+                                       mlp_ratio = cfg['encoder_mlp_ratio'],
+                                       dropout   = cfg['encoder_dropout'],
+                                       act_type  = cfg['encoder_act']
+                                       )
+        self.encoder_layers = get_clones(encoder_layer, self.num_encoder)
+
+
+    def forward(self, feat, pos_embed, adapt_pos2d):
+        # reshape: [B, C, H, W] -> [B, N, C], N = HW
+        feat = feat.flatten(2).permute(0, 2, 1).contiguous()
+        pos_embed = adapt_pos2d(pos_embed.flatten(2).permute(0, 2, 1).contiguous())
+
+        # Transformer encoder
+        for encoder in self.encoder_layers:
+            feat = encoder(feat, pos_embed)
+
+        return feat
+
+
+# build detection head
+def build_encoder(cfg):
+    transformer_encoder = TransformerEncoder(cfg) 
+
+    return transformer_encoder

+ 1 - 1
utils/solver/optimizer.py

@@ -54,7 +54,7 @@ def build_detr_optimizer(cfg, model, resume=None):
         {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
         {
             "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
-            "lr": cfg['lr0'] * cfg['backbone_lr_raio'],
+            "lr": cfg['lr0'] * cfg['backbone_lr_ratio'],
         },
     ]