
add RT-DETR

yjh0410 · 2 years ago · commit d0c17b2f29
37 changed files with 2172 additions and 218 deletions
  1. +48 -43   config/__init__.py
  2. +79 -0    config/rtdetr_config.py
  3. +1 -0     config/yolov1_config.py
  4. +1 -0     config/yolov2_config.py
  5. +2 -0     config/yolov3_config.py
  6. +2 -0     config/yolov4_config.py
  7. +5 -0     config/yolov5_config.py
  8. +3 -0     config/yolov7_config.py
  9. +1 -0     config/yolox2_config.py
  10. +5 -0    config/yolox_config.py
  11. +430 -41 engine.py
  12. +5 -0    models/detectors/__init__.py
  13. +33 -0   models/detectors/rtdetr/build.py
  14. +154 -0  models/detectors/rtdetr/image_encoder/cnn_backbone.py
  15. +190 -0  models/detectors/rtdetr/image_encoder/cnn_basic.py
  16. +70 -0   models/detectors/rtdetr/image_encoder/cnn_neck.py
  17. +94 -0   models/detectors/rtdetr/image_encoder/cnn_pafpn.py
  18. +78 -0   models/detectors/rtdetr/image_encoder/img_encoder.py
  19. +171 -0  models/detectors/rtdetr/loss.py
  20. +102 -0  models/detectors/rtdetr/matcher.py
  21. +107 -0  models/detectors/rtdetr/rtdetr.py
  22. +221 -0  models/detectors/rtdetr/rtdetr_basic.py
  23. +122 -0  models/detectors/rtdetr/rtdetr_decoder.py
  24. +77 -0   models/detectors/rtdetr/rtdetr_dethead.py
  25. +10 -0   models/detectors/rtdetr/rtdetr_encoder.py
  26. +0 -1    models/detectors/yolox2/yolox2.py
  27. +0 -0    models/trackers/__init__.py
  28. +0 -0    models/trackers/byte_tracker/basetrack.py
  29. +0 -0    models/trackers/byte_tracker/build.py
  30. +0 -0    models/trackers/byte_tracker/byte_tracker.py
  31. +0 -0    models/trackers/byte_tracker/kalman_filter.py
  32. +0 -0    models/trackers/byte_tracker/matching.py
  33. +2 -2    track.py
  34. +16 -96  train.py
  35. +72 -34  utils/box_ops.py
  36. +37 -0   utils/misc.py
  37. +34 -1   utils/solver/optimizer.py

+ 48 - 43
config/__init__.py

@@ -14,48 +14,6 @@ def build_dataset_config(args):
     return cfg
 
 
-# ------------------ Model Config ----------------------
-from .yolov1_config import yolov1_cfg
-from .yolov2_config import yolov2_cfg
-from .yolov3_config import yolov3_cfg
-from .yolov4_config import yolov4_cfg
-from .yolov5_config import yolov5_cfg
-from .yolov7_config import yolov7_cfg
-from .yolox_config import yolox_cfg
-from .yolox2_config import yolox2_cfg
-
-
-def build_model_config(args):
-    print('==============================')
-    print('Model: {} ...'.format(args.model.upper()))
-    # YOLOv1
-    if args.model == 'yolov1':
-        cfg = yolov1_cfg
-    # YOLOv2
-    elif args.model == 'yolov2':
-        cfg = yolov2_cfg
-    # YOLOv3
-    elif args.model in ['yolov3', 'yolov3_t']:
-        cfg = yolov3_cfg[args.model]
-    # YOLOv4
-    elif args.model in ['yolov4', 'yolov4_t']:
-        cfg = yolov4_cfg[args.model]
-    # YOLOv5
-    elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
-        cfg = yolov5_cfg[args.model]
-    # YOLOv7
-    elif args.model in ['yolov7_t', 'yolov7_l', 'yolov7_x']:
-        cfg = yolov7_cfg[args.model]
-    # YOLOX
-    elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
-        cfg = yolox_cfg[args.model]
-    # YOLOX2
-    elif args.model in ['yolox2_n', 'yolox2_s', 'yolox2_m', 'yolox2_l', 'yolox2_x']:
-        cfg = yolox2_cfg[args.model]
-
-    return cfg
-
-
 # ------------------ Transform Config ----------------------
 from .transform_config import (
     # YOLOv5-Style
@@ -112,4 +70,51 @@ def build_trans_config(trans_config='ssd'):
     elif trans_config == 'yolox_huge':
         cfg = yolox_huge_trans_config
         
-    return cfg
+    return cfg
+
+
+# ------------------ Model Config ----------------------
+from .yolov1_config import yolov1_cfg
+from .yolov2_config import yolov2_cfg
+from .yolov3_config import yolov3_cfg
+from .yolov4_config import yolov4_cfg
+from .yolov5_config import yolov5_cfg
+from .yolov7_config import yolov7_cfg
+from .yolox_config import yolox_cfg
+from .yolox2_config import yolox2_cfg
+from .rtdetr_config import rtdetr_cfg
+
+
+def build_model_config(args):
+    print('==============================')
+    print('Model: {} ...'.format(args.model.upper()))
+    # YOLOv1
+    if args.model == 'yolov1':
+        cfg = yolov1_cfg
+    # YOLOv2
+    elif args.model == 'yolov2':
+        cfg = yolov2_cfg
+    # YOLOv3
+    elif args.model in ['yolov3', 'yolov3_t']:
+        cfg = yolov3_cfg[args.model]
+    # YOLOv4
+    elif args.model in ['yolov4', 'yolov4_t']:
+        cfg = yolov4_cfg[args.model]
+    # YOLOv5
+    elif args.model in ['yolov5_n', 'yolov5_s', 'yolov5_m', 'yolov5_l', 'yolov5_x']:
+        cfg = yolov5_cfg[args.model]
+    # YOLOv7
+    elif args.model in ['yolov7_t', 'yolov7_l', 'yolov7_x']:
+        cfg = yolov7_cfg[args.model]
+    # YOLOX
+    elif args.model in ['yolox_n', 'yolox_s', 'yolox_m', 'yolox_l', 'yolox_x']:
+        cfg = yolox_cfg[args.model]
+    # YOLOX2
+    elif args.model in ['yolox2_n', 'yolox2_s', 'yolox2_m', 'yolox2_l', 'yolox2_x']:
+        cfg = yolox2_cfg[args.model]
+    # RT-DETR
+    elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
+        cfg = rtdetr_cfg[args.model]
+
+    return cfg
+

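A quick sanity check of the new dispatch branch (illustration only, not part of the commit; SimpleNamespace stands in for the argparse namespace that train.py normally passes in):

    from types import SimpleNamespace
    from config import build_model_config

    args = SimpleNamespace(model='rtdetr_n')        # any key of rtdetr_cfg works here
    cfg = build_model_config(args)
    print(cfg['trainer_type'], cfg['num_queries'])  # -> detr 300
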
+ 79 - 0
config/rtdetr_config.py

@@ -0,0 +1,79 @@
+# RT-DETR config
+
+
+rtdetr_cfg = {
+    # P5
+    'rtdetr_n': {
+        # ---------------- Model config ----------------
+        ## ------- Image Encoder -------
+        ### CNN-Backbone
+        'backbone': 'elannet',
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': False,
+        'width': 0.25,
+        'depth': 0.34,
+        'stride': [8, 16, 32],  # P3, P4, P5
+        'max_stride': 32,
+        ### CNN-Neck
+        'neck': 'sppf',
+        'neck_expand_ratio': 0.5,
+        'pooling_size': 5,
+        'neck_act': 'silu',
+        'neck_norm': 'BN',
+        'neck_depthwise': False,
+        ### CNN-CSFM
+        'fpn': 'yolo_pafpn',
+        'fpn_reduce_layer': 'conv',
+        'fpn_downsample_layer': 'conv',
+        'fpn_core_block': 'elanblock',
+        'fpn_act': 'silu',
+        'fpn_norm': 'BN',
+        'fpn_depthwise': False,
+        ## ------- Transformer Decoder -------
+        'd_model': 256,
+        'attn_type': 'mhsa',
+        'num_decoder_layers': 6,
+        'num_queries': 300,
+        'de_dim_feedforward': 1024,
+        'de_num_heads': 8,
+        'de_dropout': 0.1,
+        'de_act': 'silu',
+        'de_norm': 'LN',
+        # ---------------- Data process config ----------------
+        ## input
+        'multi_scale': [0.5, 1.0],   # 320 -> 640
+        'trans_type': 'yolov5_nano',
+        # ---------------- Assignment config ----------------
+        ## matcher
+        'set_cost_class': 2.0,
+        'set_cost_bbox': 5.0,
+        'set_cost_giou': 2.0,
+        # ---------------- Loss config ----------------
+        ## loss weight
+        'focal_alpha': 0.25,
+        'loss_cls_weight': 1.0,
+        'loss_box_weight': 5.0,
+        'loss_giou_weight': 2.0,
+        # ---------------- Train config ----------------
+        ## close strong augmentation
+        'no_aug_epoch': 10,
+        'trainer_type': 'detr',
+        ## optimizer
+        'optimizer': 'adamw',
+        'momentum': None,
+        'weight_decay': 1e-4,
+        'clip_grad': 0.1,
+        ## model EMA
+        'ema_decay': 0.9999,       # SGD: 0.9999;   AdamW: 0.9998
+        'ema_tau': 2000,
+        ## lr schedule
+        'scheduler': 'linear',
+        'lr0': 0.0001,             # SGD: 0.01;     AdamW: 0.001
+        'lrf': 0.05,               # SGD: 0.01;     AdamW: 0.01
+        'warmup_momentum': 0.8,
+        'warmup_bias_lr': 0.1,
+        },
+
+}

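For orientation, a small sketch (assuming the 640-pixel training resolution used elsewhere in this repo) of the image sizes that 'multi_scale': [0.5, 1.0] yields once snapped to max_stride, mirroring rescale_image_targets in engine.py below:

    base_size, max_stride = 640, 32        # assumed args.img_size and cfg['max_stride']
    lo, hi = 0.5, 1.0                      # cfg['multi_scale']
    sizes = sorted({s // max_stride * max_stride
                    for s in range(int(base_size * lo), int(base_size * hi) + max_stride)})
    print(sizes)                           # [320, 352, ..., 608, 640]
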
+ 1 - 0
config/yolov1_config.py

@@ -29,6 +29,7 @@ yolov1_cfg = {
     'loss_box_weight': 5.0,
     # training configuration
     'no_aug_epoch': -1,
+    'trainer_type': 'yolo',
     # optimizer
     'optimizer': 'sgd',        # optional: sgd, adam, adamw
     'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid

+ 1 - 0
config/yolov2_config.py

@@ -36,6 +36,7 @@ yolov2_cfg = {
     'loss_box_weight': 5.0,
     # training configuration
     'no_aug_epoch': -1,
+    'trainer_type': 'yolo',
     # optimizer
     'optimizer': 'sgd',        # optional: sgd, adam, adamw
     'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid

+ 2 - 0
config/yolov3_config.py

@@ -47,6 +47,7 @@ yolov3_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 10,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None
@@ -109,6 +110,7 @@ yolov3_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 10,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None

+ 2 - 0
config/yolov4_config.py

@@ -47,6 +47,7 @@ yolov4_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 10,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None
@@ -109,6 +110,7 @@ yolov4_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 10,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None

+ 5 - 0
config/yolov5_config.py

@@ -46,6 +46,7 @@ yolov5_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None
@@ -107,6 +108,7 @@ yolov5_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None
@@ -168,6 +170,7 @@ yolov5_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None
@@ -229,6 +232,7 @@ yolov5_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None
@@ -290,6 +294,7 @@ yolov5_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.937,         # SGD: 0.937;    AdamW: None

+ 3 - 0
config/yolov7_config.py

@@ -44,6 +44,7 @@ yolov7_cfg = {
         'loss_box_weight': 5.0,
         # training configuration
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         # optimizer
         'optimizer': 'sgd',        # optional: sgd, adam, adamw
         'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
@@ -103,6 +104,7 @@ yolov7_cfg = {
         'loss_box_weight': 5.0,
         # training configuration
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         # optimizer
         'optimizer': 'sgd',        # optional: sgd, adam, adamw
         'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid
@@ -162,6 +164,7 @@ yolov7_cfg = {
         'loss_box_weight': 5.0,
         # training configuration
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         # optimizer
         'optimizer': 'sgd',        # optional: sgd, adam, adamw
         'momentum': 0.937,         # SGD: 0.937;    AdamW: invalid

+ 1 - 0
config/yolox2_config.py

@@ -53,6 +53,7 @@ yolox2_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.9,           # SGD: 0.9;    AdamW: None

+ 5 - 0
config/yolox_config.py

@@ -45,6 +45,7 @@ yolox_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.9,           # SGD: 0.9;    AdamW: None
@@ -104,6 +105,7 @@ yolox_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.9,           # SGD: 0.9;    AdamW: None
@@ -163,6 +165,7 @@ yolox_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.9,           # SGD: 0.9;    AdamW: None
@@ -222,6 +225,7 @@ yolox_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.9,           # SGD: 0.9;    AdamW: None
@@ -281,6 +285,7 @@ yolox_cfg = {
         # ---------------- Train config ----------------
         ## close strong augmentation
         'no_aug_epoch': 20,
+        'trainer_type': 'yolo',
         ## optimizer
         'optimizer': 'sgd',        # optional: sgd, AdamW
         'momentum': 0.9,           # SGD: 0.9;    AdamW: None

+ 430 - 41
engine.py

@@ -6,39 +6,121 @@ import os
 import numpy as np
 import random
 
+# ----------------- Extra Components -----------------
 from utils import distributed_utils
+from utils.misc import ModelEMA, CollateFunc, build_dataloader
 from utils.vis_tools import vis_data
 
+# ----------------- Evaluator Components -----------------
+from evaluator.build import build_evluator
 
+# ----------------- Optimizer & LrScheduler Components -----------------
+from utils.solver.optimizer import build_yolo_optimizer, build_detr_optimizer
+from utils.solver.lr_scheduler import build_lr_scheduler
 
-class Trainer(object):
-    def __init__(self, args, device, cfg, model_ema, optimizer, lf, lr_scheduler, criterion, scaler):
+# ----------------- Dataset Components -----------------
+from dataset.build import build_dataset, build_transform
+
+
+# Trainer for YOLO
+class YoloTrainer(object):
+    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
         # ------------------- basic parameters -------------------
         self.args = args
-        self.cfg = cfg
-        self.device = device
         self.epoch = 0
         self.best_map = -1.
-        # ------------------- core modules -------------------
-        self.model_ema = model_ema
-        self.optimizer = optimizer
-        self.lf = lf
-        self.lr_scheduler = lr_scheduler
-        self.criterion = criterion
-        self.scaler = scaler
         self.last_opt_step = 0
+        self.device = device
+        self.criterion = criterion
+        self.heavy_eval = False
+
+        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
+        self.data_cfg = data_cfg
+        self.model_cfg = model_cfg
+        self.trans_cfg = trans_cfg
+
+        # ---------------------------- Build Transform ----------------------------
+        self.train_transform, self.trans_cfg = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.val_transform, _ = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
+
+        # ---------------------------- Build Dataset & Dataloader ----------------------------
+        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
+        world_size = distributed_utils.get_world_size()
+        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // world_size, CollateFunc())
+
+        # ---------------------------- Build Evaluator ----------------------------
+        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
+
+        # ---------------------------- Build Grad. Scaler ----------------------------
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+        # ---------------------------- Build Optimizer ----------------------------
+        accumulate = max(1, round(64 / args.batch_size))
+        self.model_cfg['weight_decay'] *= args.batch_size * accumulate / 64
+        self.optimizer, self.start_epoch = build_yolo_optimizer(self.model_cfg, model, self.model_cfg['lr0'], args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        args.max_epoch += args.wp_epoch
+        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, args.max_epoch)
+        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
+        if args.resume:
+            self.lr_scheduler.step()
+
+        # ---------------------------- Build Model-EMA ----------------------------
+        if args.ema and distributed_utils.get_rank() in [-1, 0]:
+            print('Build ModelEMA ...')
+            self.model_ema = ModelEMA(
+                model,
+                self.model_cfg['ema_decay'],
+                self.model_cfg['ema_tau'],
+                self.start_epoch * len(self.train_loader))
+        else:
+            self.model_ema = None
+
+
+    def train(self, model):
+        for epoch in range(self.start_epoch, self.args.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # check second stage
+            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
+                # close mosaic augmentation
+                if self.train_loader.dataset.mosaic_prob > 0.:
+                    print('close Mosaic Augmentation ...')
+                    self.train_loader.dataset.mosaic_prob = 0.
+                    self.heavy_eval = True
+                # close mixup augmentation
+                if self.train_loader.dataset.mixup_prob > 0.:
+                    print('close Mixup Augmentation ...')
+                    self.train_loader.dataset.mixup_prob = 0.
+                    self.heavy_eval = True
+
+            # train one epoch
+            self.train_one_epoch(model)
+
+            # eval one epoch
+            model_eval = model.module if self.args.distributed else model
+            if self.heavy_eval:
+                self.eval(model_eval)
+            else:
+                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
+                    self.eval(model_eval)
 
 
-    def train_one_epoch(self, model, train_loader):
+    def train_one_epoch(self, model):
         # basic parameters
-        epoch_size = len(train_loader)
+        epoch_size = len(self.train_loader)
         img_size = self.args.img_size
         t0 = time.time()
         nw = epoch_size * self.args.wp_epoch
         accumulate = max(1, round(64 / self.args.batch_size))
 
-        # train one epoch
-        for iter_i, (images, targets) in enumerate(train_loader):
+        # Train one epoch
+        for iter_i, (images, targets) in enumerate(self.train_loader):
             ni = iter_i + self.epoch * epoch_size
             # Warmup
             if ni <= nw:
@@ -47,57 +129,48 @@ class Trainer(object):
                 for j, x in enumerate(self.optimizer.param_groups):
                     # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
                     x['lr'] = np.interp(
-                        ni, xi, [self.cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
+                        ni, xi, [self.model_cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * self.lf(self.epoch)])
                     if 'momentum' in x:
-                        x['momentum'] = np.interp(ni, xi, [self.cfg['warmup_momentum'], self.cfg['momentum']])
+                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
                                 
-            # to device
+            # To device
             images = images.to(self.device, non_blocking=True).float() / 255.
 
-            # multi scale
+            # Multi scale
             if self.args.multi_scale:
                 images, targets, img_size = self.rescale_image_targets(
-                    images, targets, model.stride, self.args.min_box_size, self.cfg['multi_scale'])
+                    images, targets, model.stride, self.args.min_box_size, self.model_cfg['multi_scale'])
             else:
                 targets = self.refine_targets(targets, self.args.min_box_size)
                 
-            # visualize train targets
+            # Visualize train targets
             if self.args.vis_tgt:
                 vis_data(images*255, targets)
 
-            # inference
+            # Inference
             with torch.cuda.amp.autocast(enabled=self.args.fp16):
                 outputs = model(images)
-                # loss
+                # Compute loss
                 loss_dict = self.criterion(outputs=outputs, targets=targets)
                 losses = loss_dict['losses']
                 losses *= images.shape[0]  # loss * bs
 
-                # reduce            
                 loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
 
                 if self.args.distributed:
                     # gradient averaged between devices in DDP mode
                     losses *= distributed_utils.get_world_size()
 
-            # check loss
-            try:
-                if torch.isnan(losses):
-                    print('loss is NAN !!')
-                    continue
-            except:
-                print(loss_dict)
-
-            # backward
+            # Backward
             self.scaler.scale(losses).backward()
 
             # Optimize
             if ni - self.last_opt_step >= accumulate:
-                if self.cfg['clip_grad'] > 0:
+                if self.model_cfg['clip_grad'] > 0:
                     # unscale gradients
                     self.scaler.unscale_(self.optimizer)
                     # clip gradients
-                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg['clip_grad'])
+                    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
                 # optimizer.step
                 self.scaler.step(self.optimizer)
                 self.scaler.update()
@@ -107,7 +180,7 @@ class Trainer(object):
                     self.model_ema.update(model)
                 self.last_opt_step = ni
 
-            # display
+            # Logs
             if distributed_utils.is_main_process() and iter_i % 10 == 0:
                 t1 = time.time()
                 cur_lr = [param_group['lr']  for param_group in self.optimizer.param_groups]
@@ -132,12 +205,12 @@ class Trainer(object):
                 
                 t0 = time.time()
         
+        # LR Schedule
         self.lr_scheduler.step()
         self.epoch += 1
         
 
-    @torch.no_grad()
-    def eval_one_epoch(self, model, evaluator):
+    def eval(self, model):
         # check model
         model_eval = model if self.model_ema is None else self.model_ema.ema
 
@@ -147,7 +220,7 @@ class Trainer(object):
 
         if distributed_utils.is_main_process():
             # check evaluator
-            if evaluator is None:
+            if self.evaluator is None:
                 print('No evaluator ... save model and go on training.')
                 print('Saving state, epoch: {}'.format(self.epoch + 1))
                 weight_name = '{}_no_eval.pth'.format(self.args.model)
@@ -166,10 +239,11 @@ class Trainer(object):
                 model_eval.eval()
 
                 # evaluate
-                evaluator.evaluate(model_eval)
+                with torch.no_grad():
+                    self.evaluator.evaluate(model_eval)
 
                 # save model
-                cur_map = evaluator.map
+                cur_map = self.evaluator.map
                 if cur_map > self.best_map:
                     # update best-map
                     self.best_map = cur_map
@@ -247,3 +321,318 @@ class Trainer(object):
 
         return images, targets, new_img_size
 
+
+# Trainer for DETR
+class DetrTrainer(object):
+    def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.epoch = 0
+        self.best_map = -1.
+        self.last_opt_step = 0
+        self.device = device
+        self.criterion = criterion
+        self.heavy_eval = False
+
+        # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
+        self.data_cfg = data_cfg
+        self.model_cfg = model_cfg
+        self.trans_cfg = trans_cfg
+
+        # ---------------------------- Build Transform ----------------------------
+        self.train_transform, self.trans_cfg = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
+        self.val_transform, _ = build_transform(
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
+
+        # ---------------------------- Build Dataset & Dataloader ----------------------------
+        self.dataset, self.dataset_info = build_dataset(args, self.data_cfg, self.trans_cfg, self.train_transform, is_train=True)
+        world_size = distributed_utils.get_world_size()
+        self.train_loader = build_dataloader(args, self.dataset, self.args.batch_size // world_size, CollateFunc())
+
+        # ---------------------------- Build Evaluator ----------------------------
+        self.evaluator = build_evluator(args, self.data_cfg, self.val_transform, self.device)
+
+        # ---------------------------- Build Grad. Scaler ----------------------------
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+        # ---------------------------- Build Optimizer ----------------------------
+        self.model_cfg['lr0'] *= args.batch_size / 16.
+        self.optimizer, self.start_epoch = build_detr_optimizer(model_cfg, model, args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        args.max_epoch += args.wp_epoch
+        self.lr_scheduler, self.lf = build_lr_scheduler(self.model_cfg, self.optimizer, args.max_epoch)
+        self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
+        if args.resume:
+            self.lr_scheduler.step()
+
+        # ---------------------------- Build Model-EMA ----------------------------
+        if args.ema and distributed_utils.get_rank() in [-1, 0]:
+            print('Build ModelEMA ...')
+            self.model_ema = ModelEMA(
+                model,
+                self.model_cfg['ema_decay'],
+                self.model_cfg['ema_tau'],
+                self.start_epoch * len(self.train_loader))
+        else:
+            self.model_ema = None
+
+
+    def train(self, model):
+        for epoch in range(self.start_epoch, self.args.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # check second stage
+            if epoch >= (self.args.max_epoch - self.model_cfg['no_aug_epoch'] - 1):
+                # close mosaic augmentation
+                if self.train_loader.dataset.mosaic_prob > 0.:
+                    print('close Mosaic Augmentation ...')
+                    self.train_loader.dataset.mosaic_prob = 0.
+                    self.heavy_eval = True
+                # close mixup augmentation
+                if self.train_loader.dataset.mixup_prob > 0.:
+                    print('close Mixup Augmentation ...')
+                    self.train_loader.dataset.mixup_prob = 0.
+                    self.heavy_eval = True
+
+            # train one epoch
+            self.train_one_epoch(model)
+
+            # eval one epoch
+            model_eval = model.module if self.args.distributed else model
+            if self.heavy_eval:
+                self.eval(model_eval)
+            else:
+                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
+                    self.eval(model_eval)
+
+
+    def train_one_epoch(self, model):
+        # basic parameters
+        epoch_size = len(self.train_loader)
+        img_size = self.args.img_size
+        t0 = time.time()
+        nw = epoch_size * self.args.wp_epoch
+
+        # train one epoch
+        for iter_i, (images, targets) in enumerate(self.train_loader):
+            ni = iter_i + self.epoch * epoch_size
+            # Warmup
+            if ni <= nw:
+                xi = [0, nw]  # x interp
+                for j, x in enumerate(self.optimizer.param_groups):
+                    # all lrs rise from 0.0 to the scheduled lr during warmup (no bias warmup for DETR)
+                    x['lr'] = np.interp(
+                        ni, xi, [0.0, x['initial_lr'] * self.lf(self.epoch)])
+                    if 'momentum' in x:
+                        x['momentum'] = np.interp(ni, xi, [self.model_cfg['warmup_momentum'], self.model_cfg['momentum']])
+                                
+            # To device
+            images = images.to(self.device, non_blocking=True).float() / 255.
+
+            # Multi scale
+            if self.args.multi_scale:
+                images, targets, img_size = self.rescale_image_targets(
+                    images, targets, model.stride, self.args.min_box_size, self.model_cfg['multi_scale'])
+            else:
+                targets = self.refine_targets(targets, self.args.min_box_size, img_size)
+                
+            # Visualize targets
+            if self.args.vis_tgt:
+                vis_data(images*255, targets)
+
+            # Inference
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
+                outputs = model(images)
+                # Compute loss
+                loss_dict = self.criterion(outputs=outputs, targets=targets)
+                losses = loss_dict['losses']
+
+                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+
+            # Backward
+            self.scaler.scale(losses).backward()
+
+            # Optimize
+            if self.model_cfg['clip_grad'] > 0:
+                # unscale gradients
+                self.scaler.unscale_(self.optimizer)
+                # clip gradients
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.model_cfg['clip_grad'])
+            self.scaler.step(self.optimizer)
+            self.scaler.update()
+            self.optimizer.zero_grad()
+
+            # Model EMA
+            if self.model_ema is not None:
+                self.model_ema.update(model)
+            self.last_opt_step = ni
+
+            # Log
+            if distributed_utils.is_main_process() and iter_i % 10 == 0:
+                t1 = time.time()
+                cur_lr = [param_group['lr']  for param_group in self.optimizer.param_groups]
+                # basic info
+                log =  '[Epoch: {}/{}]'.format(self.epoch+1, self.args.max_epoch)
+                log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
+                log += '[lr: {:.6f}]'.format(cur_lr[0])
+                # loss info
+                for k in loss_dict_reduced.keys():
+                    if self.args.vis_aux_loss:
+                        log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
+                    else:
+                        if k in ['loss_cls', 'loss_bbox', 'loss_giou', 'losses']:
+                            log += '[{}: {:.2f}]'.format(k, loss_dict_reduced[k])
+
+                # other info
+                log += '[time: {:.2f}]'.format(t1 - t0)
+                log += '[size: {}]'.format(img_size)
+
+                # print log info
+                print(log, flush=True)
+                
+                t0 = time.time()
+        
+        # LR Scheduler
+        self.lr_scheduler.step()
+        self.epoch += 1
+        
+
+    def eval(self, model):
+        # check model
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
+        # path to save model
+        path_to_save = os.path.join(self.args.save_folder, self.args.dataset, self.args.model)
+        os.makedirs(path_to_save, exist_ok=True)
+
+        if distributed_utils.is_main_process():
+            # check evaluator
+            if self.evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                print('Saving state, epoch: {}'.format(self.epoch + 1))
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(path_to_save, weight_name)
+                torch.save({'model': model_eval.state_dict(),
+                            'mAP': -1.,
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)                      
+                
+            else:
+                print('eval ...')
+                # set eval mode
+                model_eval.trainable = False
+                model_eval.eval()
+
+                # evaluate
+                with torch.no_grad():
+                    self.evaluator.evaluate(model_eval)
+
+                # save model
+                cur_map = self.evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    # save model
+                    print('Saving state, epoch:', self.epoch + 1)
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(path_to_save, weight_name)
+                    torch.save({'model': model_eval.state_dict(),
+                                'mAP': round(self.best_map*100, 1),
+                                'optimizer': self.optimizer.state_dict(),
+                                'epoch': self.epoch,
+                                'args': self.args}, 
+                                checkpoint_path)                      
+
+                # set train mode.
+                model_eval.trainable = True
+                model_eval.train()
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+
+    def refine_targets(self, targets, min_box_size, img_size):
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"]
+            labels = tgt["labels"]
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+            # xyxy -> cxcywh
+            new_boxes = torch.zeros_like(boxes)
+            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
+            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
+            # normalize
+            new_boxes /= img_size
+            del boxes
+
+            tgt["boxes"] = new_boxes[keep]
+            tgt["labels"] = labels[keep]
+        
+        return targets
+
+
+    def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
+        """
+            Deployed for Multi scale trick.
+        """
+        if isinstance(stride, int):
+            max_stride = stride
+        elif isinstance(stride, list):
+            max_stride = max(stride)
+
+        # During training phase, the shape of input image is square.
+        old_img_size = images.shape[-1]
+        new_img_size = random.randrange(int(old_img_size * multi_scale_range[0]), int(old_img_size * multi_scale_range[1]) + max_stride)
+        new_img_size = new_img_size // max_stride * max_stride  # size
+        if new_img_size / old_img_size != 1:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images, 
+                                size=new_img_size, 
+                                mode='bilinear', 
+                                align_corners=False)
+        # rescale targets
+        for tgt in targets:
+            boxes = tgt["boxes"].clone()
+            labels = tgt["labels"].clone()
+            boxes = torch.clamp(boxes, 0, old_img_size)
+            # rescale box
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+            # refine tgt
+            tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+            keep = (min_tgt_size >= min_box_size)
+            # xyxy -> cxcywh
+            new_boxes = torch.zeros_like(boxes)
+            new_boxes[..., :2] = (boxes[..., 2:] + boxes[..., :2]) * 0.5
+            new_boxes[..., 2:] = (boxes[..., 2:] - boxes[..., :2])
+            # normalize
+            new_boxes /= new_img_size
+            del boxes
+
+            tgt["boxes"] = new_boxes[keep]
+            tgt["labels"] = labels[keep]
+
+        return images, targets, new_img_size
+
+
+# Build Trainer
+def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion):
+    if model_cfg['trainer_type'] == 'yolo':
+        return YoloTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
+    elif model_cfg['trainer_type'] == 'detr':
+        return DetrTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
+    else:
+        raise NotImplementedError
+    

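train.py drops roughly 96 lines in this commit because the dataset, dataloader, optimizer, scheduler and EMA setup now live inside the trainers above; a hedged sketch of the expected wiring (the build_* names come from this diff, while the exact build_model signature and the args fields cuda and num_classes are assumptions):

    import torch
    from config import build_dataset_config, build_model_config, build_trans_config
    from models.detectors import build_model
    from engine import build_trainer

    def run(args):
        device = torch.device('cuda' if args.cuda else 'cpu')          # args.cuda assumed
        data_cfg  = build_dataset_config(args)
        model_cfg = build_model_config(args)
        trans_cfg = build_trans_config(model_cfg['trans_type'])
        model, criterion = build_model(args, model_cfg, device,
                                       num_classes=args.num_classes,   # args.num_classes assumed
                                       trainable=True)
        trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion)
        trainer.train(model)   # dispatches to YoloTrainer or DetrTrainer via cfg['trainer_type']
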
+ 5 - 0
models/detectors/__init__.py

@@ -10,6 +10,7 @@ from .yolov5.build import build_yolov5
 from .yolov7.build import build_yolov7
 from .yolox.build import build_yolox
 from .yolox2.build import build_yolox2
+from .rtdetr.build import build_rtdetr
 
 
 # build object detector
@@ -51,6 +52,10 @@ def build_model(args,
     elif args.model in ['yolox2_n', 'yolox2_s', 'yolox2_m', 'yolox2_l', 'yolox2_x']:
         model, criterion = build_yolox2(
             args, model_cfg, device, num_classes, trainable, deploy)
+    # RT-DETR
+    elif args.model in ['rtdetr_n', 'rtdetr_s', 'rtdetr_m', 'rtdetr_l', 'rtdetr_x']:
+        model, criterion = build_rtdetr(
+            args, model_cfg, device, num_classes, trainable, deploy)
 
     if trainable:
         # Load pretrained weight

+ 33 - 0
models/detectors/rtdetr/build.py

@@ -0,0 +1,33 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+from .loss import build_criterion
+from .rtdetr import RTDETR
+
+
+# build object detector
+def build_rtdetr(args, cfg, device, num_classes=80, trainable=False, deploy=False):
+    print('==============================')
+    print('Build {} ...'.format(args.model.upper()))
+    
+    print('==============================')
+    print('Model Configuration: \n', cfg)
+    
+    # -------------- Build rtdetr --------------
+    model = RTDETR(
+        cfg=cfg,
+        device=device, 
+        num_classes=num_classes,
+        trainable=trainable,
+        aux_loss=trainable,
+        with_box_refine=True,
+        deploy=deploy
+        )
+
+    # -------------- Build criterion --------------
+    criterion = None
+    if trainable:
+        # build criterion for training
+        criterion = build_criterion(cfg, num_classes, aux_loss=True)
+        
+    return model, criterion

+ 154 - 0
models/detectors/rtdetr/image_encoder/cnn_backbone.py

@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+try:
+    from .cnn_basic import Conv, ELANBlock, DownSample
+except:
+    from cnn_basic import Conv, ELANBlock, DownSample
+
+
+
+model_urls = {
+    'elannet_pico': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_pico.pth",
+    'elannet_nano': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_nano.pth",
+    'elannet_small': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_small.pth",
+    'elannet_medium': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_medium.pth",
+    'elannet_large': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_large.pth",
+    'elannet_huge': "https://github.com/yjh0410/image_classification_pytorch/releases/download/weight/elannet_huge.pth",
+}
+
+
+# ---------------------------- Backbones ----------------------------
+# ELANNet-P5
+class ELANNet(nn.Module):
+    def __init__(self, width=1.0, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANNet, self).__init__()
+        self.feat_dims = [int(512 * width), int(1024 * width), int(1024 * width)]
+        
+        # P1/2
+        self.layer_1 = nn.Sequential(
+            Conv(3, int(64*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type),
+            Conv(int(64*width), int(64*width), k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P2/4
+        self.layer_2 = nn.Sequential(   
+            Conv(int(64*width), int(128*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise),             
+            ELANBlock(in_dim=int(128*width), out_dim=int(256*width), expand_ratio=0.5, depth=depth,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            DownSample(in_dim=int(256*width), out_dim=int(256*width), act_type=act_type, norm_type=norm_type),             
+            ELANBlock(in_dim=int(256*width), out_dim=int(512*width), expand_ratio=0.5, depth=depth,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            DownSample(in_dim=int(512*width), out_dim=int(512*width), act_type=act_type, norm_type=norm_type),             
+            ELANBlock(in_dim=int(512*width), out_dim=int(1024*width), expand_ratio=0.5, depth=depth,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            DownSample(in_dim=int(1024*width), out_dim=int(1024*width), act_type=act_type, norm_type=norm_type),             
+            ELANBlock(in_dim=int(1024*width), out_dim=int(1024*width), expand_ratio=0.25, depth=depth,
+                    act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## load pretrained weight
+def load_weight(model, model_name):
+    # load weight
+    print('Loading pretrained weight ...')
+    url = model_urls[model_name]
+    if url is not None:
+        checkpoint = torch.hub.load_state_dict_from_url(
+            url=url, map_location="cpu", check_hash=True)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("model")
+        # model state dict
+        model_state_dict = model.state_dict()
+        # check
+        for k in list(checkpoint_state_dict.keys()):
+            if k in model_state_dict:
+                shape_model = tuple(model_state_dict[k].shape)
+                shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                if shape_model != shape_checkpoint:
+                    checkpoint_state_dict.pop(k)
+            else:
+                checkpoint_state_dict.pop(k)
+                print(k)
+
+        model.load_state_dict(checkpoint_state_dict)
+    else:
+        print('No pretrained for {}'.format(model_name))
+
+    return model
+
+
+## build ELAN-Net
+def build_backbone(cfg, pretrained=False): 
+    # model
+    backbone = ELANNet(
+        width=cfg['width'],
+        depth=cfg['depth'],
+        act_type=cfg['bk_act'],
+        norm_type=cfg['bk_norm'],
+        depthwise=cfg['bk_dpw']
+        )
+    # check whether to load imagenet pretrained weight
+    if pretrained:
+        if cfg['width'] == 0.25 and cfg['depth'] == 0.34 and cfg['bk_dpw']:
+            backbone = load_weight(backbone, model_name='elannet_pico')
+        elif cfg['width'] == 0.25 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='elannet_nano')
+        elif cfg['width'] == 0.5 and cfg['depth'] == 0.34:
+            backbone = load_weight(backbone, model_name='elannet_small')
+        elif cfg['width'] == 0.75 and cfg['depth'] == 0.67:
+            backbone = load_weight(backbone, model_name='elannet_medium')
+        elif cfg['width'] == 1.0 and cfg['depth'] == 1.0:
+            backbone = load_weight(backbone, model_name='elannet_large')
+        elif cfg['width'] == 1.25 and cfg['depth'] == 1.34:
+            backbone = load_weight(backbone, model_name='elannet_huge')
+    feat_dims = backbone.feat_dims
+
+    return backbone, feat_dims
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    cfg = {
+        'pretrained': True,
+        'bk_act': 'silu',
+        'bk_norm': 'BN',
+        'bk_dpw': True,
+        'width': 0.25,
+        'depth': 0.34,
+    }
+    model, feats = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 190 - 0
models/detectors/rtdetr/image_encoder/cnn_basic.py

@@ -0,0 +1,190 @@
+import torch
+import torch.nn as nn
+
+
+# ------------------------------- Basic Modules -------------------------------
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'gelu':
+        return nn.GELU()
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type == 'LN':
+        return nn.LayerNorm(dim)
+
+
+# ------------------------------- Conv -------------------------------
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type='relu',      # activation
+                 norm_type='BN',       # normalization
+                 depthwise=False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            # depthwise conv
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+
+# ---------------------------- Modified YOLOv7's Modules ----------------------------
+## ELANBlock
+class ELANBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, expand_ratio=0.5, depth=1.0, act_type='silu', norm_type='BN', depthwise=False):
+        super(ELANBlock, self).__init__()
+        if isinstance(expand_ratio, float):
+            inter_dim = int(in_dim * expand_ratio)
+            inter_dim2 = inter_dim
+        elif isinstance(expand_ratio, list):
+            assert len(expand_ratio) == 2
+            e1, e2 = expand_ratio
+            inter_dim = int(in_dim * e1)
+            inter_dim2 = int(inter_dim * e2)
+        # branch-1
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-2
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        # branch-3
+        for idx in range(round(3*depth)):
+            if idx == 0:
+                cv3 = [Conv(inter_dim, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)]
+            else:
+                cv3.append(Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise))
+        self.cv3 = nn.Sequential(*cv3)
+        # branch-4
+        self.cv4 = nn.Sequential(*[
+            Conv(inter_dim2, inter_dim2, k=3, p=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+            for _ in range(round(3*depth))
+        ])
+        # output
+        self.out = Conv(inter_dim*2 + inter_dim2*2, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C_in, H, W]
+        Output:
+            out: [B, C_out, H, W]
+        """
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.cv3(x2)
+        x4 = self.cv4(x3)
+
+        # concat the 4 branches and fuse: [B, 2*inter_dim + 2*inter_dim2, H, W] -> [B, C_out, H, W]
+        out = self.out(torch.cat([x1, x2, x3, x4], dim=1))
+
+        return out
+
+## DownSample
+class DownSample(nn.Module):
+    def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
+        super().__init__()
+        inter_dim = out_dim // 2
+        self.mp = nn.MaxPool2d((2, 2), 2)
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.cv2 = nn.Sequential(
+            Conv(in_dim, inter_dim, k=1, act_type=act_type, norm_type=norm_type),
+            Conv(inter_dim, inter_dim, k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        )
+
+    def forward(self, x):
+        """
+        Input:
+            x: [B, C, H, W]
+        Output:
+            out: [B, C, H//2, W//2]
+        """
+        # [B, C, H, W] -> [B, C//2, H//2, W//2]
+        x1 = self.cv1(self.mp(x))
+        x2 = self.cv2(x)
+
+        # [B, C, H//2, W//2]
+        out = torch.cat([x1, x2], dim=1)
+
+        return out
+
+
+## build core block for CSFM
+def build_fpn_block(cfg, in_dim, out_dim):
+    if cfg['fpn_core_block'] == 'elanblock':
+        layer = ELANBlock(in_dim=in_dim,
+                          out_dim=out_dim,
+                          expand_ratio=[0.5, 0.5],
+                          depth=cfg['depth'],
+                          act_type=cfg['fpn_act'],
+                          norm_type=cfg['fpn_norm'],
+                          depthwise=cfg['fpn_depthwise']
+                          )
+        
+    return layer
+
+## build reduce layer for CSFM
+def build_reduce_layer(cfg, in_dim, out_dim):
+    layer = Conv(in_dim, out_dim, k=1,
+                 act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+        
+    return layer
+
+## build downsample layer for CSFM
+def build_downsample_layer(cfg, in_dim, out_dim):
+    if cfg['fpn_downsample_layer'] == 'conv':
+        layer = Conv(in_dim, out_dim, k=3, s=2, p=1,
+                     act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+        
+    return layer

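A shape sanity check for ELANBlock (a sketch meant to be run from inside cnn_basic.py; the dims assume the rtdetr_n P2 stage with width=0.25 and depth=0.34):

    import torch
    blk = ELANBlock(in_dim=32, out_dim=64, expand_ratio=0.5, depth=0.34)  # 32 = 128*0.25, 64 = 256*0.25
    x = torch.randn(1, 32, 160, 160)       # P2 feature map for a 640 input
    print(blk(x).shape)                    # torch.Size([1, 64, 160, 160])
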
+ 70 - 0
models/detectors/rtdetr/image_encoder/cnn_neck.py

@@ -0,0 +1,70 @@
+import torch
+import torch.nn as nn
+from .cnn_basic import Conv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    """
+        This code referenced to https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim, expand_ratio=0.5):
+        super().__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.cv2 = Conv(inter_dim * 4, out_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.m = nn.MaxPool2d(kernel_size=cfg['pooling_size'], stride=1, padding=cfg['pooling_size'] // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
+
+
+# SPPF block with CSP module
+class SPPFBlockCSP(nn.Module):
+    """
+        CSP Spatial Pyramid Pooling Block
+    """
+    def __init__(self, cfg, in_dim, out_dim, expand_ratio):
+        super(SPPFBlockCSP, self).__init__()
+        inter_dim = int(in_dim * expand_ratio)
+        self.out_dim = out_dim
+        self.cv1 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.cv2 = Conv(in_dim, inter_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+        self.m = nn.Sequential(
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=cfg['neck_act'], norm_type=cfg['neck_norm'], 
+                 depthwise=cfg['neck_depthwise']),
+            SPPF(cfg, inter_dim, inter_dim, expand_ratio=1.0),
+            Conv(inter_dim, inter_dim, k=3, p=1, 
+                 act_type=cfg['neck_act'], norm_type=cfg['neck_norm'], 
+                 depthwise=cfg['neck_depthwise'])
+        )
+        self.cv3 = Conv(inter_dim * 2, self.out_dim, k=1, act_type=cfg['neck_act'], norm_type=cfg['neck_norm'])
+
+        
+    def forward(self, x):
+        x1 = self.cv1(x)
+        x2 = self.cv2(x)
+        x3 = self.m(x2)
+        y = self.cv3(torch.cat([x1, x3], dim=1))
+
+        return y
+
+
+def build_neck(cfg, in_dim, out_dim):
+    model = cfg['neck']
+    print('==============================')
+    print('Neck: {}'.format(model))
+    # build neck
+    if model == 'sppf':
+        neck = SPPF(cfg, in_dim, out_dim, cfg['neck_expand_ratio'])
+    elif model == 'csp_sppf':
+        neck = SPPFBlockCSP(cfg, in_dim, out_dim, cfg['neck_expand_ratio'])
+
+    return neck
+        

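The SPPF block applies one small max-pool three times in series, which covers the same receptive fields as a parallel 5/9/13 SPP at lower cost and keeps the spatial size unchanged. A minimal usage sketch, assuming the repo root is importable and using illustrative config values:

import torch
from models.detectors.rtdetr.image_encoder.cnn_neck import build_neck

neck_cfg = {                 # key names are the ones read above; values are illustrative
    'neck': 'sppf',
    'neck_act': 'silu',
    'neck_norm': 'BN',
    'neck_expand_ratio': 0.5,
    'pooling_size': 5,
}
neck = build_neck(neck_cfg, in_dim=1024, out_dim=1024)
c5 = torch.randn(1, 1024, 20, 20)      # e.g. the stride-32 feature of a 640x640 image
print(neck(c5).shape)                  # torch.Size([1, 1024, 20, 20]); only channels are remixed
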
+ 94 - 0
models/detectors/rtdetr/image_encoder/cnn_pafpn.py

@@ -0,0 +1,94 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .cnn_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
+
+
+# YOLO-Style PaFPN
+class YoloPaFPN(nn.Module):
+    def __init__(self, cfg, in_dims=[256, 512, 1024], out_dim=None):
+        super(YoloPaFPN, self).__init__()
+        # --------------------------- Basic Parameters ---------------------------
+        self.in_dims = in_dims
+        c3, c4, c5 = in_dims
+        width = cfg['width']
+
+        # --------------------------- Network Parameters ---------------------------
+        ## top down
+        ### P5 -> P4
+        self.reduce_layer_1 = build_reduce_layer(cfg, c5, round(512*width))
+        self.reduce_layer_2 = build_reduce_layer(cfg, c4, round(512*width))
+        self.top_down_layer_1 = build_fpn_block(cfg, round(512*width) + round(512*width), round(512*width))
+
+        ### P4 -> P3
+        self.reduce_layer_3 = build_reduce_layer(cfg, round(512*width), round(256*width))
+        self.reduce_layer_4 = build_reduce_layer(cfg, c3, round(256*width))
+        self.top_down_layer_2 = build_fpn_block(cfg, round(256*width) + round(256*width), round(256*width))
+
+        ## bottom up
+        ### P3 -> P4
+        self.downsample_layer_1 = build_downsample_layer(cfg, round(256*width), round(256*width))
+        self.bottom_up_layer_1 = build_fpn_block(cfg, round(256*width) + round(256*width), round(512*width))
+
+        ### P4 -> P5
+        self.downsample_layer_2 = build_downsample_layer(cfg, round(512*width), round(512*width))
+        self.bottom_up_layer_2 = build_fpn_block(cfg, round(512*width) + round(512*width), round(1024*width))
+                
+        ## output proj layers
+        if out_dim is not None:
+            self.out_layers = nn.ModuleList([
+                Conv(in_dim, out_dim, k=1,
+                     act_type=cfg['fpn_act'], norm_type=cfg['fpn_norm'])
+                     for in_dim in [round(256*width), round(512*width), round(1024*width)]
+                     ])
+            self.out_dim = [out_dim] * 3
+        else:
+            self.out_layers = None
+            self.out_dim = [round(256*width), round(512*width), round(1024*width)]
+
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # Top down
+        ## P5 -> P4
+        c6 = self.reduce_layer_1(c5)
+        c7 = self.reduce_layer_2(c4)
+        c8 = torch.cat([F.interpolate(c6, scale_factor=2.0), c7], dim=1)
+        c9 = self.top_down_layer_1(c8)
+        ## P4 -> P3
+        c10 = self.reduce_layer_3(c9)
+        c11 = self.reduce_layer_4(c3)
+        c12 = torch.cat([F.interpolate(c10, scale_factor=2.0), c11], dim=1)
+        c13 = self.top_down_layer_2(c12)
+
+        # Bottom up
+        # P3 -> P4
+        c14 = self.downsample_layer_1(c13)
+        c15 = torch.cat([c14, c10], dim=1)
+        c16 = self.bottom_up_layer_1(c15)
+        # P4 -> P5
+        c17 = self.downsample_layer_2(c16)
+        c18 = torch.cat([c17, c6], dim=1)
+        c19 = self.bottom_up_layer_2(c18)
+
+        out_feats = [c13, c16, c19] # [P3, P4, P5]
+        
+        # output proj layers
+        if self.out_layers is not None:
+            out_feats_proj = []
+            for feat, layer in zip(out_feats, self.out_layers):
+                out_feats_proj.append(layer(feat))
+            return out_feats_proj
+
+        return out_feats
+
+
+def build_fpn(cfg, in_dims, out_dim=None):
+    model = cfg['fpn']
+    # build pafpn
+    if model == 'yolo_pafpn':
+        fpn_net = YoloPaFPN(cfg, in_dims, out_dim)
+
+    return fpn_net

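A shape walk-through of the PaFPN above (a sketch; assumes the repo root is importable, config values are illustrative). With width = 1.0 the native outputs are 256/512/1024 channels; the optional out_dim projects all three levels to a common width for the transformer:

import torch
from models.detectors.rtdetr.image_encoder.cnn_pafpn import build_fpn

fpn_cfg = {                       # key names as read above; values illustrative
    'fpn': 'yolo_pafpn', 'width': 1.0, 'depth': 1.0,
    'fpn_core_block': 'elanblock', 'fpn_act': 'silu', 'fpn_norm': 'BN',
    'fpn_depthwise': False, 'fpn_downsample_layer': 'conv',
}
feats = [torch.randn(1, 256, 80, 80),    # C3, stride 8
         torch.randn(1, 512, 40, 40),    # C4, stride 16
         torch.randn(1, 1024, 20, 20)]   # C5, stride 32
fpn = build_fpn(fpn_cfg, in_dims=[256, 512, 1024], out_dim=256)
for p in fpn(feats):
    print(p.shape)   # [1, 256, 80, 80], [1, 256, 40, 40], [1, 256, 20, 20]
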
+ 78 - 0
models/detectors/rtdetr/image_encoder/img_encoder.py

@@ -0,0 +1,78 @@
+import torch
+import torch.nn as nn
+
+from .cnn_backbone import build_backbone
+from .cnn_neck import build_neck
+from .cnn_pafpn import build_fpn
+
+
+# ------------------------ Image Encoder ------------------------
+class ImageEncoder(nn.Module):
+    def __init__(self, cfg, trainable=False) -> None:
+        super().__init__()
+        ## Backbone
+        self.backbone, feats_dim = build_backbone(cfg, cfg['pretrained']*trainable)
+
+        ## Encoder
+        self.encoder = build_neck(cfg, feats_dim[-1], feats_dim[-1])
+
+        ## CSFM
+        self.csfm = build_fpn(cfg=cfg, in_dims=feats_dim, out_dim=round(cfg['d_model']*cfg['width']))
+
+
+    def position_embedding(self, x, temperature=10000):
+        hs, ws = x.shape[-2:]
+        device = x.device
+        num_pos_feats = x.shape[1] // 2       
+        scale = 2 * 3.141592653589793
+
+        # generate xy coord mat
+        y_embed, x_embed = torch.meshgrid(
+            [torch.arange(1, hs+1, dtype=torch.float32),
+             torch.arange(1, ws+1, dtype=torch.float32)])
+        y_embed = y_embed / (hs + 1e-6) * scale
+        x_embed = x_embed / (ws + 1e-6) * scale
+    
+        # [H, W] -> [1, H, W]
+        y_embed = y_embed[None, :, :].to(device)
+        x_embed = x_embed[None, :, :].to(device)
+
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
+        dim_t = temperature ** (2 * dim_t_)
+
+        pos_x = torch.div(x_embed[:, :, :, None], dim_t)
+        pos_y = torch.div(y_embed[:, :, :, None], dim_t)
+        pos_x = torch.stack((pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4).flatten(3)
+        pos_y = torch.stack((pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4).flatten(3)
+
+        # [B, C, H, W]
+        pos_embed = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+        
+        return pos_embed
+        
+
+    def forward(self, x):
+        # Backbone
+        pyramid_feats = self.backbone(x)
+
+        # Encoder
+        pyramid_feats[-1] = self.encoder(pyramid_feats[-1])
+
+        # CSFM
+        pyramid_feats = self.csfm(pyramid_feats)
+
+        # Prepare memory & memory_pos for Decoder
+        memory = torch.cat([feat.flatten(2) for feat in pyramid_feats], dim=-1)
+        memory = memory.permute(0, 2, 1).contiguous()
+        memory_pos = torch.cat([self.position_embedding(feat).flatten(2)
+                                for feat in pyramid_feats], dim=-1)
+        memory_pos = memory_pos.permute(0, 2, 1).contiguous()
+
+        return memory, memory_pos
+
+
+# build img-encoder
+def build_img_encoder(cfg, trainable):
+    return ImageEncoder(cfg, trainable)
+

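The forward pass flattens the three pyramid levels into one token sequence before the decoder, and the sinusoidal position embedding is flattened the same way. A quick count of the resulting sequence length, assuming a 640x640 input and the usual 8/16/32 strides:

img_size = 640
strides = [8, 16, 32]                                # P3, P4, P5
num_tokens = sum((img_size // s) ** 2 for s in strides)
print(num_tokens)   # 6400 + 1600 + 400 = 8400 tokens in memory, each of dim round(d_model * width)
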
+ 171 - 0
models/detectors/rtdetr/loss.py

@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import copy
+
+from .matcher import build_matcher
+from utils.misc import sigmoid_focal_loss
+from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+from utils.distributed_utils import is_dist_avail_and_initialized, get_world_size
+
+
+class Criterion(nn.Module):
+    """ This class computes the loss for DETR.
+    The process happens in two steps:
+        1) we compute hungarian assignment between ground truth boxes and the outputs of the model
+        2) we supervise each pair of matched ground-truth / prediction (supervise class and box)
+    """
+    def __init__(self, num_classes, matcher, weight_dict, losses, focal_alpha=0.25):
+        """ Create the criterion.
+        Parameters:
+            num_classes: number of object categories, omitting the special no-object category
+            matcher: module able to compute a matching between targets and proposals
+            weight_dict: dict containing as key the names of the losses and as values their relative weight.
+            focal_alpha: alpha parameter of the sigmoid focal loss used for classification
+            losses: list of all the losses to be applied. See get_loss for list of available losses.
+        """
+        super().__init__()
+        self.num_classes = num_classes
+        self.matcher = matcher
+        self.weight_dict = weight_dict
+        self.losses = losses
+        self.focal_alpha = focal_alpha
+
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+
+    def loss_labels(self, outputs, targets, indices, num_boxes):
+        """Classification loss (NLL)
+        targets dicts must contain the key "labels" containing a tensor of dim [nb_target_boxes]
+        """
+        assert 'pred_logits' in outputs
+        src_logits = outputs['pred_logits']
+
+        idx = self._get_src_permutation_idx(indices)
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)]).to(src_logits.device)
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+
+        target_classes_onehot = torch.zeros([src_logits.shape[0], src_logits.shape[1], src_logits.shape[2] + 1],
+                                            dtype=src_logits.dtype, layout=src_logits.layout, device=src_logits.device)
+        target_classes_onehot.scatter_(2, target_classes.unsqueeze(-1), 1)
+
+        target_classes_onehot = target_classes_onehot[:, :, :-1]
+        loss_cls = sigmoid_focal_loss(src_logits, target_classes_onehot, num_boxes, alpha=self.focal_alpha, gamma=2) * \
+                  src_logits.shape[1]
+        losses = {'loss_cls': loss_cls}
+
+        return losses
+
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(src_boxes.device)
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+
+        losses = {}
+        losses['loss_bbox'] = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(generalized_box_iou(
+            box_cxcywh_to_xyxy(src_boxes),
+            box_cxcywh_to_xyxy(target_boxes)))
+        losses['loss_giou'] = loss_giou.sum() / num_boxes
+        return losses
+
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'labels': self.loss_labels,
+            'boxes': self.loss_boxes,
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+
+    def forward(self, outputs, targets):
+        """ This performs the loss computation.
+        Parameters:
+             outputs: dict of tensors, see the output specification of the model for the format
+             targets: list of dicts, such that len(targets) == batch_size.
+                      The expected keys in each dict depends on the losses applied, see each loss' doc
+        """
+        outputs_without_aux = {k: v for k, v in outputs.items() if k != 'aux_outputs'}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_boxes)
+        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            losses.update(self.get_loss(loss, outputs, targets, indices, num_boxes))
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    kwargs = {}
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes, **kwargs)
+                    l_dict = {k + f'_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        weight_dict = self.weight_dict
+        total_loss = sum(losses[k] * weight_dict[k] for k in losses.keys() if k in weight_dict)
+        losses['losses'] = total_loss
+
+        return losses
+
+
+# build criterion
+def build_criterion(cfg, num_classes, aux_loss=False):
+    matcher = build_matcher(cfg)
+    
+    weight_dict = {'loss_cls': cfg['loss_cls_weight'],
+                  'loss_bbox': cfg['loss_box_weight'],
+                  'loss_giou': cfg['loss_giou_weight']}
+
+    # TODO this is a hack
+    if aux_loss:
+        aux_weight_dict = {}
+        for i in range(cfg['num_decoder_layers'] - 1):
+            aux_weight_dict.update({k + f'_{i}': v for k, v in weight_dict.items()})
+        weight_dict.update(aux_weight_dict)
+
+    losses = ['labels', 'boxes']
+    
+    criterion = Criterion(
+        num_classes=num_classes,
+        matcher=matcher,
+        weight_dict=weight_dict,
+        losses=losses,
+        focal_alpha=cfg['focal_alpha'])
+
+    return criterion
+    

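A minimal sketch of driving the criterion with dummy predictions. The cfg keys mirror the ones read in build_criterion and the matcher, and all values are illustrative; pred_logits are raw (pre-sigmoid) and pred_boxes are normalized (cx, cy, w, h):

import torch
from models.detectors.rtdetr.loss import build_criterion

cfg = {
    'set_cost_class': 2.0, 'set_cost_bbox': 5.0, 'set_cost_giou': 2.0,
    'loss_cls_weight': 1.0, 'loss_box_weight': 5.0, 'loss_giou_weight': 2.0,
    'focal_alpha': 0.25, 'num_decoder_layers': 6,
}
criterion = build_criterion(cfg, num_classes=80, aux_loss=False)

outputs = {'pred_logits': torch.randn(2, 300, 80),   # raw logits, no sigmoid applied
           'pred_boxes': torch.rand(2, 300, 4)}      # normalized (cx, cy, w, h)
targets = [{'labels': torch.tensor([3, 17]), 'boxes': torch.rand(2, 4)},
           {'labels': torch.tensor([5]),     'boxes': torch.rand(1, 4)}]

loss_dict = criterion(outputs, targets)
print(loss_dict['losses'])   # weighted sum of loss_cls / loss_bbox / loss_giou
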
+ 102 - 0
models/detectors/rtdetr/matcher.py

@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+from scipy.optimize import linear_sum_assignment
+from utils.box_ops import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+    """This class computes an assignment between the targets and the predictions of the network
+    For efficiency reasons, the targets don't include the no_object. Because of this, in general,
+    there are more predictions than targets. In this case, we do a 1-to-1 matching of the best predictions,
+    while the others are un-matched (and thus treated as non-objects).
+    """
+
+    def __init__(self, cost_class: float = 1, cost_bbox: float = 1, cost_giou: float = 1):
+        """Creates the matcher
+        Params:
+            cost_class: This is the relative weight of the classification error in the matching cost
+            cost_bbox: This is the relative weight of the L1 error of the bounding box coordinates in the matching cost
+            cost_giou: This is the relative weight of the giou loss of the bounding box in the matching cost
+        """
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_bbox = cost_bbox
+        self.cost_giou = cost_giou
+        assert cost_class != 0 or cost_bbox != 0 or cost_giou != 0, "all costs can't be 0"
+
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        """ Performs the matching
+        Params:
+            outputs: This is a dict that contains at least these entries:
+                 "pred_logits": Tensor of dim [batch_size, num_queries, num_classes] with the classification logits
+                 "pred_boxes": Tensor of dim [batch_size, num_queries, 4] with the predicted box coordinates
+            targets: This is a list of targets (len(targets) = batch_size), where each target is a dict containing:
+                 "labels": Tensor of dim [num_target_boxes] (where num_target_boxes is the number of ground-truth
+                           objects in the target) containing the class labels
+                 "boxes": Tensor of dim [num_target_boxes, 4] containing the target box coordinates
+        Returns:
+            A list of size batch_size, containing tuples of (index_i, index_j) where:
+                - index_i is the indices of the selected predictions (in order)
+                - index_j is the indices of the corresponding selected targets (in order)
+            For each batch element, it holds:
+                len(index_i) = len(index_j) = min(num_queries, num_target_boxes)
+        """
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        # [B * num_queries, C] = [N, C], where N is B * num_queries
+        out_prob = outputs["pred_logits"].flatten(0, 1).sigmoid()
+        # [B * num_queries, 4] = [N, 4]
+        out_bbox = outputs["pred_boxes"].flatten(0, 1)
+
+        # Also concat the target labels and boxes
+        # [M,] where M is number of all targets in this batch
+        tgt_ids = torch.cat([v["labels"] for v in targets])
+        # [M, 4] where M is number of all targets in this batch
+        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost.
+        alpha = 0.25
+        gamma = 2.0
+        neg_cost_class = (1 - alpha) * (out_prob ** gamma) * (-(1 - out_prob + 1e-8).log())
+        pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-(out_prob + 1e-8).log())
+        cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids]
+
+        # Compute the L1 cost between boxes
+        # [N, M]
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox.to(out_bbox.device), p=1)
+
+        # Compute the giou cost between boxes
+        # [N, M]
+        cost_giou = -generalized_box_iou(
+            box_cxcywh_to_xyxy(out_bbox),
+            box_cxcywh_to_xyxy(tgt_bbox.to(out_bbox.device)))
+
+        # Final cost matrix: [N, M]
+        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+        # [N, M] -> [B, num_queries, M]
+        C = C.view(bs, num_queries, -1).cpu()
+
+        # The number of boxes in each image
+        sizes = [len(v["boxes"]) for v in targets]
+        # Split the last dim of C into per-image chunks: chunk i has shape [B, num_queries, M_i],
+        # where M_i is the number of targets in image i and sum(M_i) = M.
+        # For chunk i, only the i-th batch slice is relevant, so c[i] is the [num_queries, M_i]
+        # cost matrix between image i's predictions and its own targets.
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+        # As for each (i, j) in indices, i is the prediction indexes and j is the target indexes
+        # i contains row indexes of cost matrix: array([row_1, row_2, row_3]) 
+        # j contains col indexes of cost matrix: array([col_1, col_2, col_3])
+        # len(i) == len(j)
+        # len(indices) = batch_size
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
+
+
+def build_matcher(cfg):
+    return HungarianMatcher(
+        cost_class=cfg['set_cost_class'],
+        cost_bbox=cfg['set_cost_bbox'],
+        cost_giou=cfg['set_cost_giou']
+        )

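The assignment itself is just scipy's linear_sum_assignment run on each image's block of the cost matrix. A toy example with hypothetical numbers, three predictions against two targets:

import numpy as np
from scipy.optimize import linear_sum_assignment

cost = np.array([[0.9, 0.1],    # prediction 0 is cheap to match with target 1
                 [0.4, 0.8],    # prediction 1 is cheap to match with target 0
                 [0.7, 0.6]])   # prediction 2 ends up unmatched (more preds than targets)
rows, cols = linear_sum_assignment(cost)
print(rows, cols)   # [0 1] [1 0]: pred 0 <-> tgt 1, pred 1 <-> tgt 0, total cost 0.5
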
+ 107 - 0
models/detectors/rtdetr/rtdetr.py

@@ -0,0 +1,107 @@
+import torch
+import torch.nn as nn
+
+from .rtdetr_encoder import build_encoder
+from .rtdetr_decoder import build_decoder
+from .rtdetr_dethead import build_dethead
+
+
+# Real-time DETR
+class RTDETR(nn.Module):
+    def __init__(self, 
+                 cfg,
+                 device, 
+                 num_classes = 20, 
+                 trainable = False, 
+                 aux_loss = False,
+                 with_box_refine = False,
+                 deploy = False):
+        super(RTDETR, self).__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.device = device
+        self.num_classes = num_classes
+        self.trainable = trainable
+        self.max_stride = max(cfg['stride'])
+        self.d_model = round(cfg['d_model'] * self.cfg['width'])
+        self.aux_loss = aux_loss
+        self.with_box_refine = with_box_refine
+        self.deploy = deploy
+        
+        # --------- Network Parameters ----------
+        ## Encoder
+        self.img_encoder = build_encoder(cfg, trainable, 'img_encoder')
+
+        ## Decoder
+        self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)
+
+        ## DetHead
+        self.dethead = build_dethead(cfg, self.d_model, num_classes, with_box_refine)
+            
+        # share the detection heads with the TR-Decoder for iterative box refinement
+        self.decoder.class_embed = self.dethead.class_embed
+        self.decoder.bbox_embed = self.dethead.bbox_embed
+
+
+    # ---------------------- Basic Functions ----------------------
+    @torch.jit.unused
+    def set_aux_loss(self, outputs_class, outputs_coord):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [{'pred_logits': a, 'pred_boxes': b}
+                for a, b in zip(outputs_class[:-1], outputs_coord[:-1])]
+
+
+    # ---------------------- Main Process for Inference ----------------------
+    @torch.no_grad()
+    def inference_single_image(self, x):
+        # -------------------- Encoder --------------------
+        memory, memory_pos = self.img_encoder(x)
+
+        # -------------------- Decoder --------------------
+        hs, reference = self.decoder(memory, memory_pos)
+
+        # -------------------- DetHead --------------------
+        out_logits, out_bbox = self.dethead(hs, reference, False)
+
+        # -------------------- Decode bbox --------------------
+        cls_pred = out_logits[0]
+        box_pred = out_bbox[0]
+        x1y1_pred = box_pred[..., :2] - box_pred[..., 2:] * 0.5
+        x2y2_pred = box_pred[..., :2] + box_pred[..., 2:] * 0.5
+        box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+        # -------------------- Top-k --------------------
+        cls_pred = cls_pred.flatten().sigmoid_()
+        num_topk = 100
+        predicted_prob, topk_idxs = cls_pred.sort(descending=True)
+        topk_idxs = topk_idxs[:num_topk]
+        topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+        topk_scores = predicted_prob[:num_topk]
+        topk_labels = topk_idxs % self.num_classes
+        topk_bboxes = box_pred[topk_box_idxs]
+
+        return topk_bboxes, topk_scores, topk_labels
+        
+
+    # ---------------------- Main Process for Training ----------------------
+    def forward(self, x):
+        if not self.trainable:
+            return self.inference_single_image(x)
+        else:
+            # -------------------- Encoder --------------------
+            memory, memory_pos = self.img_encoder(x)
+
+            # -------------------- Decoder --------------------
+            hs, reference = self.decoder(memory, memory_pos)
+
+            # -------------------- DetHead --------------------
+            outputs_class, outputs_coords = self.dethead(hs, reference, True)
+
+            outputs = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coords[-1]}
+            if self.aux_loss:
+                outputs['aux_outputs'] = self.set_aux_loss(outputs_class, outputs_coords)
+            
+            return outputs
+    

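In inference_single_image the [num_queries, num_classes] score map is flattened before the top-k step, so every selected index encodes both a query and a class. A sketch of that index decoding with illustrative shapes:

import torch

num_classes = 80
scores = torch.rand(300, num_classes).flatten()      # [num_queries * num_classes]
sorted_scores, sorted_idxs = scores.sort(descending=True)
topk_idxs = sorted_idxs[:100]
query_idx = torch.div(topk_idxs, num_classes, rounding_mode='floor')  # which query's box to keep
class_idx = topk_idxs % num_classes                                   # which class it is scored for
print(query_idx.shape, class_idx.shape)   # torch.Size([100]) torch.Size([100])
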
+ 221 - 0
models/detectors/rtdetr/rtdetr_basic.py

@@ -0,0 +1,221 @@
+import copy
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import nn, Tensor
+
+
+# ------------------------------- Basic Modules -------------------------------
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'gelu':
+        return nn.GELU()
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+
+
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type == 'LN':
+        return nn.LayerNorm(dim)
+
+
+def get_clones(module, N):
+    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
+    
+
+def build_multi_head_attention(d_model, num_heads, dropout, attn_type='mhsa'):
+    if attn_type == 'mhsa':
+        attn_layer = MultiHeadAttention(d_model, num_heads, dropout)
+    elif attn_type == 's_mhsa':
+        attn_layer = None
+
+    return attn_layer
+
+
+# ------------------------------- MLP -------------------------------
+class MLP(nn.Module):
+    """ Very simple multi-layer perceptron (also called FFN)"""
+
+    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+
+# ------------------------------- Transformer Modules -------------------------------
+## Vanilla Multi-Head Attention
+class MultiHeadAttention(nn.Module):
+    def __init__(self, d_model, num_heads, dropout=0.) -> None:
+        super().__init__()
+        # --------------- Basic parameters ---------------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.dropout = dropout
+        self.scale = (d_model // num_heads) ** -0.5
+
+        # --------------- Network parameters ---------------
+        self.q_proj = nn.Linear(d_model, d_model, bias = False) # W_q
+        self.k_proj = nn.Linear(d_model, d_model, bias = False) # W_k
+        self.v_proj = nn.Linear(d_model, d_model, bias = False) # W_v
+
+        self.out_proj = nn.Linear(d_model, d_model)
+        self.dropout = nn.Dropout(dropout)
+
+
+    def forward(self, query, key, value):
+        """
+        Inputs:
+            query : (Tensor) -> [B, Nq, C]
+            key   : (Tensor) -> [B, Nk, C]
+            value : (Tensor) -> [B, Nk, C]
+        """
+        bs = query.shape[0]
+        Nq = query.shape[1]
+        Nk = key.shape[1]
+
+        # ----------------- Input proj -----------------
+        query = self.q_proj(query)
+        key   = self.k_proj(key)
+        value = self.v_proj(value)
+
+        # ----------------- Multi-head Attn -----------------
+        ## [B, N, C] -> [B, N, H, C_h] -> [B, H, N, C_h]
+        query = query.view(bs, Nq, self.num_heads, self.d_model // self.num_heads)
+        query = query.permute(0, 2, 1, 3).contiguous()
+        key   = key.view(bs, Nk, self.num_heads, self.d_model // self.num_heads)
+        key   = key.permute(0, 2, 1, 3).contiguous()
+        value = value.view(bs, Nk, self.num_heads, self.d_model // self.num_heads)
+        value = value.permute(0, 2, 1, 3).contiguous()
+        # Attention
+        ## [B, H, Nq, C_h] X [B, H, C_h, Nk] = [B, H, Nq, Nk]
+        sim_matrix = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        sim_matrix = torch.softmax(sim_matrix, dim=-1)
+
+        # ----------------- Output -----------------
+        out = torch.matmul(sim_matrix, value)  # [B, H, Nq, C_h]
+        out = out.permute(0, 2, 1, 3).contiguous().view(bs, Nq, -1)
+        out = self.out_proj(out)
+
+        return out
+        
+## Transformer Encoder layer
+class TREncoderLayer(nn.Module):
+    def __init__(self,
+                 d_model,
+                 num_heads,
+                 dim_feedforward=2048,
+                 dropout=0.1,
+                 act_type="relu",
+                 attn_type='mhsa'
+                 ):
+        super().__init__()
+        # Multi-head Self-Attn
+        self.self_attn = build_multi_head_attention(d_model, num_heads, dropout, attn_type)
+
+        # Feedforward Network
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+
+        self.activation = get_activation(act_type)
+
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+
+    def forward(self, src, pos):
+        """
+        Input:
+            src: [torch.Tensor] -> [B, N, C]
+            pos: [torch.Tensor] -> [B, N, C]
+        Output:
+            src: [torch.Tensor] -> [B, N, C]
+        """
+        q = k = self.with_pos_embed(src, pos)
+
+        # self-attn
+        src2 = self.self_attn(q, k, value=src)
+
+        # residual connection + layer norm
+        src = src + self.dropout1(src2)
+        src = self.norm1(src)
+
+        # ffn
+        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+        src = src + self.dropout2(src2)
+        src = self.norm2(src)
+        
+        return src
+
+## Transformer Decoder layer
+class TRDecoderLayer(nn.Module):
+    def __init__(self, d_model, num_heads, dim_feedforward=2048, dropout=0.1, act_type="relu", attn_type='mhsa'):
+        super().__init__()
+        # Multi-head Self-Attn
+        self.self_attn = build_multi_head_attention(d_model, num_heads, dropout, attn_type)
+        self.cross_attn = build_multi_head_attention(d_model, num_heads, dropout)
+        # Feedforward Network
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = get_activation(act_type)
+
+
+    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
+        return tensor if pos is None else tensor + pos
+
+
+    def forward(self, tgt, tgt_query_pos, memory, memory_pos):
+        # self attention
+        tgt2 = self.self_attn(
+            query=self.with_pos_embed(tgt, tgt_query_pos),
+            key=self.with_pos_embed(tgt, tgt_query_pos),
+            value=tgt)
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # cross attention
+        tgt2 = self.cross_attn(
+            query=self.with_pos_embed(tgt, tgt_query_pos),
+            key=self.with_pos_embed(memory, memory_pos),
+            value=memory)
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # ffn
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+        tgt = tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+        
+        return tgt

+ 122 - 0
models/detectors/rtdetr/rtdetr_decoder.py

@@ -0,0 +1,122 @@
+import torch
+import torch.nn as nn
+
+from .rtdetr_basic import get_clones, TRDecoderLayer, MLP
+
+
+# Transformer Decoder Module
+class TransformerDecoder(nn.Module):
+    def __init__(self, cfg, in_dim, return_intermediate=False):
+        super().__init__()
+        self.d_model = in_dim
+        self.query_dim = 4
+        self.scale = 2 * 3.141592653589793
+        self.num_queries = cfg['num_queries']
+        self.num_decoder_layers = cfg['num_decoder_layers']
+        self.return_intermediate = return_intermediate
+
+        # -------------------- Network Parameters ---------------------
+        ## Decoder
+        decoder_layer = TRDecoderLayer(
+            d_model=in_dim,
+            num_heads=cfg['de_num_heads'],
+            dim_feedforward=cfg['de_dim_feedforward'],
+            dropout=cfg['de_dropout'],
+            act_type=cfg['de_act']
+        )
+        self.decoder_layers = get_clones(decoder_layer, cfg['num_decoder_layers'])
+        ## RefPoint Embed
+        self.refpoint_embed = nn.Embedding(cfg['num_queries'], 4)
+        ## Object Query Embed
+        self.object_query = nn.Embedding(cfg['num_queries'], in_dim)
+        nn.init.normal_(self.object_query.weight.data)
+        ## TODO: Group queries
+
+        self.ref_point_head = MLP(self.query_dim // 2 * in_dim, in_dim, in_dim, 2)
+        self.query_pos_sine_scale = MLP(in_dim, in_dim, in_dim, 2)
+        self.ref_anchor_head = MLP(in_dim, in_dim, 2, 2)
+
+        self.bbox_embed = None
+        self.class_embed = None
+
+
+    def query_sine_embed(self, num_feats, reference_points):
+        dim_t = torch.arange(num_feats, dtype=torch.float32, device=reference_points.device)
+        dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_feats
+        dim_t = 10000 ** (2 * dim_t_)
+
+        x_embed = reference_points[:, :, 0] * self.scale
+        y_embed = reference_points[:, :, 1] * self.scale
+        pos_x = x_embed[:, :, None] / dim_t
+        pos_y = y_embed[:, :, None] / dim_t
+        pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+        pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+        w_embed = reference_points[:, :, 2] * self.scale
+        pos_w = w_embed[:, :, None] / dim_t
+        pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
+
+        h_embed = reference_points[:, :, 3] * self.scale
+        pos_h = h_embed[:, :, None] / dim_t
+        pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
+        query_sine_embed = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
+
+        return query_sine_embed
+    
+
+    def forward(self, memory, memory_pos):
+        bs, _, channels = memory.size()
+        num_feats = channels // 2
+
+        # prepare tgt & refpoint
+        tgt = self.object_query.weight[None].repeat(bs, 1, 1)
+        refpoint_embed = self.refpoint_embed.weight[None].repeat(bs, 1, 1)
+
+        intermediate = []
+        reference_points = refpoint_embed.sigmoid()
+        ref_points = [reference_points]
+
+        # main process
+        output = tgt
+        for layer_id, layer in enumerate(self.decoder_layers):
+            # query sine embed
+            query_sine_embed = self.query_sine_embed(num_feats, reference_points)
+
+            # conditional query
+            query_pos = self.ref_point_head(query_sine_embed) # [B, N, C]
+
+            # decoder
+            output = layer(
+                    # input for decoder
+                    tgt = output,
+                    tgt_query_pos = query_pos,
+                    # input from encoder
+                    memory = memory,
+                    memory_pos = memory_pos,
+                )
+
+            # iter update
+            if self.bbox_embed is not None:
+                # --------------- Start inverse_sigmoid ---------------
+                reference_points = reference_points.clamp(min=0, max=1)
+                reference_points_1 = reference_points.clamp(min=1e-5)
+                reference_points_2 = (1 - reference_points).clamp(min=1e-5)
+                reference_before_sigmoid = torch.log(reference_points_1/reference_points_2)
+                # --------------- End inverse_sigmoid ---------------
+
+                delta_unsig = self.bbox_embed[layer_id](output)
+                outputs_unsig = delta_unsig + reference_before_sigmoid
+                new_reference_points = outputs_unsig.sigmoid()
+
+                reference_points = new_reference_points.detach()
+                ref_points.append(new_reference_points)
+
+            intermediate.append(output)
+
+        return torch.stack(intermediate), torch.stack(ref_points)
+
+
+# build transformer decoder
+def build_decoder(cfg, in_dim, return_intermediate=False):
+    decoder = TransformerDecoder(cfg, in_dim, return_intermediate=return_intermediate) 
+
+    return decoder

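The iterative update in the loop above applies the head's delta in inverse-sigmoid (logit) space, so refined boxes always stay inside (0, 1). A quick numeric check with illustrative values:

import torch

ref   = torch.tensor([0.30, 0.60, 0.20, 0.40])    # current (cx, cy, w, h), normalized
delta = torch.tensor([0.50, -0.25, 0.00, 0.10])   # raw offset predicted by bbox_embed

ref_logit = torch.log(ref.clamp(min=1e-5) / (1 - ref).clamp(min=1e-5))  # inverse sigmoid
new_ref = (delta + ref_logit).sigmoid()
print(new_ref)   # ~[0.414, 0.539, 0.200, 0.424], still valid normalized coordinates
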
+ 77 - 0
models/detectors/rtdetr/rtdetr_dethead.py

@@ -0,0 +1,77 @@
+import torch
+import torch.nn as nn
+
+from .rtdetr_basic import MLP
+
+
+class DetectHead(nn.Module):
+    def __init__(self, cfg, d_model, num_classes, with_box_refine=False):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.num_classes = num_classes
+
+        # --------- Network Parameters ----------
+        self.class_embed = nn.ModuleList([nn.Linear(d_model, self.num_classes)])
+        self.bbox_embed = nn.ModuleList([MLP(d_model, d_model, 4, 3)])
+        if with_box_refine:
+            self.class_embed = nn.ModuleList([
+                self.class_embed[0] for _ in range(cfg['num_decoder_layers'])])
+            self.bbox_embed = nn.ModuleList([
+                self.bbox_embed[0] for _ in range(cfg['num_decoder_layers'])])
+
+        self.init_weight()
+
+
+    def init_weight(self):
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+
+        # cls pred
+        for class_embed in self.class_embed:
+            class_embed.bias.data = torch.ones(self.num_classes) * bias_value
+
+        # box pred
+        for bbox_embed in self.bbox_embed:
+            nn.init.constant_(bbox_embed.layers[-1].weight.data, 0)
+            nn.init.constant_(bbox_embed.layers[-1].bias.data, 0)
+        
+
+    def forward(self, hs, reference, multi_layer=False):
+        if multi_layer:
+            ## class embed
+            outputs_class = torch.stack([layer_cls_embed(layer_hs) for
+                                        layer_cls_embed, layer_hs in zip(self.class_embed, hs)])
+            ## Bbox embed
+            outputs_coords = []
+            for dec_lid, (layer_ref_sig, layer_bbox_embed, layer_hs) in enumerate(zip(reference[:-1], self.bbox_embed, hs)):
+                layer_delta_unsig = layer_bbox_embed(layer_hs)
+                # ---------- start <inverse sigmoid> ----------
+                layer_ref_sig = layer_ref_sig.clamp(min=0, max=1)
+                layer_ref_sig_1 = layer_ref_sig.clamp(min=1e-5)
+                layer_ref_sig_2 = (1 - layer_ref_sig).clamp(min=1e-5)
+                layer_ref_sig = torch.log(layer_ref_sig_1/layer_ref_sig_2)
+                # ---------- end <inverse sigmoid> ----------
+                layer_outputs_unsig = layer_delta_unsig + layer_ref_sig
+                layer_outputs_unsig = layer_outputs_unsig.sigmoid()
+                outputs_coords.append(layer_outputs_unsig)
+        else:
+            ## class embed
+            outputs_class = self.class_embed[-1](hs[-1]) 
+            ## bbox embed
+            delta_unsig = self.bbox_embed[-1](hs[-1])
+            ref_sig = reference[-2]
+            ## ---------- start <inverse sigmoid> ----------
+            ref_sig = ref_sig.clamp(min=0, max=1)
+            ref_sig_1 = ref_sig.clamp(min=1e-5)
+            ref_sig_2 = (1 - ref_sig).clamp(min=1e-5)
+            ref_sig = torch.log(ref_sig_1/ref_sig_2)
+            ## ---------- end <inverse sigmoid> ----------
+            outputs_unsig = delta_unsig + ref_sig
+            outputs_coords = outputs_unsig.sigmoid()
+
+        return outputs_class, outputs_coords
+
+
+def build_dethead(cfg, d_model, num_classes, with_box_refine):
+    return DetectHead(cfg, d_model, num_classes, with_box_refine)

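The classification bias in init_weight follows the focal-loss prior-probability trick: with init_prob = 0.01 the bias is -log((1 - p) / p) ≈ -4.595, so every class starts with sigmoid(bias) ≈ 0.01 and the loss is not swamped by easy negatives early in training. A quick numeric check:

import math

p = 0.01
bias = -math.log((1 - p) / p)
print(bias, 1 / (1 + math.exp(-bias)))   # -4.59512..., 0.00999... ≈ 0.01
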
+ 10 - 0
models/detectors/rtdetr/rtdetr_encoder.py

@@ -0,0 +1,10 @@
+from .image_encoder.img_encoder import build_img_encoder
+
+
+# build encoder
+def build_encoder(cfg, trainable=False, en_type='img_encoder'):
+    if en_type == 'img_encoder':
+        return build_img_encoder(cfg, trainable)
+    elif en_type == 'text_encoder':
+        ## TODO: design text encoder
+        return None

+ 0 - 1
models/detectors/yolox2/yolox2.py

@@ -1,7 +1,6 @@
 # --------------- Torch components ---------------
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 
 # --------------- Model components ---------------
 from .yolox2_backbone import build_backbone

+ 0 - 0
models/tracker/__init__.py → models/trackers/__init__.py


+ 0 - 0
models/tracker/byte_tracker/basetrack.py → models/trackers/byte_tracker/basetrack.py


+ 0 - 0
models/tracker/byte_tracker/build.py → models/trackers/byte_tracker/build.py


+ 0 - 0
models/tracker/byte_tracker/byte_tracker.py → models/trackers/byte_tracker/byte_tracker.py


+ 0 - 0
models/tracker/byte_tracker/kalman_filter.py → models/trackers/byte_tracker/kalman_filter.py


+ 0 - 0
models/tracker/byte_tracker/matching.py → models/trackers/byte_tracker/matching.py


+ 2 - 2
track.py

@@ -7,7 +7,7 @@ import numpy as np
 
 import torch
 
-from dataset.data_augment import build_transform
+from dataset.build import build_transform
 from utils.vis_tools import plot_tracking
 from utils.misc import load_weight
 from utils.box_ops import rescale_bboxes
@@ -15,7 +15,7 @@ from utils.box_ops import rescale_bboxes
 from config import build_model_config, build_trans_config
 
 from models.detectors import build_model
-from models.tracker import build_tracker
+from models.trackers import build_tracker
 
 os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
 IMAGE_EXT = [".jpg", ".jpeg", ".webp", ".bmp", ".png"]

+ 16 - 96
train.py

@@ -12,26 +12,15 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 # ----------------- Extra Components -----------------
 from utils import distributed_utils
 from utils.misc import compute_flops
-from utils.misc import ModelEMA, CollateFunc, build_dataloader
-
-# ----------------- Evaluator Components -----------------
-from evaluator.build import build_evluator
-
-# ----------------- Optimizer & LrScheduler Components -----------------
-from utils.solver.optimizer import build_optimizer
-from utils.solver.lr_scheduler import build_lr_scheduler
 
 # ----------------- Config Components -----------------
 from config import build_dataset_config, build_model_config, build_trans_config
 
-# ----------------- Dataset Components -----------------
-from dataset.build import build_dataset, build_transform
-
 # ----------------- Model Components -----------------
 from models.detectors import build_model
 
 # ----------------- Train Components -----------------
-from engine import Trainer
+from engine import build_trainer
 
 
 def parse_args():
@@ -53,6 +42,8 @@ def parse_args():
                         help="Adopting mix precision training.")
     parser.add_argument('--vis_tgt', action="store_true", default=False,
                         help="visualize training data.")
+    parser.add_argument('--vis_aux_loss', action="store_true", default=False,
+                        help="visualize aux loss.")
     
     # Batchsize
     parser.add_argument('-bs', '--batch_size', default=16, type=int, 
@@ -65,8 +56,6 @@ def parse_args():
                         help='warmup epoch.')
     parser.add_argument('--eval_epoch', default=10, type=int, 
                         help='after eval epoch, the model is evaluated on val dataset.')
-    parser.add_argument('--step_epoch', nargs='+', default=[90, 120], type=int,
-                        help='lr epoch to decay')
 
     # model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
@@ -118,15 +107,14 @@ def train():
     print("Setting Arguments.. : ", args)
     print("----------------------------------------------------------")
 
-    # ---------------------------- Build DDP ----------------------------
+    # Build DDP
     world_size = distributed_utils.get_world_size()
-    per_gpu_batch = args.batch_size // world_size
     print('World size: {}'.format(world_size))
     if args.distributed:
         distributed_utils.init_distributed_mode(args)
         print("git:\n  {}\n".format(distributed_utils.get_sha()))
 
-    # ---------------------------- Build CUDA ----------------------------
+    # Build CUDA
     if args.cuda:
         print('use cuda')
         # cudnn.benchmark = True
@@ -134,38 +122,23 @@ def train():
     else:
         device = torch.device("cpu")
 
-    # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
+    # Build Dataset & Model & Trans. Config
     data_cfg = build_dataset_config(args)
     model_cfg = build_model_config(args)
     trans_cfg = build_trans_config(model_cfg['trans_type'])
 
-    # ---------------------------- Build Transform ----------------------------
-    train_transform, trans_cfg = build_transform(
-        args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
-    val_transform, _ = build_transform(
-        args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
-
-    # ---------------------------- Build Dataset & Dataloader ----------------------------
-    dataset, dataset_info = build_dataset(args, data_cfg, trans_cfg, train_transform, is_train=True)
-    train_loader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())
-
-    # ---------------------------- Build Evaluator ----------------------------
-    evaluator = build_evluator(args, data_cfg, val_transform, device)
-
-    # ---------------------------- Build Model ----------------------------
-    model, criterion = build_model(args, model_cfg, device, dataset_info['num_classes'], True)
+    # Build Model
+    model, criterion = build_model(args, model_cfg, device, data_cfg['num_classes'], True)
     model = model.to(device).train()
+    model_without_ddp = model
     if args.sybn and args.distributed:
         print('use SyncBatchNorm ...')
         model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
-
-    # ---------------------------- Build DDP Model ----------------------------
-    model_without_ddp = model
     if args.distributed:
         model = DDP(model, device_ids=[args.gpu])
         model_without_ddp = model.module
 
-    # ---------------------------- Calcute Params & GFLOPs ----------------------------
+    # Calculate Params & GFLOPs
     if distributed_utils.is_main_process:
         model_copy = deepcopy(model_without_ddp)
         model_copy.trainable = False
@@ -175,74 +148,21 @@ def train():
                       device=device)
         del model_copy
     if args.distributed:
-        # wait for all processes to synchronize
-        dist.barrier()
         dist.barrier()
 
-    # ---------------------------- Build Grad. Scaler ----------------------------
-    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
-
-    # ---------------------------- Build Optimizer ----------------------------
-    accumulate = max(1, round(64 / args.batch_size))
-    print('Grad_Accumulate: ', accumulate)
-    model_cfg['weight_decay'] *= args.batch_size * accumulate / 64
-    optimizer, start_epoch = build_optimizer(model_cfg, model_without_ddp, model_cfg['lr0'], args.resume)
-
-    # ---------------------------- Build LR Scheduler ----------------------------
-    args.max_epoch += args.wp_epoch
-    lr_scheduler, lf = build_lr_scheduler(model_cfg, optimizer, args.max_epoch)
-    lr_scheduler.last_epoch = start_epoch - 1  # do not move
-    if args.resume:
-        lr_scheduler.step()
-
-    # ---------------------------- Build Model-EMA ----------------------------
-    if args.ema and distributed_utils.get_rank() in [-1, 0]:
-        print('Build ModelEMA ...')
-        model_ema = ModelEMA(model, model_cfg['ema_decay'], model_cfg['ema_tau'], start_epoch * len(train_loader))
-    else:
-        model_ema = None
-
-    # ---------------------------- Build Trainer ----------------------------
-    trainer = Trainer(args, device, model_cfg, model_ema, optimizer, lf, lr_scheduler, criterion, scaler)
+    # Build Trainer
+    trainer = build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model_without_ddp, criterion)
 
-    # start training loop
-    heavy_eval = False
-    optimizer.zero_grad()
-    
-    # --------------------------------- Main process for training ---------------------------------
+    # --------------------------------- Start ---------------------------------
     ## Eval before training
     if args.eval_first and distributed_utils.is_main_process():
         # to check whether the evaluator can work
         model_eval = model_without_ddp
-        trainer.eval_one_epoch(model_eval, evaluator)
+        trainer.eval_one_epoch(model_eval)
 
     ## Start Training
-    for epoch in range(start_epoch, args.max_epoch):
-        if args.distributed:
-            train_loader.batch_sampler.sampler.set_epoch(epoch)
-
-        # check second stage
-        if epoch >= (args.max_epoch - model_cfg['no_aug_epoch'] - 1):
-            # close mosaic augmentation
-            if train_loader.dataset.mosaic_prob > 0.:
-                print('close Mosaic Augmentation ...')
-                train_loader.dataset.mosaic_prob = 0.
-                heavy_eval = True
-            # close mixup augmentation
-            if train_loader.dataset.mixup_prob > 0.:
-                print('close Mixup Augmentation ...')
-                train_loader.dataset.mixup_prob = 0.
-                heavy_eval = True
-
-        # train one epoch
-        trainer.train_one_epoch(model, train_loader)
-
-        # eval one epoch
-        if heavy_eval:
-            trainer.eval_one_epoch(model_without_ddp, evaluator)
-        else:
-            if (epoch % args.eval_epoch) == 0 or (epoch == args.max_epoch - 1):
-                trainer.eval_one_epoch(model_without_ddp, evaluator)
+    trainer.train(model)
+    # --------------------------------- End ---------------------------------
 
     # Empty cache after train loop
     del trainer

+ 72 - 34
utils/box_ops.py

@@ -4,7 +4,54 @@ import numpy as np
 from torchvision.ops.boxes import box_area
 
 
-# modified from torchvision to also return the union
+# ------------------ Box ops ------------------
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
+
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
+
+
+def rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas=None):
+    origin_h, origin_w = origin_img_size
+    cur_img_h, cur_img_w = cur_img_size
+    if deltas is None:
+        # rescale
+        bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / cur_img_w * origin_w
+        bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / cur_img_h * origin_h
+
+        # clip bboxes
+        bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
+        bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
+    else:
+        # rescale
+        bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / (cur_img_w - deltas[0]) * origin_w
+        bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / (cur_img_h - deltas[1]) * origin_h
+        
+        # clip bboxes
+        bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
+        bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
+
+    return bboxes
+
+
+def bbox2dist(anchor_points, bbox, reg_max):
+    '''Transform bbox(xyxy) to dist(ltrb).'''
+    x1y1, x2y2 = torch.split(bbox, 2, -1)
+    lt = anchor_points - x1y1
+    rb = x2y2 - anchor_points
+    dist = torch.cat([lt, rb], -1).clamp(0, reg_max - 0.01)
+    return dist
+
+
+# ------------------ IoU ops ------------------
 def box_iou(boxes1, boxes2):
     area1 = box_area(boxes1)
     area2 = box_area(boxes2)
@@ -21,6 +68,30 @@ def box_iou(boxes1, boxes2):
     return iou, union
 
 
+def generalized_box_iou(boxes1, boxes2):
+    """
+    Generalized IoU from https://giou.stanford.edu/
+
+    The boxes should be in [x0, y0, x1, y1] format
+
+    Returns a [N, M] pairwise matrix, where N = len(boxes1)
+    and M = len(boxes2)
+    """
+    # degenerate boxes gives inf / nan results
+    # so do an early check
+    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    iou, union = box_iou(boxes1, boxes2)
+
+    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    area = wh[:, :, 0] * wh[:, :, 1]
+
+    return iou - (area - union) / area
+
+
 def get_ious(bboxes1,
              bboxes2,
              box_mode="xyxy",
@@ -74,39 +145,6 @@ def get_ious(bboxes1,
     else:
         raise NotImplementedError
 
-
-def rescale_bboxes(bboxes, origin_img_size, cur_img_size, deltas=None):
-    origin_h, origin_w = origin_img_size
-    cur_img_h, cur_img_w = cur_img_size
-    if deltas is None:
-        # rescale
-        bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / cur_img_w * origin_w
-        bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / cur_img_h * origin_h
-
-        # clip bboxes
-        bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
-        bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
-    else:
-        # rescale
-        bboxes[..., [0, 2]] = bboxes[..., [0, 2]] / (cur_img_w - deltas[0]) * origin_w
-        bboxes[..., [1, 3]] = bboxes[..., [1, 3]] / (cur_img_h - deltas[1]) * origin_h
-        
-        # clip bboxes
-        bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_w)
-        bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_h)
-
-    return bboxes
-
-
-def bbox2dist(anchor_points, bbox, reg_max):
-    '''Transform bbox(xyxy) to dist(ltrb).'''
-    x1y1, x2y2 = torch.split(bbox, 2, -1)
-    lt = anchor_points - x1y1
-    rb = x2y2 - anchor_points
-    dist = torch.cat([lt, rb], -1).clamp(0, reg_max - 0.01)
-    return dist
-
-
 # copy from YOLOv5
 def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
     # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)

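A worked example of the generalized_box_iou added above, on two hypothetical xyxy boxes that do not overlap; the enclosing-box term is what pushes the value below zero and keeps a useful gradient even at IoU = 0:

import torch
from utils.box_ops import generalized_box_iou

a = torch.tensor([[0., 0., 2., 2.]])
b = torch.tensor([[3., 0., 5., 2.]])
print(generalized_box_iou(a, b))   # tensor([[-0.2000]]): IoU 0, union 8, enclosing area 10 -> 0 - (10 - 8) / 10
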
+ 37 - 0
utils/misc.py

@@ -44,6 +44,43 @@ class CollateFunc(object):
         return images, targets
 
 
+# ---------------------------- For Loss ----------------------------
+## FocalLoss
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0,1) to balance
+                positive vs negative examples. Default = 0.25; a negative value disables the weighting.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor
+    """
+    prob = inputs.sigmoid()
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = prob * targets + (1 - prob) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss.mean(1).sum() / num_boxes
+
+## InverseSigmoid
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1/x2)
+
+
 # ---------------------------- For Model ----------------------------
 ## fuse Conv & BN layer
 def fuse_conv_bn(module):

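A quick numeric check of the focal modulation in sigmoid_focal_loss (illustrative logits): an easy positive at p = 0.9 is scaled by (1 - p_t)^2 = 0.01, while a hard positive at p = 0.1 keeps most of its cross-entropy, which is what concentrates training on hard examples:

import torch
import torch.nn.functional as F

logits  = torch.tensor([[2.1972, -2.1972]])   # sigmoid -> [0.9, 0.1]
targets = torch.tensor([[1.0, 1.0]])
prob = logits.sigmoid()
ce   = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
p_t  = prob * targets + (1 - prob) * (1 - targets)
print((1 - p_t) ** 2 * ce)   # ~[[0.0011, 1.8651]] before the alpha weighting
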
+ 34 - 1
utils/solver/optimizer.py

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 
 
-def build_optimizer(cfg, model, base_lr=0.01, resume=None):
+def build_yolo_optimizer(cfg, model, base_lr=0.01, resume=None):
     print('==============================')
     print('Optimizer: {}'.format(cfg['optimizer']))
     print('--base lr: {}'.format(base_lr))
@@ -41,3 +41,36 @@ def build_optimizer(cfg, model, base_lr=0.01, resume=None):
         start_epoch = checkpoint.pop("epoch")
                                                         
     return optimizer, start_epoch
+
+
+def build_detr_optimizer(cfg, model, resume=None):
+    print('==============================')
+    print('Optimizer: {}'.format(cfg['optimizer']))
+    print('--base lr: {}'.format(cfg['lr0']))
+    print('--weight_decay: {}'.format(cfg['weight_decay']))
+
+    param_dicts = [
+        {"params": [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]},
+        {
+            "params": [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad],
+            "lr": cfg['lr0'] * 0.1,
+        },
+    ]
+
+    if cfg['optimizer'] == 'adam':
+        optimizer = torch.optim.Adam(param_dicts, lr=cfg['lr0'], weight_decay=cfg['weight_decay'])
+    elif cfg['optimizer'] == 'adamw':
+        optimizer = torch.optim.AdamW(param_dicts, lr=cfg['lr0'], weight_decay=cfg['weight_decay'])
+    else:
+        raise NotImplementedError('Optimizer {} not implemented.'.format(cfg['optimizer']))
+
+    start_epoch = 0
+    if resume is not None:
+        print('keep training: ', resume)
+        checkpoint = torch.load(resume)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("optimizer")
+        optimizer.load_state_dict(checkpoint_state_dict)
+        start_epoch = checkpoint.pop("epoch")
+                                                        
+    return optimizer, start_epoch
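
build_detr_optimizer puts every parameter whose name contains "backbone" into a second group trained at lr0 * 0.1, the usual DETR recipe of fine-tuning the pretrained backbone more gently. A small sketch with a toy module and illustrative cfg values:

import torch
import torch.nn as nn
from utils.solver.optimizer import build_detr_optimizer

cfg = {'optimizer': 'adamw', 'lr0': 1e-4, 'weight_decay': 1e-4}

class ToyDetector(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)   # matched by the "backbone" name filter
        self.head = nn.Linear(8, 4)

optimizer, start_epoch = build_detr_optimizer(cfg, ToyDetector())
print([group['lr'] for group in optimizer.param_groups])   # [0.0001, 1e-05]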