yjh0410 1 year ago
commit
e8084d351b
87 files changed, 12416 additions and 0 deletions
1. .gitignore (+16, -0)
2. LICENSE (+21, -0)
3. README.md (+212, -0)
4. config/__init__.py (+24, -0)
5. config/rtdetr_config.py (+125, -0)
6. config/yolov1_config.py (+96, -0)
7. config/yolov2_config.py (+0, -0)
8. config/yolov3_config.py (+0, -0)
9. config/yolov4_config.py (+0, -0)
10. config/yolov5_config.py (+0, -0)
11. config/yolov7_config.py (+0, -0)
12. config/yolov8_config.py (+131, -0)
13. config/yolox_config.py (+0, -0)
14. dataset/__init__.py (+0, -0)
15. dataset/build.py (+99, -0)
16. dataset/coco.py (+326, -0)
17. dataset/customed.py (+309, -0)
18. dataset/data_augment/ssd_augment.py (+554, -0)
19. dataset/data_augment/strong_augment.py (+225, -0)
20. dataset/data_augment/yolo_augment.py (+291, -0)
21. dataset/scripts/COCO2017.sh (+20, -0)
22. dataset/scripts/VOC2007.sh (+42, -0)
23. dataset/scripts/VOC2012.sh (+38, -0)
24. dataset/scripts/data_to_h5py.py (+70, -0)
25. dataset/voc.py (+313, -0)
26. demo.py (+290, -0)
27. engine.py (+562, -0)
28. eval.py (+119, -0)
29. evaluator/build.py (+33, -0)
30. evaluator/coco_evaluator.py (+98, -0)
31. evaluator/customed_evaluator.py (+108, -0)
32. evaluator/voc_evaluator.py (+356, -0)
33. models/__init__.py (+59, -0)
34. models/rtdetr/README.md (+50, -0)
35. models/rtdetr/basic_modules/backbone.py (+103, -0)
36. models/rtdetr/basic_modules/conv.py (+144, -0)
37. models/rtdetr/basic_modules/dn_compoments.py (+109, -0)
38. models/rtdetr/basic_modules/ext_op/README.md (+85, -0)
39. models/rtdetr/basic_modules/ext_op/ms_deformable_attn_op.cc (+65, -0)
40. models/rtdetr/basic_modules/ext_op/ms_deformable_attn_op.cu (+1073, -0)
41. models/rtdetr/basic_modules/ext_op/setup_ms_deformable_attn_op.py (+7, -0)
42. models/rtdetr/basic_modules/ext_op/test_ms_deformable_attn_op.py (+140, -0)
43. models/rtdetr/basic_modules/fpn.py (+164, -0)
44. models/rtdetr/basic_modules/mlp.py (+51, -0)
45. models/rtdetr/basic_modules/nms_ops.py (+71, -0)
46. models/rtdetr/basic_modules/norm.py (+33, -0)
47. models/rtdetr/basic_modules/transformer.py (+459, -0)
48. models/rtdetr/build.py (+16, -0)
49. models/rtdetr/loss.py (+170, -0)
50. models/rtdetr/loss_utils.py (+240, -0)
51. models/rtdetr/matcher.py (+52, -0)
52. models/rtdetr/rtdetr.py (+143, -0)
53. models/rtdetr/rtdetr_decoder.py (+304, -0)
54. models/rtdetr/rtdetr_encoder.py (+34, -0)
55. models/yolov1/README.md (+52, -0)
56. models/yolov1/build.py (+16, -0)
57. models/yolov1/loss.py (+98, -0)
58. models/yolov1/matcher.py (+69, -0)
59. models/yolov1/yolov1.py (+146, -0)
60. models/yolov1/yolov1_backbone.py (+209, -0)
61. models/yolov1/yolov1_basic.py (+147, -0)
62. models/yolov1/yolov1_head.py (+121, -0)
63. models/yolov1/yolov1_neck.py (+33, -0)
64. models/yolov1/yolov1_pred.py (+95, -0)
65. models/yolov8/README.md (+47, -0)
66. models/yolov8/build.py (+24, -0)
67. models/yolov8/loss.py (+187, -0)
68. models/yolov8/matcher.py (+199, -0)
69. models/yolov8/yolov8.py (+145, -0)
70. models/yolov8/yolov8_backbone.py (+183, -0)
71. models/yolov8/yolov8_basic.py (+171, -0)
72. models/yolov8/yolov8_head.py (+277, -0)
73. models/yolov8/yolov8_neck.py (+33, -0)
74. models/yolov8/yolov8_pafpn.py (+104, -0)
75. models/yolov8/yolov8_pred.py (+315, -0)
76. requirements.txt (+27, -0)
77. test.py (+156, -0)
78. train.py (+208, -0)
79. train.sh (+36, -0)
80. utils/__init__.py (+0, -0)
81. utils/box_ops.py (+206, -0)
82. utils/distributed_utils.py (+166, -0)
83. utils/misc.py (+574, -0)
84. utils/solver/__init__.py (+0, -0)
85. utils/solver/lr_scheduler.py (+49, -0)
86. utils/solver/optimizer.py (+104, -0)
87. utils/vis_tools.py (+169, -0)

+ 16 - 0
.gitignore

@@ -0,0 +1,16 @@
+*.pt
+*.pth
+*.pkl
+*.onnx
+*.pyc
+*.zip
+weights
+__pycache__
+det_results
+.vscode
+deployment/OpenVINO/cpp/build
+cluster.json
+train_nebula.py
+train_nebula.sh
+make_data_nebula.sh
+dataset/make_dataset_nebula.py

+ 21 - 0
LICENSE

@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Jianhua Yang
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.

+ 212 - 0
README.md

@@ -0,0 +1,212 @@
+# General Object Detection for Open World
+
+## Requirements
+- We recommend using Anaconda to create a conda environment:
+```Shell
+conda create -n odlab python=3.10
+```
+
+- Then, activate the environment:
+```Shell
+conda activate odlab
+```
+
+- Requirements:
+1. Install the necessary libraries
+```Shell
+pip install -r requirements.txt 
+```
+
+2. (optional) Compile MSDeformableAttention ops for DETR series
+
+```bash
+cd ./models/rtdetr/basic_modules/ext_op/
+
+python setup_ms_deformable_attn_op.py install
+```
+See [details](./models/rtdetr/basic_modules/ext_op/)
+
+My environment:
+- PyTorch = 2.2.0
+- Torchvision = 0.17.0
+
+At a minimum, please make sure your PyTorch version is not older than 1.x.
+
+## Experiments
+### VOC
+- Download VOC.
+```Shell
+cd <ODLab-World>
+cd dataset/scripts/
+sh VOC2007.sh
+sh VOC2012.sh
+```
+
+- Check VOC
+```Shell
+cd <ODLab-World>
+python dataset/voc.py
+```
+
+### COCO
+
+- Download COCO.
+```Shell
+cd <ODLab-World>
+cd dataset/scripts/
+sh COCO2017.sh
+```
+
+- Clean COCO
+```Shell
+cd <ODLab-World>
+cd tools/
+python clean_coco.py --root path/to/coco --image_set val
+python clean_coco.py --root path/to/coco --image_set train
+```
+
+- Check COCO
+```Shell
+cd <ODLab-World>
+python dataset/coco.py
+```
+
+## Train 
+We provide a script `train.sh` to launch the training code. Use it with the following argument order:
+```Shell
+bash train.sh <model> <data> <data_path> <batch_size> <num_gpus> <master_port> <resume_weight>
+```
+
+For example, to train YOLO-N from epoch 0:
+```Shell
+bash train.sh yolo_n coco path/to/coco 128 4 1699 None
+```
+
+We can also resume training from an existing checkpoint by passing its path as the resume argument:
+```Shell
+bash train.sh yolo_n coco path/to/coco 128 4 1699 path/to/yolo_n.pth
+```
+
+
+## Train on custom dataset
+Besides the popular benchmark datasets, we can also train the model on our own dataset. To do so, follow these steps:
+- Step-1: Prepare the images (JPG/JPEG/PNG ...) and use `labelimg` to make XML format annotation files.
+
+```
+CustomedDataset
+|_ train
+|  |_ images     
+|     |_ 0.jpg
+|     |_ 1.jpg
+|     |_ ...
+|  |_ annotations
+|     |_ 0.xml
+|     |_ 1.xml
+|     |_ ...
+|_ val
+|  |_ images     
+|     |_ 0.jpg
+|     |_ 1.jpg
+|     |_ ...
+|  |_ annotations
+|     |_ 0.xml
+|     |_ 1.xml
+|     |_ ...
+|  ...
+```
+
+- Step-2: Make the configuration for our dataset.
+```Shell
+cd <ODLab-World>
+cd config/data_config
+```
+You need to edit the `dataset_cfg` defined in `dataset_config.py`. Refer to the `customed` entry in `dataset_cfg` and adjust the relevant parameters, such as `num_classes` and `class_names`, to match your dataset.
+
+For example:
+```Shell
+dataset_cfg = {
+    'customed':{
+        'data_name': 'AnimalDataset',
+        'num_classes': 9,
+        'class_indexs': (0, 1, 2, 3, 4, 5, 6, 7, 8),
+        'class_names': ('bird', 'butterfly', 'cat', 'cow', 'dog', 'lion', 'person', 'pig', 'tiger', ),
+    },
+}
+```
+
+- Step-3: Convert the custom dataset to COCO format.
+
+```Shell
+cd <ODLab-World>
+cd tools
+# convert train split
+python convert_ours_to_coco.py --root path/to/dataset/ --split train
+# convert val split
+python convert_ours_to_coco.py --root path/to/dataset/ --split val
+```
+Then, we can get a `train.json` file and a `val.json` file, as shown below.
+```
+CustomedDataset
+|_ train
+|  |_ images     
+|     |_ 0.jpg
+|     |_ 1.jpg
+|     |_ ...
+|  |_ annotations
+|     |_ 0.xml
+|     |_ 1.xml
+|     |_ ...
+|     |_ train.json
+|_ val
+|  |_ images     
+|     |_ 0.jpg
+|     |_ 1.jpg
+|     |_ ...
+|  |_ annotations
+|     |_ 0.xml
+|     |_ 1.xml
+|     |_ ...
+|     |_ val.json
+|  ...
+```
+
+- Step-4 Check the data.
+
+```Shell
+cd <ODLab-World>
+cd dataset
+# check train split
+python customed.py --root path/to/dataset/ --split train
+# check val split
+python customed.py --root path/to/dataset/ --split val
+```
+
+- Step-5 **Train**
+
+For example:
+
+```Shell
+cd <ODLab-World>
+python train.py --root path/to/dataset/ -d customed -m yolo_n -bs 16 -p path/to/yolo_n_coco.pth
+```
+
+- Step-6 **Test**
+
+For example:
+
+```Shell
+cd <ODLab-World>
+python test.py --root path/to/dataset/ -d customed -m yolo_n --weight path/to/checkpoint --show
+```
+
+- Step-7 **Eval**
+
+For example:
+
+```Shell
+cd <ODLab-World>
+python eval.py --root path/to/dataset/ -d customed -m yolo_n --weight path/to/checkpoint
+```
+
+## Deployment
+1. [ONNX export and an ONNXRuntime demo](./deployment/ONNXRuntime/)

+ 24 - 0
config/__init__.py

@@ -0,0 +1,24 @@
+# ------------------ Model Config ------------------
+from .yolov1_config   import build_yolov1_config
+from .yolov8_config   import build_yolov8_config
+from .rtdetr_config import build_rtdetr_config
+
+def build_config(args):
+    print('==============================')
+    print('Model: {} ...'.format(args.model.upper()))
+    # YOLO series
+    if 'yolov1' in args.model:
+        cfg = build_yolov1_config(args)
+    elif 'yolov8' in args.model:
+        cfg = build_yolov8_config(args)
+    # RT-DETR
+    elif 'rtdetr' in args.model:
+        cfg = build_rtdetr_config(args)
+    else:
+        raise NotImplementedError("Unknown model config: {}".format(args.model))
+    
+    # Print model config
+    cfg.print_config()
+
+    return cfg
+
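For reference, a minimal usage sketch of `build_config` (hypothetical; the only field assumed on `args` is `model`, since that is the field `build_config` reads):

```python
# Hypothetical usage sketch of config.build_config; only args.model is assumed.
from types import SimpleNamespace
from config import build_config

args = SimpleNamespace(model="yolov8_s")
cfg = build_config(args)          # prints the config and returns a Yolov8SConfig
print(cfg.train_img_size)         # 640
```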

+ 125 - 0
config/rtdetr_config.py

@@ -0,0 +1,125 @@
+# Real-time Transformer-based Object Detector
+
+
+def build_rtdetr_config(args):
+    if args.model == "rtdetr_r18":
+        return RTDetrR18Config()
+    raise NotImplementedError("No config for model: {}".format(args.model))   
+ 
+# RT-DETR Base config
+class RTDetrBaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+        ## Backbone
+        self.backbone        = 'resnet18'
+        self.backbone_norm   = 'BN'
+        self.pretrained_weight  = 'imagenet1k_v1'
+        self.pretrained = True
+        self.freeze_at = 0
+        self.freeze_stem_only = False
+        ## Image Encoder - FPN
+        self.fpn      = 'hybrid_encoder'
+        self.fpn_num_blocks = 3
+        self.fpn_expand_ratio = 0.5
+        self.fpn_act  = 'silu'
+        self.fpn_norm = 'BN'
+        self.fpn_depthwise = False
+        self.hidden_dim = 256
+        self.en_num_heads = 8
+        self.en_num_layers = 1
+        self.en_ffn_dim = 1024
+        self.en_dropout = 0.0
+        self.en_act = 'gelu'
+        ## Transformer Decoder
+        self.transformer   = 'rtdetr_transformer'
+        self.de_num_heads  = 8
+        self.de_num_layers = 3
+        self.de_ffn_dim    = 1024
+        self.de_dropout    = 0.0
+        self.de_act        = 'relu'
+        self.de_num_points = 4
+        self.num_queries   = 300
+        self.learnt_init_query = False
+        ## DN
+        self.dn_num_denoising     = 100
+        self.dn_label_noise_ratio = 0.5
+        self.dn_box_noise_scale   = 1
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 300
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 300
+        self.test_conf_thresh = 0.3
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.cost_class = 2.0
+        self.cost_bbox  = 5.0
+        self.cost_giou  = 2.0
+        ## Loss weight
+        self.loss_cls  = 1.0
+        self.loss_box  = 5.0
+        self.loss_giou = 2.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9999
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer = 'rtdetr'
+        self.optimizer = 'adamw'
+        self.per_image_lr = 0.0001 / 16
+        self.base_lr      = None      # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio      = 0.0
+        self.backbone_lr_ratio = 0.1
+        self.momentum  = None
+        self.weight_decay = 0.0001
+        self.clip_max_norm = 0.1
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup = 'linear'
+        self.warmup_iters = 2000
+        self.warmup_factor = 0.00066667
+        self.lr_scheduler = "step"
+        self.lr_epoch = [100]
+        self.max_epoch = 72
+        self.eval_epoch = 1
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'ssd'
+        self.box_format = 'xywh'
+        self.normalize_coords = True
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0
+        self.multi_scale = [0.75, 1.25]
+        ## Pixel mean & std
+        self.pixel_mean = [123.675, 116.28, 103.53]   # RGB format
+        self.pixel_std  = [58.395, 57.12, 57.375]     # RGB format
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+    
+# RT-DETR-R18
+class RTDetrR18Config(RTDetrBaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        ## Backbone
+        self.backbone        = 'resnet18'
+        self.backbone_norm   = 'BN'
+        self.pretrained_weight  = 'imagenet1k_v1'
+        self.pretrained = True
+        self.freeze_at = -1
+        self.freeze_stem_only = False
+
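As a hedged aside, the learning-rate fields above are meant to combine as the inline comments describe (`base_lr = per_image_lr * batch_size`); a tiny sketch, assuming a batch size of 16:

```python
# Sketch of how the RT-DETR lr fields combine, following the inline comments above.
# The batch size of 16 is an assumption chosen for illustration.
batch_size   = 16
per_image_lr = 0.0001 / 16
base_lr      = per_image_lr * batch_size   # 0.0001 at batch size 16
backbone_lr  = base_lr * 0.1               # backbone_lr_ratio = 0.1
min_lr       = base_lr * 0.0               # min_lr_ratio = 0.0
print(base_lr, backbone_lr, min_lr)        # roughly 0.0001, 1e-05, 0.0
```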

+ 96 - 0
config/yolov1_config.py

@@ -0,0 +1,96 @@
+# YOLOv1 Config
+
+
+def build_yolov1_config(args):
+    if args.model == 'yolov1_r18':
+        return Yolov1R18Config()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# YOLOv1-Base config
+class Yolov1BaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.out_stride = 32
+        self.max_stride = 32
+        ## Backbone
+        self.backbone       = 'resnet50'
+        self.use_pretrained = True
+        ## Neck
+        self.neck_act       = 'lrelu'
+        self.neck_norm      = 'BN'
+        self.neck_depthwise = False
+        self.neck_expand_ratio = 0.5
+        self.spp_pooling_size  = 5
+        ## Head
+        self.head_act  = 'lrelu'
+        self.head_norm = 'BN'
+        self.head_depthwise = False
+        self.head_dim  = 512
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 100
+        self.test_conf_thresh = 0.2
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment config ----------------
+        ## Loss weight
+        self.loss_obj = 1.0
+        self.loss_cls = 1.0
+        self.loss_box = 5.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema   = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.per_image_lr = 0.001 / 64
+        self.base_lr      = None      # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = -1.
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 150
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 20
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'ssd'
+        self.box_format = 'xyxy'
+        self.normalize_coords = False
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0          # approximated by the YOLOX's mixup
+        self.multi_scale = [0.5, 1.25]   # multi-scale range: [img_size * 0.5, img_size * 1.25]
+        ## Pixel mean & std
+        self.pixel_mean = [123.675, 116.28, 103.53]   # RGB format
+        self.pixel_std  = [58.395, 57.12, 57.375]     # RGB format
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# YOLOv1-R18
+class Yolov1R18Config(Yolov1BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        self.backbone = 'resnet18'

+ 0 - 0
config/yolov2_config.py


+ 0 - 0
config/yolov3_config.py


+ 0 - 0
config/yolov4_config.py


+ 0 - 0
config/yolov5_config.py


+ 0 - 0
config/yolov7_config.py


+ 131 - 0
config/yolov8_config.py

@@ -0,0 +1,131 @@
+# YOLOv8 Config
+
+
+def build_yolov8_config(args):
+    if args.model == 'yolov8_s':
+        return Yolov8SConfig()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# YOLOv8-Base config
+class Yolov8BaseConfig(object):
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.width    = 1.0
+        self.depth    = 1.0
+        self.ratio    = 1.0
+        self.reg_max  = 16
+        self.out_stride = [8, 16, 32]
+        self.max_stride = 32
+        self.num_levels = 3
+        self.scale      = "b"
+        ## Backbone
+        self.bk_act   = 'silu'
+        self.bk_norm  = 'BN'
+        self.bk_depthwise = False
+        self.use_pretrained = False
+        ## Neck
+        self.neck_act       = 'silu'
+        self.neck_norm      = 'BN'
+        self.neck_depthwise = False
+        self.neck_expand_ratio = 0.5
+        self.spp_pooling_size  = 5
+        ## FPN
+        self.fpn_act  = 'silu'
+        self.fpn_norm = 'BN'
+        self.fpn_depthwise = False
+        ## Head
+        self.head_act  = 'silu'
+        self.head_norm = 'BN'
+        self.head_depthwise = False
+        self.num_cls_head   = 2
+        self.num_reg_head   = 2
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.001
+        self.val_nms_thresh  = 0.7
+        self.test_topk = 100
+        self.test_conf_thresh = 0.2
+        self.test_nms_thresh  = 0.5
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.tal_topk_candidates = 10
+        self.tal_alpha = 0.5
+        self.tal_beta  = 6.0
+        ## Loss weight
+        self.loss_cls = 0.5
+        self.loss_box = 7.5
+        self.loss_dfl = 1.5
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.per_image_lr = 0.001 / 64
+        self.base_lr      = None      # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = -1.
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 500
+        self.eval_epoch   = 10
+        self.no_aug_epoch = 20
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.box_format = 'xyxy'
+        self.normalize_coords = False
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.15
+        self.copy_paste  = 0.0          # approximated by the YOLOX's mixup
+        self.multi_scale = [0.5, 1.25]   # multi-scale range: [img_size * 0.5, img_size * 1.25]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.use_ablu = True
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.1,
+            'scale': [0.1, 2.0],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# YOLOv8-S
+class Yolov8SConfig(Yolov8BaseConfig):
+    def __init__(self) -> None:
+        super().__init__()
+        # ---------------- Model config ----------------
+        self.width = 0.50
+        self.depth = 0.34
+        self.ratio = 2.0
+        self.scale = "s"
+
+        # ---------------- Data process config ----------------
+        self.mosaic_prob = 1.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0
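A hedged illustration of what `width`, `depth`, and `ratio` usually control in a YOLOv8-style scaling scheme (the helper names below are hypothetical; the actual rounding rules live in `models/yolov8/yolov8_backbone.py`):

```python
# Hypothetical illustration of YOLOv8-style model scaling.
# scale_channels / scale_blocks are invented helpers; the real logic is in yolov8_backbone.py.
def scale_channels(channels: int, width: float) -> int:
    return max(round(channels * width), 1)

def scale_blocks(num_blocks: int, depth: float) -> int:
    return max(round(num_blocks * depth), 1)

width, depth, ratio = 0.50, 0.34, 2.0            # Yolov8SConfig values above
print(scale_channels(512, width))                 # 256: channel widths shrink with `width`
print(scale_blocks(3, depth))                     # 1: block counts shrink with `depth`
print(int(scale_channels(512, width) * ratio))    # 512: `ratio` widens the deepest stage
```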

+ 0 - 0
config/yolox_config.py


+ 0 - 0
dataset/__init__.py


+ 99 - 0
dataset/build.py

@@ -0,0 +1,99 @@
+import os
+
+try:
+    # dataset class
+    from .voc        import VOCDataset
+    from .coco       import COCODataset
+    from .customed   import CustomedDataset
+    # transform class
+    from .data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
+    from .data_augment.ssd_augment  import SSDAugmentation, SSDBaseTransform
+
+except:
+    # dataset class
+    from voc        import VOCDataset
+    from coco       import COCODataset
+    from customed   import CustomedDataset
+    # transform class
+    from data_augment.yolo_augment import YOLOAugmentation, YOLOBaseTransform
+    from data_augment.ssd_augment  import SSDAugmentation, SSDBaseTransform
+
+
+# ------------------------------ Dataset ------------------------------
+def build_dataset(args, cfg, transform=None, is_train=False):
+    # ------------------------- Build dataset -------------------------
+    ## VOC dataset
+    if args.dataset == 'voc':
+        image_set = [('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')]
+        cfg.num_classes  = 20
+        dataset = VOCDataset(cfg       = cfg,
+                             data_dir  = args.root,
+                             image_set = image_set,
+                             transform = transform,
+                             is_train  = is_train,
+                             )
+    ## COCO dataset
+    elif args.dataset == 'coco':
+        image_set = 'train2017' if is_train else 'val2017'
+        cfg.num_classes  = 80
+        dataset = COCODataset(cfg       = cfg,
+                              data_dir  = args.root,
+                              image_set = image_set,
+                              transform = transform,
+                              is_train  = is_train,
+                              )
+    ## Custom dataset
+    elif args.dataset == 'customed':
+        image_set = 'train' if is_train else 'val'
+        cfg.num_classes  = 20
+        dataset = CustomedDataset(cfg       = cfg,
+                                  data_dir  = args.root,
+                                  image_set = image_set,
+                                  transform = transform,
+                                  is_train  = is_train,
+                                  )
+
+    cfg.class_labels = dataset.class_labels
+    cfg.class_indexs = dataset.class_indexs
+    cfg.num_classes  = dataset.num_classes
+
+    return dataset
+
+
+# ------------------------------ Transform ------------------------------
+def build_transform(cfg, is_train=False):
+    # ---------------- Build transform ----------------
+    ## YOLO style transform
+    if cfg.aug_type == 'yolo':
+        if is_train:
+            transform = YOLOAugmentation(cfg.train_img_size,
+                                         cfg.affine_params,
+                                         cfg.use_ablu,
+                                         cfg.pixel_mean,
+                                         cfg.pixel_std,
+                                         cfg.box_format,
+                                         cfg.normalize_coords)
+        else:
+            transform = YOLOBaseTransform(cfg.test_img_size,
+                                          cfg.max_stride,
+                                          cfg.pixel_mean,
+                                          cfg.pixel_std,
+                                          cfg.box_format,
+                                          cfg.normalize_coords)
+
+    ## RT-DETR style transform
+    elif cfg.aug_type == 'ssd':
+        if is_train:
+            transform = SSDAugmentation(cfg.train_img_size,
+                                           cfg.pixel_mean,
+                                           cfg.pixel_std,
+                                           cfg.box_format,
+                                           cfg.normalize_coords)
+        else:
+            transform = SSDBaseTransform(cfg.test_img_size,
+                                            cfg.pixel_mean,
+                                            cfg.pixel_std,
+                                            cfg.box_format,
+                                            cfg.normalize_coords)
+
+    return transform
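A minimal, hypothetical usage sketch of `build_transform` and `build_dataset` (field and argument names are taken from the functions above; the dataset path is a placeholder):

```python
# Hypothetical usage of dataset.build; the COCO root path is a placeholder.
from types import SimpleNamespace
from dataset.build import build_transform, build_dataset

cfg = SimpleNamespace(
    aug_type="ssd", max_stride=32, box_format="xywh", normalize_coords=True,
    mosaic_prob=0.0, mixup_prob=0.0, copy_paste=0.0,
    pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375],
    train_img_size=640, test_img_size=640,
)
args = SimpleNamespace(dataset="coco", root="path/to/coco")

transform = build_transform(cfg, is_train=False)   # SSDBaseTransform branch
dataset   = build_dataset(args, cfg, transform, is_train=False)
image, target, deltas = dataset[0]                 # COCODataset.pull_item returns 3 values
```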

+ 326 - 0
dataset/coco.py

@@ -0,0 +1,326 @@
+import os
+import cv2
+import time
+import random
+import numpy as np
+from torch.utils.data import Dataset
+from pycocotools.coco import COCO
+
+try:
+    from .data_augment.strong_augment import MosaicAugment, MixupAugment
+except:
+    from  data_augment.strong_augment import MosaicAugment, MixupAugment
+
+
+coco_class_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
+coco_class_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',  'traffic light',  'fire hydrant',  'stop sign',  'parking meter',  'bench',  'bird',  'cat',  'dog',  'horse',  'sheep',  'cow',  'elephant',  'bear',  'zebra',  'giraffe',  'backpack',  'umbrella',  'handbag',  'tie',  'suitcase',  'frisbee',  'skis',  'snowboard',  'sports ball',  'kite',  'baseball bat',  'baseball glove',  'skateboard',  'surfboard',  'tennis racket',  'bottle',  'wine glass',  'cup',  'fork',  'knife',  'spoon',  'bowl',  'banana',  'apple',  'sandwich',  'orange',  'broccoli',  'carrot',  'hot dog',  'pizza',  'donut',  'cake',  'chair',  'couch',  'potted plant',  'bed',  'dining table',  'toilet',  'tv',  'laptop',  'mouse',  'remote',  'keyboard',  'cell phone',  'microwave',  'oven',  'toaster',  'sink',  'refrigerator',  'book',  'clock',  'vase',  'scissors',  'teddy bear',  'hair drier',  'toothbrush')
+coco_json_files = {
+    'train2017_clean': 'instances_train2017_clean.json',
+    'val2017_clean'  : 'instances_val2017_clean.json',
+    'train2017'      : 'instances_train2017.json',
+    'val2017'        : 'instances_val2017.json',
+    'test2017'       : 'image_info_test.json',
+}
+
+
+class COCODataset(Dataset):
+    def __init__(self, 
+                 cfg,
+                 data_dir  :str = None, 
+                 image_set :str = 'train2017',
+                 transform = None,
+                 is_train  :bool = False,
+                 use_mask  :bool = False,
+                 ):
+        # ----------- Basic parameters -----------
+        self.data_dir  = data_dir
+        self.image_set = image_set
+        self.is_train  = is_train
+        self.use_mask  = use_mask
+        self.num_classes = 80
+        # ----------- Data parameters -----------
+        try:
+            self.json_file = coco_json_files['{}_clean'.format(image_set)]
+            self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
+        except:
+            self.json_file = coco_json_files['{}'.format(image_set)]
+            self.coco = COCO(os.path.join(self.data_dir, 'annotations', self.json_file))
+        self.ids = self.coco.getImgIds()
+        self.class_ids = sorted(self.coco.getCatIds())
+        self.dataset_size = len(self.ids)
+        self.class_labels = coco_class_labels
+        self.class_indexs = coco_class_indexs
+        # ----------- Transform parameters -----------
+        self.transform = transform
+        if is_train:
+            self.mosaic_prob = cfg.mosaic_prob
+            self.mixup_prob  = cfg.mixup_prob
+            self.copy_paste  = cfg.copy_paste
+            self.mosaic_augment = None if cfg.mosaic_prob == 0. else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
+            self.mixup_augment  = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0.  else MixupAugment(cfg.train_img_size)
+        else:
+            self.mosaic_prob = 0.0
+            self.mixup_prob  = 0.0
+            self.copy_paste  = 0.0
+            self.mosaic_augment = None
+            self.mixup_augment  = None
+        print('==============================')
+        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
+        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
+        print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
+
+    # ------------ Basic dataset function ------------
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, index):
+        return self.pull_item(index)
+
+    # ------------ Mosaic & Mixup ------------
+    def load_mosaic(self, index):
+        # ------------ Prepare 4 indexes of images ------------
+        ## Load 4x mosaic image
+        index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
+        id1 = index
+        id2, id3, id4 = random.sample(index_list, 3)
+        indexs = [id1, id2, id3, id4]
+
+        ## Load images and targets
+        image_list = []
+        target_list = []
+        for index in indexs:
+            img_i, target_i = self.load_image_target(index)
+            image_list.append(img_i)
+            target_list.append(target_i)
+
+        # ------------ Mosaic augmentation ------------
+        image, target = self.mosaic_augment(image_list, target_list)
+
+        return image, target
+
+    def load_mixup(self, origin_image, origin_target, yolox_style=False):
+        # ------------ Load a new image & target ------------
+        new_index = np.random.randint(0, len(self.ids))
+        new_image, new_target = self.load_mosaic(new_index)
+            
+        # ------------ Mixup augmentation ------------
+        image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
+
+        return image, target
+    
+    # ------------ Load data function ------------
+    def load_image_target(self, index):
+        # load an image
+        image, _ = self.pull_image(index)
+        height, width, channels = image.shape
+
+        # load a target
+        bboxes, labels = self.pull_anno(index)
+        target = {
+            "boxes": bboxes,
+            "labels": labels,
+            "orig_size": [height, width]
+        }
+
+        return image, target
+
+    def pull_item(self, index):
+        if random.random() < self.mosaic_prob:
+            # load a mosaic image
+            mosaic = True
+            image, target = self.load_mosaic(index)
+        else:
+            mosaic = False
+            # load an image and target
+            image, target = self.load_image_target(index)
+
+        # Yolov5-MixUp
+        mixup = False
+        if random.random() < self.mixup_prob:
+            mixup = True
+            image, target = self.load_mixup(image, target)
+
+        # Copy-paste (use Yolox-Mixup to approximate copy-paste)
+        if not mixup and random.random() < self.copy_paste:
+            image, target = self.load_mixup(image, target, yolox_style=True)
+
+        # augment
+        image, target, deltas = self.transform(image, target, mosaic)
+
+        return image, target, deltas
+
+    def pull_image(self, index):
+        img_id = self.ids[index]
+        img_file = os.path.join(self.data_dir, self.image_set,
+                                '{:012}'.format(img_id) + '.jpg')
+        image = cv2.imread(img_file)
+
+        if self.json_file == 'instances_val5k.json' and image is None:
+            img_file = os.path.join(self.data_dir, 'train2017',
+                                    '{:012}'.format(img_id) + '.jpg')
+            image = cv2.imread(img_file)
+
+        assert image is not None
+
+        return image, img_id
+
+    def pull_anno(self, index):
+        img_id = self.ids[index]
+        im_ann = self.coco.loadImgs(img_id)[0]
+        anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=False)
+        annotations = self.coco.loadAnns(anno_ids)
+
+        # image info
+        width = im_ann['width']
+        height = im_ann['height']
+        
+        #load a target
+        bboxes = []
+        labels = []
+        for anno in annotations:
+            if 'bbox' in anno and anno['area'] > 0:
+                # bbox
+                x1 = np.max((0, anno['bbox'][0]))
+                y1 = np.max((0, anno['bbox'][1]))
+                x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
+                y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
+                if x2 < x1 or y2 < y1:
+                    continue
+                # class label
+                cls_id = self.class_ids.index(anno['category_id'])
+                
+                bboxes.append([x1, y1, x2, y2])
+                labels.append(cls_id)
+
+        # guard against no boxes via resizing
+        bboxes = np.array(bboxes).reshape(-1, 4)
+        labels = np.array(labels).reshape(-1)
+        
+        return bboxes, labels
+
+
+if __name__ == "__main__":
+    import time
+    import argparse
+    from build import build_transform
+    
+    parser = argparse.ArgumentParser(description='COCO-Dataset')
+
+    # opt
+    parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
+                        help='data root')
+    parser.add_argument('--is_train', action="store_true", default=False,
+                        help='mixup augmentation.')
+    parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
+                        help='yolo, ssd.')
+
+    args = parser.parse_args()
+
+    class YoloBaseConfig(object):
+        def __init__(self) -> None:
+            self.max_stride = 32
+            # ---------------- Data process config ----------------
+            self.box_format = 'xywh'
+            self.normalize_coords = False
+            self.mosaic_prob = 1.0
+            self.mixup_prob  = 0.15
+            self.copy_paste  = 0.3
+            ## Pixel mean & std
+            self.pixel_mean = [0., 0., 0.]
+            self.pixel_std  = [255., 255., 255.]
+            ## Transforms
+            self.train_img_size = 640
+            self.test_img_size  = 640
+            self.use_ablu = True
+            self.aug_type = 'yolo'
+            self.affine_params = {
+                'degrees': 0.0,
+                'translate': 0.2,
+                'scale': [0.1, 2.0],
+                'shear': 0.0,
+                'perspective': 0.0,
+                'hsv_h': 0.015,
+                'hsv_s': 0.7,
+                'hsv_v': 0.4,
+            }
+
+    class SSDBaseConfig(object):
+        def __init__(self) -> None:
+            self.max_stride = 32
+            # ---------------- Data process config ----------------
+            self.box_format = 'xywh'
+            self.normalize_coords = False
+            self.mosaic_prob = 0.0
+            self.mixup_prob  = 0.0
+            self.copy_paste  = 0.0
+            ## Pixel mean & std
+            self.pixel_mean = [0., 0., 0.]
+            self.pixel_std  = [255., 255., 255.]
+            ## Transforms
+            self.train_img_size = 640
+            self.test_img_size  = 640
+            self.aug_type = 'ssd'
+
+    if args.aug_type == "yolo":
+        cfg = YoloBaseConfig()
+    elif args.aug_type == "ssd":
+        cfg = SSDBaseConfig()
+
+    transform = build_transform(cfg, args.is_train)
+    dataset = COCODataset(cfg, args.root, 'val2017', transform, args.is_train)
+    
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(80)]
+    print('Data length: ', len(dataset))
+
+    for i in range(1000):
+        t0 = time.time()
+        image, target, deltas = dataset.pull_item(i)
+        print("Load data: {} s".format(time.time() - t0))
+
+        # to numpy
+        image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * cfg.pixel_std + cfg.pixel_mean
+
+        # rgb -> bgr
+        if transform.color_format == 'rgb':
+            image = image[..., (2, 1, 0)]
+
+        # to uint8
+        image = image.astype(np.uint8)
+        image = image.copy()
+        img_h, img_w = image.shape[:2]
+
+        boxes = target["boxes"]
+        labels = target["labels"]
+
+        for box, label in zip(boxes, labels):
+            if cfg.box_format == 'xyxy':
+                x1, y1, x2, y2 = box
+            elif cfg.box_format == 'xywh':
+                cx, cy, bw, bh = box
+                x1 = cx - 0.5 * bw
+                y1 = cy - 0.5 * bh
+                x2 = cx + 0.5 * bw
+                y2 = cy + 0.5 * bh
+            
+            if cfg.normalize_coords:
+                x1 *= img_w
+                y1 *= img_h
+                x2 *= img_w
+                y2 *= img_h
+
+            cls_id = int(label)
+            color = class_colors[cls_id]
+            # class name
+            label = coco_class_labels[cls_id]
+            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
+            # put the text on the bbox
+            cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
+        cv2.imshow('gt', image)
+        # cv2.imwrite(str(i)+'.jpg', img)
+        cv2.waitKey(0)
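A small worked check of the bbox clipping in `pull_anno` above; the numbers are made up, the arithmetic mirrors the code:

```python
# Worked example of the pull_anno conversion: COCO [x, y, w, h] -> clipped [x1, y1, x2, y2].
import numpy as np

width, height = 640, 480                       # image size (made-up values)
bbox = [600.0, 450.0, 80.0, 60.0]              # box extending past the image border
x1 = np.max((0, bbox[0]))
y1 = np.max((0, bbox[1]))
x2 = np.min((width - 1,  x1 + np.max((0, bbox[2] - 1))))
y2 = np.min((height - 1, y1 + np.max((0, bbox[3] - 1))))
print(x1, y1, x2, y2)                          # 600.0 450.0 639.0 479.0 -> clipped, still kept
```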

+ 309 - 0
dataset/customed.py

@@ -0,0 +1,309 @@
+import os
+import cv2
+import time
+import random
+import numpy as np
+from torch.utils.data import Dataset
+from pycocotools.coco import COCO
+
+try:
+    from .data_augment.strong_augment import MosaicAugment, MixupAugment
+except:
+    from  data_augment.strong_augment import MosaicAugment, MixupAugment
+
+
+customed_class_indexs = [0, 1, 2, 3, 4, 5, 6, 7, 8]
+customed_class_labels = ('bird', 'butterfly', 'cat', 'cow', 'dog', 'lion', 'person', 'pig', 'tiger', )
+
+
+class CustomedDataset(Dataset):
+    def __init__(self, 
+                 cfg,
+                 data_dir     :str = None, 
+                 image_set    :str = 'train2017',
+                 transform    = None,
+                 is_train     :bool =False,
+                 ):
+        # ----------- Basic parameters -----------
+        self.image_set = image_set
+        self.is_train  = is_train
+        self.num_classes = len(customed_class_labels)
+        # ----------- Path parameters -----------
+        self.data_dir = data_dir
+        self.json_file = '{}.json'.format(image_set)
+        # ----------- Data parameters -----------
+        self.coco = COCO(os.path.join(self.data_dir, image_set, 'annotations', self.json_file))
+        self.ids = self.coco.getImgIds()
+        self.class_ids = sorted(self.coco.getCatIds())
+        self.dataset_size = len(self.ids)
+        self.class_labels = customed_class_labels
+        self.class_indexs = customed_class_indexs
+        # ----------- Transform parameters -----------
+        self.transform = transform
+        if is_train:
+            self.mosaic_prob = cfg.mosaic_prob
+            self.mixup_prob  = cfg.mixup_prob
+            self.copy_paste  = cfg.copy_paste
+            self.mosaic_augment = None if cfg.mosaic_prob == 0. else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
+            self.mixup_augment  = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0.  else MixupAugment(cfg.train_img_size)
+        else:
+            self.mosaic_prob = 0.0
+            self.mixup_prob  = 0.0
+            self.copy_paste  = 0.0
+            self.mosaic_augment = None
+            self.mixup_augment  = None
+        print('==============================')
+        print('Image Set: {}'.format(image_set))
+        print('Json file: {}'.format(self.json_file))
+        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
+        print('use Mixup Augmentation: {}'.format(self.mixup_prob))
+        print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
+
+    # ------------ Basic dataset function ------------
+    def __len__(self):
+        return len(self.ids)
+
+    def __getitem__(self, index):
+        return self.pull_item(index)
+
+    # ------------ Mosaic & Mixup ------------
+    def load_mosaic(self, index):
+        # ------------ Prepare 4 indexes of images ------------
+        ## Load 4x mosaic image
+        index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
+        id1 = index
+        id2, id3, id4 = random.sample(index_list, 3)
+        indexs = [id1, id2, id3, id4]
+
+        ## Load images and targets
+        image_list = []
+        target_list = []
+        for index in indexs:
+            img_i, target_i = self.load_image_target(index)
+            image_list.append(img_i)
+            target_list.append(target_i)
+
+        # ------------ Mosaic augmentation ------------
+        image, target = self.mosaic_augment(image_list, target_list)
+
+        return image, target
+
+    def load_mixup(self, origin_image, origin_target, yolox_style=False):
+        # ------------ Load a new image & target ------------
+        new_index = np.random.randint(0, len(self.ids))
+        new_image, new_target = self.load_mosaic(new_index)
+            
+        # ------------ Mixup augmentation ------------
+        image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
+
+        return image, target
+    
+    # ------------ Load data function ------------
+    def load_image_target(self, index):
+        # load an image
+        image, _ = self.pull_image(index)
+        height, width, channels = image.shape
+
+        # load a target
+        bboxes, labels = self.pull_anno(index)
+        target = {
+            "boxes": bboxes,
+            "labels": labels,
+            "orig_size": [height, width]
+        }
+
+        return image, target
+
+    def pull_item(self, index):
+        if random.random() < self.mosaic_prob:
+            # load a mosaic image
+            mosaic = True
+            image, target = self.load_mosaic(index)
+        else:
+            mosaic = False
+            # load an image and target
+            image, target = self.load_image_target(index)
+
+        # Yolov5-MixUp
+        mixup = False
+        if random.random() < self.mixup_prob:
+            mixup = True
+            image, target = self.load_mixup(image, target)
+
+        # Copy-paste (use Yolox-Mixup to approximate copy-paste)
+        if not mixup and random.random() < self.copy_paste:
+            image, target = self.load_mixup(image, target, yolox_style=True)
+
+        # augment
+        image, target, deltas = self.transform(image, target, mosaic)
+
+        return image, target, deltas
+
+    def pull_image(self, index):
+        id_ = self.ids[index]
+        im_ann = self.coco.loadImgs(id_)[0] 
+        img_file = os.path.join(
+                self.data_dir, self.image_set, 'images', im_ann["file_name"])
+        image = cv2.imread(img_file)
+
+        return image, id_
+
+    def pull_anno(self, index):
+        img_id = self.ids[index]
+        im_ann = self.coco.loadImgs(img_id)[0]
+        anno_ids = self.coco.getAnnIds(imgIds=[int(img_id)], iscrowd=0)
+        annotations = self.coco.loadAnns(anno_ids)
+        
+        # image info
+        width = im_ann['width']
+        height = im_ann['height']
+        
+        #load a target
+        bboxes = []
+        labels = []
+        for anno in annotations:
+            if 'bbox' in anno and anno['area'] > 0:
+                # bbox
+                x1 = np.max((0, anno['bbox'][0]))
+                y1 = np.max((0, anno['bbox'][1]))
+                x2 = np.min((width - 1, x1 + np.max((0, anno['bbox'][2] - 1))))
+                y2 = np.min((height - 1, y1 + np.max((0, anno['bbox'][3] - 1))))
+                if x2 <= x1 or y2 <= y1:
+                    continue
+                # class label
+                cls_id = self.class_ids.index(anno['category_id'])
+                
+                bboxes.append([x1, y1, x2, y2])
+                labels.append(cls_id)
+
+        # guard against no boxes via resizing
+        bboxes = np.array(bboxes).reshape(-1, 4)
+        labels = np.array(labels).reshape(-1)
+        
+        return bboxes, labels
+
+
+if __name__ == "__main__":
+    import time
+    import argparse
+    from build import build_transform
+
+    parser = argparse.ArgumentParser(description='RT-ODLab')
+
+    # opt
+    parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
+                        help='data root')
+    parser.add_argument('--is_train', action="store_true", default=False,
+                        help='mixup augmentation.')
+    parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "rtdetr"],
+                        help='yolo, rtdetr.')
+
+    args = parser.parse_args()
+
+    class YoloBaseConfig(object):
+        def __init__(self) -> None:
+            self.max_stride = 32
+            # ---------------- Data process config ----------------
+            self.box_format = 'xywh'
+            self.normalize_coords = False
+            self.mosaic_prob = 1.0
+            self.mixup_prob  = 0.15
+            self.copy_paste  = 0.3
+            ## Pixel mean & std
+            self.pixel_mean = [0., 0., 0.]
+            self.pixel_std  = [255., 255., 255.]
+            ## Transforms
+            self.train_img_size = 640
+            self.test_img_size  = 640
+            self.random_crop_size = [320, 352, 384, 416, 448, 480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]
+            self.use_ablu = True
+            self.aug_type = 'yolo'
+            self.affine_params = {
+                'degrees': 0.0,
+                'translate': 0.2,
+                'scale': [0.1, 2.0],
+                'shear': 0.0,
+                'perspective': 0.0,
+                'hsv_h': 0.015,
+                'hsv_s': 0.7,
+                'hsv_v': 0.4,
+            }
+
+    class RTDetrBaseConfig(object):
+        def __init__(self) -> None:
+            self.max_stride = 32
+            # ---------------- Data process config ----------------
+            self.box_format = 'xywh'
+            self.normalize_coords = False
+            self.mosaic_prob = 0.0
+            self.mixup_prob  = 0.0
+            self.copy_paste  = 0.0
+            ## Pixel mean & std
+            self.pixel_mean = [0., 0., 0.]
+            self.pixel_std  = [255., 255., 255.]
+            ## Transforms
+            self.train_img_size = 640
+            self.test_img_size  = 640
+            self.aug_type = 'ssd'    # RT-DETR uses the SSD-style transform in build_transform
+
+    if args.aug_type == "yolo":
+        cfg = YoloBaseConfig()
+    elif args.aug_type == "rtdetr":
+        cfg = RTDetrBaseConfig()
+
+    transform = build_transform(cfg, args.is_train)
+    dataset = CustomedDataset(cfg, args.root, 'val', transform, args.is_train)
+    
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(80)]
+    print('Data length: ', len(dataset))
+
+    for i in range(1000):
+        t0 = time.time()
+        image, target, deltas = dataset.pull_item(i)
+        print("Load data: {} s".format(time.time() - t0))
+
+        # to numpy
+        image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * cfg.pixel_std + cfg.pixel_mean
+
+        # rgb -> bgr
+        if transform.color_format == 'rgb':
+            image = image[..., (2, 1, 0)]
+
+        # to uint8
+        image = image.astype(np.uint8)
+        image = image.copy()
+        img_h, img_w = image.shape[:2]
+
+        boxes = target["boxes"]
+        labels = target["labels"]
+
+        for box, label in zip(boxes, labels):
+            if cfg.box_format == 'xyxy':
+                x1, y1, x2, y2 = box
+            elif cfg.box_format == 'xywh':
+                cx, cy, bw, bh = box
+                x1 = cx - 0.5 * bw
+                y1 = cy - 0.5 * bh
+                x2 = cx + 0.5 * bw
+                y2 = cy + 0.5 * bh
+            
+            if cfg.normalize_coords:
+                x1 *= img_w
+                y1 *= img_h
+                x2 *= img_w
+                y2 *= img_h
+
+            cls_id = int(label)
+            color = class_colors[cls_id]
+            # class name
+            label = customed_class_labels[cls_id]
+            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
+            # put the text on the bbox
+            cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
+        cv2.imshow('gt', image)
+        # cv2.imwrite(str(i)+'.jpg', img)
+        cv2.waitKey(0)
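One easily-missed detail: the class list hard-coded in `dataset/customed.py` is presumably meant to agree with the `dataset_cfg` example shown in the README above. A tiny, hypothetical consistency check:

```python
# Hypothetical sanity check that the hard-coded labels in dataset/customed.py match
# the dataset_cfg entry from the README (both copied from the snippets above).
customed_class_labels = ('bird', 'butterfly', 'cat', 'cow', 'dog', 'lion', 'person', 'pig', 'tiger')
readme_cfg = {
    'data_name': 'AnimalDataset',
    'num_classes': 9,
    'class_indexs': (0, 1, 2, 3, 4, 5, 6, 7, 8),
    'class_names': ('bird', 'butterfly', 'cat', 'cow', 'dog', 'lion', 'person', 'pig', 'tiger'),
}
assert readme_cfg['num_classes'] == len(customed_class_labels)
assert readme_cfg['class_names'] == customed_class_labels
```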

+ 554 - 0
dataset/data_augment/ssd_augment.py

@@ -0,0 +1,554 @@
+# ------------------------------------------------------------
+# Data preprocessor for Real-time DETR
+# ------------------------------------------------------------
+import cv2
+import numpy as np
+from numpy import random
+
+import torch
+import torch.nn.functional as F
+
+
+# ------------------------- Augmentations -------------------------
+class Compose(object):
+    """Composes several augmentations together.
+    Args:
+        transforms (List[Transform]): list of transforms to compose.
+    Example:
+        >>> augmentations.Compose([
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.ToTensor(),
+        >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target=None):
+        for t in self.transforms:
+            image, target = t(image, target)
+        return image, target
+
+## Convert color format
+class ConvertColorFormat(object):
+    def __init__(self, color_format='rgb'):
+        self.color_format = color_format
+
+    def __call__(self, image, target=None):
+        """
+        Input:
+            image: (np.array) an OpenCV image in BGR color format.
+            target: None
+        Output:
+            image: (np.array) an OpenCV image in the given color format.
+            target: None
+        """
+        # Convert color format
+        if self.color_format == 'rgb':
+            image = image[..., (2, 1, 0)]    # BGR -> RGB
+        elif self.color_format == 'bgr':
+            image = image
+        else:
+            raise NotImplementedError("Unknown color format: <{}>".format(self.color_format))
+
+        return image, target
+
+## Random color jitter
+class RandomDistort(object):
+    def __init__(self,
+                 hue=[-18, 18, 0.5],
+                 saturation=[0.5, 1.5, 0.5],
+                 contrast=[0.5, 1.5, 0.5],
+                 brightness=[0.5, 1.5, 0.5],
+                 random_apply=True,
+                 count=4,
+                 random_channel=False,
+                 prob=1.0):
+        super(RandomDistort, self).__init__()
+        self.hue = hue
+        self.saturation = saturation
+        self.contrast = contrast
+        self.brightness = brightness
+        self.random_apply = random_apply
+        self.count = count
+        self.random_channel = random_channel
+        self.prob = prob
+
+    def apply_hue(self, image, target=None):
+        low, high, prob = self.hue
+        # skip this op with probability `prob` (the third element of self.hue)
+        if np.random.uniform(0., 1.) < prob:
+            return image, target
+
+        image = image.astype(np.float32)
+        # it works, but result differ from HSV version
+        delta = np.random.uniform(low, high)
+        u = np.cos(delta * np.pi)
+        w = np.sin(delta * np.pi)
+        bt = np.array([[1.0, 0.0, 0.0], [0.0, u, -w], [0.0, w, u]])
+        tyiq = np.array([[0.299, 0.587, 0.114], [0.596, -0.274, -0.321],
+                         [0.211, -0.523, 0.311]])
+        ityiq = np.array([[1.0, 0.956, 0.621], [1.0, -0.272, -0.647],
+                          [1.0, -1.107, 1.705]])
+        t = np.dot(np.dot(ityiq, bt), tyiq).T
+        image = np.dot(image, t)
+
+        return image, target
+
+    def apply_saturation(self, image, target=None):
+        low, high, prob = self.saturation
+        if np.random.uniform(0., 1.) < prob:    # per-op skip probability
+            return image, target
+        delta = np.random.uniform(low, high)
+        image = image.astype(np.float32)
+        # it works, but result differ from HSV version
+        gray = image * np.array([[[0.299, 0.587, 0.114]]], dtype=np.float32)
+        gray = gray.sum(axis=2, keepdims=True)
+        gray *= (1.0 - delta)
+        image *= delta
+        image += gray
+
+        return image, target
+
+    def apply_contrast(self, image, target=None):
+        low, high, prob = self.contrast
+        if np.random.uniform(0., 1.) < prob:    # per-op skip probability
+            return image, target
+
+        delta = np.random.uniform(low, high)
+        image = image.astype(np.float32)
+        image *= delta
+
+        return image, target
+
+    def apply_brightness(self, image, target=None):
+        low, high, prob = self.brightness
+        if np.random.uniform(0., 1.) < prob:    # per-op skip probability
+            return image, target
+
+        delta = np.random.uniform(low, high)
+        image = image.astype(np.float32)
+        image += delta
+
+        return image, target
+
+    def __call__(self, image, target=None):
+        if random.random() > self.prob:
+            return image, target
+
+        if self.random_apply:
+            functions = [
+                self.apply_brightness, self.apply_contrast,
+                self.apply_saturation, self.apply_hue
+            ]
+            distortions = np.random.permutation(functions)[:self.count]
+            for func in distortions:
+                image, target = func(image, target)
+
+            return image, target
+
+        image, target = self.apply_brightness(image, target)
+        mode = np.random.randint(0, 2)
+
+        if mode:
+            image, target = self.apply_contrast(image, target)
+
+        image, target = self.apply_saturation(image, target)
+        image, target = self.apply_hue(image, target)
+
+        if not mode:
+            image, target = self.apply_contrast(image, target)
+
+        if self.random_channel:
+            if np.random.randint(0, 2):
+                image = image[..., np.random.permutation(3)]
+
+        return image, target
+
+## Random scaling
+class RandomExpand(object):
+    def __init__(self, fill_value) -> None:
+        self.fill_value = fill_value
+
+    def __call__(self, image, target=None):
+        if random.randint(2):
+            return image, target
+
+        height, width, channels = image.shape
+        ratio = random.uniform(1, 4)
+        left = random.uniform(0, width*ratio - width)
+        top = random.uniform(0, height*ratio - height)
+
+        expand_image = np.ones(
+            (int(height*ratio), int(width*ratio), channels),
+            dtype=image.dtype) * self.fill_value
+        expand_image[int(top):int(top + height),
+                     int(left):int(left + width)] = image
+        image = expand_image
+
+        boxes = target['boxes'].copy()
+        boxes[:, :2] += (int(left), int(top))
+        boxes[:, 2:] += (int(left), int(top))
+        target['boxes'] = boxes
+
+        return image, target
+
+## Random IoU based Sample Crop
+class RandomIoUCrop(object):
+    def __init__(self, p=0.5):
+        self.p = p
+        self.sample_options = (
+            # use the entire original input image
+            None,
+            # sample a patch s.t. MIN jaccard w/ obj in .1, .3, .5, .7, .9
+            (0.1, None),
+            (0.3, None),
+            (0.5, None),
+            (0.7, None),
+            (0.9, None),
+        )
+
+    def intersect(self, box_a, box_b):
+        max_xy = np.minimum(box_a[:, 2:], box_b[2:])
+        min_xy = np.maximum(box_a[:, :2], box_b[:2])
+        inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
+
+        return inter[:, 0] * inter[:, 1]
+
+    def compute_iou(self, box_a, box_b):
+        inter = self.intersect(box_a, box_b)
+        area_a = ((box_a[:, 2]-box_a[:, 0]) *
+                (box_a[:, 3]-box_a[:, 1]))  # [A,B]
+        area_b = ((box_b[2]-box_b[0]) *
+                (box_b[3]-box_b[1]))  # [A,B]
+        union = area_a + area_b - inter
+        return inter / union  # [A,B]
+
+    def __call__(self, image, target=None):
+        height, width, _ = image.shape
+
+        # check target
+        if len(target["boxes"]) == 0 or random.random() > self.p:
+            return image, target
+
+        while True:
+            # randomly choose a mode
+            sample_id = np.random.randint(len(self.sample_options))
+            mode = self.sample_options[sample_id]
+            if mode is None:
+                return image, target
+
+            boxes = target["boxes"]
+            labels = target["labels"]
+
+            min_iou, max_iou = mode
+            if min_iou is None:
+                min_iou = float('-inf')
+            if max_iou is None:
+                max_iou = float('inf')
+
+            # max trials (50)
+            for _ in range(50):
+                current_image = image
+
+                w = random.uniform(0.3 * width, width)
+                h = random.uniform(0.3 * height, height)
+
+                # aspect ratio constraint b/t .5 & 2
+                if h / w < 0.5 or h / w > 2:
+                    continue
+
+                left = random.uniform(width - w)
+                top = random.uniform(height - h)
+
+                # convert to integer rect x1,y1,x2,y2
+                rect = np.array([int(left), int(top), int(left+w), int(top+h)])
+
+                # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
+                overlap = self.compute_iou(boxes, rect)
+
+                # is min and max overlap constraint satisfied? if not try again
+                if overlap.min() < min_iou and max_iou < overlap.max():
+                    continue
+
+                # cut the crop from the image
+                current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
+                                              :]
+
+                # keep overlap with gt box IF center in sampled patch
+                centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
+
+                # mask in all gt boxes that are above and to the left of the centers
+                m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
+
+                # mask in all gt boxes that are under and to the right of the centers
+                m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
+
+                # keep only the boxes where both m1 and m2 are true
+                mask = m1 * m2
+
+                # have any valid boxes? try again if not
+                if not mask.any():
+                    continue
+
+                # take only matching gt boxes
+                current_boxes = boxes[mask, :].copy()
+
+                # take only matching gt labels
+                current_labels = labels[mask]
+
+                # should we use the box left and top corner or the crop's
+                current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
+                                                  rect[:2])
+                # adjust to crop (by subtracting the crop's left, top)
+                current_boxes[:, :2] -= rect[:2]
+
+                current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
+                                                  rect[2:])
+                # adjust to crop (by subtracting the crop's left, top)
+                current_boxes[:, 2:] -= rect[:2]
+
+                # update target
+                target["boxes"] = current_boxes
+                target["labels"] = current_labels
+
+                return current_image, target
+
+## Random JitterCrop
+class RandomJitterCrop(object):
+    """Jitter and crop the image and box."""
+    def __init__(self, fill_value, p=0.5, jitter_ratio=0.3):
+        super().__init__()
+        self.p = p
+        self.jitter_ratio = jitter_ratio
+        self.fill_value = fill_value
+
+    def crop(self, image, pleft, pright, ptop, pbot, output_size):
+        oh, ow = image.shape[:2]
+
+        swidth, sheight = output_size
+
+        src_rect = [pleft, ptop, swidth + pleft,
+                    sheight + ptop]  # x1,y1,x2,y2
+        img_rect = [0, 0, ow, oh]
+        # rect intersection
+        new_src_rect = [max(src_rect[0], img_rect[0]),
+                        max(src_rect[1], img_rect[1]),
+                        min(src_rect[2], img_rect[2]),
+                        min(src_rect[3], img_rect[3])]
+        dst_rect = [max(0, -pleft),
+                    max(0, -ptop),
+                    max(0, -pleft) + new_src_rect[2] - new_src_rect[0],
+                    max(0, -ptop) + new_src_rect[3] - new_src_rect[1]]
+
+        # crop the image
+        cropped = np.ones([sheight, swidth, 3], dtype=image.dtype) * self.fill_value
+        # cropped[:, :, ] = np.mean(image, axis=(0, 1))
+        cropped[dst_rect[1]:dst_rect[3], dst_rect[0]:dst_rect[2]] = \
+            image[new_src_rect[1]:new_src_rect[3],
+            new_src_rect[0]:new_src_rect[2]]
+
+        return cropped
+
+    def __call__(self, image, target=None):
+        if random.random() > self.p:
+            return image, target
+        else:
+            oh, ow = image.shape[:2]
+            dw = int(ow * self.jitter_ratio)
+            dh = int(oh * self.jitter_ratio)
+            pleft = np.random.randint(-dw, dw)
+            pright = np.random.randint(-dw, dw)
+            ptop = np.random.randint(-dh, dh)
+            pbot = np.random.randint(-dh, dh)
+
+            swidth = ow - pleft - pright
+            sheight = oh - ptop - pbot
+            output_size = (swidth, sheight)
+            # crop image
+            cropped_image = self.crop(image=image,
+                                    pleft=pleft, 
+                                    pright=pright, 
+                                    ptop=ptop, 
+                                    pbot=pbot,
+                                    output_size=output_size)
+            # crop bbox
+            if target is not None:
+                bboxes = target['boxes'].copy()
+                coords_offset = np.array([pleft, ptop], dtype=np.float32)
+                bboxes[..., [0, 2]] = bboxes[..., [0, 2]] - coords_offset[0]
+                bboxes[..., [1, 3]] = bboxes[..., [1, 3]] - coords_offset[1]
+                swidth, sheight = output_size
+
+                bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], 0, swidth - 1)
+                bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], 0, sheight - 1)
+                target['boxes'] = bboxes
+
+            return cropped_image, target
+    
+## Random HFlip
+class RandomHorizontalFlip(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, image, target=None):
+        if random.random() < self.p:
+            orig_h, orig_w = image.shape[:2]
+            image = image[:, ::-1]
+            if target is not None:
+                if "boxes" in target:
+                    boxes = target["boxes"].copy()
+                    boxes[..., [0, 2]] = orig_w - boxes[..., [2, 0]]
+                    target["boxes"] = boxes
+
+        return image, target
+
+## Resize image
+class Resize(object):
+    def __init__(self, img_size=640):
+        self.img_size = img_size
+
+    def __call__(self, image, target=None):
+        orig_h, orig_w = image.shape[:2]
+
+        # resize
+        image = cv2.resize(image, (self.img_size, self.img_size)).astype(np.float32)
+        img_h, img_w = image.shape[:2]
+
+        # rescale bboxes
+        if target is not None:
+            boxes = target["boxes"].astype(np.float32)
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
+            target["boxes"] = boxes
+
+        return image, target
+
+## Normalize image
+class Normalize(object):
+    def __init__(self, pixel_mean, pixel_std, normalize_coords=False):
+        self.pixel_mean = pixel_mean
+        self.pixel_std = pixel_std
+        self.normalize_coords = normalize_coords
+
+    def __call__(self, image, target=None):
+        # normalize image
+        image = (image - self.pixel_mean) / self.pixel_std
+
+        # normalize bbox
+        if target is not None and self.normalize_coords:
+            img_h, img_w = image.shape[:2]
+            target["boxes"][..., [0, 2]] = target["boxes"][..., [0, 2]] / float(img_w)
+            target["boxes"][..., [1, 3]] = target["boxes"][..., [1, 3]] / float(img_h)
+
+        return image, target
+
+## Convert ndarray to torch.Tensor
+class ToTensor(object):
+    def __call__(self, image, target=None):        
+        # Convert torch.Tensor
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+
+        if target is not None:
+            target["boxes"] = torch.as_tensor(target["boxes"]).float()
+            target["labels"] = torch.as_tensor(target["labels"]).long()
+
+        return image, target
+
+## Convert BBox format
+class ConvertBoxFormat(object):
+    def __init__(self, box_format="xyxy"):
+        self.box_format = box_format
+
+    def __call__(self, image, target=None):
+        # convert box format
+        if self.box_format == "xyxy" or target is None:
+            pass
+        elif self.box_format == "xywh":
+            target = target.copy()
+            if "boxes" in target:
+                boxes_xyxy = target["boxes"]
+                boxes_xywh = torch.zeros_like(boxes_xyxy)
+                boxes_xywh[..., :2] = (boxes_xyxy[..., :2] + boxes_xyxy[..., 2:]) * 0.5   # cxcy
+                boxes_xywh[..., 2:] = boxes_xyxy[..., 2:] - boxes_xyxy[..., :2]           # bwbh
+                target["boxes"] = boxes_xywh
+        else:
+            raise NotImplementedError("Unknown box format: {}".format(self.box_format))
+
+        return image, target
+
+
+# ------------------------- Preprocessors -------------------------
+## Transform for Train
+class SSDAugmentation(object):
+    def __init__(self,
+                 img_size   = 640,
+                 pixel_mean = [123.675, 116.28, 103.53],
+                 pixel_std  = [58.395, 57.12, 57.375],
+                 box_format = 'xywh',
+                 normalize_coords = False):
+        # ----------------- Basic parameters -----------------
+        self.img_size = img_size
+        self.box_format = box_format
+        self.pixel_mean = pixel_mean   # RGB format
+        self.pixel_std  = pixel_std    # RGB format
+        self.normalize_coords = normalize_coords
+        self.color_format = 'rgb'
+        print("================= Pixel Statistics =================")
+        print("Pixel mean: {}".format(self.pixel_mean))
+        print("Pixel std:  {}".format(self.pixel_std))
+
+        # ----------------- Transforms -----------------
+        self.augment = Compose([
+            RandomDistort(prob=0.8),
+            RandomExpand(fill_value=self.pixel_mean[::-1]),
+            RandomIoUCrop(p=0.8),
+            RandomHorizontalFlip(p=0.5),
+            Resize(img_size=self.img_size),
+            ConvertColorFormat(self.color_format),
+            Normalize(self.pixel_mean, self.pixel_std, normalize_coords),
+            ToTensor(),
+            ConvertBoxFormat(self.box_format),
+        ])
+
+    def __call__(self, image, target, mosaic=False):
+        orig_h, orig_w = image.shape[:2]
+        ratio = [self.img_size / orig_w, self.img_size / orig_h]
+
+        image, target = self.augment(image, target)
+
+        return image, target, ratio
+
+## Transform for Eval
+class SSDBaseTransform(object):
+    def __init__(self,
+                 img_size   = 640,
+                 pixel_mean = [123.675, 116.28, 103.53],
+                 pixel_std  = [58.395, 57.12, 57.375],
+                 box_format = 'xywh',
+                 normalize_coords = False):
+        # ----------------- Basic parameters -----------------
+        self.img_size = img_size
+        self.box_format = box_format
+        self.pixel_mean = pixel_mean  # RGB format
+        self.pixel_std  = pixel_std    # RGB format
+        self.normalize_coords = normalize_coords
+        self.color_format = 'rgb'
+        print("================= Pixel Statistics =================")
+        print("Pixel mean: {}".format(self.pixel_mean))
+        print("Pixel std:  {}".format(self.pixel_std))
+
+        # ----------------- Transforms -----------------
+        self.transform = Compose([
+            Resize(img_size=self.img_size),
+            ConvertColorFormat(self.color_format),
+            Normalize(self.pixel_mean, self.pixel_std, self.normalize_coords),
+            ToTensor(),
+            ConvertBoxFormat(self.box_format),
+        ])
+
+
+    def __call__(self, image, target=None, mosaic=False):
+        orig_h, orig_w = image.shape[:2]
+        ratio = [self.img_size / orig_w, self.img_size / orig_h]
+
+        image, target = self.transform(image, target)
+
+        return image, target, ratio
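+
+
+# --- Illustrative usage: a minimal sketch with dummy data (not part of the training code).
+# It assumes a BGR uint8 image and a target dict with absolute xyxy "boxes" and integer
+# "labels"; the values below are placeholders.
+if __name__ == "__main__":
+    dummy_image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+    dummy_target = {
+        "boxes": np.array([[50., 60., 200., 220.]], dtype=np.float32),
+        "labels": np.array([3], dtype=np.int64),
+    }
+    base_transform = SSDBaseTransform(img_size=640, box_format='xywh')
+    img_tensor, out_target, ratio = base_transform(dummy_image, dummy_target)
+    # img_tensor: [3, 640, 640] float tensor; boxes are now cxcywh in the resized image
+    print(img_tensor.shape, out_target["boxes"], ratio)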

+ 225 - 0
dataset/data_augment/strong_augment.py

@@ -0,0 +1,225 @@
+import random
+import cv2
+import numpy as np
+
+from .yolo_augment import random_perspective
+
+
+# ------------------------- Strong augmentations -------------------------
+## Mosaic Augmentation
+class MosaicAugment(object):
+    def __init__(self,
+                 img_size,
+                 affine_params,
+                 is_train=False,
+                 ) -> None:
+        self.img_size = img_size
+        self.is_train = is_train
+        self.affine_params = affine_params
+
+    def __call__(self, image_list, target_list):
+        assert len(image_list) == 4
+        # mosaic center
+        yc, xc = [int(random.uniform(-x, 2*self.img_size + x)) for x in [-self.img_size // 2, -self.img_size // 2]]
+
+        mosaic_bboxes = []
+        mosaic_labels = []
+        mosaic_img = np.zeros([self.img_size*2, self.img_size*2, image_list[0].shape[2]], dtype=np.uint8)
+        for i in range(4):
+            img_i, target_i = image_list[i], target_list[i]
+            bboxes_i = target_i["boxes"]
+            labels_i = target_i["labels"]
+            orig_h, orig_w, _ = img_i.shape
+
+            # ------------------ Keep ratio Resize ------------------
+            r = self.img_size / max(orig_h, orig_w)
+            if r != 1: 
+                interp = cv2.INTER_LINEAR if (self.is_train or r > 1) else cv2.INTER_AREA
+                img_i = cv2.resize(img_i, (int(orig_w * r), int(orig_h * r)), interpolation=interp)
+            h, w, _ = img_i.shape
+
+            # ------------------ Create mosaic image ------------------
+            ## Place image in mosaic image
+            if i == 0:  # top left
+                x1a, y1a, x2a, y2a = max(xc - w, 0), max(yc - h, 0), xc, yc  # xmin, ymin, xmax, ymax (large image)
+                x1b, y1b, x2b, y2b = w - (x2a - x1a), h - (y2a - y1a), w, h  # xmin, ymin, xmax, ymax (small image)
+            elif i == 1:  # top right
+                x1a, y1a, x2a, y2a = xc, max(yc - h, 0), min(xc + w, self.img_size * 2), yc
+                x1b, y1b, x2b, y2b = 0, h - (y2a - y1a), min(w, x2a - x1a), h
+            elif i == 2:  # bottom left
+                x1a, y1a, x2a, y2a = max(xc - w, 0), yc, xc, min(self.img_size * 2, yc + h)
+                x1b, y1b, x2b, y2b = w - (x2a - x1a), 0, w, min(y2a - y1a, h)
+            elif i == 3:  # bottom right
+                x1a, y1a, x2a, y2a = xc, yc, min(xc + w, self.img_size * 2), min(self.img_size * 2, yc + h)
+                x1b, y1b, x2b, y2b = 0, 0, min(w, x2a - x1a), min(y2a - y1a, h)
+
+            mosaic_img[y1a:y2a, x1a:x2a] = img_i[y1b:y2b, x1b:x2b]
+            padw = x1a - x1b
+            padh = y1a - y1b
+
+            ## Mosaic target
+            bboxes_i_ = bboxes_i.copy()
+            if len(bboxes_i) > 0:
+                # a valid target, and modify it.
+                bboxes_i_[:, 0] = (w * bboxes_i[:, 0] / orig_w + padw)
+                bboxes_i_[:, 1] = (h * bboxes_i[:, 1] / orig_h + padh)
+                bboxes_i_[:, 2] = (w * bboxes_i[:, 2] / orig_w + padw)
+                bboxes_i_[:, 3] = (h * bboxes_i[:, 3] / orig_h + padh)    
+
+                mosaic_bboxes.append(bboxes_i_)
+                mosaic_labels.append(labels_i)
+
+        if len(mosaic_bboxes) == 0:
+            mosaic_bboxes = np.array([]).reshape(-1, 4)
+            mosaic_labels = np.array([]).reshape(-1)
+        else:
+            mosaic_bboxes = np.concatenate(mosaic_bboxes)
+            mosaic_labels = np.concatenate(mosaic_labels)
+
+        # clip
+        mosaic_bboxes = mosaic_bboxes.clip(0, self.img_size * 2)
+
+        # ----------------------- Random perspective -----------------------
+        mosaic_targets = np.concatenate([mosaic_labels[..., None], mosaic_bboxes], axis=-1)
+        mosaic_img, mosaic_targets = random_perspective(
+            mosaic_img,
+            mosaic_targets,
+            self.affine_params['degrees'],
+            translate   = self.affine_params['translate'],
+            scale       = self.affine_params['scale'],
+            shear       = self.affine_params['shear'],
+            perspective = self.affine_params['perspective'],
+            border      = [-self.img_size//2, -self.img_size//2]
+            )
+
+        # target
+        mosaic_target = {
+            "boxes": mosaic_targets[..., 1:],
+            "labels": mosaic_targets[..., 0],
+        }
+
+        return mosaic_img, mosaic_target
+
+## Mixup Augmentation
+class MixupAugment(object):
+    def __init__(self, img_size) -> None:
+        self.img_size = img_size
+
+    def yolox_mixup_augment(self, origin_image, origin_target, new_image, new_target):
+        jit_factor = random.uniform(0.5, 1.5)
+        FLIP = random.uniform(0, 1) > 0.5
+
+        # resize new image
+        orig_h, orig_w = new_image.shape[:2]
+        cp_scale_ratio = self.img_size / max(orig_h, orig_w)
+        if cp_scale_ratio != 1: 
+            interp = cv2.INTER_LINEAR if cp_scale_ratio > 1 else cv2.INTER_AREA
+            resized_new_img = cv2.resize(
+                new_image, (int(orig_w * cp_scale_ratio), int(orig_h * cp_scale_ratio)), interpolation=interp)
+        else:
+            resized_new_img = new_image
+
+        # pad new image
+        cp_img = np.ones([self.img_size, self.img_size, new_image.shape[2]], dtype=np.uint8) * 114
+        new_shape = (resized_new_img.shape[1], resized_new_img.shape[0])
+        cp_img[:new_shape[1], :new_shape[0]] = resized_new_img
+
+        # resize padded new image
+        cp_img_h, cp_img_w = cp_img.shape[:2]
+        cp_new_shape = (int(cp_img_w * jit_factor),
+                        int(cp_img_h * jit_factor))
+        cp_img = cv2.resize(cp_img, (cp_new_shape[0], cp_new_shape[1]))
+        cp_scale_ratio *= jit_factor
+
+        # flip new image
+        if FLIP:
+            cp_img = cp_img[:, ::-1, :]
+
+        # pad image
+        origin_h, origin_w = cp_img.shape[:2]
+        target_h, target_w = origin_image.shape[:2]
+        padded_img = np.zeros(
+            (max(origin_h, target_h), max(origin_w, target_w), 3), dtype=np.uint8
+        )
+        padded_img[:origin_h, :origin_w] = cp_img
+
+        # crop padded image
+        x_offset, y_offset = 0, 0
+        if padded_img.shape[0] > target_h:
+            y_offset = random.randint(0, padded_img.shape[0] - target_h - 1)
+        if padded_img.shape[1] > target_w:
+            x_offset = random.randint(0, padded_img.shape[1] - target_w - 1)
+        padded_cropped_img = padded_img[
+            y_offset: y_offset + target_h, x_offset: x_offset + target_w
+        ]
+
+        # process target
+        new_boxes = new_target["boxes"]
+        new_labels = new_target["labels"]
+        new_boxes[:, 0::2] = np.clip(new_boxes[:, 0::2] * cp_scale_ratio, 0, origin_w)
+        new_boxes[:, 1::2] = np.clip(new_boxes[:, 1::2] * cp_scale_ratio, 0, origin_h)
+        if FLIP:
+            new_boxes[:, 0::2] = (
+                origin_w - new_boxes[:, 0::2][:, ::-1]
+            )
+        new_boxes[:, 0::2] = np.clip(
+            new_boxes[:, 0::2] - x_offset, 0, target_w
+        )
+        new_boxes[:, 1::2] = np.clip(
+            new_boxes[:, 1::2] - y_offset, 0, target_h
+        )
+
+        # mixup target
+        mixup_boxes = np.concatenate([new_boxes, origin_target['boxes']], axis=0)
+        mixup_labels = np.concatenate([new_labels, origin_target['labels']], axis=0)
+        mixup_target = {
+            'boxes': mixup_boxes,
+            'labels': mixup_labels
+        }
+
+        # mixup images
+        origin_image = origin_image.astype(np.float32)
+        origin_image = 0.5 * origin_image + 0.5 * padded_cropped_img.astype(np.float32)
+
+        return origin_image.astype(np.uint8), mixup_target
+            
+    def yolo_mixup_augment(self, origin_image, origin_target, new_image, new_target):
+        if origin_image.shape[:2] != new_image.shape[:2]:
+            img_size = max(new_image.shape[:2])
+            # origin_image is not a mosaic image
+            orig_h, orig_w = origin_image.shape[:2]
+            scale_ratio = img_size / max(orig_h, orig_w)
+            resize_size = (int(orig_w * scale_ratio), int(orig_h * scale_ratio))
+            if scale_ratio != 1:
+                interp = cv2.INTER_LINEAR if scale_ratio > 1 else cv2.INTER_AREA
+                origin_image = cv2.resize(origin_image, resize_size, interpolation=interp)
+
+            # pad the origin image up to a square canvas of size img_size
+            pad_origin_image = np.zeros([img_size, img_size, origin_image.shape[2]], dtype=np.uint8)
+            pad_origin_image[:resize_size[1], :resize_size[0]] = origin_image
+            origin_image = pad_origin_image.copy()
+            del pad_origin_image
+
+        r = np.random.beta(32.0, 32.0)
+        mixup_image = r * origin_image.astype(np.float32) + \
+                    (1.0 - r)* new_image.astype(np.float32)
+        mixup_image = mixup_image.astype(np.uint8)
+        
+        cls_labels = new_target["labels"].copy()
+        box_labels = new_target["boxes"].copy()
+
+        mixup_bboxes = np.concatenate([origin_target["boxes"], box_labels], axis=0)
+        mixup_labels = np.concatenate([origin_target["labels"], cls_labels], axis=0)
+
+        mixup_target = {
+            "boxes": mixup_bboxes,
+            "labels": mixup_labels,
+        }
+        
+        return mixup_image, mixup_target
+
+    def __call__(self, origin_image, origin_target, new_image, new_target, yolox_style=False):
+        if yolox_style:
+            return self.yolox_mixup_augment(origin_image, origin_target, new_image, new_target)
+        else:
+            return self.yolo_mixup_augment(origin_image, origin_target, new_image, new_target)
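+
+
+# --- Illustrative usage: a minimal sketch with dummy data (not part of the training code).
+# Because of the relative import at the top, this module is meant to be imported as part of
+# the dataset package; the affine_params below are placeholders, not recommended settings.
+if __name__ == "__main__":
+    affine_params = {'degrees': 0.0, 'translate': 0.1, 'scale': [0.5, 1.5],
+                     'shear': 0.0, 'perspective': 0.0}
+    mosaic = MosaicAugment(img_size=640, affine_params=affine_params, is_train=True)
+    images  = [np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8) for _ in range(4)]
+    targets = [{"boxes": np.array([[50., 60., 200., 220.]], dtype=np.float32),
+                "labels": np.array([1], dtype=np.int64)} for _ in range(4)]
+    mosaic_img, mosaic_target = mosaic(images, targets)
+    # mosaic_img is an img_size x img_size canvas built from four resized tiles,
+    # followed by a random affine transform of both the image and the boxes
+    print(mosaic_img.shape, mosaic_target["boxes"].shape)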

+ 291 - 0
dataset/data_augment/yolo_augment.py

@@ -0,0 +1,291 @@
+import random
+import cv2
+import math
+import numpy as np
+import albumentations as albu
+
+import torch
+import torchvision.transforms.functional as F
+
+
+# ------------------------- Basic augmentations -------------------------
+## Spatial transform
+def random_perspective(image,
+                       targets=(),
+                       degrees=10,
+                       translate=.1,
+                       scale=[0.1, 2.0],
+                       shear=10,
+                       perspective=0.0,
+                       border=(0, 0)):
+    # torchvision.transforms.RandomAffine(degrees=(-10, 10), translate=(0.1, 0.1), scale=(0.9, 1.1), shear=(-10, 10))
+    # targets = [cls, xyxy]
+
+    height = image.shape[0] + border[0] * 2  # shape(h,w,c)
+    width = image.shape[1] + border[1] * 2
+
+    # Center
+    C = np.eye(3)
+    C[0, 2] = -image.shape[1] / 2  # x translation (pixels)
+    C[1, 2] = -image.shape[0] / 2  # y translation (pixels)
+
+    # Perspective
+    P = np.eye(3)
+    P[2, 0] = random.uniform(-perspective, perspective)  # x perspective (about y)
+    P[2, 1] = random.uniform(-perspective, perspective)  # y perspective (about x)
+
+    # Rotation and Scale
+    R = np.eye(3)
+    a = random.uniform(-degrees, degrees)
+    # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
+    s = random.uniform(scale[0], scale[1])
+    # s = 2 ** random.uniform(-scale, scale)
+    R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
+
+    # Shear
+    S = np.eye(3)
+    S[0, 1] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # x shear (deg)
+    S[1, 0] = math.tan(random.uniform(-shear, shear) * math.pi / 180)  # y shear (deg)
+
+    # Translation
+    T = np.eye(3)
+    T[0, 2] = random.uniform(0.5 - translate, 0.5 + translate) * width  # x translation (pixels)
+    T[1, 2] = random.uniform(0.5 - translate, 0.5 + translate) * height  # y translation (pixels)
+
+    # Combined rotation matrix
+    M = T @ S @ R @ P @ C  # order of operations (right to left) is IMPORTANT
+    if (border[0] != 0) or (border[1] != 0) or (M != np.eye(3)).any():  # image changed
+        if perspective:
+            image = cv2.warpPerspective(image, M, dsize=(width, height), borderValue=(0, 0, 0))
+        else:  # affine
+            image = cv2.warpAffine(image, M[:2], dsize=(width, height), borderValue=(0, 0, 0))
+
+    # Transform label coordinates
+    n = len(targets)
+    if n:
+        new = np.zeros((n, 4))
+        # warp boxes
+        xy = np.ones((n * 4, 3))
+        xy[:, :2] = targets[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
+        xy = xy @ M.T  # transform
+        xy = (xy[:, :2] / xy[:, 2:3] if perspective else xy[:, :2]).reshape(n, 8)  # perspective rescale or affine
+
+        # create new boxes
+        x = xy[:, [0, 2, 4, 6]]
+        y = xy[:, [1, 3, 5, 7]]
+        new = np.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
+
+        # clip
+        new[:, [0, 2]] = new[:, [0, 2]].clip(0, width)
+        new[:, [1, 3]] = new[:, [1, 3]].clip(0, height)
+
+        targets[:, 1:5] = new
+
+    return image, targets
+
+## Color transform
+def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
+    r = np.random.uniform(-1, 1, 3) * [hgain, sgain, vgain] + 1  # random gains
+    hue, sat, val = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2HSV))
+    dtype = img.dtype  # uint8
+
+    x = np.arange(0, 256, dtype=np.int16)
+    lut_hue = ((x * r[0]) % 180).astype(dtype)
+    lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
+    lut_val = np.clip(x * r[2], 0, 255).astype(dtype)
+
+    img_hsv = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val))).astype(dtype)
+    cv2.cvtColor(img_hsv, cv2.COLOR_HSV2BGR, dst=img)  # convert back to BGR in place
+
+    return img
+
+## Albumentations transform
+class Albumentations(object):
+    def __init__(self, img_size=640):
+        self.img_size = img_size
+        self.transform = albu.Compose(
+            [albu.Blur(p=0.01),
+             albu.MedianBlur(p=0.01),
+             albu.ToGray(p=0.01),
+             albu.CLAHE(p=0.01),
+             ],
+             bbox_params=albu.BboxParams(format='pascal_voc', label_fields=['labels'])
+        )
+
+    def __call__(self, image, target=None):
+        labels = target['labels']
+        bboxes = target['boxes']
+        if len(labels) > 0:
+            new = self.transform(image=image, bboxes=bboxes, labels=labels)
+            if len(new["labels"]) > 0:
+                image = new['image']
+                target['labels'] = np.array(new["labels"], dtype=labels.dtype)
+                target['boxes'] = np.array(new["bboxes"], dtype=bboxes.dtype)
+
+        return image, target
+
+
+# ------------------------- Preprocessors -------------------------
+## YOLO-style Transform for Train
+class YOLOAugmentation(object):
+    def __init__(self,
+                 img_size=640,
+                 affine_params=None,
+                 use_ablu=False,
+                 pixel_mean = [0., 0., 0.],
+                 pixel_std  = [255., 255., 255.],
+                 box_format='xyxy',
+                 normalize_coords=False):
+        # Basic parameters
+        self.img_size   = img_size
+        self.pixel_mean = pixel_mean
+        self.pixel_std  = pixel_std
+        self.box_format = box_format
+        self.affine_params = affine_params
+        self.normalize_coords = normalize_coords
+        self.color_format = 'bgr'
+        # Albumentations
+        self.ablu_trans = Albumentations(img_size) if use_ablu else None
+
+    def __call__(self, image, target, mosaic=False):
+        # --------------- Resize image ---------------
+        orig_h, orig_w = image.shape[:2]
+        ratio = self.img_size / max(orig_h, orig_w)
+        if ratio != 1: 
+            new_shape = (int(round(orig_w * ratio)), int(round(orig_h * ratio)))
+            image = cv2.resize(image, new_shape)
+        img_h, img_w = image.shape[:2]
+
+        # --------------- Filter bad targets ---------------
+        tgt_boxes_wh = target["boxes"][..., 2:] - target["boxes"][..., :2]
+        min_tgt_size = np.min(tgt_boxes_wh, axis=-1)
+        keep = (min_tgt_size > 1)
+        target["boxes"]  = target["boxes"][keep]
+        target["labels"] = target["labels"][keep]
+
+        # --------------- Albumentations ---------------
+        if self.ablu_trans is not None:
+            image, target = self.ablu_trans(image, target)
+
+        # --------------- HSV augmentations ---------------
+        image = augment_hsv(image,
+                            hgain=self.affine_params['hsv_h'], 
+                            sgain=self.affine_params['hsv_s'], 
+                            vgain=self.affine_params['hsv_v'])
+        
+        # --------------- Spatial augmentations ---------------
+        ## Random perspective
+        if not mosaic:
+            # rescale bbox
+            target["boxes"][..., [0, 2]] = target["boxes"][..., [0, 2]] / orig_w * img_w
+            target["boxes"][..., [1, 3]] = target["boxes"][..., [1, 3]] / orig_h * img_h
+
+            # spatial augment
+            target_ = np.concatenate((target['labels'][..., None], target['boxes']), axis=-1)
+            image, target_ = random_perspective(image, target_,
+                                                degrees     = self.affine_params['degrees'],
+                                                translate   = self.affine_params['translate'],
+                                                scale       = self.affine_params['scale'],
+                                                shear       = self.affine_params['shear'],
+                                                perspective = self.affine_params['perspective']
+                                                )
+            target['boxes']  = target_[..., 1:]
+            target['labels'] = target_[..., 0]
+
+        ## Random flip
+        if random.random() < 0.5:
+            w = image.shape[1]
+            image = np.fliplr(image).copy()
+            boxes = target['boxes'].copy()
+            boxes[..., [0, 2]] = w - boxes[..., [2, 0]]
+            target["boxes"] = boxes
+
+        # --------------- To torch.Tensor ---------------
+        image = F.to_tensor(image) * 255.
+        image = F.normalize(image, self.pixel_mean, self.pixel_std)
+        if target is not None:
+            target["boxes"] = torch.as_tensor(target["boxes"]).float()
+            target["labels"] = torch.as_tensor(target["labels"]).long()
+
+            # normalize coords
+            if self.normalize_coords:
+                target["boxes"][..., [0, 2]] /= img_w
+                target["boxes"][..., [1, 3]] /= img_h
+
+            # xyxy -> xywh
+            if self.box_format == "xywh":
+                box_cxcy = (target["boxes"][..., :2] + target["boxes"][..., 2:]) * 0.5
+                box_bwbh =  target["boxes"][..., 2:] - target["boxes"][..., :2]
+                target["boxes"] = torch.cat([box_cxcy, box_bwbh], dim=-1)
+
+
+        # --------------- Pad Image ---------------
+        img_h0, img_w0 = image.shape[1:]
+        pad_image = torch.zeros([image.size(0), self.img_size, self.img_size]).float()
+        pad_image[:, :img_h0, :img_w0] = image
+
+        return pad_image, target, ratio
+
+## YOLO-style Transform for Eval
+class YOLOBaseTransform(object):
+    def __init__(self,
+                 img_size=640,
+                 max_stride=32,
+                 pixel_mean = [0., 0., 0.],
+                 pixel_std  = [255., 255., 255.],
+                 box_format='xyxy',
+                 normalize_coords=False):
+        self.img_size = img_size
+        self.max_stride = max_stride
+        self.pixel_mean = pixel_mean
+        self.pixel_std  = pixel_std
+        self.box_format = box_format
+        self.normalize_coords = normalize_coords
+        self.color_format = 'bgr'
+
+    def __call__(self, image, target=None, mosaic=False):
+        # --------------- Resize image ---------------
+        orig_h, orig_w = image.shape[:2]
+        ratio = self.img_size / max(orig_h, orig_w)
+        if ratio != 1: 
+            new_shape = (int(round(orig_w * ratio)), int(round(orig_h * ratio)))
+            image = cv2.resize(image, new_shape)
+        img_h, img_w = image.shape[:2]
+
+        # --------------- Rescale bboxes ---------------
+        if target is not None:
+            # rescale bbox
+            target["boxes"][..., [0, 2]] = target["boxes"][..., [0, 2]] / orig_w * img_w
+            target["boxes"][..., [1, 3]] = target["boxes"][..., [1, 3]] / orig_h * img_h
+
+        # --------------- To torch.Tensor ---------------
+        image = F.to_tensor(image) * 255.
+        image = F.normalize(image, self.pixel_mean, self.pixel_std)
+        if target is not None:
+            target["boxes"] = torch.as_tensor(target["boxes"]).float()
+            target["labels"] = torch.as_tensor(target["labels"]).long()
+
+            # normalize coords
+            if self.normalize_coords:
+                target["boxes"][..., [0, 2]] /= img_w
+                target["boxes"][..., [1, 3]] /= img_h
+            
+            # xyxy -> xywh
+            if self.box_format == "xywh":
+                box_cxcy = (target["boxes"][..., :2] + target["boxes"][..., 2:]) * 0.5
+                box_bwbh =  target["boxes"][..., 2:] - target["boxes"][..., :2]
+                target["boxes"] = torch.cat([box_cxcy, box_bwbh], dim=-1)
+
+        # --------------- Pad image ---------------
+        img_h0, img_w0 = image.shape[1:]
+        dh = img_h0 % self.max_stride
+        dw = img_w0 % self.max_stride
+        dh = dh if dh == 0 else self.max_stride - dh
+        dw = dw if dw == 0 else self.max_stride - dw
+        
+        pad_img_h = img_h0 + dh
+        pad_img_w = img_w0 + dw
+        pad_image = torch.zeros([image.size(0), pad_img_h, pad_img_w]).float()
+        pad_image[:, :img_h0, :img_w0] = image
+
+        return pad_image, target, ratio
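+
+
+# --- Illustrative usage: a minimal sketch with dummy data (not part of the training code).
+# It assumes a BGR uint8 image and absolute xyxy boxes; the values are placeholders.
+if __name__ == "__main__":
+    dummy_image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+    dummy_target = {"boxes": np.array([[50., 60., 200., 220.]], dtype=np.float32),
+                    "labels": np.array([1], dtype=np.int64)}
+    base_transform = YOLOBaseTransform(img_size=640, max_stride=32)
+    img_tensor, out_target, ratio = base_transform(dummy_image, dummy_target)
+    # the padded tensor height/width are multiples of max_stride: here [3, 480, 640]
+    print(img_tensor.shape, out_target["boxes"], ratio)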

+ 20 - 0
dataset/scripts/COCO2017.sh

@@ -0,0 +1,20 @@
+mkdir COCO
+cd COCO
+
+wget http://images.cocodataset.org/zips/train2017.zip
+wget http://images.cocodataset.org/zips/val2017.zip
+wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
+wget http://images.cocodataset.org/zips/test2017.zip
+wget http://images.cocodataset.org/annotations/image_info_test2017.zip 
+
+unzip train2017.zip
+unzip val2017.zip
+unzip annotations_trainval2017.zip
+unzip test2017.zip
+unzip image_info_test2017.zip
+
+# rm -f train2017.zip
+# rm -f val2017.zip
+# rm -f annotations_trainval2017.zip
+# rm -f test2017.zip
+# rm -f image_info_test2017.zip
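+
+# Expected layout after the unzips above (the commented-out rm lines can be enabled
+# afterwards to reclaim disk space):
+#   COCO/
+#     train2017/      # training images
+#     val2017/        # validation images
+#     test2017/       # test images
+#     annotations/    # instances_*.json, captions_*.json, person_keypoints_*.json, image_info_test*.json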

+ 42 - 0
dataset/scripts/VOC2007.sh

@@ -0,0 +1,42 @@
+#!/bin/bash
+# Ellis Brown
+
+start=`date +%s`
+
+# handle optional download dir
+if [ -z "$1" ]
+  then
+    # navigate to ~/data
+    echo "navigating to ~/data/ ..." 
+    mkdir -p ~/data
+    cd ~/data/
+  else
+    # check if is valid directory
+    if [ ! -d $1 ]; then
+        echo $1 "is not a valid directory"
+        exit 0
+    fi
+    echo "navigating to" $1 "..."
+    cd $1
+fi
+
+echo "Downloading VOC2007 trainval ..."
+# Download the data.
+curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar
+echo "Downloading VOC2007 test data ..."
+curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtest_06-Nov-2007.tar
+echo "Done downloading."
+
+# Extract data
+echo "Extracting trainval ..."
+tar -xvf VOCtrainval_06-Nov-2007.tar
+echo "Extracting test ..."
+tar -xvf VOCtest_06-Nov-2007.tar
+echo "removing tars ..."
+rm VOCtrainval_06-Nov-2007.tar
+rm VOCtest_06-Nov-2007.tar
+
+end=`date +%s`
+runtime=$((end-start))
+
+echo "Completed in" $runtime "seconds"

+ 38 - 0
dataset/scripts/VOC2012.sh

@@ -0,0 +1,38 @@
+#!/bin/bash
+# Ellis Brown
+
+start=`date +%s`
+
+# handle optional download dir
+if [ -z "$1" ]
+  then
+    # navigate to ~/data
+    echo "navigating to ~/data/ ..." 
+    mkdir -p ~/data
+    cd ~/data/
+  else
+    # check if is valid directory
+    if [ ! -d $1 ]; then
+        echo $1 "is not a valid directory"
+        exit 0
+    fi
+    echo "navigating to" $1 "..."
+    cd $1
+fi
+
+echo "Downloading VOC2012 trainval ..."
+# Download the data.
+curl -LO http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
+echo "Done downloading."
+
+
+# Extract data
+echo "Extracting trainval ..."
+tar -xvf VOCtrainval_11-May-2012.tar
+echo "removing tar ..."
+rm VOCtrainval_11-May-2012.tar
+
+end=`date +%s`
+runtime=$((end-start))
+
+echo "Completed in" $runtime "seconds"

+ 70 - 0
dataset/scripts/data_to_h5py.py

@@ -0,0 +1,70 @@
+import cv2
+import h5py
+import os
+import argparse
+import numpy as np
+import sys
+
+sys.path.append('..')
+from voc import VOCDataset
+from coco import COCODataset
+
+# ---------------------- Opt ----------------------
+parser = argparse.ArgumentParser(description='Cache-Dataset')
+parser.add_argument('-d', '--dataset', default='voc',
+                    help='coco, voc, widerface, crowdhuman')
+parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
+                    help='data root')
+parser.add_argument('-size', '--img_size', default=640, type=int,
+                    help='input image size.')
+parser.add_argument('--mosaic', default=None, type=float,
+                    help='mosaic augmentation.')
+parser.add_argument('--mixup', default=None, type=float,
+                    help='mixup augmentation.')
+parser.add_argument('--keep_ratio', action="store_true", default=False,
+                    help='keep aspect ratio.')
+parser.add_argument('--show', action="store_true", default=False,
+                    help='show the cached images.')
+
+args = parser.parse_args()
+
+
+# ---------------------- Build Dataset ----------------------
+if args.dataset == 'voc':
+    root = os.path.join(args.root, 'VOCdevkit')
+    dataset = VOCDataset(None, root)  # cfg is only needed when is_train=True
+elif args.dataset == 'coco':
+    root = os.path.join(args.root, 'COCO')
+    dataset = COCODataset(args.img_size, root)
+print('Data length: ', len(dataset))
+
+
+# ---------------------- Main Process ----------------------
+cached_image = []
+dataset_size = len(dataset)
+for i in range(len(dataset)):
+    if i % 5000 == 0:
+        print("[{} / {}]".format(i, dataset_size))
+    # load an image
+    image, image_id = dataset.pull_image(i)
+    orig_h, orig_w, _ = image.shape
+
+    # resize image
+    if args.keep_ratio:
+        r = args.img_size / max(orig_h, orig_w)
+        if r != 1: 
+            interp = cv2.INTER_LINEAR
+            new_size = (int(orig_w * r), int(orig_h * r))
+            image = cv2.resize(image, new_size, interpolation=interp)
+    else:
+        image = cv2.resize(image, (int(args.img_size), int(args.img_size)))
+
+    cached_image.append(image)
+    if args.show:
+        cv2.imshow('image', image)
+        # cv2.imwrite(str(i)+'.jpg', img)
+        cv2.waitKey(0)
+
+save_path = "dataset/cache/"
+os.makedirs(save_path, exist_ok=True)
+np.save(save_path + '{}_train_images.npy'.format(args.dataset), cached_image)
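+
+# Note (a hedged reminder, not from the original script): with --keep_ratio the cached
+# images can have different shapes, so the saved array is an object array (recent NumPy
+# versions may require wrapping the list with dtype=object first); load it back with
+#   cached = np.load("dataset/cache/voc_train_images.npy", allow_pickle=True)
+# Without --keep_ratio every image is the same square size and a plain np.load suffices.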

+ 313 - 0
dataset/voc.py

@@ -0,0 +1,313 @@
+import cv2
+import random
+import numpy as np
+import os.path as osp
+import xml.etree.ElementTree as ET
+import torch.utils.data as data
+
+try:
+    from .data_augment.strong_augment import MosaicAugment, MixupAugment
+except ImportError:
+    from data_augment.strong_augment import MosaicAugment, MixupAugment
+
+
+VOC_CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
+voc_class_indexs = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19]
+voc_class_labels = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor')
+
+
+class VOCAnnotationTransform(object):
+    def __init__(self, class_to_ind=None, keep_difficult=False):
+        self.class_to_ind = class_to_ind or dict(
+            zip(VOC_CLASSES, range(len(VOC_CLASSES))))
+        self.keep_difficult = keep_difficult
+
+    def __call__(self, target):
+        res = []
+        for obj in target.iter('object'):
+            difficult = int(obj.find('difficult').text) == 1
+            if not self.keep_difficult and difficult:
+                continue
+            name = obj.find('name').text.lower().strip()
+            bbox = obj.find('bndbox')
+
+            pts = ['xmin', 'ymin', 'xmax', 'ymax']
+            bndbox = []
+            for i, pt in enumerate(pts):
+                cur_pt = int(bbox.find(pt).text) - 1
+                bndbox.append(cur_pt)
+            label_idx = self.class_to_ind[name]
+            bndbox.append(label_idx)
+            res += [bndbox]  # [x1, y1, x2, y2, label_ind]
+
+        return res  # [[x1, y1, x2, y2, label_ind], ... ]
+
+
+class VOCDataset(data.Dataset):
+    def __init__(self, 
+                 cfg,
+                 data_dir   :str = None, 
+                 image_set  = [('2007', 'trainval'), ('2012', 'trainval')],
+                 transform  = None,
+                 is_train   :bool =False,
+                 ):
+        # ----------- Basic parameters -----------
+        self.image_set = image_set
+        self.is_train  = is_train
+        self.num_classes = 20
+        # ----------- Path parameters -----------
+        self.root = data_dir
+        self._annopath = osp.join('%s', 'Annotations', '%s.xml')
+        self._imgpath = osp.join('%s', 'JPEGImages', '%s.jpg')
+        # ----------- Data parameters -----------
+        self.ids = list()
+        for (year, name) in image_set:
+            rootpath = osp.join(self.root, 'VOC' + year)
+            for line in open(osp.join(rootpath, 'ImageSets', 'Main', name + '.txt')):
+                self.ids.append((rootpath, line.strip()))
+        self.dataset_size = len(self.ids)
+        self.class_labels = voc_class_labels
+        self.class_indexs = voc_class_indexs
+        # ----------- Transform parameters -----------
+        self.target_transform = VOCAnnotationTransform()
+        self.transform = transform
+        if is_train:
+            self.mosaic_prob = cfg.mosaic_prob
+            self.mixup_prob  = cfg.mixup_prob
+            self.copy_paste  = cfg.copy_paste
+            self.mosaic_augment = None if cfg.mosaic_prob == 0. else MosaicAugment(cfg.train_img_size, cfg.affine_params, is_train)
+            self.mixup_augment  = None if cfg.mixup_prob == 0. and cfg.copy_paste == 0.  else MixupAugment(cfg.train_img_size)
+        else:
+            self.mosaic_prob = 0.0
+            self.mixup_prob  = 0.0
+            self.copy_paste  = 0.0
+            self.mosaic_augment = None
+            self.mixup_augment  = None
+        print('==============================')
+        print('use Mosaic Augmentation: {}'.format(self.mosaic_prob))
+        print('use Mixup Augmentation:  {}'.format(self.mixup_prob))
+        print('use Copy-paste Augmentation: {}'.format(self.copy_paste))
+
+    # ------------ Basic dataset function ------------
+    def __getitem__(self, index):
+        image, target, deltas = self.pull_item(index)
+        return image, target, deltas
+
+    def __len__(self):
+        return self.dataset_size
+
+    # ------------ Mosaic & Mixup ------------
+    def load_mosaic(self, index):
+        # ------------ Prepare 4 indexes of images ------------
+        ## Load 4x mosaic image
+        index_list = np.arange(index).tolist() + np.arange(index+1, len(self.ids)).tolist()
+        id1 = index
+        id2, id3, id4 = random.sample(index_list, 3)
+        indexs = [id1, id2, id3, id4]
+
+        ## Load images and targets
+        image_list = []
+        target_list = []
+        for index in indexs:
+            img_i, target_i = self.load_image_target(index)
+            image_list.append(img_i)
+            target_list.append(target_i)
+
+        # ------------ Mosaic augmentation ------------
+        image, target = self.mosaic_augment(image_list, target_list)
+
+        return image, target
+
+    def load_mixup(self, origin_image, origin_target, yolox_style=False):
+        # ------------ Load a new image & target ------------
+        new_index = np.random.randint(0, len(self.ids))
+        new_image, new_target = self.load_mosaic(new_index)
+            
+        # ------------ Mixup augmentation ------------
+        image, target = self.mixup_augment(origin_image, origin_target, new_image, new_target, yolox_style)
+
+        return image, target
+    
+    # ------------ Load data function ------------
+    def load_image_target(self, index):
+        # load an image
+        image, _ = self.pull_image(index)
+        height, width, channels = image.shape
+
+        # load an annotation
+        anno, _ = self.pull_anno(index)
+
+        # guard against no boxes via resizing
+        anno = np.array(anno).reshape(-1, 5)
+        target = {
+            "boxes": anno[:, :4],
+            "labels": anno[:, 4],
+            "orig_size": [height, width]
+        }
+        
+        return image, target
+
+    def pull_item(self, index):
+        if random.random() < self.mosaic_prob:
+            # load a mosaic image
+            mosaic = True
+            image, target = self.load_mosaic(index)
+        else:
+            mosaic = False
+            # load an image and target
+            image, target = self.load_image_target(index)
+
+        # Yolov5-MixUp
+        mixup = False
+        if random.random() < self.mixup_prob:
+            mixup = True
+            image, target = self.load_mixup(image, target)
+
+        # Copy-paste (use Yolox-Mixup to approximate copy-paste)
+        if not mixup and random.random() < self.copy_paste:
+            image, target = self.load_mixup(image, target, yolox_style=True)
+
+        # augment
+        image, target, deltas = self.transform(image, target, mosaic)
+
+        return image, target, deltas
+
+    def pull_image(self, index):
+        img_id = self.ids[index]
+        image = cv2.imread(self._imgpath % img_id, cv2.IMREAD_COLOR)
+
+        return image, img_id
+
+    def pull_anno(self, index):
+        img_id = self.ids[index]
+        anno = ET.parse(self._annopath % img_id).getroot()
+        anno = self.target_transform(anno)
+
+        return anno, img_id
+
+
+if __name__ == "__main__":
+    import time
+    import argparse
+    from build import build_transform
+    
+    parser = argparse.ArgumentParser(description='VOC-Dataset')
+
+    # opt
+    parser.add_argument('--root', default='D:/python_work/dataset/VOCdevkit/',
+                        help='data root')
+    parser.add_argument('--is_train', action="store_true", default=False,
+                        help='train or not.')
+    parser.add_argument('--aug_type', default="yolo", type=str, choices=["yolo", "ssd"],
+                        help='yolo, ssd.')
+    
+    args = parser.parse_args()
+
+    class YoloBaseConfig(object):
+        def __init__(self) -> None:
+            self.max_stride = 32
+            # ---------------- Data process config ----------------
+            self.box_format = 'xywh'
+            self.normalize_coords = False
+            self.mosaic_prob = 1.0
+            self.mixup_prob  = 0.15
+            self.copy_paste  = 0.3
+            ## Pixel mean & std
+            self.pixel_mean = [0., 0., 0.]
+            self.pixel_std  = [255., 255., 255.]
+            ## Transforms
+            self.train_img_size = 640
+            self.test_img_size  = 640
+            self.use_ablu = True
+            self.aug_type = 'yolo'
+            self.affine_params = {
+                'degrees': 0.0,
+                'translate': 0.2,
+                'scale': [0.1, 2.0],
+                'shear': 0.0,
+                'perspective': 0.0,
+                'hsv_h': 0.015,
+                'hsv_s': 0.7,
+                'hsv_v': 0.4,
+            }
+
+    class SSDBaseConfig(object):
+        def __init__(self) -> None:
+            self.max_stride = 32
+            # ---------------- Data process config ----------------
+            self.box_format = 'xywh'
+            self.normalize_coords = False
+            self.mosaic_prob = 0.0
+            self.mixup_prob  = 0.0
+            self.copy_paste  = 0.0
+            ## Pixel mean & std
+            self.pixel_mean = [0., 0., 0.]
+            self.pixel_std  = [255., 255., 255.]
+            ## Transforms
+            self.train_img_size = 640
+            self.test_img_size  = 640
+            self.aug_type = 'ssd'
+
+    if args.aug_type == "yolo":
+        cfg = YoloBaseConfig()
+    elif args.aug_type == "ssd":
+        cfg = SSDBaseConfig()
+
+    transform = build_transform(cfg, args.is_train)
+    dataset = VOCDataset(cfg, args.root, [('2007', 'test')], transform, args.is_train)
+    
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(80)]
+    print('Data length: ', len(dataset))
+
+    for i in range(1000):
+        t0 = time.time()
+        image, target, deltas = dataset.pull_item(i)
+        print("Load data: {} s".format(time.time() - t0))
+
+        # to numpy
+        image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * cfg.pixel_std + cfg.pixel_mean
+
+        # rgb -> bgr
+        if transform.color_format == 'rgb':
+            image = image[..., (2, 1, 0)]
+
+        # to uint8
+        image = image.astype(np.uint8)
+        image = image.copy()
+        img_h, img_w = image.shape[:2]
+
+        boxes = target["boxes"]
+        labels = target["labels"]
+
+        for box, label in zip(boxes, labels):
+            if cfg.box_format == 'xyxy':
+                x1, y1, x2, y2 = box
+            elif cfg.box_format == 'xywh':
+                cx, cy, bw, bh = box
+                x1 = cx - 0.5 * bw
+                y1 = cy - 0.5 * bh
+                x2 = cx + 0.5 * bw
+                y2 = cy + 0.5 * bh
+            
+            if cfg.normalize_coords:
+                x1 *= img_w
+                y1 *= img_h
+                x2 *= img_w
+                y2 *= img_h
+
+            cls_id = int(label)
+            color = class_colors[cls_id]
+            # class name
+            label = voc_class_labels[cls_id]
+            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
+            # put the test on the bbox
+            cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
+        cv2.imshow('gt', image)
+        # cv2.imwrite(str(i)+'.jpg', img)
+        cv2.waitKey(0)

+ 290 - 0
demo.py

@@ -0,0 +1,290 @@
+import argparse
+import cv2
+import os
+import time
+import numpy as np
+import imageio
+
+import torch
+
+# load transform
+from dataset.build import build_transform
+
+# load some utils
+from utils.misc import load_weight
+from utils.box_ops import rescale_bboxes
+from utils.vis_tools import visualize
+
+from models import build_model
+from config import build_config
+from dataset.coco import coco_class_labels
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
+    # Basic setting
+    parser.add_argument('-size', '--img_size', default=640, type=int,
+                        help='the max size of input image')
+    parser.add_argument('--mode', default='image',
+                        type=str, help='Use the data from image, video or camera')
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='Use cuda')
+    parser.add_argument('--path_to_img', default='dataset/demo/images/',
+                        type=str, help='The path to image files')
+    parser.add_argument('--path_to_vid', default='dataset/demo/videos/',
+                        type=str, help='The path to video files')
+    parser.add_argument('--path_to_save', default='det_results/demos/',
+                        type=str, help='The path to save the detection results')
+    parser.add_argument('--show', action='store_true', default=False,
+                        help='show visualization')
+    parser.add_argument('--gif', action='store_true', default=False, 
+                        help='generate gif.')
+
+    # Model setting
+    parser.add_argument('-m', '--model', default='yolo_n', type=str,
+                        help='build yolo')
+    parser.add_argument('-nc', '--num_classes', default=80, type=int,
+                        help='number of classes.')
+    parser.add_argument('--weight', default=None,
+                        type=str, help='Trained state_dict file path to open')
+    parser.add_argument("--deploy", action="store_true", default=False,
+                        help="deploy mode or not")
+    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
+                        help='fuse Conv & BN')
+
+    # Data setting
+    parser.add_argument('-d', '--dataset', default='coco',
+                        help='coco, voc, crowdhuman, widerface.')
+
+    return parser.parse_args()
+                    
+
+def detect(args,
+           model, 
+           device, 
+           transform, 
+           num_classes,
+           class_names,
+           mode='image'):
+    # class color
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(num_classes)]
+    save_path = os.path.join(args.path_to_save, mode)
+    os.makedirs(save_path, exist_ok=True)
+
+    # ------------------------- Camera ----------------------------
+    if mode == 'camera':
+        print('use camera !!!')
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        save_size = (640, 480)
+        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
+        save_video_name = os.path.join(save_path, cur_time+'.avi')
+        fps = 15.0
+        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
+        print(save_video_name)
+        image_list = []
+
+        cap = cv2.VideoCapture(0, cv2.CAP_DSHOW)
+        while True:
+            ret, frame = cap.read()
+            if ret:
+                if cv2.waitKey(1) == ord('q'):
+                    break
+                orig_h, orig_w, _ = frame.shape
+
+                # prepare
+                x, _, ratio = transform(frame)
+                x = x.unsqueeze(0).to(device)
+                
+                # inference
+                t0 = time.time()
+                outputs = model(x)
+                scores = outputs['scores']
+                labels = outputs['labels']
+                bboxes = outputs['bboxes']
+                t1 = time.time()
+                print("Infer time: {:.1f} ms. ".format((t1 - t0) * 1000))
+
+                # rescale bboxes
+                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)
+
+                # vis detection
+                frame_vis = visualize(image=frame, 
+                                      bboxes=bboxes,
+                                      scores=scores, 
+                                      labels=labels,
+                                      class_colors=class_colors,
+                                      class_names=class_names
+                                      )
+                frame_resized = cv2.resize(frame_vis, save_size)
+                out.write(frame_resized)
+
+                if args.gif:
+                    gif_resized = cv2.resize(frame_vis, (640, 480))
+                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
+                    image_list.append(gif_resized_rgb)
+
+                if args.show:
+                    cv2.imshow('detection', frame_resized)
+                    cv2.waitKey(1)
+            else:
+                break
+        cap.release()
+        out.release()
+        cv2.destroyAllWindows()
+
+        # generate GIF
+        if args.gif:
+            save_gif_path =  os.path.join(save_path, 'gif_files')
+            os.makedirs(save_gif_path, exist_ok=True)
+            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
+            print('generating GIF ...')
+            imageio.mimsave(save_gif_name, image_list, fps=fps)
+            print('GIF done: {}'.format(save_gif_name))
+
+    # ------------------------- Video ---------------------------
+    elif mode == 'video':
+        video = cv2.VideoCapture(args.path_to_vid)
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+        save_size = (640, 480)
+        cur_time = time.strftime('%Y-%m-%d-%H-%M-%S',time.localtime(time.time()))
+        save_video_name = os.path.join(save_path, cur_time+'.avi')
+        fps = 15.0
+        out = cv2.VideoWriter(save_video_name, fourcc, fps, save_size)
+        print(save_video_name)
+        image_list = []
+
+        while(True):
+            ret, frame = video.read()
+            
+            if ret:
+                # ------------------------- Detection ---------------------------
+                orig_h, orig_w, _ = frame.shape
+
+                # prepare
+                x, _, ratio = transform(frame)
+                x = x.unsqueeze(0).to(device)
+
+                # inference
+                t0 = time.time()
+                outputs = model(x)
+                scores = outputs['scores']
+                labels = outputs['labels']
+                bboxes = outputs['bboxes']
+                t1 = time.time()
+                print("Infer time: {:.1f} ms. ".format((t1 - t0) * 1000))
+
+                # rescale bboxes
+                bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)
+
+                # vis detection
+                frame_vis = visualize(image=frame, 
+                                      bboxes=bboxes,
+                                      scores=scores, 
+                                      labels=labels,
+                                      class_colors=class_colors,
+                                      class_names=class_names
+                                      )
+
+                frame_resized = cv2.resize(frame_vis, save_size)
+                out.write(frame_resized)
+
+                if args.gif:
+                    gif_resized = cv2.resize(frame_vis, (640, 480))
+                    gif_resized_rgb = gif_resized[..., (2, 1, 0)]
+                    image_list.append(gif_resized_rgb)
+
+                if args.show:
+                    cv2.imshow('detection', frame_resized)
+                    cv2.waitKey(1)
+            else:
+                break
+        video.release()
+        out.release()
+        cv2.destroyAllWindows()
+
+        # generate GIF
+        if args.gif:
+            save_gif_path =  os.path.join(save_path, 'gif_files')
+            os.makedirs(save_gif_path, exist_ok=True)
+            save_gif_name = os.path.join(save_gif_path, '{}.gif'.format(cur_time))
+            print('generating GIF ...')
+            imageio.mimsave(save_gif_name, image_list, fps=fps)
+            print('GIF done: {}'.format(save_gif_name))
+
+    # ------------------------- Image ----------------------------
+    elif mode == 'image':
+        for i, img_id in enumerate(os.listdir(args.path_to_img)):
+            image = cv2.imread((args.path_to_img + '/' + img_id), cv2.IMREAD_COLOR)
+            orig_h, orig_w, _ = image.shape
+
+            # prepare
+            x, _, ratio = transform(image)
+            x = x.unsqueeze(0).to(device)
+
+            # inference
+            t0 = time.time()
+            outputs = model(x)
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
+            t1 = time.time()
+            print("Infer time: {:.1f} ms. ".format((t1 - t0) * 1000))
+
+            # rescale bboxes
+            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)
+
+            # vis detection
+            img_processed = visualize(image=image, 
+                                      bboxes=bboxes,
+                                      scores=scores, 
+                                      labels=labels,
+                                      class_colors=class_colors,
+                                      class_names=class_names
+                                      )
+            cv2.imwrite(os.path.join(save_path, str(i).zfill(6)+'.jpg'), img_processed)
+            if args.show:
+                cv2.imshow('detection', img_processed)
+                cv2.waitKey(0)
+
+
+def run():
+    args = parse_args()
+    # cuda
+    if args.cuda:
+        print('use cuda')
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    # config
+    cfg = build_config(args)
+    cfg.num_classes = 80
+    cfg.class_labels = coco_class_labels
+    
+    # build model
+    model = build_model(args, cfg, False)
+
+    # load trained weight
+    model = load_weight(model, args.weight, args.fuse_conv_bn)
+    model.to(device).eval()
+
+    # transform
+    transform = build_transform(cfg, is_train=False)
+
+    print("================= DETECT =================")
+    # run
+    detect(args         = args,
+           mode         = args.mode,
+           model        = model, 
+           device       = device,
+           transform    = transform,
+           num_classes  = cfg.num_classes,
+           class_names  = cfg.class_labels,
+           )
+
+
+if __name__ == '__main__':
+    run()

+ 562 - 0
engine.py

@@ -0,0 +1,562 @@
+import torch
+import torch.distributed as dist
+
+import os
+import random
+
+# ----------------- Extra Components -----------------
+from utils import distributed_utils
+from utils.misc import MetricLogger, SmoothedValue
+from utils.vis_tools import vis_data
+
+# ----------------- Optimizer & LrScheduler Components -----------------
+from utils.solver.optimizer import build_yolo_optimizer, build_rtdetr_optimizer
+from utils.solver.lr_scheduler import LinearWarmUpLrScheduler, build_lr_scheduler
+
+
+class YoloTrainer(object):
+    def __init__(self,
+                 # Basic parameters
+                 args,
+                 cfg,
+                 device,
+                 # Model parameters
+                 model,
+                 model_ema,
+                 criterion,
+                 # Data parameters
+                 train_transform,
+                 val_transform,
+                 dataset,
+                 train_loader,
+                 evaluator,
+                 ):
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.cfg  = cfg
+        self.epoch = 0
+        self.best_map = -1.
+        self.device = device
+        self.criterion = criterion
+        self.heavy_eval = False
+        self.model_ema = model_ema
+        # weak augmentation stage
+        self.second_stage = False
+        self.second_stage_epoch = cfg.no_aug_epoch
+        # path to save model
+        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
+        os.makedirs(self.path_to_save, exist_ok=True)
+
+        # ---------------------------- Transform ----------------------------
+        self.train_transform = train_transform
+        self.val_transform   = val_transform
+
+        # ---------------------------- Dataset & Dataloader ----------------------------
+        self.dataset      = dataset
+        self.train_loader = train_loader
+
+        # ---------------------------- Evaluator ----------------------------
+        self.evaluator = evaluator
+
+        # ---------------------------- Build Grad. Scaler ----------------------------
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+        # ---------------------------- Build Optimizer ----------------------------
+        cfg.grad_accumulate = max(64 // args.batch_size, 1)
+        cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
+        cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
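+        # e.g., with batch_size=16 this yields grad_accumulate=4, i.e. an effective batch of
+        # ~64 images per optimizer step; base_lr then scales linearly with that effective batch.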
+        self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        self.lr_scheduler_warmup = LinearWarmUpLrScheduler(cfg.base_lr, wp_iter=cfg.warmup_epoch * len(self.train_loader))
+        self.lr_scheduler = build_lr_scheduler(cfg, self.optimizer, args.resume)
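+        # Warmup ramps the LR linearly over warmup_epoch * iters_per_epoch iterations;
+        # afterwards self.lr_scheduler.step() is called once per epoch in train().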
+
+        # ---------------------------- Build Model-EMA ----------------------------
+        if self.model_ema is not None:
+            update_init = self.start_epoch * len(self.train_loader) // cfg.grad_accumulate
+            print("Initialize ModelEMA's updates: {}".format(update_init))
+            self.model_ema.updates = update_init
+
+    def train(self, model):
+        for epoch in range(self.start_epoch, self.cfg.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # check second stage
+            if epoch >= (self.cfg.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
+                self.check_second_stage()
+                # save model of the last mosaic epoch
+                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
+                torch.save({'model': model.state_dict(),
+                            'mAP': round(self.evaluator.map*100, 1),
+                            'optimizer': self.optimizer.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)
+
+            # train one epoch
+            self.epoch = epoch
+            self.train_one_epoch(model)
+
+            # LR Schedule
+            if (epoch + 1) > self.cfg.warmup_epoch:
+                self.lr_scheduler.step()
+
+            # eval one epoch
+            if self.heavy_eval:
+                model_eval = model.module if self.args.distributed else model
+                self.eval(model_eval)
+            else:
+                model_eval = model.module if self.args.distributed else model
+                if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
+                    self.eval(model_eval)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
+
+    def eval(self, model):
+        # set eval mode
+        model.eval()
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
+        if distributed_utils.is_main_process():
+            # check evaluator
+            if self.evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                print('Saving state, epoch: {}'.format(self.epoch))
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                torch.save({'model': model_eval.state_dict(),
+                            'mAP': -1.,
+                            'optimizer': self.optimizer.state_dict(),
+                            'lr_scheduler': self.lr_scheduler.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)               
+            else:
+                print('eval ...')
+                # evaluate
+                with torch.no_grad():
+                    self.evaluator.evaluate(model_eval)
+
+                # save model
+                cur_map = self.evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    # save model
+                    print('Saving state, epoch:', self.epoch)
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                    torch.save({'model': model_eval.state_dict(),
+                                'mAP': round(self.best_map*100, 1),
+                                'optimizer': self.optimizer.state_dict(),
+                                'lr_scheduler': self.lr_scheduler.state_dict(),
+                                'epoch': self.epoch,
+                                'args': self.args}, 
+                                checkpoint_path)                      
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+        # set train mode.
+        model.train()
+
+    def train_one_epoch(self, model):
+        metric_logger = MetricLogger(delimiter="  ")
+        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+        metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+        header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
+        epoch_size = len(self.train_loader)
+        print_freq = 10
+
+        # basic parameters
+        img_size   = self.cfg.train_img_size
+        nw = epoch_size * self.cfg.warmup_epoch
+
+        # Train one epoch
+        for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
+            ni = iter_i + self.epoch * epoch_size
+            # Warmup
+            if nw > 0 and ni < nw:
+                self.lr_scheduler_warmup(ni, self.optimizer)
+            elif ni == nw:
+                print("Warmup stage is over.")
+                self.lr_scheduler_warmup.set_lr(self.optimizer, self.cfg.base_lr)
+                                
+            # To device
+            images = images.to(self.device, non_blocking=True).float()
+
+            # Multi scale
+            images, targets, img_size = self.rescale_image_targets(
+                images, targets, self.cfg.max_stride, self.cfg.multi_scale)
+                
+            # Visualize train targets
+            if self.args.vis_tgt:
+                vis_data(images,
+                         targets,
+                         self.cfg.num_classes,
+                         self.cfg.normalize_coords,
+                         self.train_transform.color_format,
+                         self.cfg.pixel_mean,
+                         self.cfg.pixel_std,
+                         self.cfg.box_format)
+
+            # Inference
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
+                outputs = model(images)
+                # Compute loss
+                loss_dict = self.criterion(outputs=outputs, targets=targets)
+                losses = loss_dict['losses']
+                losses /= self.cfg.grad_accumulate
+                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+
+            # Backward
+            self.scaler.scale(losses).backward()
+
+            # Gradient clip
+            if self.cfg.clip_max_norm > 0:
+                self.scaler.unscale_(self.optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
+
+            # Optimize
+            if (iter_i + 1) % self.cfg.grad_accumulate == 0:
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.optimizer.zero_grad()
+
+                # ModelEMA
+                if self.model_ema is not None:
+                    self.model_ema.update(model)
+
+            # Update log
+            metric_logger.update(**loss_dict_reduced)
+            metric_logger.update(lr=self.optimizer.param_groups[2]["lr"])
+            metric_logger.update(size=img_size)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
+        # Gather the stats from all processes
+        metric_logger.synchronize_between_processes()
+        print("Averaged stats:", metric_logger)
+
+    def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
+        """
+            Multi-scale training trick: randomly rescale the input batch and its box targets.
+        """
+        # During training phase, the shape of input image is square.
+        old_img_size = images.shape[-1]
+        min_img_size = old_img_size * multi_scale_range[0]
+        max_img_size = old_img_size * multi_scale_range[1]
+
+        # Choose a new image size
+        new_img_size = random.randrange(int(min_img_size), int(max_img_size + max_stride), max_stride)
+        
+        # Resize
+        if new_img_size != old_img_size:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images, 
+                                size=new_img_size, 
+                                mode='bilinear', 
+                                align_corners=False)
+        # rescale targets
+        if not self.cfg.normalize_coords:
+            for tgt in targets:
+                boxes = tgt["boxes"].clone()
+                labels = tgt["labels"].clone()
+                boxes = torch.clamp(boxes, 0, old_img_size)
+                # rescale box
+                boxes[:, [0, 2]] = boxes[:, [0, 2]] / old_img_size * new_img_size
+                boxes[:, [1, 3]] = boxes[:, [1, 3]] / old_img_size * new_img_size
+                # refine tgt
+                tgt_boxes_wh = boxes[..., 2:] - boxes[..., :2]
+                min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
+                keep = (min_tgt_size >= 1)
+
+                tgt["boxes"] = boxes[keep]
+                tgt["labels"] = labels[keep]
+
+        return images, targets, new_img_size
+
+    def check_second_stage(self):
+        # set second stage
+        print('============== Second stage of Training ==============')
+        self.second_stage = True
+        self.heavy_eval = True
+
+        # close mosaic augmentation
+        if self.train_loader.dataset.mosaic_prob > 0.:
+            print(' - Close < Mosaic Augmentation > ...')
+            self.train_loader.dataset.mosaic_prob = 0.
+
+        # close mixup augmentation
+        if self.train_loader.dataset.mixup_prob > 0.:
+            print(' - Close < Mixup Augmentation > ...')
+            self.train_loader.dataset.mixup_prob = 0.
+
+        # close copy-paste augmentation
+        if self.train_loader.dataset.copy_paste > 0.:
+            print(' - Close < Copy-paste Augmentation > ...')
+            self.train_loader.dataset.copy_paste = 0.
+
+
+class RTDetrTrainer(object):
+    def __init__(self,
+                 # Basic parameters
+                 args,
+                 cfg,
+                 device,
+                 # Model parameters
+                 model,
+                 model_ema,
+                 criterion,
+                 # Data parameters
+                 train_transform,
+                 val_transform,
+                 dataset,
+                 train_loader,
+                 evaluator,
+                 ):
+        # ------------------- basic parameters -------------------
+        self.args = args
+        self.cfg  = cfg
+        self.epoch = 0
+        self.best_map = -1.
+        self.device = device
+        self.criterion = criterion
+        self.heavy_eval = False
+        self.model_ema = model_ema
+        # path to save model
+        self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
+        os.makedirs(self.path_to_save, exist_ok=True)
+
+        # ---------------------------- Transform ----------------------------
+        self.train_transform = train_transform
+        self.val_transform   = val_transform
+
+        # ---------------------------- Dataset & Dataloader ----------------------------
+        self.dataset      = dataset
+        self.train_loader = train_loader
+
+        # ---------------------------- Evaluator ----------------------------
+        self.evaluator = evaluator
+
+        # ---------------------------- Build Grad. Scaler ----------------------------
+        self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
+
+        # ---------------------------- Build Optimizer ----------------------------
+        cfg.grad_accumulate = max(16 // args.batch_size, 1)
+        cfg.base_lr = cfg.per_image_lr * args.batch_size * cfg.grad_accumulate
+        cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
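+        # e.g., with batch_size=8 this yields grad_accumulate=2, i.e. an effective batch of
+        # ~16 images per optimizer step, the smaller budget used by the RT-DETR recipe here.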
+        self.optimizer, self.start_epoch = build_rtdetr_optimizer(cfg, model, args.resume)
+
+        # ---------------------------- Build LR Scheduler ----------------------------
+        self.wp_lr_scheduler = LinearWarmUpLrScheduler(cfg.base_lr, wp_iter=cfg.warmup_iters)
+        self.lr_scheduler    = build_lr_scheduler(cfg, self.optimizer, args.resume)
+
+        # ---------------------------- Build Model-EMA ----------------------------
+        if self.model_ema is not None:
+            update_init = self.start_epoch * len(self.train_loader) // cfg.grad_accumulate
+            print("Initialize ModelEMA's updates: {}".format(update_init))
+            self.model_ema.updates = update_init            
+
+    def train(self, model):
+        for epoch in range(self.start_epoch, self.cfg.max_epoch):
+            if self.args.distributed:
+                self.train_loader.batch_sampler.sampler.set_epoch(epoch)
+
+            # train one epoch
+            self.epoch = epoch
+            self.train_one_epoch(model)
+
+            # LR Scheduler
+            self.lr_scheduler.step()
+
+            # eval one epoch
+            if self.heavy_eval:
+                model_eval = model.module if self.args.distributed else model
+                self.eval(model_eval)
+            else:
+                model_eval = model.module if self.args.distributed else model
+                if (epoch % self.cfg.eval_epoch) == 0 or (epoch == self.cfg.max_epoch - 1):
+                    self.eval(model_eval)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 epoch")
+                break
+
+    def eval(self, model):
+        # set eval mode
+        model.eval()
+        model_eval = model if self.model_ema is None else self.model_ema.ema
+
+        if distributed_utils.is_main_process():
+            # check evaluator
+            if self.evaluator is None:
+                print('No evaluator ... save model and go on training.')
+                print('Saving state, epoch: {}'.format(self.epoch))
+                weight_name = '{}_no_eval.pth'.format(self.args.model)
+                checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                torch.save({'model': model_eval.state_dict(),
+                            'mAP': -1.,
+                            'optimizer': self.optimizer.state_dict(),
+                            'lr_scheduler': self.lr_scheduler.state_dict(),
+                            'epoch': self.epoch,
+                            'args': self.args}, 
+                            checkpoint_path)               
+            else:
+                print('eval ...')
+                # evaluate
+                with torch.no_grad():
+                    self.evaluator.evaluate(model_eval)
+
+                # save model
+                cur_map = self.evaluator.map
+                if cur_map > self.best_map:
+                    # update best-map
+                    self.best_map = cur_map
+                    # save model
+                    print('Saving state, epoch:', self.epoch)
+                    weight_name = '{}_best.pth'.format(self.args.model)
+                    checkpoint_path = os.path.join(self.path_to_save, weight_name)
+                    torch.save({'model': model_eval.state_dict(),
+                                'mAP': round(self.best_map*100, 1),
+                                'optimizer': self.optimizer.state_dict(),
+                                'lr_scheduler': self.lr_scheduler.state_dict(),
+                                'epoch': self.epoch,
+                                'args': self.args}, 
+                                checkpoint_path)                      
+
+        if self.args.distributed:
+            # wait for all processes to synchronize
+            dist.barrier()
+
+        # set train mode.
+        model.train()
+
+    def train_one_epoch(self, model):
+        metric_logger = MetricLogger(delimiter="  ")
+        metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+        metric_logger.add_meter('size', SmoothedValue(window_size=1, fmt='{value:d}'))
+        metric_logger.add_meter('grad_norm', SmoothedValue(window_size=1, fmt='{value:.1f}'))
+        header = 'Epoch: [{} / {}]'.format(self.epoch, self.cfg.max_epoch)
+        epoch_size = len(self.train_loader)
+        print_freq = 10
+
+        # basic parameters
+        img_size   = self.cfg.train_img_size
+        nw         = self.cfg.warmup_iters
+        lr_warmup_stage = True
+
+        # Train one epoch
+        for iter_i, (images, targets) in enumerate(metric_logger.log_every(self.train_loader, print_freq, header)):
+            ni = iter_i + self.epoch * epoch_size
+            # WarmUp
+            if ni < nw and lr_warmup_stage:
+                self.wp_lr_scheduler(ni, self.optimizer)
+            elif ni == nw and lr_warmup_stage:
+                print('Warmup stage is over.')
+                lr_warmup_stage = False
+                self.wp_lr_scheduler.set_lr(self.optimizer, self.cfg.base_lr)
+                                
+            # To device
+            images = images.to(self.device, non_blocking=True).float()
+            for tgt in targets:
+                tgt['boxes'] = tgt['boxes'].to(self.device)
+                tgt['labels'] = tgt['labels'].to(self.device)
+
+            # Multi scale
+            images, targets, img_size = self.rescale_image_targets(
+                images, targets, self.cfg.max_stride, self.cfg.multi_scale)
+                
+            # Visualize train targets
+            if self.args.vis_tgt:
+                vis_data(images,
+                         targets,
+                         self.cfg.num_classes,
+                         self.cfg.normalize_coords,
+                         self.train_transform.color_format,
+                         self.cfg.pixel_mean,
+                         self.cfg.pixel_std,
+                         self.cfg.box_format)
+
+            # Inference
+            with torch.cuda.amp.autocast(enabled=self.args.fp16):
+                outputs = model(images, targets)    
+                loss_dict = self.criterion(outputs, targets)
+                losses = sum(loss_dict.values())
+                losses /= self.cfg.grad_accumulate
+                loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+
+            # Backward
+            self.scaler.scale(losses).backward()
+
+            # Gradient clip
+            grad_norm = None
+            if self.cfg.clip_max_norm > 0:
+                self.scaler.unscale_(self.optimizer)
+                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=self.cfg.clip_max_norm)
+
+            # Optimize
+            if (iter_i + 1) % self.cfg.grad_accumulate == 0:
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+                self.optimizer.zero_grad()
+
+                # ModelEMA
+                if self.model_ema is not None:
+                    self.model_ema.update(model)
+
+            # Update log
+            metric_logger.update(loss=losses.item(), **loss_dict_reduced)
+            metric_logger.update(lr=self.optimizer.param_groups[0]["lr"])
+            metric_logger.update(grad_norm=grad_norm)
+            metric_logger.update(size=img_size)
+
+            if self.args.debug:
+                print("For debug mode, we only train 1 iteration")
+                break
+
+    def rescale_image_targets(self, images, targets, max_stride, multi_scale_range=[0.5, 1.5]):
+        """
+            Multi-scale training trick: randomly rescale the input batch each iteration.
+        """
+        # During training phase, the shape of input image is square.
+        old_img_size = images.shape[-1]
+        min_img_size = old_img_size * multi_scale_range[0]
+        max_img_size = old_img_size * multi_scale_range[1]
+
+        # Choose a new image size
+        new_img_size = random.randrange(int(min_img_size), int(max_img_size + max_stride), max_stride)
+        
+        # Resize
+        if new_img_size != old_img_size:
+            # interpolate
+            images = torch.nn.functional.interpolate(
+                                input=images, 
+                                size=new_img_size, 
+                                mode='bilinear', 
+                                align_corners=False)
+
+        return images, targets, new_img_size
+
+
+# Build Trainer
+def build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator):
+    # ----------------------- Det trainers -----------------------
+    if   cfg.trainer == 'yolo':
+        return YoloTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
+    elif cfg.trainer == 'rtdetr':
+        return RTDetrTrainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
+    else:
+        raise NotImplementedError(cfg.trainer)
+    

+ 119 - 0
eval.py

@@ -0,0 +1,119 @@
+import argparse
+import torch
+
+from evaluator.voc_evaluator import VOCAPIEvaluator
+from evaluator.coco_evaluator import COCOAPIEvaluator
+from evaluator.customed_evaluator import CustomedEvaluator
+
+# load transform
+from dataset.build import build_dataset, build_transform
+
+# load some utils
+from utils.misc import load_weight
+
+from config import build_config
+from models import build_model
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
+    # Basic setting
+    parser.add_argument('-size', '--img_size', default=640, type=int,
+                        help='the max size of input image')
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='Use cuda')
+
+    # Model setting
+    parser.add_argument('-m', '--model', default='yolov1', type=str,
+                        help='model name, e.g., yolov1, yolov8, rtdetr')
+    parser.add_argument('--weight', default=None,
+                        type=str, help='Trained state_dict file path to open')
+    parser.add_argument('-p', '--pretrained', default=None, type=str,
+                        help='load pretrained weight')
+    parser.add_argument('-r', '--resume', default=None, type=str,
+                        help='keep training')
+    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
+                        help='fuse Conv & BN')
+    parser.add_argument('--fuse_rep_conv', action='store_true', default=False,
+                        help='fuse RepConv branches')
+
+    # Data setting
+    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/',
+                        help='data root')
+    parser.add_argument('-d', '--dataset', default='coco',
+                        help='coco, voc.')
+
+    # TTA
+    parser.add_argument('-tta', '--test_aug', action='store_true', default=False,
+                        help='use test augmentation.')
+
+    return parser.parse_args()
+
+
+
+def voc_test(cfg, model, data_dir, device, transform):
+    evaluator = VOCAPIEvaluator(cfg=cfg,
+                                data_dir=data_dir,
+                                device=device,
+                                transform=transform,
+                                display=True)
+
+    # VOC evaluation
+    evaluator.evaluate(model)
+
+def coco_test(cfg, model, data_dir, device, transform):
+    # eval
+    evaluator = COCOAPIEvaluator(
+                    cfg=cfg,
+                    data_dir=data_dir,
+                    device=device,
+                    transform=transform)
+
+    # COCO evaluation
+    evaluator.evaluate(model)
+
+def customed_test(cfg, model, data_dir, device, transform):
+    evaluator = CustomedEvaluator(
+        cfg=cfg,
+        data_dir=data_dir,
+        device=device,
+        image_set='val',
+        transform=transform)
+
+    # Custom dataset evaluation
+    evaluator.evaluate(model)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    # cuda
+    if args.cuda:
+        print('use cuda')
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    # Dataset & Model Config
+    cfg = build_config(args)
+    
+    # Transform
+    transform = build_transform(cfg, is_train=False)
+
+    # Dataset
+    dataset = build_dataset(args, cfg, transform, is_train=False)
+
+    # build model
+    model, _ = build_model(args, cfg, is_val=True)
+
+    # load trained weight
+    model = load_weight(model, args.weight, args.fuse_conv_bn)
+    model.to(device).eval()
+
+    # evaluation
+    with torch.no_grad():
+        if args.dataset == 'voc':
+            voc_test(cfg, model, args.root, device, transform)
+        elif args.dataset == 'coco':
+            coco_test(cfg, model, args.root, device, transform)
+        elif args.dataset == 'customed':
+            customed_test(cfg, model, args.root, device, transform)

+ 33 - 0
evaluator/build.py

@@ -0,0 +1,33 @@
+import os
+
+from evaluator.coco_evaluator import COCOAPIEvaluator
+from evaluator.voc_evaluator import VOCAPIEvaluator
+from evaluator.customed_evaluator import CustomedEvaluator
+
+
+
+def build_evluator(args, cfg, transform, device):
+    # Evaluator
+    ## VOC Evaluator
+    if args.dataset == 'voc':
+        evaluator = VOCAPIEvaluator(cfg       = cfg,
+                                    data_dir  = args.root,
+                                    device    = device,
+                                    transform = transform
+                                    )
+    ## COCO Evaluator
+    elif args.dataset == 'coco':
+        evaluator = COCOAPIEvaluator(cfg       = cfg,
+                                     data_dir  = args.root,
+                                     device    = device,
+                                     transform = transform
+                                     )
+    ## Custom dataset Evaluator
+    elif args.dataset == 'ourdataset':
+        evaluator = CustomedEvaluator(cfg       = cfg,
+                                      data_dir  = args.root,
+                                      device    = device,
+                                      transform = transform
+                                      )
+    else:
+        raise NotImplementedError('Unknown dataset: {}'.format(args.dataset))
+
+    return evaluator

+ 98 - 0
evaluator/coco_evaluator.py

@@ -0,0 +1,98 @@
+import json
+import tempfile
+import torch
+from pycocotools.cocoeval import COCOeval
+
+from dataset.coco import COCODataset
+from utils.box_ops import rescale_bboxes
+
+
+class COCOAPIEvaluator():
+    def __init__(self, cfg, data_dir, device, transform=None):
+        # ----------------- Basic parameters -----------------
+        self.image_set = 'val2017'
+        self.transform = transform
+        self.device = device
+        # ----------------- Metrics -----------------
+        self.map = 0.
+        self.ap50_95 = 0.
+        self.ap50 = 0.
+        # ----------------- Dataset -----------------
+        self.dataset = COCODataset(cfg=cfg, data_dir=data_dir, image_set=self.image_set, transform=None, is_train=False)
+
+
+    @torch.no_grad()
+    def evaluate(self, model):
+        model.eval()
+        ids = []
+        data_dict = []
+        num_images = len(self.dataset)
+        print('total number of images: %d' % (num_images))
+
+        # start testing
+        for index in range(num_images): # all the data in val2017
+            if index % 500 == 0:
+                print('[Eval: %d / %d]'%(index, num_images))
+
+            # load an image
+            img, id_ = self.dataset.pull_image(index)
+            orig_h, orig_w, _ = img.shape
+            orig_size = [orig_w, orig_h]
+
+            # preprocess
+            x, _, ratio = self.transform(img)
+            x = x.unsqueeze(0).to(self.device)
+            
+            id_ = int(id_)
+            ids.append(id_)
+
+            # inference
+            outputs = model(x)
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
+
+            # rescale bboxes
+            bboxes = rescale_bboxes(bboxes, orig_size, ratio)
+
+            # process outputs
+            for i, box in enumerate(bboxes):
+                x1 = float(box[0])
+                y1 = float(box[1])
+                x2 = float(box[2])
+                y2 = float(box[3])
+                label = self.dataset.class_ids[int(labels[i])]
+                
+                bbox = [x1, y1, x2 - x1, y2 - y1]
+                score = float(scores[i]) # object score * class score
+                A = {"image_id": id_, "category_id": label, "bbox": bbox,
+                     "score": score} # COCO json format
+                data_dict.append(A)
+
+        annType = ['segm', 'bbox', 'keypoints']
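+        # only annType[1] ('bbox') is used below: detections are scored with the COCO box metric.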
+
+        # Evaluate the Dt (detection) json comparing with the ground truth
+        if len(data_dict) > 0:
+            print('evaluating ......')
+            cocoGt = self.dataset.coco
+            # workaround: temporarily write data to json file because pycocotools can't process dict in py36.
+            _, tmp = tempfile.mkstemp()
+            json.dump(data_dict, open(tmp, 'w'))
+            cocoDt = cocoGt.loadRes(tmp)
+            cocoEval = COCOeval(self.dataset.coco, cocoDt, annType[1])
+            cocoEval.params.imgIds = ids
+            cocoEval.evaluate()
+            cocoEval.accumulate()
+            cocoEval.summarize()
+
+            ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1]
+            print('ap50_95 : ', ap50_95)
+            print('ap50 : ', ap50)
+            self.map = ap50_95
+            self.ap50_95 = ap50_95
+            self.ap50 = ap50
+
+            return ap50, ap50_95
+        else:
+            return 0, 0
+

+ 108 - 0
evaluator/customed_evaluator.py

@@ -0,0 +1,108 @@
+import json
+import tempfile
+import torch
+from dataset.customed import CustomedDataset
+from utils.box_ops import rescale_bboxes
+
+try:
+    from pycocotools.cocoeval import COCOeval
+except ImportError:
+    print("It seems that the COCO API (pycocotools) is not installed.")
+
+
+class CustomedEvaluator():
+    def __init__(self, cfg, data_dir, device, image_set='val', transform=None):
+        # ----------------- Basic parameters -----------------
+        self.image_set = image_set
+        self.transform = transform
+        self.device = device
+        # ----------------- Metrics -----------------
+        self.map = 0.
+        self.ap50_95 = 0.
+        self.ap50 = 0.
+        # ----------------- Dataset -----------------
+        self.dataset = CustomedDataset(cfg, data_dir=data_dir, image_set=image_set, transform=None, is_train=False)
+
+
+    @torch.no_grad()
+    def evaluate(self, model):
+        """
+        COCO average precision (AP) Evaluation. Iterate inference on the test dataset
+        and the results are evaluated by COCO API.
+        Args:
+            model : model object
+        Returns:
+            ap50_95 (float) : calculated COCO AP for IoU=50:95
+            ap50 (float) : calculated COCO AP for IoU=50
+        """
+        model.eval()
+        ids = []
+        data_dict = []
+        num_images = len(self.dataset)
+        print('total number of images: %d' % (num_images))
+
+        # start testing
+        for index in range(num_images): # all the data in the val split
+            if index % 500 == 0:
+                print('[Eval: %d / %d]'%(index, num_images))
+
+            # load an image
+            img, id_ = self.dataset.pull_image(index)
+            orig_h, orig_w, _ = img.shape
+
+            # preprocess
+            x, _, ratio = self.transform(img)
+            x = x.unsqueeze(0).to(self.device)
+            
+            id_ = int(id_)
+            ids.append(id_)
+
+            # inference
+            outputs = model(x)
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
+
+            # rescale bboxes
+            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)
+
+            for i, box in enumerate(bboxes):
+                x1 = float(box[0])
+                y1 = float(box[1])
+                x2 = float(box[2])
+                y2 = float(box[3])
+                label = self.dataset.class_ids[int(labels[i])]
+                
+                bbox = [x1, y1, x2 - x1, y2 - y1]
+                score = float(scores[i]) # object score * class score
+                A = {"image_id": id_, "category_id": label, "bbox": bbox,
+                     "score": score} # COCO json format
+                data_dict.append(A)
+
+        annType = ['segm', 'bbox', 'keypoints']
+
+        # Evaluate the Dt (detection) json comparing with the ground truth
+        if len(data_dict) > 0:
+            print('evaluating ......')
+            cocoGt = self.dataset.coco
+            # workaround: temporarily write data to json file because pycocotools can't process dict in py36.
+            _, tmp = tempfile.mkstemp()
+            json.dump(data_dict, open(tmp, 'w'))
+            cocoDt = cocoGt.loadRes(tmp)
+            cocoEval = COCOeval(self.dataset.coco, cocoDt, annType[1])
+            cocoEval.params.imgIds = ids
+            cocoEval.evaluate()
+            cocoEval.accumulate()
+            cocoEval.summarize()
+
+            ap50_95, ap50 = cocoEval.stats[0], cocoEval.stats[1]
+            print('ap50_95 : ', ap50_95)
+            print('ap50 : ', ap50)
+            self.map = ap50_95
+            self.ap50_95 = ap50_95
+            self.ap50 = ap50
+
+            return ap50, ap50_95
+        else:
+            return 0, 0
+

+ 356 - 0
evaluator/voc_evaluator.py

@@ -0,0 +1,356 @@
+"""Adapted from:
+    @longcw faster_rcnn_pytorch: https://github.com/longcw/faster_rcnn_pytorch
+    @rbgirshick py-faster-rcnn https://github.com/rbgirshick/py-faster-rcnn
+    Licensed under The MIT License [see LICENSE for details]
+"""
+
+from dataset.voc import VOCDataset, VOC_CLASSES
+import os
+import time
+import numpy as np
+import pickle
+import xml.etree.ElementTree as ET
+
+from utils.box_ops import rescale_bboxes
+
+
+class VOCAPIEvaluator():
+    """ VOC AP Evaluation class """
+    def __init__(self,
+                 cfg,
+                 data_dir, 
+                 device,
+                 transform,
+                 set_type='test', 
+                 year='2007', 
+                 display=False):
+        # basic config
+        self.data_dir = data_dir
+        self.device = device
+        self.labelmap = VOC_CLASSES
+        self.set_type = set_type
+        self.year = year
+        self.display = display
+        self.map = 0.
+
+        # transform
+        self.transform = transform
+
+        # path
+        self.devkit_path = os.path.join(data_dir, 'VOC' + year)
+        self.annopath = os.path.join(data_dir, 'VOC2007', 'Annotations', '%s.xml')
+        self.imgpath = os.path.join(data_dir, 'VOC2007', 'JPEGImages', '%s.jpg')
+        self.imgsetpath = os.path.join(data_dir, 'VOC2007', 'ImageSets', 'Main', set_type+'.txt')
+        self.output_dir = self.get_output_dir('det_results/eval/voc_eval/', self.set_type)
+
+        # dataset
+        self.dataset = VOCDataset(
+            cfg=cfg,
+            data_dir=data_dir, 
+            image_set=[('2007', set_type)],
+            is_train=False)
+        
+
+    def evaluate(self, net):
+        net.eval()
+        num_images = len(self.dataset)
+        # all detections are collected into:
+        #    all_boxes[cls][image] = N x 5 array of detections in
+        #    (x1, y1, x2, y2, score)
+        self.all_boxes = [[[] for _ in range(num_images)]
+                        for _ in range(len(self.labelmap))]
+
+        # timers
+        det_file = os.path.join(self.output_dir, 'detections.pkl')
+
+        for i in range(num_images):
+            img, _ = self.dataset.pull_image(i)
+            orig_h, orig_w = img.shape[:2]
+
+            # preprocess
+            x, _, ratio = self.transform(img)
+            x = x.unsqueeze(0).to(self.device)
+
+            # forward
+            t0 = time.time()
+            outputs = net(x)
+            scores = outputs['scores']
+            labels = outputs['labels']
+            bboxes = outputs['bboxes']
+            detect_time = time.time() - t0
+
+            # rescale bboxes
+            bboxes = rescale_bboxes(bboxes, [orig_w, orig_h], ratio)
+
+            for j in range(len(self.labelmap)):
+                inds = np.where(labels == j)[0]
+                if len(inds) == 0:
+                    self.all_boxes[j][i] = np.empty([0, 5], dtype=np.float32)
+                    continue
+                c_bboxes = bboxes[inds]
+                c_scores = scores[inds]
+                c_dets = np.hstack((c_bboxes,
+                                    c_scores[:, np.newaxis])).astype(np.float32,
+                                                                    copy=False)
+                self.all_boxes[j][i] = c_dets
+
+            if i % 500 == 0:
+                print('im_detect: {:d}/{:d} {:.3f}s'.format(i + 1, num_images, detect_time))
+
+        with open(det_file, 'wb') as f:
+            pickle.dump(self.all_boxes, f, pickle.HIGHEST_PROTOCOL)
+
+        print('Evaluating detections')
+        self.evaluate_detections(self.all_boxes)
+
+        print('Mean AP: ', self.map)
+  
+
+    def parse_rec(self, filename):
+        """ Parse a PASCAL VOC xml file """
+        tree = ET.parse(filename)
+        objects = []
+        for obj in tree.findall('object'):
+            obj_struct = {}
+            obj_struct['name'] = obj.find('name').text
+            obj_struct['pose'] = obj.find('pose').text
+            obj_struct['truncated'] = int(obj.find('truncated').text)
+            obj_struct['difficult'] = int(obj.find('difficult').text)
+            bbox = obj.find('bndbox')
+            obj_struct['bbox'] = [int(bbox.find('xmin').text),
+                                int(bbox.find('ymin').text),
+                                int(bbox.find('xmax').text),
+                                int(bbox.find('ymax').text)]
+            objects.append(obj_struct)
+
+        return objects
+
+
+    def get_output_dir(self, name, phase):
+        """Return the directory where experimental artifacts are placed.
+        If the directory does not exist, it is created.
+        The path is built from the given name and evaluation phase.
+        """
+        filedir = os.path.join(name, phase)
+        if not os.path.exists(filedir):
+            os.makedirs(filedir, exist_ok=True)
+        return filedir
+
+
+    def get_voc_results_file_template(self, cls):
+        # VOCdevkit/VOC2007/results/det_test_aeroplane.txt
+        filename = 'det_' + self.set_type + '_%s.txt' % (cls)
+        filedir = os.path.join(self.devkit_path, 'results')
+        if not os.path.exists(filedir):
+            os.makedirs(filedir)
+        path = os.path.join(filedir, filename)
+        return path
+
+
+    def write_voc_results_file(self, all_boxes):
+        for cls_ind, cls in enumerate(self.labelmap):
+            if self.display:
+                print('Writing {:s} VOC results file'.format(cls))
+            filename = self.get_voc_results_file_template(cls)
+            with open(filename, 'wt') as f:
+                for im_ind, index in enumerate(self.dataset.ids):
+                    dets = all_boxes[cls_ind][im_ind]
+                    if len(dets) == 0:
+                        continue
+                    # the VOCdevkit expects 1-based indices
+                    for k in range(dets.shape[0]):
+                        f.write('{:s} {:.3f} {:.1f} {:.1f} {:.1f} {:.1f}\n'.
+                                format(index[1], dets[k, -1],
+                                    dets[k, 0] + 1, dets[k, 1] + 1,
+                                    dets[k, 2] + 1, dets[k, 3] + 1))
+
+
+    def do_python_eval(self, use_07=True):
+        cachedir = os.path.join(self.devkit_path, 'annotations_cache')
+        aps = []
+        # The PASCAL VOC metric changed in 2010
+        use_07_metric = use_07
+        print('VOC07 metric? ' + ('Yes' if use_07_metric else 'No'))
+        if not os.path.isdir(self.output_dir):
+            os.mkdir(self.output_dir)
+        for i, cls in enumerate(self.labelmap):
+            filename = self.get_voc_results_file_template(cls)
+            rec, prec, ap = self.voc_eval(detpath=filename, 
+                                          classname=cls, 
+                                          cachedir=cachedir, 
+                                          ovthresh=0.5, 
+                                          use_07_metric=use_07_metric
+                                        )
+            aps += [ap]
+            print('AP for {} = {:.4f}'.format(cls, ap))
+            with open(os.path.join(self.output_dir, cls + '_pr.pkl'), 'wb') as f:
+                pickle.dump({'rec': rec, 'prec': prec, 'ap': ap}, f)
+        if self.display:
+            self.map = np.mean(aps)
+            print('Mean AP = {:.4f}'.format(np.mean(aps)))
+            print('~~~~~~~~')
+            print('Results:')
+            for ap in aps:
+                print('{:.3f}'.format(ap))
+            print('{:.3f}'.format(np.mean(aps)))
+            print('~~~~~~~~')
+            print('')
+            print('--------------------------------------------------------------')
+            print('Results computed with the **unofficial** Python eval code.')
+            print('Results should be very close to the official MATLAB eval code.')
+            print('--------------------------------------------------------------')
+        else:
+            self.map = np.mean(aps)
+            print('Mean AP = {:.4f}'.format(np.mean(aps)))
+
+
+    def voc_ap(self, rec, prec, use_07_metric=True):
+        """ ap = voc_ap(rec, prec, [use_07_metric])
+        Compute VOC AP given precision and recall.
+        If use_07_metric is true, uses the
+        VOC 07 11 point method (default:True).
+        """
+        if use_07_metric:
+            # 11 point metric
+            ap = 0.
+            for t in np.arange(0., 1.1, 0.1):
+                if np.sum(rec >= t) == 0:
+                    p = 0
+                else:
+                    p = np.max(prec[rec >= t])
+                ap = ap + p / 11.
+        else:
+            # correct AP calculation
+            # first append sentinel values at the end
+            mrec = np.concatenate(([0.], rec, [1.]))
+            mpre = np.concatenate(([0.], prec, [0.]))
+
+            # compute the precision envelope
+            for i in range(mpre.size - 1, 0, -1):
+                mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])
+
+            # to calculate area under PR curve, look for points
+            # where X axis (recall) changes value
+            i = np.where(mrec[1:] != mrec[:-1])[0]
+
+            # and sum (\Delta recall) * prec
+            ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
+        return ap
+
+
+    def voc_eval(self, detpath, classname, cachedir, ovthresh=0.5, use_07_metric=True):
+        if not os.path.isdir(cachedir):
+            os.mkdir(cachedir)
+        cachefile = os.path.join(cachedir, 'annots.pkl')
+        # read list of images
+        with open(self.imgsetpath, 'r') as f:
+            lines = f.readlines()
+        imagenames = [x.strip() for x in lines]
+        if not os.path.isfile(cachefile):
+            # load annots
+            recs = {}
+            for i, imagename in enumerate(imagenames):
+                recs[imagename] = self.parse_rec(self.annopath % (imagename))
+                if i % 100 == 0 and self.display:
+                    print('Reading annotation for {:d}/{:d}'.format(
+                    i + 1, len(imagenames)))
+            # save
+            if self.display:
+                print('Saving cached annotations to {:s}'.format(cachefile))
+            with open(cachefile, 'wb') as f:
+                pickle.dump(recs, f)
+        else:
+            # load
+            with open(cachefile, 'rb') as f:
+                recs = pickle.load(f)
+
+        # extract gt objects for this class
+        class_recs = {}
+        npos = 0
+        for imagename in imagenames:
+            R = [obj for obj in recs[imagename] if obj['name'] == classname]
+            bbox = np.array([x['bbox'] for x in R])
+            difficult = np.array([x['difficult'] for x in R]).astype(bool)
+            det = [False] * len(R)
+            npos = npos + sum(~difficult)
+            class_recs[imagename] = {'bbox': bbox,
+                                    'difficult': difficult,
+                                    'det': det}
+
+        # read dets
+        detfile = detpath.format(classname)
+        with open(detfile, 'r') as f:
+            lines = f.readlines()
+        if any(lines):
+            splitlines = [x.strip().split(' ') for x in lines]
+            image_ids = [x[0] for x in splitlines]
+            confidence = np.array([float(x[1]) for x in splitlines])
+            BB = np.array([[float(z) for z in x[2:]] for x in splitlines])
+
+            # sort by confidence
+            sorted_ind = np.argsort(-confidence)
+            sorted_scores = np.sort(-confidence)
+            BB = BB[sorted_ind, :]
+            image_ids = [image_ids[x] for x in sorted_ind]
+
+            # go down dets and mark TPs and FPs
+            nd = len(image_ids)
+            tp = np.zeros(nd)
+            fp = np.zeros(nd)
+            for d in range(nd):
+                R = class_recs[image_ids[d]]
+                bb = BB[d, :].astype(float)
+                ovmax = -np.inf
+                BBGT = R['bbox'].astype(float)
+                if BBGT.size > 0:
+                    # compute overlaps
+                    # intersection
+                    ixmin = np.maximum(BBGT[:, 0], bb[0])
+                    iymin = np.maximum(BBGT[:, 1], bb[1])
+                    ixmax = np.minimum(BBGT[:, 2], bb[2])
+                    iymax = np.minimum(BBGT[:, 3], bb[3])
+                    iw = np.maximum(ixmax - ixmin, 0.)
+                    ih = np.maximum(iymax - iymin, 0.)
+                    inters = iw * ih
+                    uni = ((bb[2] - bb[0]) * (bb[3] - bb[1]) +
+                        (BBGT[:, 2] - BBGT[:, 0]) *
+                        (BBGT[:, 3] - BBGT[:, 1]) - inters)
+                    overlaps = inters / uni
+                    ovmax = np.max(overlaps)
+                    jmax = np.argmax(overlaps)
+
+                if ovmax > ovthresh:
+                    if not R['difficult'][jmax]:
+                        if not R['det'][jmax]:
+                            tp[d] = 1.
+                            R['det'][jmax] = 1
+                        else:
+                            fp[d] = 1.
+                else:
+                    fp[d] = 1.
+
+            # compute precision recall
+            fp = np.cumsum(fp)
+            tp = np.cumsum(tp)
+            rec = tp / float(npos)
+            # avoid divide by zero in case the first detection matches a difficult
+            # ground truth
+            prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
+            ap = self.voc_ap(rec, prec, use_07_metric)
+        else:
+            rec = -1.
+            prec = -1.
+            ap = -1.
+
+        return rec, prec, ap
+
+
+    def evaluate_detections(self, box_list):
+        self.write_voc_results_file(box_list)
+        self.do_python_eval()
+
+
+if __name__ == '__main__':
+    pass
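+    # Toy sanity check of the all-point AP formula above (illustrative only,
+    # with made-up precision/recall values); it mirrors the else-branch of voc_ap.
+    import numpy as np
+    rec  = np.array([0.25, 0.50, 0.50, 0.75])
+    prec = np.array([1.00, 0.67, 0.50, 0.60])
+    mrec = np.concatenate(([0.], rec, [1.]))
+    mpre = np.concatenate(([0.], prec, [0.]))
+    for i in range(mpre.size - 1, 0, -1):
+        mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i])   # precision envelope
+    idx = np.where(mrec[1:] != mrec[:-1])[0]
+    print('toy AP:', np.sum((mrec[idx + 1] - mrec[idx]) * mpre[idx + 1]))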

+ 59 - 0
models/__init__.py

@@ -0,0 +1,59 @@
+#!/usr/bin/env python3
+# -*- coding:utf-8 -*-
+
+import torch
+from .yolov1.build import build_yolov1
+from .yolov8.build import build_yolov8
+from .rtdetr.build import build_rtdetr
+
+# build object detector
+def build_model(args, cfg, is_val=False):
+    # ------------ build object detector ------------
+    ## YOLO series
+    if 'yolov1' in args.model:
+        model, criterion = build_yolov1(cfg, is_val)
+    elif 'yolov8' in args.model:
+        model, criterion = build_yolov8(cfg, is_val)
+    ## RT-DETR
+    elif 'rtdetr' in args.model:
+        model, criterion = build_rtdetr(cfg, is_val)
+    else:
+        raise NotImplementedError("Unknown model: {}".format(args.model))
+
+    if is_val:
+        # ------------ Load pretrained weight ------------
+        if args.pretrained is not None:
+            print('Loading COCO pretrained weight ...')
+            checkpoint = torch.load(args.pretrained, map_location='cpu')
+            # checkpoint state dict
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = model.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                        print('Skip parameter (shape mismatch): {}'.format(k))
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print('Skip parameter (unexpected key): {}'.format(k))
+
+            model.load_state_dict(checkpoint_state_dict, strict=False)
+
+        # ------------ Keep training from the given checkpoint ------------
+        if args.resume and args.resume != "None":
+            checkpoint = torch.load(args.resume, map_location='cpu')
+            # checkpoint state dict
+            try:
+                checkpoint_state_dict = checkpoint.pop("model")
+                print('Load model from the checkpoint: ', args.resume)
+                model.load_state_dict(checkpoint_state_dict)
+                del checkpoint, checkpoint_state_dict
+            except KeyError:
+                print("No model in the given checkpoint.")
+
+        return model, criterion
+
+    else:      
+        return model
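+
+# A minimal usage sketch (illustrative only). In train.py / test.py the `args`
+# namespace comes from argparse and `cfg` from the config package; the values
+# below are placeholders rather than part of this module:
+#
+#   args = argparse.Namespace(model='rtdetr_r18', pretrained=None, resume=None)
+#   cfg = ...  # e.g. the RT-DETR-R18 config from config/rtdetr_config.py
+#   model, criterion = build_model(args, cfg, is_val=True)   # training / evaluation
+#   model = build_model(args, cfg, is_val=False)             # inference only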

+ 50 - 0
models/rtdetr/README.md

@@ -0,0 +1,50 @@
+# Real-time Transformer-based Object Detector:
+
+## Results on the COCO-val
+|     Model    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | ckpt | Logs |
+|--------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|------|
+| RT-DETR-R18  | 4xb4  |  640  |           45.5         |        63.5       |        66.8       |        21.0        | [ckpt](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/rtdetr_r18_coco.pth) | [log](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/RT-DETR-R18-COCO.txt)|
+| RT-DETR-R50  | 4xb4  |  640  |           50.6         |        69.4       |       112.1       |        36.7        | [ckpt](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/rtdetr_r50_coco.pth) | [log](https://github.com/yjh0410/ODLab-World/releases/download/coco_weight/RT-DETR-R50-COCO.txt)|
+| RT-DETR-R101 | 4xb4  |  640  |                        |                   |                   |                    |  | |
+
+
+## Train RT-DETR
+### Single GPU
+For example, to train RT-DETR-R18 on COCO with a single GPU:
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m rtdetr_r18 -bs 16  --fp16
+```
+
+### Multi GPU
+For example, to train RT-DETR-R18 on COCO with multiple GPUs:
+```Shell
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda --distributed -d coco --root /data/datasets/ -m rtdetr_r18 -bs 16 --fp16 --sybn 
+```
+
+## Test RT-DETR
+For example, to test RT-DETR-R18 on COCO-val:
+```Shell
+python test.py --cuda -d coco --root path/to/coco -m rtdetr_r18 --weight path/to/rtdetr_r18.pth --show 
+```
+
+## Evaluate RT-DETR
+For example, to evaluate RT-DETR-R18 on COCO-val:
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m rtdetr_r18 -bs 16 --fp16 --resume path/to/rtdetr_r18.pth --eval_first
+```
+
+## Demo
+### Detect with Image
+```Shell
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m rtdetr_r18 --weight path/to/weight --show
+```
+
+### Detect with Video
+```Shell
+python demo.py --mode video --path_to_vid path/to/video --cuda -m rtdetr_r18 --weight path/to/weight --show
+```
+
+### Detect with Camera
+```Shell
+python demo.py --mode camera --cuda -m rtdetr_r18 --weight path/to/weight --show
+```

+ 103 - 0
models/rtdetr/basic_modules/backbone.py

@@ -0,0 +1,103 @@
+import torch.nn as nn
+import torchvision
+from torchvision.models._utils import IntermediateLayerGetter
+from torchvision.models.resnet import (ResNet18_Weights,
+                                       ResNet34_Weights,
+                                       ResNet50_Weights,
+                                       ResNet101_Weights)
+from .norm import FrozenBatchNorm2d
+
+
+# IN1K pretrained weights
+pretrained_urls = {
+    # ResNet series
+    'resnet18':  ResNet18_Weights,
+    'resnet34':  ResNet34_Weights,
+    'resnet50':  ResNet50_Weights,
+    'resnet101': ResNet101_Weights,
+
+}
+
+
+# ----------------- Model functions -----------------
+## Build backbone network
+def build_backbone(cfg, pretrained):
+    print('==============================')
+    print('Backbone: {}'.format(cfg.backbone))
+    # ResNet
+    if 'resnet' in cfg.backbone:
+        pretrained_weight = cfg.pretrained_weight if pretrained else None
+        model = build_resnet(cfg, pretrained_weight)
+    else:
+        raise NotImplementedError("Unknown backbone: {}.".format(cfg.backbone))
+    
+    return model
+
+
+# ----------------- ResNet Backbone -----------------
+class ResNet(nn.Module):
+    """ResNet backbone with frozen BatchNorm."""
+    def __init__(self,
+                 name: str,
+                 norm_type: str,
+                 pretrained_weights: str = "imagenet1k_v1",
+                 freeze_at: int = -1,
+                 freeze_stem_only: bool = False):
+        super().__init__()
+        # Pretrained
+        assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
+        if pretrained_weights is not None:
+            if name in ('resnet18', 'resnet34'):
+                pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
+            else:
+                if pretrained_weights == "imagenet1k_v1":
+                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
+                else:
+                    pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
+        else:
+            pretrained_weights = None
+        print('- Backbone pretrained weight: ', pretrained_weights)
+
+        # Norm layer
+        print("- Norm layer of backbone: {}".format(norm_type))
+        if norm_type == 'BN':
+            norm_layer = nn.BatchNorm2d
+        elif norm_type == 'FrozeBN':
+            norm_layer = FrozenBatchNorm2d
+
+        # Backbone
+        backbone = getattr(torchvision.models, name)(norm_layer=norm_layer, weights=pretrained_weights)
+        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
+
+        # Freeze
+        print("- Freeze at: {}".format(freeze_at))
+        if freeze_at >= 0:
+            if freeze_stem_only:
+                print("- Freeze the stem layer only")
+            else:
+                print("- Freeze the stem layer and layer1")
+            for pname, parameter in backbone.named_parameters():
+                if freeze_stem_only:
+                    if 'layer1' not in pname and 'layer2' not in pname and 'layer3' not in pname and 'layer4' not in pname:
+                        parameter.requires_grad_(False)
+                else:
+                    if 'layer2' not in pname and 'layer3' not in pname and 'layer4' not in pname:
+                        parameter.requires_grad_(False)
+
+    def forward(self, x):
+        xs = self.body(x)
+        fmp_list = []
+        for name, fmp in xs.items():
+            fmp_list.append(fmp)
+
+        return fmp_list
+
+def build_resnet(cfg, pretrained_weight=None):
+    # ResNet series
+    backbone = ResNet(cfg.backbone,
+                      cfg.backbone_norm,
+                      pretrained_weight,
+                      cfg.freeze_at,
+                      cfg.freeze_stem_only)
+
+    return backbone
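+
+
+if __name__ == '__main__':
+    # Quick shape check (illustrative only): a stand-in config namespace is
+    # used here instead of the real fields from config/rtdetr_config.py.
+    import torch
+    from types import SimpleNamespace
+    cfg = SimpleNamespace(backbone='resnet18', backbone_norm='BN',
+                          pretrained_weight=None, freeze_at=-1,
+                          freeze_stem_only=False)
+    model = build_backbone(cfg, pretrained=False)
+    feats = model(torch.randn(1, 3, 640, 640))
+    print([f.shape for f in feats])   # C3 / C4 / C5 features at stride 8 / 16 / 32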

+ 144 - 0
models/rtdetr/basic_modules/conv.py

@@ -0,0 +1,144 @@
+import torch
+import torch.nn as nn
+
+
+# ----------------- Basic CNN Ops -----------------
+def get_conv2d(c1, c2, k, p, s, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type == 'gelu':
+        return nn.GELU()
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        groups=groups,
+        bias=False,
+        dilation=dilation,
+    )
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+
+# ----------------- CNN Modules -----------------
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        add_bias = False if norm_type else True
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, g=1, bias=add_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            # depthwise conv (groups = in channels) followed by a 1x1 pointwise conv
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, g=in_dim, bias=add_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, g=1, bias=add_bias)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.norm1(self.conv1(x))
+            # Pointwise conv
+            x = self.norm2(self.conv2(x))
+            return x
+
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expand_ratio = 0.5,
+                 kernel_sizes = [3, 3],
+                 shortcut     = True,
+                 act_type     = 'silu',
+                 norm_type    = 'BN',
+                 depthwise    = False,):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)
+        paddings = [k // 2 for k in kernel_sizes]
+        self.cv1 = BasicConv(in_dim, inter_dim,
+                             kernel_size=kernel_sizes[0], padding=paddings[0],
+                             act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.cv2 = BasicConv(inter_dim, out_dim,
+                             kernel_size=kernel_sizes[1], padding=paddings[1],
+                             act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
+
+class ELANLayer(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 num_blocks   :int   = 1,
+                 expand_ratio :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'silu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,):
+        super(ELANLayer, self).__init__()
+        self.inter_dim = round(out_dim * expand_ratio)
+        self.conv1 = BasicConv(in_dim, self.inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.conv2 = BasicConv(in_dim, self.inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.cmodules = nn.ModuleList([Bottleneck(self.inter_dim, self.inter_dim,
+                                                   1.0, [3, 3], shortcut,
+                                                   act_type, norm_type, depthwise)
+                                                   for _ in range(num_blocks)])
+        self.conv3 = BasicConv(self.inter_dim * (2 + num_blocks), out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+    def forward(self, x):
+        x1, x2 = self.conv1(x), self.conv2(x)
+        out = [x1, x2]
+        for m in self.cmodules:
+            x2 = m(x2)
+            out.append(x2)
+
+        return self.conv3(torch.cat(out, dim=1))
+    

+ 109 - 0
models/rtdetr/basic_modules/dn_compoments.py

@@ -0,0 +1,109 @@
+import torch
+
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0., max=1.)
+    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
+
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
+
+def get_contrastive_denoising_training_group(targets,
+                                             num_classes,
+                                             num_queries,
+                                             class_embed,
+                                             num_denoising=100,
+                                             label_noise_ratio=0.5,
+                                             box_noise_scale=1.0,):
+    if num_denoising <= 0:
+        return None, None, None, None
+
+    num_gts = [len(t['labels']) for t in targets]
+    device = targets[0]['labels'].device
+    
+    max_gt_num = max(num_gts)
+    if max_gt_num == 0:
+        return None, None, None, None
+
+    num_group = num_denoising // max_gt_num
+    num_group = 1 if num_group == 0 else num_group
+    # pad gt to max_num of a batch
+    bs = len(num_gts)
+
+    input_query_class = torch.full([bs, max_gt_num], num_classes, dtype=torch.int32, device=device)
+    input_query_bbox = torch.zeros([bs, max_gt_num, 4], device=device)
+    pad_gt_mask = torch.zeros([bs, max_gt_num], dtype=torch.bool, device=device)
+
+    for i in range(bs):
+        num_gt = num_gts[i]
+        if num_gt > 0:
+            input_query_class[i, :num_gt] = targets[i]['labels']
+            input_query_bbox[i, :num_gt] = targets[i]['boxes']
+            pad_gt_mask[i, :num_gt] = 1
+    # each group has positive and negative queries.
+    input_query_class = input_query_class.tile([1, 2 * num_group])
+    input_query_bbox = input_query_bbox.tile([1, 2 * num_group, 1])
+    pad_gt_mask = pad_gt_mask.tile([1, 2 * num_group])
+    # positive and negative mask
+    negative_gt_mask = torch.zeros([bs, max_gt_num * 2, 1], device=device)
+    negative_gt_mask[:, max_gt_num:] = 1
+    negative_gt_mask = negative_gt_mask.tile([1, num_group, 1])
+    positive_gt_mask = 1 - negative_gt_mask
+    # contrastive denoising training positive index
+    positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
+    dn_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
+    dn_positive_idx = torch.split(dn_positive_idx, [n * num_group for n in num_gts])
+    # total denoising queries
+    num_denoising = int(max_gt_num * 2 * num_group)
+
+    if label_noise_ratio > 0:
+        mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
+        # randomly put a new one here
+        new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
+        input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
+
+    if box_noise_scale > 0:
+        known_bbox = box_cxcywh_to_xyxy(input_query_bbox)
+        diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
+        rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
+        rand_part = torch.rand_like(input_query_bbox)
+        rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
+        rand_part *= rand_sign
+        known_bbox += rand_part * diff
+        known_bbox.clip_(min=0.0, max=1.0)
+        input_query_bbox = box_xyxy_to_cxcywh(known_bbox)
+        input_query_bbox = inverse_sigmoid(input_query_bbox)
+    input_query_class = class_embed(input_query_class)
+
+    tgt_size = num_denoising + num_queries
+    # attn_mask = torch.ones([tgt_size, tgt_size], device=device) < 0
+    attn_mask = torch.full([tgt_size, tgt_size], False, dtype=torch.bool, device=device)
+    # match query cannot see the reconstruction
+    attn_mask[num_denoising:, :num_denoising] = True
+    
+    # reconstruct cannot see each other
+    for i in range(num_group):
+        if i == 0:
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
+        if i == num_group - 1:
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * i * 2] = True
+        else:
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), max_gt_num * 2 * (i + 1): num_denoising] = True
+            attn_mask[max_gt_num * 2 * i: max_gt_num * 2 * (i + 1), :max_gt_num * 2 * i] = True
+        
+    dn_meta = {
+        "dn_positive_idx": dn_positive_idx,
+        "dn_num_group": num_group,
+        "dn_num_split": [num_denoising, num_queries]
+    }
+
+    return input_query_class, input_query_bbox, attn_mask, dn_meta
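+
+
+if __name__ == '__main__':
+    # Minimal smoke test of the denoising-group construction (illustrative
+    # only; in training, `targets` come from the dataloader and `class_embed`
+    # from the decoder's label embedding).
+    import torch.nn as nn
+    num_classes, hidden_dim = 80, 256
+    targets = [{'labels': torch.tensor([1, 3]),
+                'boxes':  torch.tensor([[0.50, 0.50, 0.20, 0.30],
+                                        [0.30, 0.60, 0.10, 0.10]])}]
+    class_embed = nn.Embedding(num_classes + 1, hidden_dim)
+    dn_class, dn_bbox, attn_mask, dn_meta = get_contrastive_denoising_training_group(
+        targets, num_classes, num_queries=300, class_embed=class_embed)
+    print(dn_class.shape, dn_bbox.shape, attn_mask.shape, dn_meta['dn_num_split'])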

+ 85 - 0
models/rtdetr/basic_modules/ext_op/README.md

@@ -0,0 +1,85 @@
+# Compiling the custom multi-scale deformable attention OP
+This custom OP is implemented following the Paddle [custom external operator](https://www.paddlepaddle.org.cn/documentation/docs/zh/guides/custom_op/new_cpp_op_cn.html) guide.
+
+## 1. Requirements
+- Paddle >= 2.3.2
+- gcc 8.2
+
+## 2. Installation
+Build and install the OP from this directory:
+```
+cd models/rtdetr/basic_modules/ext_op/
+python setup_ms_deformable_attn_op.py install
+```
+
+Once compiled, the OP is ready to use. Below is a usage example of `ms_deformable_attn`:
+```
+# import paddle and the custom op
+import paddle
+from deformable_detr_ops import ms_deformable_attn
+
+# construct fake input tensors
+bs, n_heads, c = 2, 8, 8
+query_length, n_levels, n_points = 2, 2, 2
+spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
+level_start_index = paddle.concat((paddle.to_tensor(
+    [0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
+value_length = sum([(H * W).item() for H, W in spatial_shapes])
+
+def get_test_tensors(channels):
+    value = paddle.rand(
+        [bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
+    sampling_locations = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points, 2],
+        dtype=paddle.float32)
+    attention_weights = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points],
+        dtype=paddle.float32) + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
+        -2, keepdim=True)
+    return [value, sampling_locations, attention_weights]
+
+value, sampling_locations, attention_weights = get_test_tensors(c)
+
+output = ms_deformable_attn(value,
+                            spatial_shapes,
+                            level_start_index,
+                            sampling_locations,
+                            attention_weights)
+```
+
+## 3. Unit test
+Run the unit test to verify that the custom operator works correctly:
+```
+python test_ms_deformable_attn_op.py
+```
+On success, it prints the following:
+```
+*True check_forward_equal_with_paddle_float: max_abs_err 6.98e-10 max_rel_err 2.03e-07
+*tensor1 True check_gradient_numerical(D=30)
+*tensor2 True check_gradient_numerical(D=30)
+*tensor3 True check_gradient_numerical(D=30)
+*tensor1 True check_gradient_numerical(D=32)
+*tensor2 True check_gradient_numerical(D=32)
+*tensor3 True check_gradient_numerical(D=32)
+*tensor1 True check_gradient_numerical(D=64)
+*tensor2 True check_gradient_numerical(D=64)
+*tensor3 True check_gradient_numerical(D=64)
+*tensor1 True check_gradient_numerical(D=71)
+*tensor2 True check_gradient_numerical(D=71)
+*tensor3 True check_gradient_numerical(D=71)
+*tensor1 True check_gradient_numerical(D=128)
+*tensor2 True check_gradient_numerical(D=128)
+*tensor3 True check_gradient_numerical(D=128)
+*tensor1 True check_gradient_numerical(D=1024)
+*tensor2 True check_gradient_numerical(D=1024)
+*tensor3 True check_gradient_numerical(D=1024)
+*tensor1 True check_gradient_numerical(D=1025)
+*tensor2 True check_gradient_numerical(D=1025)
+*tensor3 True check_gradient_numerical(D=1025)
+*tensor1 True check_gradient_numerical(D=2048)
+*tensor2 True check_gradient_numerical(D=2048)
+*tensor3 True check_gradient_numerical(D=2048)
+*tensor1 True check_gradient_numerical(D=3096)
+*tensor2 True check_gradient_numerical(D=3096)
+*tensor3 True check_gradient_numerical(D=3096)
+```

+ 65 - 0
models/rtdetr/basic_modules/ext_op/ms_deformable_attn_op.cc

@@ -0,0 +1,65 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/extension.h"
+
+#include <vector>
+
+// declare GPU implementation
+std::vector<paddle::Tensor>
+MSDeformableAttnCUDAForward(const paddle::Tensor &value,
+                            const paddle::Tensor &value_spatial_shapes,
+                            const paddle::Tensor &value_level_start_index,
+                            const paddle::Tensor &sampling_locations,
+                            const paddle::Tensor &attention_weights);
+
+std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
+    const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
+    const paddle::Tensor &value_level_start_index,
+    const paddle::Tensor &sampling_locations,
+    const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out);
+
+//// CPU not implemented
+
+std::vector<std::vector<int64_t>>
+MSDeformableAttnInferShape(std::vector<int64_t> value_shape,
+                           std::vector<int64_t> value_spatial_shapes_shape,
+                           std::vector<int64_t> value_level_start_index_shape,
+                           std::vector<int64_t> sampling_locations_shape,
+                           std::vector<int64_t> attention_weights_shape) {
+  return {{value_shape[0], sampling_locations_shape[1],
+           value_shape[2] * value_shape[3]}};
+}
+
+std::vector<paddle::DataType>
+MSDeformableAttnInferDtype(paddle::DataType value_dtype,
+                           paddle::DataType value_spatial_shapes_dtype,
+                           paddle::DataType value_level_start_index_dtype,
+                           paddle::DataType sampling_locations_dtype,
+                           paddle::DataType attention_weights_dtype) {
+  return {value_dtype};
+}
+
+PD_BUILD_OP(ms_deformable_attn)
+    .Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
+             "AttentionWeights"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDAForward))
+    .SetInferShapeFn(PD_INFER_SHAPE(MSDeformableAttnInferShape))
+    .SetInferDtypeFn(PD_INFER_DTYPE(MSDeformableAttnInferDtype));
+
+PD_BUILD_GRAD_OP(ms_deformable_attn)
+    .Inputs({"Value", "SpatialShapes", "LevelIndex", "SamplingLocations",
+             "AttentionWeights", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("Value"), paddle::Grad("SpatialShapes"),
+              paddle::Grad("LevelIndex"), paddle::Grad("SamplingLocations"),
+              paddle::Grad("AttentionWeights")})
+    .SetKernelFn(PD_KERNEL(MSDeformableAttnCUDABackward));

+ 1073 - 0
models/rtdetr/basic_modules/ext_op/ms_deformable_attn_op.cu

@@ -0,0 +1,1073 @@
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/extension.h"
+
+#define CUDA_KERNEL_LOOP(i, n)                                                 \
+  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n);                 \
+       i += blockDim.x * gridDim.x)
+
+const int CUDA_NUM_THREADS = 1024;
+inline int GET_BLOCKS(const int N, const int num_threads) {
+  return (N + num_threads - 1) / num_threads;
+}
+
+// forward bilinear
+template <typename data_t>
+__device__ data_t deformable_attn_bilinear_forward(
+    const data_t *&bottom_data, const int &height, const int &width,
+    const int &nheads, const int &channels, const data_t &h, const data_t &w,
+    const int &m, const int &c) {
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const data_t lh = h - h_low;
+  const data_t lw = w - w_low;
+  const data_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  data_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0) {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+  }
+  data_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1) {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+  }
+  data_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0) {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+  }
+  data_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1) {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+  }
+
+  const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+
+  const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  return val;
+}
+
+// forward kernel
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_forward(
+    const int n, const data_t *data_value, const int64_t *data_spatial_shapes,
+    const int64_t *data_level_start_index, const data_t *data_sampling_loc,
+    const data_t *data_attn_weight, const int batch_size,
+    const int value_length, const int num_heads, const int channels,
+    const int num_levels, const int query_length, const int num_points,
+    data_t *output_data_ptr) {
+  CUDA_KERNEL_LOOP(index, n) {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    data_t *data_ptr = output_data_ptr + index;
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+    data_t col = 0;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const data_t *data_value_ptr = data_value + (data_value_ptr_init_offset +
+                                                   level_start_id * qid_stride);
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          col += deformable_attn_bilinear_forward(
+                     data_value_ptr, spatial_h, spatial_w, num_heads, channels,
+                     h_im, w_im, m_col, c_col) *
+                 weight;
+        }
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+      }
+    }
+    *data_ptr = col;
+  }
+}
+
+#define CHECK_INPUT_GPU(x) PD_CHECK(x.is_gpu(), #x " must be a GPU Tensor.")
+// forward
+std::vector<paddle::Tensor>
+MSDeformableAttnCUDAForward(const paddle::Tensor &value,
+                            const paddle::Tensor &value_spatial_shapes,
+                            const paddle::Tensor &value_level_start_index,
+                            const paddle::Tensor &sampling_locations,
+                            const paddle::Tensor &attention_weights) {
+
+  CHECK_INPUT_GPU(value);
+  CHECK_INPUT_GPU(value_spatial_shapes);
+  CHECK_INPUT_GPU(value_level_start_index);
+  CHECK_INPUT_GPU(sampling_locations);
+  CHECK_INPUT_GPU(attention_weights);
+
+  const int batch_size = value.shape()[0];
+  const int value_length = value.shape()[1];
+  const int num_heads = value.shape()[2];
+  const int channels = value.shape()[3];
+
+  const int num_levels = value_spatial_shapes.shape()[0];
+  const int query_length = sampling_locations.shape()[1];
+  const int num_points = sampling_locations.shape()[4];
+
+  auto output = paddle::full({batch_size, query_length, num_heads * channels},
+                             0, value.dtype(), paddle::GPUPlace());
+
+  const int num_kernels = batch_size * query_length * num_heads * channels;
+  deformable_attn_cuda_kernel_forward<float>
+      <<<GET_BLOCKS(num_kernels, CUDA_NUM_THREADS), CUDA_NUM_THREADS, 0,
+         value.stream()>>>(num_kernels, value.data<float>(),
+                           value_spatial_shapes.data<int64_t>(),
+                           value_level_start_index.data<int64_t>(),
+                           sampling_locations.data<float>(),
+                           attention_weights.data<float>(), batch_size,
+                           value_length, num_heads, channels, num_levels,
+                           query_length, num_points, output.data<float>());
+  return {output};
+}
+
+// backward bilinear
+template <typename data_t>
+__device__ void deformable_attn_bilinear_backward(
+    const data_t *&bottom_data, const int &height, const int &width,
+    const int &nheads, const int &channels, const data_t &h, const data_t &w,
+    const int &m, const int &c, const data_t &top_grad,
+    const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const data_t lh = h - h_low;
+  const data_t lw = w - w_low;
+  const data_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const data_t top_grad_value = top_grad * attn_weight;
+  data_t grad_h_weight = 0, grad_w_weight = 0;
+
+  data_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0) {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value + ptr1, w1 * top_grad_value);
+  }
+  data_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1) {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value + ptr2, w2 * top_grad_value);
+  }
+  data_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0) {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value + ptr3, w3 * top_grad_value);
+  }
+  data_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1) {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value + ptr4, w4 * top_grad_value);
+  }
+
+  const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  *grad_attn_weight = top_grad * val;
+  *grad_sampling_loc = width * grad_w_weight * top_grad_value;
+  *(grad_sampling_loc + 1) = height * grad_h_weight * top_grad_value;
+}
+
+template <typename data_t>
+__device__ void deformable_attn_bilinear_backward_gm(
+    const data_t *&bottom_data, const int &height, const int &width,
+    const int &nheads, const int &channels, const data_t &h, const data_t &w,
+    const int &m, const int &c, const data_t &top_grad,
+    const data_t &attn_weight, data_t *&grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  const int h_low = floor(h);
+  const int w_low = floor(w);
+  const int h_high = h_low + 1;
+  const int w_high = w_low + 1;
+
+  const data_t lh = h - h_low;
+  const data_t lw = w - w_low;
+  const data_t hh = 1 - lh, hw = 1 - lw;
+
+  const int w_stride = nheads * channels;
+  const int h_stride = width * w_stride;
+  const int h_low_ptr_offset = h_low * h_stride;
+  const int h_high_ptr_offset = h_low_ptr_offset + h_stride;
+  const int w_low_ptr_offset = w_low * w_stride;
+  const int w_high_ptr_offset = w_low_ptr_offset + w_stride;
+  const int base_ptr = m * channels + c;
+
+  const data_t w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw;
+  const data_t top_grad_value = top_grad * attn_weight;
+  data_t grad_h_weight = 0, grad_w_weight = 0;
+
+  data_t v1 = 0;
+  if (h_low >= 0 && w_low >= 0) {
+    const int ptr1 = h_low_ptr_offset + w_low_ptr_offset + base_ptr;
+    v1 = bottom_data[ptr1];
+    grad_h_weight -= hw * v1;
+    grad_w_weight -= hh * v1;
+    atomicAdd(grad_value + ptr1, w1 * top_grad_value);
+  }
+  data_t v2 = 0;
+  if (h_low >= 0 && w_high <= width - 1) {
+    const int ptr2 = h_low_ptr_offset + w_high_ptr_offset + base_ptr;
+    v2 = bottom_data[ptr2];
+    grad_h_weight -= lw * v2;
+    grad_w_weight += hh * v2;
+    atomicAdd(grad_value + ptr2, w2 * top_grad_value);
+  }
+  data_t v3 = 0;
+  if (h_high <= height - 1 && w_low >= 0) {
+    const int ptr3 = h_high_ptr_offset + w_low_ptr_offset + base_ptr;
+    v3 = bottom_data[ptr3];
+    grad_h_weight += hw * v3;
+    grad_w_weight -= lh * v3;
+    atomicAdd(grad_value + ptr3, w3 * top_grad_value);
+  }
+  data_t v4 = 0;
+  if (h_high <= height - 1 && w_high <= width - 1) {
+    const int ptr4 = h_high_ptr_offset + w_high_ptr_offset + base_ptr;
+    v4 = bottom_data[ptr4];
+    grad_h_weight += lw * v4;
+    grad_w_weight += lh * v4;
+    atomicAdd(grad_value + ptr4, w4 * top_grad_value);
+  }
+
+  const data_t val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4);
+  atomicAdd(grad_attn_weight, top_grad * val);
+  atomicAdd(grad_sampling_loc, width * grad_w_weight * top_grad_value);
+  atomicAdd(grad_sampling_loc + 1, height * grad_h_weight * top_grad_value);
+}
+
+// backward kernels
+// channels > 1024
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    extern __shared__ int _s[];
+    data_t *cache_grad_sampling_loc = (data_t *)_s;
+    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
+             s >>= 1, spre >>= 1) {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] +=
+                cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre) {
+              cache_grad_attn_weight[tid] +=
+                  cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] +=
+                  cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] +=
+                  cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0) {
+          atomicAdd(grad_sampling_loc, cache_grad_sampling_loc[0]);
+          atomicAdd(grad_sampling_loc + 1, cache_grad_sampling_loc[1]);
+          atomicAdd(grad_attn_weight, cache_grad_attn_weight[0]);
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_gm(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward_gm(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              grad_sampling_loc, grad_attn_weight);
+        }
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+// channels <= 1024
+template <typename data_t, unsigned int blockSize>
+__global__ void
+deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    __shared__ data_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ data_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0) {
+          data_t _grad_w = cache_grad_sampling_loc[0],
+                 _grad_h = cache_grad_sampling_loc[1],
+                 _grad_a = cache_grad_attn_weight[0];
+          int sid = 2;
+          for (unsigned int tid = 1; tid < blockSize; ++tid) {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[tid];
+            sid += 2;
+          }
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+template <typename data_t, unsigned int blockSize>
+__global__ void
+deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    __shared__ data_t cache_grad_sampling_loc[blockSize * 2];
+    __shared__ data_t cache_grad_attn_weight[blockSize];
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s = blockSize / 2; s > 0; s >>= 1) {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] +=
+                cache_grad_sampling_loc[xid2 + 1];
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0) {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v1(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    extern __shared__ int _s[];
+    data_t *cache_grad_sampling_loc = (data_t *)_s;
+    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+        if (tid == 0) {
+          data_t _grad_w = cache_grad_sampling_loc[0],
+                 _grad_h = cache_grad_sampling_loc[1],
+                 _grad_a = cache_grad_attn_weight[0];
+          int sid = 2;
+          for (unsigned int tid = 1; tid < blockDim.x; ++tid) {
+            _grad_w += cache_grad_sampling_loc[sid];
+            _grad_h += cache_grad_sampling_loc[sid + 1];
+            _grad_a += cache_grad_attn_weight[tid];
+            sid += 2;
+          }
+
+          *grad_sampling_loc = _grad_w;
+          *(grad_sampling_loc + 1) = _grad_h;
+          *grad_attn_weight = _grad_a;
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+template <typename data_t>
+__global__ void deformable_attn_cuda_kernel_backward_shm_reduce_v2(
+    const int n, const data_t *grad_col, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  CUDA_KERNEL_LOOP(index, n) {
+    extern __shared__ int _s[];
+    data_t *cache_grad_sampling_loc = (data_t *)_s;
+    data_t *cache_grad_attn_weight = cache_grad_sampling_loc + 2 * blockDim.x;
+    unsigned int tid = threadIdx.x;
+    int _temp = index;
+    const int c_col = _temp % channels;
+    _temp /= channels;
+    const int sampling_index = _temp;
+    const int m_col = _temp % num_heads;
+    _temp /= num_heads;
+    const int q_col = _temp % query_length;
+    _temp /= query_length;
+    const int b_col = _temp;
+
+    const data_t top_grad = grad_col[index];
+
+    int data_weight_ptr = sampling_index * num_levels * num_points;
+    int data_loc_w_ptr = data_weight_ptr << 1;
+    const int grad_sampling_ptr = data_weight_ptr;
+    grad_sampling_loc += grad_sampling_ptr << 1;
+    grad_attn_weight += grad_sampling_ptr;
+    const int grad_weight_stride = 1;
+    const int grad_loc_stride = 2;
+    const int qid_stride = num_heads * channels;
+    const int data_value_ptr_init_offset = b_col * value_length * qid_stride;
+
+    for (int l_col = 0; l_col < num_levels; ++l_col) {
+      const int level_start_id = data_level_start_index[l_col];
+      const int spatial_h_ptr = l_col << 1;
+      const int spatial_h = data_spatial_shapes[spatial_h_ptr];
+      const int spatial_w = data_spatial_shapes[spatial_h_ptr + 1];
+      const int value_ptr_offset =
+          data_value_ptr_init_offset + level_start_id * qid_stride;
+      const data_t *data_value_ptr = data_value + value_ptr_offset;
+      data_t *grad_value_ptr = grad_value + value_ptr_offset;
+
+      for (int p_col = 0; p_col < num_points; ++p_col) {
+        const data_t loc_w = data_sampling_loc[data_loc_w_ptr];
+        const data_t loc_h = data_sampling_loc[data_loc_w_ptr + 1];
+        const data_t weight = data_attn_weight[data_weight_ptr];
+
+        const data_t h_im = loc_h * spatial_h - 0.5;
+        const data_t w_im = loc_w * spatial_w - 0.5;
+        *(cache_grad_sampling_loc + (threadIdx.x << 1)) = 0;
+        *(cache_grad_sampling_loc + ((threadIdx.x << 1) + 1)) = 0;
+        *(cache_grad_attn_weight + threadIdx.x) = 0;
+        if (h_im > -1 && w_im > -1 && h_im < spatial_h && w_im < spatial_w) {
+          deformable_attn_bilinear_backward(
+              data_value_ptr, spatial_h, spatial_w, num_heads, channels, h_im,
+              w_im, m_col, c_col, top_grad, weight, grad_value_ptr,
+              cache_grad_sampling_loc + (threadIdx.x << 1),
+              cache_grad_attn_weight + threadIdx.x);
+        }
+
+        __syncthreads();
+
+        for (unsigned int s = blockDim.x / 2, spre = blockDim.x; s > 0;
+             s >>= 1, spre >>= 1) {
+          if (tid < s) {
+            const unsigned int xid1 = tid << 1;
+            const unsigned int xid2 = (tid + s) << 1;
+            cache_grad_attn_weight[tid] += cache_grad_attn_weight[tid + s];
+            cache_grad_sampling_loc[xid1] += cache_grad_sampling_loc[xid2];
+            cache_grad_sampling_loc[xid1 + 1] +=
+                cache_grad_sampling_loc[xid2 + 1];
+            if (tid + (s << 1) < spre) {
+              cache_grad_attn_weight[tid] +=
+                  cache_grad_attn_weight[tid + (s << 1)];
+              cache_grad_sampling_loc[xid1] +=
+                  cache_grad_sampling_loc[xid2 + (s << 1)];
+              cache_grad_sampling_loc[xid1 + 1] +=
+                  cache_grad_sampling_loc[xid2 + 1 + (s << 1)];
+            }
+          }
+          __syncthreads();
+        }
+
+        if (tid == 0) {
+          *grad_sampling_loc = cache_grad_sampling_loc[0];
+          *(grad_sampling_loc + 1) = cache_grad_sampling_loc[1];
+          *grad_attn_weight = cache_grad_attn_weight[0];
+        }
+        __syncthreads();
+
+        data_weight_ptr += 1;
+        data_loc_w_ptr += 2;
+        grad_attn_weight += grad_weight_stride;
+        grad_sampling_loc += grad_loc_stride;
+      }
+    }
+  }
+}
+
+// backward branch
+template <typename data_t>
+void deformable_attn_cuda_backward(
+    cudaStream_t stream, const data_t *grad_out, const data_t *data_value,
+    const int64_t *data_spatial_shapes, const int64_t *data_level_start_index,
+    const data_t *data_sampling_loc, const data_t *data_attn_weight,
+    const int batch_size, const int value_length, const int num_heads,
+    const int channels, const int num_levels, const int query_length,
+    const int num_points, data_t *grad_value, data_t *grad_sampling_loc,
+    data_t *grad_attn_weight) {
+  const int num_threads =
+      (channels > CUDA_NUM_THREADS) ? CUDA_NUM_THREADS : channels;
+  const int num_kernels = batch_size * query_length * num_heads * channels;
+  const int num_actual_kernels =
+      batch_size * query_length * num_heads * channels;
+  if (channels > 1024) {
+    if ((channels & 1023) == 0) {
+      deformable_attn_cuda_kernel_backward_shm_reduce_v2_multi_blocks<data_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+             num_threads * 3 * sizeof(data_t), stream>>>(
+              num_kernels, grad_out, data_value, data_spatial_shapes,
+              data_level_start_index, data_sampling_loc, data_attn_weight,
+              batch_size, value_length, num_heads, channels, num_levels,
+              query_length, num_points, grad_value, grad_sampling_loc,
+              grad_attn_weight);
+    } else {
+      deformable_attn_cuda_kernel_backward_gm<data_t>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+    }
+  } else {
+    switch (channels) {
+    case 1:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         1>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 2:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         2>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 4:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         4>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 8:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         8>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 16:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         16>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 32:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v1<data_t,
+                                                                         32>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 64:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         64>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 128:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         128>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 256:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         256>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 512:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         512>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    case 1024:
+      deformable_attn_cuda_kernel_backward_shm_blocksize_aware_reduce_v2<data_t,
+                                                                         1024>
+          <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads, 0,
+             stream>>>(num_kernels, grad_out, data_value, data_spatial_shapes,
+                       data_level_start_index, data_sampling_loc,
+                       data_attn_weight, batch_size, value_length, num_heads,
+                       channels, num_levels, query_length, num_points,
+                       grad_value, grad_sampling_loc, grad_attn_weight);
+      break;
+    default:
+      if (channels < 64) {
+        deformable_attn_cuda_kernel_backward_shm_reduce_v1<data_t>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+               num_threads * 3 * sizeof(data_t), stream>>>(
+                num_kernels, grad_out, data_value, data_spatial_shapes,
+                data_level_start_index, data_sampling_loc, data_attn_weight,
+                batch_size, value_length, num_heads, channels, num_levels,
+                query_length, num_points, grad_value, grad_sampling_loc,
+                grad_attn_weight);
+      } else {
+        deformable_attn_cuda_kernel_backward_shm_reduce_v2<data_t>
+            <<<GET_BLOCKS(num_actual_kernels, num_threads), num_threads,
+               num_threads * 3 * sizeof(data_t), stream>>>(
+                num_kernels, grad_out, data_value, data_spatial_shapes,
+                data_level_start_index, data_sampling_loc, data_attn_weight,
+                batch_size, value_length, num_heads, channels, num_levels,
+                query_length, num_points, grad_value, grad_sampling_loc,
+                grad_attn_weight);
+      }
+    }
+  }
+}
+
+// backward
+std::vector<paddle::Tensor> MSDeformableAttnCUDABackward(
+    const paddle::Tensor &value, const paddle::Tensor &value_spatial_shapes,
+    const paddle::Tensor &value_level_start_index,
+    const paddle::Tensor &sampling_locations,
+    const paddle::Tensor &attention_weights, const paddle::Tensor &grad_out) {
+
+  CHECK_INPUT_GPU(value);
+  CHECK_INPUT_GPU(value_spatial_shapes);
+  CHECK_INPUT_GPU(value_level_start_index);
+  CHECK_INPUT_GPU(sampling_locations);
+  CHECK_INPUT_GPU(attention_weights);
+  CHECK_INPUT_GPU(grad_out);
+
+  const int batch_size = value.shape()[0];
+  const int value_length = value.shape()[1];
+  const int num_heads = value.shape()[2];
+  const int channels = value.shape()[3];
+
+  const int num_levels = value_spatial_shapes.shape()[0];
+  const int query_length = sampling_locations.shape()[1];
+  const int num_points = sampling_locations.shape()[4];
+
+  auto grad_value =
+      paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
+  auto grad_spatial_shapes =
+      paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
+  auto grad_level_start_index =
+      paddle::full(value.shape(), 0, value.dtype(), paddle::GPUPlace());
+  auto grad_sampling_locations =
+      paddle::full(sampling_locations.shape(), 0, sampling_locations.dtype(),
+                   paddle::GPUPlace());
+  auto grad_attention_weights =
+      paddle::full(attention_weights.shape(), 0, attention_weights.dtype(),
+                   paddle::GPUPlace());
+
+  deformable_attn_cuda_backward<float>(
+      value.stream(), grad_out.data<float>(), value.data<float>(),
+      value_spatial_shapes.data<int64_t>(),
+      value_level_start_index.data<int64_t>(), sampling_locations.data<float>(),
+      attention_weights.data<float>(), batch_size, value_length, num_heads,
+      channels, num_levels, query_length, num_points, grad_value.data<float>(),
+      grad_sampling_locations.data<float>(),
+      grad_attention_weights.data<float>());
+
+  return {grad_value, grad_spatial_shapes, grad_level_start_index,
+          grad_sampling_locations, grad_attention_weights};
+}
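Note on the launcher above: it dispatches on the channel count — power-of-two sizes from 1 to 1024 use the block-size-aware shared-memory reductions (v1 up to 32 channels, v2 from 64 up), other sizes up to 1024 fall back to the dynamic shared-memory reductions (v1 below 64 channels, v2 otherwise), and sizes above 1024 use the multi-block variant when divisible by 1024 or a plain global-memory kernel. A rough Python sketch of the requested launch configuration, assuming GET_BLOCKS is a ceiling division and CUDA_NUM_THREADS is 1024 (both are defined elsewhere in this file):

import math

# Assumptions: CUDA_NUM_THREADS = 1024 and GET_BLOCKS(n, t) = ceil(n / t); the real
# definitions live near the top of ms_deformable_attn_op.cu.
CUDA_NUM_THREADS = 1024

def backward_launch_config(batch_size, query_length, num_heads, channels):
    # one thread per output-gradient element, capped at CUDA_NUM_THREADS per block
    num_threads = min(channels, CUDA_NUM_THREADS)
    num_kernels = batch_size * query_length * num_heads * channels
    num_blocks = math.ceil(num_kernels / num_threads)
    # dynamic shared memory requested by the shm_reduce_v1/v2 variants:
    # 2 floats of sampling-location gradient + 1 float of attention-weight gradient per thread
    shared_mem_bytes = num_threads * 3 * 4
    return num_blocks, num_threads, shared_mem_bytes

print(backward_launch_config(batch_size=2, query_length=300, num_heads=8, channels=256))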

+ 7 - 0
models/rtdetr/basic_modules/ext_op/setup_ms_deformable_attn_op.py

@@ -0,0 +1,7 @@
+from paddle.utils.cpp_extension import CUDAExtension, setup
+
+if __name__ == "__main__":
+    setup(
+        name='deformable_detr_ops',
+        ext_modules=CUDAExtension(
+            sources=['ms_deformable_attn_op.cc', 'ms_deformable_attn_op.cu']))
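A quick, hypothetical smoke test for this build script (the install command follows the standard Paddle custom-op workflow; the module name `deformable_detr_ops` comes from the `setup()` call above):

# Assumed build step, run from this directory first:
#   python setup_ms_deformable_attn_op.py install
try:
    from deformable_detr_ops import ms_deformable_attn
    print("ms_deformable_attn custom op is available")
except ImportError as err:
    print("custom op not built; the pure-framework fallback will be used:", err)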

+ 140 - 0
models/rtdetr/basic_modules/ext_op/test_ms_deformable_attn_op.py

@@ -0,0 +1,140 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import os
+import sys
+import random
+import numpy as np
+import paddle
+# add python path of PaddleDetection to sys.path
+parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 5)))
+if parent_path not in sys.path:
+    sys.path.append(parent_path)
+
+from ppdet.modeling.transformers.utils import deformable_attention_core_func
+ms_deform_attn_core_paddle = deformable_attention_core_func
+
+try:
+    gpu_index = int(sys.argv[1])
+except (IndexError, ValueError):
+    gpu_index = 0
+print(f'Use gpu {gpu_index} to test...')
+paddle.set_device(f'gpu:{gpu_index}')
+
+try:
+    from deformable_detr_ops import ms_deformable_attn
+except Exception as e:
+    print('import deformable_detr_ops error', e)
+    sys.exit(-1)
+
+paddle.seed(1)
+random.seed(1)
+np.random.seed(1)
+
+bs, n_heads, c = 2, 8, 8
+query_length, n_levels, n_points = 2, 2, 2
+spatial_shapes = paddle.to_tensor([(6, 4), (3, 2)], dtype=paddle.int64)
+level_start_index = paddle.concat((paddle.to_tensor(
+    [0], dtype=paddle.int64), spatial_shapes.prod(1).cumsum(0)[:-1]))
+value_length = sum([(H * W).item() for H, W in spatial_shapes])
+
+
+def get_test_tensors(channels):
+    value = paddle.rand(
+        [bs, value_length, n_heads, channels], dtype=paddle.float32) * 0.01
+    sampling_locations = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points, 2],
+        dtype=paddle.float32)
+    attention_weights = paddle.rand(
+        [bs, query_length, n_heads, n_levels, n_points],
+        dtype=paddle.float32) + 1e-5
+    attention_weights /= attention_weights.sum(-1, keepdim=True).sum(
+        -2, keepdim=True)
+
+    return [value, sampling_locations, attention_weights]
+
+
+@paddle.no_grad()
+def check_forward_equal_with_paddle_float():
+    value, sampling_locations, attention_weights = get_test_tensors(c)
+
+    output_paddle = ms_deform_attn_core_paddle(
+        value, spatial_shapes, level_start_index, sampling_locations,
+        attention_weights).detach().cpu()
+    output_cuda = ms_deformable_attn(value, spatial_shapes, level_start_index,
+                                     sampling_locations,
+                                     attention_weights).detach().cpu()
+    fwdok = paddle.allclose(
+        output_cuda, output_paddle, rtol=1e-2, atol=1e-3).item()
+    max_abs_err = (output_cuda - output_paddle).abs().max().item()
+    max_rel_err = (
+        (output_cuda - output_paddle).abs() / output_paddle.abs()).max().item()
+
+    print(
+        f'*{fwdok} check_forward_equal_with_paddle_float: max_abs_err {max_abs_err:.2e} max_rel_err {max_rel_err:.2e}'
+    )
+
+
+def check_gradient_numerical(channels=4):
+    value_paddle, sampling_locations_paddle, attention_weights_paddle = get_test_tensors(
+        channels)
+    value_paddle.stop_gradient = False
+    sampling_locations_paddle.stop_gradient = False
+    attention_weights_paddle.stop_gradient = False
+
+    value_cuda = value_paddle.detach().clone()
+    sampling_locations_cuda = sampling_locations_paddle.detach().clone()
+    attention_weights_cuda = attention_weights_paddle.detach().clone()
+    value_cuda.stop_gradient = False
+    sampling_locations_cuda.stop_gradient = False
+    attention_weights_cuda.stop_gradient = False
+
+    output_paddle = ms_deform_attn_core_paddle(
+        value_paddle, spatial_shapes, level_start_index,
+        sampling_locations_paddle, attention_weights_paddle)
+    output_paddle.sum().backward()
+
+    output_cuda = ms_deformable_attn(value_cuda, spatial_shapes,
+                                     level_start_index, sampling_locations_cuda,
+                                     attention_weights_cuda)
+    output_cuda.sum().backward()
+
+    res = paddle.allclose(
+        value_paddle.grad, value_cuda.grad, rtol=1e-2, atol=1e-3).item()
+    print(f'*tensor1 {res} check_gradient_numerical(D={channels})')
+
+    res = paddle.allclose(
+        sampling_locations_paddle.grad,
+        sampling_locations_cuda.grad,
+        rtol=1e-2,
+        atol=1e-3).item()
+    print(f'*tensor2 {res} check_gradient_numerical(D={channels})')
+
+    res = paddle.allclose(
+        attention_weights_paddle.grad,
+        attention_weights_cuda.grad,
+        rtol=1e-2,
+        atol=1e-3).item()
+    print(f'*tensor3 {res} check_gradient_numerical(D={channels})')
+
+
+if __name__ == '__main__':
+    check_forward_equal_with_paddle_float()
+
+    for channels in [30, 32, 64, 71, 128, 1024, 1025, 2048, 3096]:
+        check_gradient_numerical(channels)

+ 164 - 0
models/rtdetr/basic_modules/fpn.py

@@ -0,0 +1,164 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+from .conv import BasicConv, ELANLayer
+from .transformer import TransformerEncoder
+
+
+# Build PaFPN
+def build_fpn(cfg, in_dims):
+    if cfg.fpn == 'hybrid_encoder':
+        return HybridEncoder(in_dims     = in_dims,
+                             out_dim     = cfg.hidden_dim,
+                             num_blocks  = cfg.fpn_num_blocks,
+                             expand_ratio= cfg.fpn_expand_ratio,
+                             act_type    = cfg.fpn_act,
+                             norm_type   = cfg.fpn_norm,
+                             depthwise   = cfg.fpn_depthwise,
+                             num_heads   = cfg.en_num_heads,
+                             num_layers  = cfg.en_num_layers,
+                             ffn_dim     = cfg.en_ffn_dim,
+                             dropout     = cfg.en_dropout,
+                             en_act_type    = cfg.en_act,
+                             )
+    else:
+        raise NotImplementedError("Unknown PaFPN: <{}>".format(cfg.fpn))
+
+
+# ----------------- Feature Pyramid Network -----------------
+## Hybrid Encoder (Transformer encoder + Convolutional PaFPN)
+class HybridEncoder(nn.Module):
+    def __init__(self, 
+                 in_dims        :List  = [256, 512, 1024],
+                 out_dim        :int   = 256,
+                 num_blocks     :int   = 3,
+                 expand_ratio   :float = 0.5,
+                 act_type       :str   = 'silu',
+                 norm_type      :str   = 'BN',
+                 depthwise      :bool  = False,
+                 # Transformer's parameters
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 ffn_dim        :int   = 1024,
+                 dropout        :float = 0.1,
+                 pe_temperature :float = 10000.,
+                 en_act_type    :str   = 'gelu'
+                 ) -> None:
+        super(HybridEncoder, self).__init__()
+        print('==============================')
+        print('FPN: {}'.format("RTC-PaFPN"))
+        # ---------------- Basic parameters ----------------
+        self.in_dims = in_dims
+        self.out_dim = out_dim
+        self.out_dims = [self.out_dim] * len(in_dims)
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.ffn_dim = ffn_dim
+        c3, c4, c5 = in_dims
+
+        # ---------------- Input projs ----------------
+        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+
+        # ---------------- Downsample ----------------
+        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+
+        # ---------------- Transformer Encoder ----------------
+        self.transformer_encoder = TransformerEncoder(d_model        = self.out_dim,
+                                                      num_heads      = num_heads,
+                                                      num_layers     = num_layers,
+                                                      ffn_dim        = ffn_dim,
+                                                      pe_temperature = pe_temperature,
+                                                      dropout        = dropout,
+                                                      act_type       = en_act_type
+                                                      )
+
+        # ---------------- Top down FPN ----------------
+        ## P5 -> P4
+        self.top_down_layer_1 = ELANLayer(in_dim       = self.out_dim * 2,
+                                          out_dim      = self.out_dim,
+                                          num_blocks   = num_blocks,
+                                          expand_ratio = expand_ratio,
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise,
+                                          )
+        ## P4 -> P3
+        self.top_down_layer_2 = ELANLayer(in_dim       = self.out_dim * 2,
+                                          out_dim      = self.out_dim,
+                                          num_blocks   = num_blocks,
+                                          expand_ratio = expand_ratio,
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise,
+                                          )
+        
+        # ---------------- Bottom up PAN ----------------
+        ## P3 -> P4
+        self.bottom_up_layer_1 = ELANLayer(in_dim       = self.out_dim * 2,
+                                           out_dim      = self.out_dim,
+                                           num_blocks   = num_blocks,
+                                           expand_ratio = expand_ratio,
+                                           shortcut     = False,
+                                           act_type     = act_type,
+                                           norm_type    = norm_type,
+                                           depthwise    = depthwise,
+                                          )
+        ## P4 -> P5
+        self.bottom_up_layer_2 = ELANLayer(in_dim       = self.out_dim * 2,
+                                           out_dim      = self.out_dim,
+                                           num_blocks   = num_blocks,
+                                           expand_ratio = expand_ratio,
+                                           shortcut     = False,
+                                           act_type     = act_type,
+                                           norm_type    = norm_type,
+                                           depthwise    = depthwise,
+                                           )
+
+        self.init_weights()
+  
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # -------- Input projs --------
+        p5 = self.reduce_layer_1(c5)
+        p4 = self.reduce_layer_2(c4)
+        p3 = self.reduce_layer_3(c3)
+
+        # -------- Transformer encoder --------
+        p5 = self.transformer_encoder(p5)
+
+        # -------- Top down FPN --------
+        p5_up = F.interpolate(p5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))
+
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))
+
+        # -------- Bottom up PAN --------
+        p3_ds = self.dowmsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
+
+        p4_ds = self.dowmsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
+
+        out_feats = [p3, p4, p5]
+        
+        return out_feats
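As a rough shape check for the hybrid encoder (a sketch only; it assumes the repository root is on sys.path and that the companion conv/transformer modules from this commit import cleanly):

import torch
from models.rtdetr.basic_modules.fpn import HybridEncoder

fpn = HybridEncoder(in_dims=[256, 512, 1024], out_dim=256)
feats = [torch.randn(1, 256, 80, 80),    # c3, stride 8
         torch.randn(1, 512, 40, 40),    # c4, stride 16
         torch.randn(1, 1024, 20, 20)]   # c5, stride 32
p3, p4, p5 = fpn(feats)
print(p3.shape, p4.shape, p5.shape)  # each level keeps its spatial size and gets 256 channels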

+ 51 - 0
models/rtdetr/basic_modules/mlp.py

@@ -0,0 +1,51 @@
+import torch.nn as nn
+
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type == 'gelu':
+        return nn.GELU()
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+
+# ----------------- MLP modules -----------------
+class MLP(nn.Module):
+    def __init__(self, in_dim, hidden_dim, out_dim, num_layers):
+        super().__init__()
+        self.num_layers = num_layers
+        h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([in_dim] + h, h + [out_dim]))
+
+    def forward(self, x):
+        for i, layer in enumerate(self.layers):
+            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+        return x
+
+class FFN(nn.Module):
+    def __init__(self, d_model=256, ffn_dim=1024, dropout=0., act_type='relu'):
+        super().__init__()
+        self.ffn_dim = ffn_dim
+        self.linear1 = nn.Linear(d_model, self.ffn_dim)
+        self.activation = get_activation(act_type)
+        self.dropout2 = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(self.ffn_dim, d_model)
+        self.dropout3 = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+    def forward(self, src):
+        src2 = self.linear2(self.dropout2(self.activation(self.linear1(src))))
+        src = src + self.dropout3(src2)
+        src = self.norm(src)
+        
+        return src
+    
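A minimal shape sketch for the two modules above (the 256-d, 300-query sizes are just illustrative):

import torch
from models.rtdetr.basic_modules.mlp import MLP, FFN  # assumes repo root on sys.path

x = torch.randn(2, 300, 256)                              # [batch, queries, d_model]
bbox_head = MLP(in_dim=256, hidden_dim=256, out_dim=4, num_layers=3)
ffn = FFN(d_model=256, ffn_dim=1024, dropout=0.0, act_type='relu')
print(bbox_head(x).shape)   # torch.Size([2, 300, 4]), e.g. one box per query
print(ffn(x).shape)         # torch.Size([2, 300, 256]); residual + LayerNorm keeps d_model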

+ 71 - 0
models/rtdetr/basic_modules/nms_ops.py

@@ -0,0 +1,71 @@
+import numpy as np
+
+
+# ---------------------------- NMS ----------------------------
+## basic NMS
+def nms(bboxes, scores, nms_thresh):
+    """"Pure Python NMS."""
+    x1 = bboxes[:, 0]  #xmin
+    y1 = bboxes[:, 1]  #ymin
+    x2 = bboxes[:, 2]  #xmax
+    y2 = bboxes[:, 3]  #ymax
+
+    areas = (x2 - x1) * (y2 - y1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        # compute iou
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(1e-10, xx2 - xx1)
+        h = np.maximum(1e-10, yy2 - yy1)
+        inter = w * h
+
+        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
+        # keep only the boxes whose IoU with the current box is below the threshold
+        inds = np.where(iou <= nms_thresh)[0]
+        order = order[inds + 1]
+
+    return keep
+
+## class-agnostic NMS 
+def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
+    # nms
+    keep = nms(bboxes, scores, nms_thresh)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## class-aware NMS 
+def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
+    # nms
+    keep = np.zeros(len(bboxes), dtype=np.int32)
+    for i in range(num_classes):
+        inds = np.where(labels == i)[0]
+        if len(inds) == 0:
+            continue
+        c_bboxes = bboxes[inds]
+        c_scores = scores[inds]
+        c_keep = nms(c_bboxes, c_scores, nms_thresh)
+        keep[inds[c_keep]] = 1
+    keep = np.where(keep > 0)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## multi-class NMS 
+def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
+    if class_agnostic:
+        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
+    else:
+        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
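A small usage sketch for the NMS entry point above, on random boxes (all numbers are illustrative):

import numpy as np
from models.rtdetr.basic_modules.nms_ops import multiclass_nms  # assumes repo root on sys.path

rng = np.random.default_rng(0)
xy = rng.uniform(0, 300, size=(50, 2))
wh = rng.uniform(10, 100, size=(50, 2))
bboxes = np.concatenate([xy, xy + wh], axis=1)            # [N, 4] boxes in xyxy format
scores = rng.uniform(0, 1, size=50).astype(np.float32)
labels = rng.integers(0, 80, size=50)

scores, labels, bboxes = multiclass_nms(scores, labels, bboxes,
                                        nms_thresh=0.5, num_classes=80,
                                        class_agnostic=False)
print(len(bboxes), "boxes kept after class-aware NMS")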

+ 33 - 0
models/rtdetr/basic_modules/norm.py

@@ -0,0 +1,33 @@
+import torch
+
+
+class FrozenBatchNorm2d(torch.nn.Module):
+    def __init__(self, n):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        num_batches_tracked_key = prefix + 'num_batches_tracked'
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super(FrozenBatchNorm2d, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        eps = 1e-5
+        scale = w * (rv + eps).rsqrt()
+        bias = b - rm * scale
+        return x * scale + bias
+    
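A short sketch showing that, once the buffers are copied over, the frozen variant matches a regular BatchNorm2d in eval mode (both use eps = 1e-5):

import torch
from models.rtdetr.basic_modules.norm import FrozenBatchNorm2d  # assumes repo root on sys.path

bn = torch.nn.BatchNorm2d(8)
with torch.no_grad():
    bn(torch.randn(32, 8, 16, 16))   # populate the running statistics
bn.eval()

frozen = FrozenBatchNorm2d(8)
frozen.load_state_dict(bn.state_dict())  # num_batches_tracked is dropped by the override above

x = torch.randn(2, 8, 16, 16)
print(torch.allclose(bn(x), frozen(x), atol=1e-5))  # expected: True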

+ 459 - 0
models/rtdetr/basic_modules/transformer.py

@@ -0,0 +1,459 @@
+import math
+import copy
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .mlp import FFN
+
+
+def get_clones(module, N):
+    if N <= 0:
+        return None
+    else:
+        return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])
+
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0., max=1.)
+    return torch.log(x.clamp(min=eps) / (1 - x).clamp(min=eps))
+
+
+# ----------------- Basic Transformer Ops -----------------
+def multi_scale_deformable_attn_pytorch(
+    value: torch.Tensor,
+    value_spatial_shapes: torch.Tensor,
+    sampling_locations: torch.Tensor,
+    attention_weights: torch.Tensor,
+) -> torch.Tensor:
+
+    bs, _, num_heads, embed_dims = value.shape
+    _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
+    
+    value_list = value.split([H_ * W_ for H_, W_ in value_spatial_shapes], dim=1)
+    sampling_grids = 2 * sampling_locations - 1
+    sampling_value_list = []
+    for level, (H_, W_) in enumerate(value_spatial_shapes):
+        # bs, H_*W_, num_heads, embed_dims ->
+        # bs, H_*W_, num_heads*embed_dims ->
+        # bs, num_heads*embed_dims, H_*W_ ->
+        # bs*num_heads, embed_dims, H_, W_
+        value_l_ = (
+            value_list[level].flatten(2).transpose(1, 2).reshape(bs * num_heads, embed_dims, H_, W_)
+        )
+        # bs, num_queries, num_heads, num_points, 2 ->
+        # bs, num_heads, num_queries, num_points, 2 ->
+        # bs*num_heads, num_queries, num_points, 2
+        sampling_grid_l_ = sampling_grids[:, :, :, level].transpose(1, 2).flatten(0, 1)
+        # bs*num_heads, embed_dims, num_queries, num_points
+        sampling_value_l_ = F.grid_sample(
+            value_l_, sampling_grid_l_, mode="bilinear", padding_mode="zeros", align_corners=False
+        )
+        sampling_value_list.append(sampling_value_l_)
+    # (bs, num_queries, num_heads, num_levels, num_points) ->
+    # (bs, num_heads, num_queries, num_levels, num_points) ->
+    # (bs, num_heads, 1, num_queries, num_levels*num_points)
+    attention_weights = attention_weights.transpose(1, 2).reshape(
+        bs * num_heads, 1, num_queries, num_levels * num_points
+    )
+    output = (
+        (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
+        .sum(-1)
+        .view(bs, num_heads * embed_dims, num_queries)
+    )
+    return output.transpose(1, 2).contiguous()
+
+class MSDeformableAttention(nn.Module):
+    def __init__(self,
+                 embed_dim=256,
+                 num_heads=8,
+                 num_levels=4,
+                 num_points=4):
+        """
+        Multi-Scale Deformable Attention Module
+        """
+        super(MSDeformableAttention, self).__init__()
+        self.embed_dim = embed_dim
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.total_points = num_heads * num_levels * num_points
+
+        self.head_dim = embed_dim // num_heads
+        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
+
+        self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2)
+        self.attention_weights = nn.Linear(embed_dim, self.total_points)
+        self.value_proj = nn.Linear(embed_dim, embed_dim)
+        self.output_proj = nn.Linear(embed_dim, embed_dim)
+        
+        try:
+            # use cuda op
+            from deformable_detr_ops import ms_deformable_attn
+            self.ms_deformable_attn_core = ms_deformable_attn
+        except Exception:
+            # fall back to the pure-PyTorch implementation
+            self.ms_deformable_attn_core = multi_scale_deformable_attn_pytorch
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        """
+        Default initialization for Parameters of Module.
+        """
+        nn.init.constant_(self.sampling_offsets.weight.data, 0.0)
+        thetas = torch.arange(self.num_heads, dtype=torch.float32) * (
+            2.0 * math.pi / self.num_heads
+        )
+        grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
+        grid_init = (
+            (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
+            .view(self.num_heads, 1, 1, 2)
+            .repeat(1, self.num_levels, self.num_points, 1)
+        )
+        for i in range(self.num_points):
+            grid_init[:, :, i, :] *= i + 1
+        with torch.no_grad():
+            self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1))
+
+        # attention weight
+        nn.init.constant_(self.attention_weights.weight, 0.0)
+        nn.init.constant_(self.attention_weights.bias, 0.0)
+
+        # proj
+        nn.init.xavier_uniform_(self.value_proj.weight)
+        nn.init.constant_(self.value_proj.bias, 0.0)
+        nn.init.xavier_uniform_(self.output_proj.weight)
+        nn.init.constant_(self.output_proj.bias, 0.0)
+
+    def forward(self,
+                query,
+                reference_points,
+                value,
+                value_spatial_shapes,
+                value_mask=None):
+        """
+        Args:
+            query (Tensor): [bs, query_length, C]
+            reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0),
+                bottom-right (1, 1), including padding area
+            value (Tensor): [bs, value_length, C]
+            value_spatial_shapes (Tensor): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})]
+            value_mask (Tensor): [bs, value_length], True for non-padding elements, False for padding elements
+
+        Returns:
+            output (Tensor): [bs, Length_{query}, C]
+        """
+        bs, num_query = query.shape[:2]
+        num_value = value.shape[1]
+        assert sum([s[0] * s[1] for s in value_spatial_shapes]) == num_value
+
+        # Value projection
+        value = self.value_proj(value)
+        # fill "0" for the padding part
+        if value_mask is not None:
+            value_mask = value_mask.to(value.dtype).unsqueeze(-1)
+            value *= value_mask
+        # [bs, all_hw, 256] -> [bs, all_hw, num_head, head_dim]
+        value = value.reshape([bs, num_value, self.num_heads, -1])
+
+        # [bs, num_query, num_heads, num_levels, num_points, 2]
+        sampling_offsets = self.sampling_offsets(query).reshape(
+            [bs, num_query, self.num_heads, self.num_levels, self.num_points, 2])
+        # [bs, num_query, num_heads, num_levels * num_points]
+        attention_weights = self.attention_weights(query).reshape(
+            [bs, num_query, self.num_heads, self.num_levels * self.num_points])
+        # [bs, num_query, num_heads, num_levels, num_points]
+        attention_weights = attention_weights.softmax(-1).reshape(
+            [bs, num_query, self.num_heads, self.num_levels, self.num_points])
+
+        # [bs, num_query, num_heads, num_levels, num_points, 2]
+        if reference_points.shape[-1] == 2:
+            # reference_points   [bs, num_query, num_levels, 2] -> [bs, num_query, 1, num_levels, 1, 2]
+            # sampling_offsets   [bs, num_query, num_heads, num_levels, num_points, 2]
+            # offset_normalizer  [num_levels, 2] -> [1, 1, 1, num_levels, 1, 2]
+            # sampling_locations = reference_points + normalized sampling_offsets
+            offset_normalizer = value_spatial_shapes.flip([1]).reshape(
+                [1, 1, 1, self.num_levels, 1, 2])
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :]
+                + sampling_offsets / offset_normalizer
+            )
+        elif reference_points.shape[-1] == 4:
+            sampling_locations = (
+                reference_points[:, :, None, :, None, :2]
+                + sampling_offsets
+                / self.num_points
+                * reference_points[:, :, None, :, None, 2:]
+                * 0.5)
+        else:
+            raise ValueError(
+                "Last dim of reference_points must be 2 or 4, but get {} instead.".
+                format(reference_points.shape[-1]))
+
+        # Multi-scale Deformable attention
+        output = self.ms_deformable_attn_core(
+            value, value_spatial_shapes, sampling_locations, attention_weights)
+        
+        # Output project
+        output = self.output_proj(output)
+
+        return output
+
+
+# ----------------- Transformer modules -----------------
+## Transformer Encoder layer
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self,
+                 d_model         :int   = 256,
+                 num_heads       :int   = 8,
+                 ffn_dim         :int   = 1024,
+                 dropout         :float = 0.1,
+                 act_type        :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.ffn_dim = ffn_dim
+        self.dropout = dropout
+        self.act_type = act_type
+        # ----------- Basic parameters -----------
+        # Multi-head Self-Attn
+        self.self_attn = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
+        self.dropout = nn.Dropout(dropout)
+        self.norm = nn.LayerNorm(d_model)
+
+        # Feedforward Network
+        self.ffn = FFN(d_model, ffn_dim, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self, src, pos_embed):
+        """
+        Input:
+            src:       [torch.Tensor] -> [B, N, C]
+            pos_embed: [torch.Tensor] -> [B, N, C]
+        Output:
+            src:       [torch.Tensor] -> [B, N, C]
+        """
+        q = k = self.with_pos_embed(src, pos_embed)
+
+        # -------------- MHSA --------------
+        src2 = self.self_attn(q, k, value=src)[0]
+        src = src + self.dropout(src2)
+        src = self.norm(src)
+
+        # -------------- FFN --------------
+        src = self.ffn(src)
+        
+        return src
+
+## Transformer Encoder
+class TransformerEncoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 ffn_dim        :int   = 1024,
+                 pe_temperature : float = 10000.,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.ffn_dim = ffn_dim
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pe_temperature = pe_temperature
+        self.pos_embed = None
+        # ----------- Basic parameters -----------
+        self.encoder_layers = get_clones(
+            TransformerEncoderLayer(d_model, num_heads, ffn_dim, dropout, act_type), num_layers)
+
+    def build_2d_sincos_position_embedding(self, device, w, h, embed_dim=256, temperature=10000.):
+        assert embed_dim % 4 == 0, \
+            'Embed dimension must be divisible by 4 for 2D sin-cos position embedding'
+        
+        # ----------- Check cached pos_embed -----------
+        # the cached embedding has shape [1, H*W, C], so compare its length against h * w
+        if self.pos_embed is not None and \
+            self.pos_embed.shape[1] == int(w) * int(h):
+            return self.pos_embed
+        
+        # ----------- Generate grid coords -----------
+        grid_w = torch.arange(int(w), dtype=torch.float32)
+        grid_h = torch.arange(int(h), dtype=torch.float32)
+        grid_w, grid_h = torch.meshgrid([grid_w, grid_h])  # shape: [W, H]
+
+        pos_dim = embed_dim // 4
+        omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim
+        omega = 1. / (temperature**omega)
+
+        out_w = grid_w.flatten()[..., None] @ omega[None] # shape: [N, C]
+        out_h = grid_h.flatten()[..., None] @ omega[None] # shape: [N, C]
+
+        # shape: [1, N, C]
+        pos_embed = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h),torch.cos(out_h)], dim=1)[None, :, :]
+        pos_embed = pos_embed.to(device)
+        self.pos_embed = pos_embed
+
+        return pos_embed
+
+    def forward(self, src):
+        """
+        Input:
+            src:  [torch.Tensor] -> [B, C, H, W]
+        Output:
+            src:  [torch.Tensor] -> [B, C, H, W]
+        """
+        # -------- Transformer encoder --------
+        channels, fmp_h, fmp_w = src.shape[1:]
+        # [B, C, H, W] -> [B, N, C], N=HxW
+        src_flatten = src.flatten(2).permute(0, 2, 1).contiguous()
+        memory = src_flatten
+
+        # PosEmbed: [1, N, C]
+        pos_embed = self.build_2d_sincos_position_embedding(
+            src.device, fmp_w, fmp_h, channels, self.pe_temperature)
+        
+        # Transformer Encoder layer
+        for encoder in self.encoder_layers:
+            memory = encoder(memory, pos_embed=pos_embed)
+
+        # Output: [B, N, C] -> [B, C, N] -> [B, C, H, W]
+        src = memory.permute(0, 2, 1).contiguous()
+        src = src.view([-1, channels, fmp_h, fmp_w])
+
+        return src
+
+## Transformer Decoder layer
+class DeformableTransformerDecoderLayer(nn.Module):
+    def __init__(self,
+                 d_model     :int   = 256,
+                 num_heads   :int   = 8,
+                 num_levels  :int   = 3,
+                 num_points  :int   = 4,
+                 ffn_dim     :int   = 1024,
+                 dropout     :float = 0.1,
+                 act_type    :str   = "relu",
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.ffn_dim = ffn_dim
+        self.dropout = dropout
+        self.act_type = act_type
+        # ---------------- Network parameters ----------------
+        ## Multi-head Self-Attn
+        self.self_attn  = nn.MultiheadAttention(d_model, num_heads, dropout=dropout, batch_first=True)
+        self.dropout1 = nn.Dropout(dropout)
+        self.norm1 = nn.LayerNorm(d_model)
+        ## CrossAttention
+        self.cross_attn = MSDeformableAttention(d_model, num_heads, num_levels, num_points)
+        self.dropout2 = nn.Dropout(dropout)
+        self.norm2 = nn.LayerNorm(d_model)
+        ## FFN
+        self.ffn = FFN(d_model, ffn_dim, dropout, act_type)
+
+    def with_pos_embed(self, tensor, pos):
+        return tensor if pos is None else tensor + pos
+
+    def forward(self,
+                tgt,
+                reference_points,
+                memory,
+                memory_spatial_shapes,
+                attn_mask=None,
+                memory_mask=None,
+                query_pos_embed=None):
+        # ---------------- MSHA for Object Query -----------------
+        q = k = self.with_pos_embed(tgt, query_pos_embed)
+        tgt2 = self.self_attn(q, k, value=tgt, attn_mask=attn_mask)[0]
+        tgt = tgt + self.dropout1(tgt2)
+        tgt = self.norm1(tgt)
+
+        # ---------------- CMHA for Object Query and Image-feature -----------------
+        tgt2 = self.cross_attn(self.with_pos_embed(tgt, query_pos_embed),
+                               reference_points,
+                               memory,
+                               memory_spatial_shapes,
+                               memory_mask)
+        tgt = tgt + self.dropout2(tgt2)
+        tgt = self.norm2(tgt)
+
+        # ---------------- FeedForward Network -----------------
+        tgt = self.ffn(tgt)
+
+        return tgt
+
+## Transformer Decoder
+class DeformableTransformerDecoder(nn.Module):
+    def __init__(self,
+                 d_model        :int   = 256,
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
+                 ffn_dim        :int   = 1024,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 return_intermediate :bool = False,
+                 ):
+        super().__init__()
+        # ----------- Basic parameters -----------
+        self.d_model = d_model
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.ffn_dim = ffn_dim
+        self.dropout = dropout
+        self.act_type = act_type
+        self.pos_embed = None
+        # ----------- Network parameters -----------
+        self.decoder_layers = get_clones(
+            DeformableTransformerDecoderLayer(d_model, num_heads, num_levels, num_points, ffn_dim, dropout, act_type), num_layers)
+        self.num_layers = num_layers
+        self.return_intermediate = return_intermediate
+
+    def forward(self,
+                tgt,
+                ref_points_unact,
+                memory,
+                memory_spatial_shapes,
+                bbox_head,
+                score_head,
+                query_pos_head,
+                attn_mask=None,
+                memory_mask=None):
+        output = tgt
+        dec_out_bboxes = []
+        dec_out_logits = []
+        ref_points_detach = F.sigmoid(ref_points_unact)
+        for i, layer in enumerate(self.decoder_layers):
+            ref_points_input = ref_points_detach.unsqueeze(2)
+            query_pos_embed = query_pos_head(ref_points_detach)
+
+            output = layer(output, ref_points_input, memory,
+                           memory_spatial_shapes, attn_mask,
+                           memory_mask, query_pos_embed)
+
+            inter_ref_bbox = F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points_detach))
+
+            dec_out_logits.append(score_head[i](output))
+            if i == 0:
+                dec_out_bboxes.append(inter_ref_bbox)
+            else:
+                dec_out_bboxes.append(
+                    F.sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)))
+
+            ref_points = inter_ref_bbox
+            ref_points_detach = inter_ref_bbox.detach() if self.training else inter_ref_bbox
+
+        return torch.stack(dec_out_bboxes), torch.stack(dec_out_logits)
+
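+# Note on the refinement loop above: each decoder layer predicts a box delta in
+# inverse-sigmoid (logit) space, so the refined reference is
+#     sigmoid(bbox_head[i](output) + inverse_sigmoid(ref_points)),
+# and during training the reference passed to the next layer is detached so gradients
+# do not back-propagate through the whole refinement chain.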

+ 16 - 0
models/rtdetr/build.py

@@ -0,0 +1,16 @@
+from .loss import SetCriterion
+from .rtdetr import RTDETR
+
+
+# build object detector
+def build_rtdetr(cfg, is_val=False):    
+    # -------------- Build RT-DETR --------------
+    model = RTDETR(cfg, is_val, use_nms=True, onnx_deploy=False)
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 170 - 0
models/rtdetr/loss.py

@@ -0,0 +1,170 @@
+"""
+reference: 
+https://github.com/facebookresearch/detr/blob/main/models/detr.py
+
+by lyuwenyu
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .loss_utils import box_cxcywh_to_xyxy, box_iou, generalized_box_iou
+from .loss_utils import is_dist_avail_and_initialized, get_world_size
+from .matcher import HungarianMatcher
+
+
+# --------------- Criterion for RT-DETR ---------------
+class SetCriterion(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.num_classes = cfg.num_classes
+        self.losses = ['labels', 'boxes']
+
+        self.alpha = 0.75  # For VFL
+        self.gamma = 2.0
+
+        self.matcher = HungarianMatcher(cfg.cost_class, cfg.cost_bbox, cfg.cost_giou, alpha=0.25, gamma=2.0)
+        self.weight_dict = {'loss_cls':  cfg.loss_cls,
+                            'loss_box':  cfg.loss_box,
+                            'loss_giou': cfg.loss_giou}
+
+    def loss_labels(self, outputs, targets, indices, num_boxes):
+        "Compute variable focal loss"
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        # Compute IoU between pred and target
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+        ious, _ = box_iou(box_cxcywh_to_xyxy(src_boxes), box_cxcywh_to_xyxy(target_boxes))
+        ious = torch.diag(ious).detach()
+
+        # One-hot class label
+        src_logits = outputs['pred_logits']
+        target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
+        target_classes = torch.full(src_logits.shape[:2], self.num_classes,
+                                    dtype=torch.int64, device=src_logits.device)
+        target_classes[idx] = target_classes_o
+        target = F.one_hot(target_classes, num_classes=self.num_classes + 1)[..., :-1]
+
+        # IoU-aware class label
+        target_score_o = torch.zeros_like(target_classes, dtype=src_logits.dtype)
+        target_score_o[idx] = ious.to(target_score_o.dtype)
+        target_score = target_score_o.unsqueeze(-1) * target
+
+        # Compute VFL
+        pred_score = F.sigmoid(src_logits).detach()
+        weight = self.alpha * pred_score.pow(self.gamma) * (1 - target) + target_score
+        
+        loss = F.binary_cross_entropy_with_logits(src_logits, target_score, weight=weight, reduction='none')
+        loss = loss.mean(1).sum() * src_logits.shape[1] / num_boxes
+
+        return {'loss_cls': loss}
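+    # The classification target of a matched query is its IoU with the assigned GT box
+    # (an IoU-aware soft label), while unmatched positions keep a zero target and are
+    # down-weighted by alpha * p^gamma, following the Varifocal Loss formulation.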
+
+    def loss_boxes(self, outputs, targets, indices, num_boxes):
+        """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss
+           targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4]
+           The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size.
+        """
+        assert 'pred_boxes' in outputs
+        idx = self._get_src_permutation_idx(indices)
+        src_boxes = outputs['pred_boxes'][idx]
+        target_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0)
+
+        losses = {}
+
+        loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction='none')
+        losses['loss_box'] = loss_bbox.sum() / num_boxes
+
+        loss_giou = 1 - torch.diag(generalized_box_iou(
+                box_cxcywh_to_xyxy(src_boxes),
+                box_cxcywh_to_xyxy(target_boxes)))
+        losses['loss_giou'] = loss_giou.sum() / num_boxes
+        return losses
+
+    def _get_src_permutation_idx(self, indices):
+        # permute predictions following indices
+        batch_idx = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
+        src_idx = torch.cat([src for (src, _) in indices])
+        return batch_idx, src_idx
+
+    def _get_tgt_permutation_idx(self, indices):
+        # permute targets following indices
+        batch_idx = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
+        tgt_idx = torch.cat([tgt for (_, tgt) in indices])
+        return batch_idx, tgt_idx
+
+    def get_loss(self, loss, outputs, targets, indices, num_boxes, **kwargs):
+        loss_map = {
+            'boxes': self.loss_boxes,
+            'labels': self.loss_labels,
+        }
+        assert loss in loss_map, f'do you really want to compute {loss} loss?'
+        return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
+
+    def forward(self, outputs, targets):
+        outputs_without_aux = {k: v for k, v in outputs.items() if 'aux' not in k}
+
+        # Retrieve the matching between the outputs of the last layer and the targets
+        indices = self.matcher(outputs_without_aux, targets)
+
+        # Compute the average number of target boxes across all nodes, for normalization purposes
+        num_boxes = sum(len(t["labels"]) for t in targets)
+        num_boxes = torch.as_tensor([num_boxes], dtype=torch.float, device=next(iter(outputs.values())).device)
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_boxes)
+        num_boxes = torch.clamp(num_boxes / get_world_size(), min=1).item()
+
+        # Compute all the requested losses
+        losses = {}
+        for loss in self.losses:
+            l_dict = self.get_loss(loss, outputs, targets, indices, num_boxes)
+            l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+            losses.update(l_dict)
+
+        # In case of auxiliary losses, we repeat this process with the output of each intermediate layer.
+        if 'aux_outputs' in outputs:
+            for i, aux_outputs in enumerate(outputs['aux_outputs']):
+                indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
+                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+                    l_dict = {k + f'_aux_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        # In case of CDN (contrastive denoising) auxiliary losses, used by RT-DETR
+        if 'dn_aux_outputs' in outputs:
+            assert 'dn_meta' in outputs, 'dn_meta is required to compute denoising losses'
+            indices = self.get_cdn_matched_indices(outputs['dn_meta'], targets)
+            num_boxes = num_boxes * outputs['dn_meta']['dn_num_group']
+
+            for i, aux_outputs in enumerate(outputs['dn_aux_outputs']):
+                # indices = self.matcher(aux_outputs, targets)
+                for loss in self.losses:
+                    l_dict = self.get_loss(loss, aux_outputs, targets, indices, num_boxes)
+                    l_dict = {k: l_dict[k] * self.weight_dict[k] for k in l_dict if k in self.weight_dict}
+                    l_dict = {k + f'_dn_{i}': v for k, v in l_dict.items()}
+                    losses.update(l_dict)
+
+        return losses
+
+    @staticmethod
+    def get_cdn_matched_indices(dn_meta, targets):
+        '''get_cdn_matched_indices
+        '''
+        dn_positive_idx, dn_num_group = dn_meta["dn_positive_idx"], dn_meta["dn_num_group"]
+        num_gts = [len(t['labels']) for t in targets]
+        device = targets[0]['labels'].device
+        
+        dn_match_indices = []
+        for i, num_gt in enumerate(num_gts):
+            if num_gt > 0:
+                gt_idx = torch.arange(num_gt, dtype=torch.int64, device=device)
+                gt_idx = gt_idx.tile(dn_num_group)
+                assert len(dn_positive_idx[i]) == len(gt_idx)
+                dn_match_indices.append((dn_positive_idx[i], gt_idx))
+            else:
+                dn_match_indices.append((torch.zeros(0, dtype=torch.int64, device=device), \
+                    torch.zeros(0, dtype=torch.int64,  device=device)))
+        
+        return dn_match_indices

+ 240 - 0
models/rtdetr/loss_utils.py

@@ -0,0 +1,240 @@
+import math
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+from torchvision.ops.boxes import box_area
+
+
+# ------------------------- For loss -------------------------
+## FocalLoss
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+                positive vs negative examples. Default = 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor
+    """
+    prob = inputs.sigmoid()
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = prob * targets + (1 - prob) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss.mean(1).sum() / num_boxes
+
+## Varifocal Loss
+def varifocal_loss_with_logits(pred_logits,
+                               gt_score,
+                               label,
+                               normalizer=1.0,
+                               alpha=0.75,
+                               gamma=2.0):
+    pred_score = F.sigmoid(pred_logits)
+    weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
+    loss = F.binary_cross_entropy_with_logits(pred_logits, gt_score, reduction='none')
+    loss = loss * weight
+
+    return loss.mean(1).sum() / normalizer
+
+## InverseSigmoid
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1/x2)
+
+## GIoU loss
+class GIoULoss(object):
+    """ Modified GIoULoss from Paddle-Paddle"""
+    def __init__(self, eps=1e-10, reduction='none'):
+        self.eps = eps
+        self.reduction = reduction
+        assert reduction in ('none', 'mean', 'sum')
+
+    def bbox_overlap(self, box1, box2, eps=1e-10):
+        """calculate the iou of box1 and box2
+        Args:
+            box1 (Tensor): box1 with the shape (..., 4)
+            box2 (Tensor): box2 with the shape (..., 4)
+            eps (float): epsilon to avoid divide by zero
+        Return:
+            iou (Tensor): iou of box1 and box2
+            overlap (Tensor): overlap of box1 and box2
+            union (Tensor): union of box1 and box2
+        """
+        x1, y1, x2, y2 = box1
+        x1g, y1g, x2g, y2g = box2
+
+        xkis1 = torch.max(x1, x1g)
+        ykis1 = torch.max(y1, y1g)
+        xkis2 = torch.min(x2, x2g)
+        ykis2 = torch.min(y2, y2g)
+        w_inter = (xkis2 - xkis1).clip(0)
+        h_inter = (ykis2 - ykis1).clip(0)
+        overlap = w_inter * h_inter
+
+        area1 = (x2 - x1) * (y2 - y1)
+        area2 = (x2g - x1g) * (y2g - y1g)
+        union = area1 + area2 - overlap + eps
+        iou = overlap / union
+
+        return iou, overlap, union
+
+    def __call__(self, pbox, gbox):
+        # x1, y1, x2, y2 = torch.split(pbox, 4, dim=-1)
+        # x1g, y1g, x2g, y2g = torch.split(gbox, 4, dim=-1)
+        x1, y1, x2, y2 = torch.chunk(pbox, 4, dim=-1)
+        x1g, y1g, x2g, y2g = torch.chunk(gbox, 4, dim=-1)
+        box1 = [x1, y1, x2, y2]
+        box2 = [x1g, y1g, x2g, y2g]
+        iou, _, union = self.bbox_overlap(box1, box2, self.eps)
+        xc1 = torch.min(x1, x1g)
+        yc1 = torch.min(y1, y1g)
+        xc2 = torch.max(x2, x2g)
+        yc2 = torch.max(y2, y2g)
+
+        area_c = (xc2 - xc1) * (yc2 - yc1) + self.eps
+        miou = iou - ((area_c - union) / area_c)
+        giou = 1 - miou
+
+        if self.reduction == 'none':
+            loss = giou
+        elif self.reduction == 'sum':
+            loss = giou.sum()
+        elif self.reduction == 'mean':
+            loss = giou.mean()
+
+        return loss
+
+
+# ------------------------- For box -------------------------
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
+
+def box_iou(boxes1, boxes2):
+    area1 = box_area(boxes1)
+    area2 = box_area(boxes2)
+
+    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    union = area1[:, None] + area2 - inter
+
+    iou = inter / union
+    return iou, union
+
+def generalized_box_iou(boxes1, boxes2):
+    """
+    Generalized IoU from https://giou.stanford.edu/
+
+    The boxes should be in [x0, y0, x1, y1] format
+
+    Returns a [N, M] pairwise matrix, where N = len(boxes1)
+    and M = len(boxes2)
+    """
+    # degenerate boxes gives inf / nan results
+    # so do an early check
+    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    iou, union = box_iou(boxes1, boxes2)
+
+    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    area = wh[:, :, 0] * wh[:, :, 1]
+
+    return iou - (area - union) / area
+
+def bbox_iou(box1, box2, giou=False, diou=False, ciou=False, eps=1e-9):
+    """Modified from Paddle-paddle
+    Args:
+        box1 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
+        box2 (list): [x, y, w, h], all have the shape [b, na, h, w, 1]
+        giou (bool): whether use giou or not, default False
+        diou (bool): whether use diou or not, default False
+        ciou (bool): whether use ciou or not, default False
+        eps (float): epsilon to avoid divide by zero
+    Return:
+        iou (Tensor): iou of box1 and box1, with the shape [b, na, h, w, 1]
+    """
+    px1, py1, px2, py2 = torch.chunk(box1, 4, -1)
+    gx1, gy1, gx2, gy2 = torch.chunk(box2, 4, -1)
+    x1 = torch.max(px1, gx1)
+    y1 = torch.max(py1, gy1)
+    x2 = torch.min(px2, gx2)
+    y2 = torch.min(py2, gy2)
+
+    overlap = ((x2 - x1).clamp(0)) * ((y2 - y1).clamp(0))
+
+    area1 = (px2 - px1) * (py2 - py1)
+    area1 = area1.clamp(0)
+
+    area2 = (gx2 - gx1) * (gy2 - gy1)
+    area2 = area2.clamp(0)
+
+    union = area1 + area2 - overlap + eps
+    iou = overlap / union
+
+    if giou or ciou or diou:
+        # convex w, h
+        cw = torch.max(px2, gx2) - torch.min(px1, gx1)
+        ch = torch.max(py2, gy2) - torch.min(py1, gy1)
+        if giou:
+            c_area = cw * ch + eps
+            return iou - (c_area - union) / c_area
+        else:
+            # convex diagonal squared
+            c2 = cw**2 + ch**2 + eps
+            # center distance
+            rho2 = ((px1 + px2 - gx1 - gx2)**2 + (py1 + py2 - gy1 - gy2)**2) / 4
+            if diou:
+                return iou - rho2 / c2
+            else:
+                w1, h1 = px2 - px1, py2 - py1 + eps
+                w2, h2 = gx2 - gx1, gy2 - gy1 + eps
+                delta = torch.atan(w1 / h1) - torch.atan(w2 / h2)
+                v = (4 / math.pi**2) * torch.pow(delta, 2)
+                alpha = v / (1 + eps - iou + v)
+                alpha = alpha.detach()  # treat alpha as a constant (no gradient through it)
+                return iou - (rho2 / c2 + v * alpha)
+    else:
+        return iou
+
+
+# ------------------------- For distributed -------------------------
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
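+
+
+# Minimal self-test added as an illustration (not part of the original file):
+# sanity-check box_cxcywh_to_xyxy and box_iou on a hand-worked case.
+if __name__ == "__main__":
+    box_cxcywh = torch.tensor([[0.5, 0.5, 0.2, 0.2]])
+    box_xyxy = box_cxcywh_to_xyxy(box_cxcywh)                  # -> [[0.4, 0.4, 0.6, 0.6]]
+    iou, _ = box_iou(box_xyxy, torch.tensor([[0.4, 0.4, 0.5, 0.6]]))
+    print(box_xyxy, iou)                                       # half-overlapping box -> IoU = 0.5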

+ 52 - 0
models/rtdetr/matcher.py

@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from scipy.optimize import linear_sum_assignment
+
+from .loss_utils import box_cxcywh_to_xyxy, generalized_box_iou
+
+
+class HungarianMatcher(nn.Module):
+    def __init__(self, cost_class=2.0, cost_bbox=5.0, cost_giou=2.0, alpha=0.25, gamma=2.0):
+        super().__init__()
+        self.cost_class = cost_class
+        self.cost_bbox  = cost_bbox
+        self.cost_giou  = cost_giou
+
+        self.alpha = alpha
+        self.gamma = gamma
+
+        assert self.cost_class != 0 or self.cost_bbox != 0 or self.cost_giou != 0, "all costs can't be 0"
+
+    @torch.no_grad()
+    def forward(self, outputs, targets):
+        bs, num_queries = outputs["pred_logits"].shape[:2]
+
+        # We flatten to compute the cost matrices in a batch
+        out_prob = F.sigmoid(outputs["pred_logits"].flatten(0, 1))
+        out_bbox = outputs["pred_boxes"].flatten(0, 1)  # [batch_size * num_queries, 4]
+
+        # Also concat the target labels and boxes
+        tgt_ids = torch.cat([v["labels"] for v in targets])
+        tgt_bbox = torch.cat([v["boxes"] for v in targets])
+
+        # Compute the classification cost
+        out_prob = out_prob[:, tgt_ids]
+        neg_cost_class = (1 - self.alpha) * (out_prob**self.gamma) * (-(1 - out_prob + 1e-8).log())
+        pos_cost_class = self.alpha * ((1 - out_prob)**self.gamma) * (-(out_prob + 1e-8).log())
+        cost_class = pos_cost_class - neg_cost_class        
+
+        # Compute the L1 cost between boxes
+        cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)
+
+        # Compute the GIoU cost between boxes
+        cost_giou = -generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox))
+        
+        # Final cost matrix
+        C = self.cost_bbox * cost_bbox + self.cost_class * cost_class + self.cost_giou * cost_giou
+        C = C.view(bs, num_queries, -1).cpu()
+
+        sizes = [len(v["boxes"]) for v in targets]
+        indices = [linear_sum_assignment(c[i]) for i, c in enumerate(C.split(sizes, -1))]
+
+        return [(torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices]
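+
+
+# Illustrative usage sketch (an addition for this write-up, not part of the original file):
+# match 10 random queries against a single two-box target and print the
+# (query_idx, target_idx) pairs chosen by the Hungarian assignment.
+if __name__ == "__main__":
+    matcher = HungarianMatcher(cost_class=2.0, cost_bbox=5.0, cost_giou=2.0)
+    outputs = {"pred_logits": torch.randn(1, 10, 80),
+               "pred_boxes":  torch.rand(1, 10, 4)}    # normalized cxcywh boxes
+    targets = [{"labels": torch.tensor([3, 7]),
+                "boxes":  torch.tensor([[0.5, 0.5, 0.2, 0.2],
+                                        [0.3, 0.3, 0.1, 0.1]])}]
+    print(matcher(outputs, targets))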

+ 143 - 0
models/rtdetr/rtdetr.py

@@ -0,0 +1,143 @@
+import torch
+import torch.nn as nn
+
+from .rtdetr_encoder import ImageEncoder
+from .rtdetr_decoder import RTDetrTransformer
+
+from .basic_modules.nms_ops import multiclass_nms
+
+
+# Real-time DETR
+class RTDETR(nn.Module):
+    def __init__(self,
+                 cfg,
+                 is_val = False,
+                 use_nms = False,
+                 onnx_deploy = False,
+                 ) -> None:
+        super(RTDETR, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.use_nms = use_nms
+        self.onnx_deploy = onnx_deploy
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh     = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh      = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels = False if is_val else True
+
+
+        # ----------- Network setting -----------
+        ## Image encoder
+        self.image_encoder = ImageEncoder(cfg)
+        ## Detect decoder
+        self.detect_decoder = RTDetrTransformer(in_dims             = self.image_encoder.fpn_dims,
+                                                hidden_dim          = cfg.hidden_dim,
+                                                strides             = cfg.out_stride,
+                                                num_classes         = cfg.num_classes,
+                                                num_queries         = cfg.num_queries,
+                                                num_heads           = cfg.de_num_heads,
+                                                num_layers          = cfg.de_num_layers,
+                                                num_levels          = len(cfg.out_stride),
+                                                num_points          = cfg.de_num_points,
+                                                ffn_dim             = cfg.de_ffn_dim,
+                                                dropout             = cfg.de_dropout,
+                                                act_type            = cfg.de_act,
+                                                return_intermediate = True,
+                                                num_denoising       = cfg.dn_num_denoising,
+                                                label_noise_ratio   = cfg.dn_label_noise_ratio,
+                                                box_noise_scale     = cfg.dn_box_noise_scale,
+                                                learnt_init_query   = cfg.learnt_init_query,
+                                                )
+
+    def post_process(self, box_pred, cls_pred):
+        # xywh -> xyxy
+        box_preds_x1y1 = box_pred[..., :2] - 0.5 * box_pred[..., 2:]
+        box_preds_x2y2 = box_pred[..., :2] + 0.5 * box_pred[..., 2:]
+        box_pred = torch.cat([box_preds_x1y1, box_preds_x2y2], dim=-1)
+        
+        cls_pred = cls_pred[0]
+        box_pred = box_pred[0]
+        if self.no_multi_labels:
+            # [M,]
+            scores, labels = torch.max(cls_pred.sigmoid(), dim=1)
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk_candidates, box_pred.size(0))
+
+            # Topk candidates
+            predicted_prob, topk_idxs = scores.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # Filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_idxs = topk_idxs[keep_idxs]
+
+            # Top-k results
+            topk_scores = topk_scores[keep_idxs]
+            topk_labels = labels[topk_idxs]
+            topk_bboxes = box_pred[topk_idxs]
+
+        else:
+            # Top-k select
+            cls_pred = cls_pred.flatten().sigmoid_()
+            box_pred = box_pred
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk_candidates, box_pred.size(0))
+
+            # Topk candidates
+            predicted_prob, topk_idxs = cls_pred.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # Filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_scores = topk_scores[keep_idxs]
+            topk_idxs = topk_idxs[keep_idxs]
+            topk_box_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+
+            ## Top-k results
+            topk_labels = topk_idxs % self.num_classes
+            topk_bboxes = box_pred[topk_box_idxs]
+
+        if not self.onnx_deploy:
+            topk_scores = topk_scores.cpu().numpy()
+            topk_labels = topk_labels.cpu().numpy()
+            topk_bboxes = topk_bboxes.cpu().numpy()
+
+            # nms
+            if self.use_nms:
+                topk_scores, topk_labels, topk_bboxes = multiclass_nms(
+                    topk_scores, topk_labels, topk_bboxes, self.nms_thresh, self.num_classes)
+
+        return topk_bboxes, topk_scores, topk_labels
+    
+    def forward(self, x, targets=None):
+        # ----------- Image Encoder -----------
+        pyramid_feats = self.image_encoder(x)
+
+        # ----------- Transformer -----------
+        outputs = self.detect_decoder(pyramid_feats, targets)
+
+        if not self.training:
+            img_h, img_w = x.shape[2:]
+            box_pred = outputs["pred_boxes"]
+            cls_pred = outputs["pred_logits"]
+
+            # rescale normalized boxes to input-image pixel coordinates
+            box_pred[..., [0, 2]] *= img_w
+            box_pred[..., [1, 3]] *= img_h
+            
+            # post-process
+            bboxes, scores, labels = self.post_process(box_pred, cls_pred)
+
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes,
+            }
+
+        return outputs

+ 304 - 0
models/rtdetr/rtdetr_decoder.py

@@ -0,0 +1,304 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+from .basic_modules.conv import BasicConv
+from .basic_modules.mlp  import MLP
+from .basic_modules.transformer import DeformableTransformerDecoder
+from .basic_modules.dn_compoments import get_contrastive_denoising_training_group
+
+
+# ----------------- Decoder for Detection task -----------------
+## RTDETR's Transformer for Detection task
+class RTDetrTransformer(nn.Module):
+    def __init__(self,
+                 # basic parameters
+                 in_dims        :List = [256, 512, 1024],
+                 hidden_dim     :int  = 256,
+                 strides        :List = [8, 16, 32],
+                 num_classes    :int  = 80,
+                 num_queries    :int  = 300,
+                 # transformer parameters
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
+                 ffn_dim        :int   = 1024,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 return_intermediate :bool = False,
+                 # Denoising parameters
+                 num_denoising       :int  = 100,
+                 label_noise_ratio   :float = 0.5,
+                 box_noise_scale     :float = 1.0,
+                 learnt_init_query   :bool  = False,
+                 aux_loss            :bool  = True
+                 ):
+        super().__init__()
+        # --------------- Basic setting ---------------
+        ## Basic parameters
+        self.in_dims = in_dims
+        self.strides = strides
+        self.num_queries = num_queries
+        self.num_classes = num_classes
+        self.eps = 1e-2
+        self.aux_loss = aux_loss
+        ## Transformer parameters
+        self.num_heads  = num_heads
+        self.num_layers = num_layers
+        self.num_levels = num_levels
+        self.num_points = num_points
+        self.ffn_dim  = ffn_dim
+        self.dropout    = dropout
+        self.act_type   = act_type
+        self.return_intermediate = return_intermediate
+        ## Denoising parameters
+        self.num_denoising = num_denoising
+        self.label_noise_ratio = label_noise_ratio
+        self.box_noise_scale = box_noise_scale
+        self.learnt_init_query = learnt_init_query
+
+        # --------------- Network setting ---------------
+        ## Input proj layers
+        self.input_proj_layers = nn.ModuleList(
+            BasicConv(in_dims[i], hidden_dim, kernel_size=1, act_type=None, norm_type="BN")
+            for i in range(num_levels)
+        )
+
+        ## Deformable transformer decoder
+        self.decoder = DeformableTransformerDecoder(
+                                    d_model    = hidden_dim,
+                                    num_heads  = num_heads,
+                                    num_layers = num_layers,
+                                    num_levels = num_levels,
+                                    num_points = num_points,
+                                    ffn_dim  = ffn_dim,
+                                    dropout    = dropout,
+                                    act_type   = act_type,
+                                    return_intermediate = return_intermediate
+                                    )
+        
+        ## Detection head for Encoder
+        self.enc_output = nn.Sequential(
+            nn.Linear(hidden_dim, hidden_dim),
+            nn.LayerNorm(hidden_dim)
+            )
+        self.enc_class_head = nn.Linear(hidden_dim, num_classes)
+        self.enc_bbox_head = MLP(hidden_dim, hidden_dim, 4, num_layers=3)
+
+        ## Detection head for Decoder
+        self.dec_class_head = nn.ModuleList([
+            nn.Linear(hidden_dim, num_classes)
+            for _ in range(num_layers)
+        ])
+        self.dec_bbox_head = nn.ModuleList([
+            MLP(hidden_dim, hidden_dim, 4, num_layers=3)
+            for _ in range(num_layers)
+        ])
+
+        ## Object query
+        if learnt_init_query:
+            self.tgt_embed = nn.Embedding(num_queries, hidden_dim)
+        self.query_pos_head = MLP(4, 2 * hidden_dim, hidden_dim, num_layers=2)
+
+        ## Denoising part
+        if num_denoising > 0: 
+            self.denoising_class_embed = nn.Embedding(num_classes+1, hidden_dim, padding_idx=num_classes)
+
+        self._reset_parameters()
+
+    def _reset_parameters(self):
+        # class and bbox head init
+        prior_prob = 0.01
+        cls_bias_init = float(-math.log((1 - prior_prob) / prior_prob))
+
+        nn.init.constant_(self.enc_class_head.bias, cls_bias_init)
+        nn.init.constant_(self.enc_bbox_head.layers[-1].weight, 0.)
+        nn.init.constant_(self.enc_bbox_head.layers[-1].bias, 0.)
+        for cls_, reg_ in zip(self.dec_class_head, self.dec_bbox_head):
+            nn.init.constant_(cls_.bias, cls_bias_init)
+            nn.init.constant_(reg_.layers[-1].weight, 0.)
+            nn.init.constant_(reg_.layers[-1].bias, 0.)
+
+        nn.init.xavier_uniform_(self.enc_output[0].weight)
+        if self.learnt_init_query:
+            nn.init.xavier_uniform_(self.tgt_embed.weight)
+        nn.init.xavier_uniform_(self.query_pos_head.layers[0].weight)
+        nn.init.xavier_uniform_(self.query_pos_head.layers[1].weight)
+
+    @torch.jit.unused
+    def _set_aux_loss(self, outputs_class, outputs_coord):
+        # this is a workaround to make torchscript happy, as torchscript
+        # doesn't support dictionary with non-homogeneous values, such
+        # as a dict having both a Tensor and a list.
+        return [{'pred_logits': a, 'pred_boxes': b}
+                for a, b in zip(outputs_class, outputs_coord)]
+
+    def generate_anchors(self, spatial_shapes, grid_size=0.05):
+        anchors = []
+        for lvl, (h, w) in enumerate(spatial_shapes):
+            grid_y, grid_x = torch.meshgrid(torch.arange(h), torch.arange(w))
+            # [H, W, 2]
+            grid_xy = torch.stack([grid_x, grid_y], dim=-1).float()
+
+            valid_WH = torch.as_tensor([w, h]).float()
+            grid_xy = (grid_xy.unsqueeze(0) + 0.5) / valid_WH
+            wh = torch.ones_like(grid_xy) * grid_size * (2.0**lvl)
+            # [H, W, 4] -> [1, N, 4], N=HxW
+            anchors.append(torch.cat([grid_xy, wh], dim=-1).reshape(-1, h * w, 4))
+        # List[L, 1, N_i, 4] -> [1, N, 4], N=N_0 + N_1 + N_2 + ...
+        anchors = torch.cat(anchors, dim=1)
+        valid_mask = ((anchors > self.eps) * (anchors < 1 - self.eps)).all(-1, keepdim=True)
+        anchors = torch.log(anchors / (1 - anchors))
+        # Equivalent to: anchors = torch.masked_fill(anchors, ~valid_mask, float("inf"))
+        anchors = torch.where(valid_mask, anchors, torch.inf)
+        
+        return anchors, valid_mask
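+    # The anchors are returned in inverse-sigmoid space so they can be summed with the
+    # encoder's unactivated box deltas; positions whose normalized anchor falls outside
+    # (eps, 1 - eps) are set to +inf here, and the corresponding memory tokens are zeroed
+    # in get_decoder_input().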
+    
+    def get_encoder_input(self, feats):
+        # get projection features
+        proj_feats = [self.input_proj_layers[i](feat) for i, feat in enumerate(feats)]
+
+        # get encoder inputs
+        feat_flatten = []
+        spatial_shapes = []
+        level_start_index = [0, ]
+        for i, feat in enumerate(proj_feats):
+            _, _, h, w = feat.shape
+            spatial_shapes.append([h, w])
+            # [l], start index of each level
+            level_start_index.append(h * w + level_start_index[-1])
+            # [B, C, H, W] -> [B, N, C], N=HxW
+            feat_flatten.append(feat.flatten(2).permute(0, 2, 1).contiguous())
+
+        # [B, N, C], N = N_0 + N_1 + ...
+        feat_flatten = torch.cat(feat_flatten, dim=1)
+        level_start_index.pop()
+
+        return (feat_flatten, spatial_shapes, level_start_index)
+
+    def get_decoder_input(self,
+                          memory,
+                          spatial_shapes,
+                          denoising_class=None,
+                          denoising_bbox_unact=None):
+        bs, _, _ = memory.shape
+        # Prepare input for decoder
+        anchors, valid_mask = self.generate_anchors(spatial_shapes)
+        anchors = anchors.to(memory.device)
+        valid_mask = valid_mask.to(memory.device)
+        
+        # Process encoder's output
+        memory = torch.where(valid_mask, memory, torch.as_tensor(0., device=memory.device))
+        output_memory = self.enc_output(memory)
+
+        # Heads applied to the encoder output (every memory token): [bs, N, c]
+        enc_outputs_class = self.enc_class_head(output_memory)
+        enc_outputs_coord_unact = self.enc_bbox_head(output_memory) + anchors
+
+        # Topk proposals from encoder's output
+        topk = self.num_queries
+        topk_ind = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1]  # [bs, num_queries]
+        enc_topk_logits = torch.gather(
+            enc_outputs_class, 1, topk_ind.unsqueeze(-1).repeat(1, 1, self.num_classes))  # [bs, num_queries, nc]
+        reference_points_unact = torch.gather(
+            enc_outputs_coord_unact, 1, topk_ind.unsqueeze(-1).repeat(1, 1, 4))    # [bs, num_queries, 4]
+        enc_topk_bboxes = F.sigmoid(reference_points_unact)
+
+        if denoising_bbox_unact is not None:
+            reference_points_unact = torch.cat(
+                [denoising_bbox_unact, reference_points_unact], dim=1)
+
+        # Extract region features
+        if self.learnt_init_query:
+            # [num_queries, c] -> [b, num_queries, c]
+            target = self.tgt_embed.weight.unsqueeze(0).repeat(bs, 1, 1)
+        else:
+            # [num_queries, c] -> [b, num_queries, c]
+            target = torch.gather(output_memory, 1, topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
+            target = target.detach()
+        
+        if denoising_class is not None:
+            target = torch.cat([denoising_class, target], dim=1)
+
+        return target, reference_points_unact.detach(), enc_topk_bboxes, enc_topk_logits
+    
+    def forward(self, feats, targets=None):
+        # input projection and embedding
+        memory, spatial_shapes, _ = self.get_encoder_input(feats)
+
+        # prepare denoising training
+        if self.training and self.num_denoising > 0:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = \
+                get_contrastive_denoising_training_group(targets, \
+                                                         self.num_classes, 
+                                                         self.num_queries, 
+                                                         self.denoising_class_embed, 
+                                                         num_denoising=self.num_denoising, 
+                                                         label_noise_ratio=self.label_noise_ratio, 
+                                                         box_noise_scale=self.box_noise_scale, )
+        else:
+            denoising_class, denoising_bbox_unact, attn_mask, dn_meta = None, None, None, None
+
+        target, init_ref_points_unact, enc_topk_bboxes, enc_topk_logits = \
+            self.get_decoder_input(
+            memory, spatial_shapes, denoising_class, denoising_bbox_unact)
+
+        # decoder
+        out_bboxes, out_logits = self.decoder(target,
+                                              init_ref_points_unact,
+                                              memory,
+                                              spatial_shapes,
+                                              self.dec_bbox_head,
+                                              self.dec_class_head,
+                                              self.query_pos_head,
+                                              attn_mask)
+
+        if self.training and dn_meta is not None:
+            dn_out_bboxes, out_bboxes = torch.split(out_bboxes, dn_meta['dn_num_split'], dim=2)
+            dn_out_logits, out_logits = torch.split(out_logits, dn_meta['dn_num_split'], dim=2)
+
+        out = {'pred_logits': out_logits[-1], 'pred_boxes': out_bboxes[-1]}
+
+        if self.training and self.aux_loss:
+            out['aux_outputs'] = self._set_aux_loss(out_logits[:-1], out_bboxes[:-1])
+            out['aux_outputs'].extend(self._set_aux_loss([enc_topk_logits], [enc_topk_bboxes]))
+            
+            if self.training and dn_meta is not None:
+                out['dn_aux_outputs'] = self._set_aux_loss(dn_out_logits, dn_out_bboxes)
+                out['dn_meta'] = dn_meta
+
+        return out
+
+
+## RTDETR's Transformer for Instance Segmentation task (not complete yet)
+class MaskRTDetrTransformer(RTDetrTransformer):
+    def __init__(self,
+                 # basic parameters
+                 in_dims        :List = [256, 512, 1024],
+                 hidden_dim     :int  = 256,
+                 strides        :List = [8, 16, 32],
+                 num_classes    :int  = 80,
+                 num_queries    :int  = 300,
+                 # transformer parameters
+                 num_heads      :int   = 8,
+                 num_layers     :int   = 1,
+                 num_levels     :int   = 3,
+                 num_points     :int   = 4,
+                 ffn_dim        :int   = 1024,
+                 dropout        :float = 0.1,
+                 act_type       :str   = "relu",
+                 return_intermediate :bool = False,
+                 # Denoising parameters
+                 num_denoising       :int  = 100,
+                 label_noise_ratio   :float = 0.5,
+                 box_noise_scale     :float = 1.0,
+                 learnt_init_query   :bool  = False,
+                 aux_loss            :bool  = True
+                 ):
+        super().__init__()
+
+    def forward(self, feats, targets=None):
+        return

+ 34 - 0
models/rtdetr/rtdetr_encoder.py

@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .basic_modules.backbone import build_backbone
+from .basic_modules.fpn      import build_fpn
+
+
+# ----------------- Image Encoder -----------------
+class ImageEncoder(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        # ---------------- Basic settings ----------------
+        ## Basic parameters
+        self.cfg = cfg
+        ## Network parameters
+        self.strides    = cfg.out_stride
+        self.hidden_dim = cfg.hidden_dim
+        self.num_levels = len(self.strides)
+        
+        # ---------------- Network settings ----------------
+        ## Backbone Network
+        self.backbone = build_backbone(cfg, pretrained=cfg.pretrained)
+        self.fpn_feat_dims = self.backbone.feat_dims[-3:]
+
+        ## Feature Pyramid Network
+        self.fpn = build_fpn(cfg, self.fpn_feat_dims)
+        self.fpn_dims = self.fpn.out_dims
+        
+    def forward(self, x):
+        pyramid_feats = self.backbone(x)
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        return pyramid_feats

+ 52 - 0
models/yolov1/README.md

@@ -0,0 +1,52 @@
+# Redesigned YOLOv1:
+
+| Model  |  Backbone  | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|--------|------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
+| YOLOv1 | ResNet-18  | 1xb16 |  640  |        27.9            |       47.5        |   37.8            |   21.3             | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov1_coco.pth) |
+
+- For training, we train the redesigned YOLOv1 for 150 epochs on COCO.
+- For data augmentation, we only use large-scale jitter (LSJ); no Mosaic or Mixup augmentation.
+- For the optimizer, we use SGD with momentum 0.937, weight decay 0.0005, and base lr 0.01.
+- For the learning rate schedule, we use a linear decay scheduler (see the sketch below).
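+
+A minimal sketch of this optimizer and schedule (our illustration, not the actual training script; `model`, `max_epoch`, and the final decay factor are assumptions):
+
+```python
+import torch
+
+model = ...  # the YOLOv1 model built elsewhere (assumed)
+optimizer = torch.optim.SGD(model.parameters(), lr=0.01,
+                            momentum=0.937, weight_decay=0.0005)
+max_epoch = 150
+# Linear decay of the lr factor from 1.0 toward an assumed final factor of 0.01.
+scheduler = torch.optim.lr_scheduler.LambdaLR(
+    optimizer, lr_lambda=lambda epoch: 1.0 - (1.0 - 0.01) * epoch / max_epoch)
+```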
+
+
+## Train YOLOv1
+### Single GPU
+Taking training YOLOv1 on COCO as an example:
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m yolov1 -bs 16 -size 640 --wp_epoch 3 --max_epoch 150 --eval_epoch 10 --no_aug_epoch 10 --ema --fp16 --multi_scale 
+```
+
+### Multi GPU
+Taking training YOLOv1 on COCO as an example:
+```Shell
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root /data/datasets/ -m yolov1 -bs 128 -size 640 --wp_epoch 3 --max_epoch 150  --eval_epoch 10 --no_aug_epoch 20 --ema --fp16 --sybn --multi_scale --save_folder weights/ 
+```
+
+## Test YOLOv1
+Taking testing YOLOv1 on COCO-val as an example:
+```Shell
+python test.py --cuda -d coco --root path/to/coco -m yolov1 --weight path/to/yolov1.pth -size 640 -vt 0.3 --show 
+```
+
+## Evaluate YOLOv1
+Taking evaluating YOLOv1 on COCO-val as an example:
+```Shell
+python eval.py --cuda -d coco-val --root path/to/coco -m yolov1 --weight path/to/yolov1.pth 
+```
+
+## Demo
+### Detect with Image
+```Shell
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov1 --weight path/to/weight -size 640 -vt 0.3 --show
+```
+
+### Detect with Video
+```Shell
+python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov1 --weight path/to/weight -size 640 -vt 0.3 --show --gif
+```
+
+### Detect with Camera
+```Shell
+python demo.py --mode camera --cuda -m yolov1 --weight path/to/weight -size 640 -vt 0.3 --show --gif
+```

+ 16 - 0
models/yolov1/build.py

@@ -0,0 +1,16 @@
+from .loss import SetCriterion
+from .yolov1 import Yolov1
+
+
+# build object detector
+def build_yolov1(cfg, is_val=False):
+    # -------------- Build YOLO --------------
+    model = Yolov1(cfg, is_val)
+  
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 98 - 0
models/yolov1/loss.py

@@ -0,0 +1,98 @@
+import torch
+import torch.nn.functional as F
+from .matcher import YoloMatcher
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+
+class SetCriterion(object):
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        self.loss_obj_weight = cfg.loss_obj
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
+
+        # matcher
+        self.matcher = YoloMatcher(cfg.num_classes)
+
+    def loss_objectness(self, pred_obj, gt_obj):
+        loss_obj = F.binary_cross_entropy_with_logits(pred_obj, gt_obj, reduction='none')
+
+        return loss_obj
+    
+    def loss_classes(self, pred_cls, gt_label):
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_label, reduction='none')
+
+        return loss_cls
+
+    def loss_bboxes(self, pred_box, gt_box):
+        # regression loss
+        ious = get_ious(pred_box,
+                        gt_box,
+                        box_mode="xyxy",
+                        iou_type='giou')
+        loss_box = 1.0 - ious
+
+        return loss_box
+
+    def __call__(self, outputs, targets):
+        device = outputs['pred_cls'][0].device
+        stride = outputs['stride']
+        fmp_size = outputs['fmp_size']
+        (
+            gt_objectness, 
+            gt_classes, 
+            gt_bboxes,
+            ) = self.matcher(fmp_size=fmp_size, 
+                             stride=stride, 
+                             targets=targets)
+        # List[B, M, C] -> [B, M, C] -> [BM, C]
+        pred_obj = outputs['pred_obj'].view(-1)        # [B, M, 1] -> [BM,]
+        pred_cls = outputs['pred_cls'].flatten(0, 1)   # [B, M, C] -> [BM, C]
+        pred_box = outputs['pred_box'].flatten(0, 1)   # [B, M, 4] -> [BM, 4]
+       
+        gt_objectness = gt_objectness.view(-1).to(device).float()               # [BM,]
+        gt_classes = gt_classes.view(-1, self.num_classes).to(device).float()   # [BM, C]
+        gt_bboxes = gt_bboxes.view(-1, 4).to(device).float()                    # [BM, 4]
+
+        pos_masks = (gt_objectness > 0)
+        num_fgs = pos_masks.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # obj loss
+        loss_obj = self.loss_objectness(pred_obj, gt_objectness)
+        loss_obj = loss_obj.sum() / num_fgs
+
+        # cls loss
+        pred_cls_pos = pred_cls[pos_masks]
+        gt_classes_pos = gt_classes[pos_masks]
+        loss_cls = self.loss_classes(pred_cls_pos, gt_classes_pos)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # box loss
+        pred_box_pos = pred_box[pos_masks]
+        gt_bboxes_pos = gt_bboxes[pos_masks]
+        loss_box = self.loss_bboxes(pred_box_pos, gt_bboxes_pos)
+        loss_box = loss_box.sum() / num_fgs
+        
+        # total loss
+        losses = self.loss_obj_weight * loss_obj + \
+                 self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+
+        loss_dict = dict(
+                loss_obj = loss_obj,
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
+    
+if __name__ == "__main__":
+    pass

+ 69 - 0
models/yolov1/matcher.py

@@ -0,0 +1,69 @@
+import torch
+import numpy as np
+
+
+class YoloMatcher(object):
+    def __init__(self, num_classes):
+        self.num_classes = num_classes
+
+
+    @torch.no_grad()
+    def __call__(self, fmp_size, stride, targets):
+        """
+            fmp_size: (Tuple) -> size (H, W) of the YOLOv1 feature map.
+            stride: (Int) -> stride of YOLOv1 output.
+            targets: (List[Dict]) list of dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}
+        """
+        # prepare
+        bs = len(targets)
+        fmp_h, fmp_w = fmp_size
+        gt_objectness = np.zeros([bs, fmp_h, fmp_w, 1]) 
+        gt_classes = np.zeros([bs, fmp_h, fmp_w, self.num_classes]) 
+        gt_bboxes = np.zeros([bs, fmp_h, fmp_w, 4])
+
+        for batch_index in range(bs):
+            targets_per_image = targets[batch_index]
+            # [N,]
+            tgt_cls = targets_per_image["labels"].numpy()
+            # [N, 4]
+            tgt_box = targets_per_image['boxes'].numpy()
+
+            for gt_box, gt_label in zip(tgt_box, tgt_cls):
+                x1, y1, x2, y2 = gt_box
+                # xyxy -> cxcywh
+                xc, yc = (x2 + x1) * 0.5, (y2 + y1) * 0.5
+                bw, bh = x2 - x1, y2 - y1
+
+                # check
+                if bw < 1. or bh < 1.:
+                    continue    
+
+                # grid
+                xs_c = xc / stride
+                ys_c = yc / stride
+                grid_x = int(xs_c)
+                grid_y = int(ys_c)
+
+                if grid_x < fmp_w and grid_y < fmp_h:
+                    # obj
+                    gt_objectness[batch_index, grid_y, grid_x] = 1.0
+                    # cls
+                    cls_one_hot = np.zeros(self.num_classes)
+                    cls_one_hot[int(gt_label)] = 1.0
+                    gt_classes[batch_index, grid_y, grid_x] = cls_one_hot
+                    # box
+                    gt_bboxes[batch_index, grid_y, grid_x] = np.array([x1, y1, x2, y2])
+
+        # [B, M, C]
+        gt_objectness = gt_objectness.reshape(bs, -1, 1)
+        gt_classes = gt_classes.reshape(bs, -1, self.num_classes)
+        gt_bboxes = gt_bboxes.reshape(bs, -1, 4)
+
+        # to tensor
+        gt_objectness = torch.from_numpy(gt_objectness).float()
+        gt_classes = torch.from_numpy(gt_classes).float()
+        gt_bboxes = torch.from_numpy(gt_bboxes).float()
+
+        return gt_objectness, gt_classes, gt_bboxes
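+
+
+# Minimal self-test added as an illustration (not part of the original file):
+# an 80x80 box centered at (100, 60) on a 13x13 map with stride 32 should land
+# in grid cell (grid_x=3, grid_y=1).
+if __name__ == "__main__":
+    matcher = YoloMatcher(num_classes=20)
+    targets = [{"labels": torch.tensor([5]),
+                "boxes":  torch.tensor([[60., 20., 140., 100.]])}]
+    gt_obj, gt_cls, gt_box = matcher(fmp_size=(13, 13), stride=32, targets=targets)
+    print(gt_obj.view(13, 13)[1, 3])   # expected: tensor(1.)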

+ 146 - 0
models/yolov1/yolov1.py

@@ -0,0 +1,146 @@
+# --------------- Torch components ---------------
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .yolov1_backbone import Yolov1Backbone
+from .yolov1_neck     import SPPF
+from .yolov1_head     import Yolov1DetHead
+from .yolov1_pred     import Yolov1DetPredLayer
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# YOLOv1
+class Yolov1(nn.Module):
+    def __init__(self,
+                 cfg,
+                 is_val = False,
+                 ) -> None:
+        super(Yolov1, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = False if is_val else True
+        
+        # ---------------------- Network Parameters ----------------------
+        self.backbone = Yolov1Backbone(cfg)
+        self.neck     = SPPF(cfg, self.backbone.feat_dim, cfg.head_dim)
+        self.head     = Yolov1DetHead(cfg, self.neck.out_dim)
+        self.pred     = Yolov1DetPredLayer(cfg, self.num_classes)
+
+    def post_process(self, obj_preds, cls_preds, box_preds):
+        """
+        We process predictions at each scale hierarchically
+        Input:
+            obj_preds: List[torch.Tensor] -> [[B, M, 1], ...], B=1
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
+        Output:
+            bboxes: np.array -> [N, 4]
+            scores: np.array -> [N,]
+            labels: np.array -> [N,]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for obj_pred_i, cls_pred_i, box_pred_i in zip(obj_preds, cls_preds, box_preds):
+            obj_pred_i = obj_pred_i[0]
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            if self.no_multi_labels:
+                # [M,]
+                scores, labels = torch.max(
+                    torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()), dim=1)
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # topk candidates
+                predicted_prob, topk_idxs = scores.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                labels = labels[topk_idxs]
+                bboxes = box_pred_i[topk_idxs]
+            else:
+                # [M, C] -> [MC,]
+                scores_i = torch.sqrt(obj_pred_i.sigmoid() * cls_pred_i.sigmoid()).flatten()
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # torch.sort is actually faster than .topk (at least on GPUs)
+                predicted_prob, topk_idxs = scores_i.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+                labels = topk_idxs % self.num_classes
+
+                bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
+        return bboxes, scores, labels
+    
+    def forward(self, x):
+        # ---------------- Backbone ----------------
+        feat = self.backbone(x)
+
+        # ---------------- Neck ----------------
+        feat = self.neck(feat)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head(feat)
+
+        # ---------------- Preds ----------------
+        outputs = self.pred(cls_feats, reg_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
+
+        if not self.training:
+            all_obj_preds = [outputs['pred_obj'],]
+            all_cls_preds = [outputs['pred_cls'],]
+            all_box_preds = [outputs['pred_box'],]
+
+            # post process
+            bboxes, scores, labels = self.post_process(
+                all_obj_preds, all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+        
+        return outputs 

+ 209 - 0
models/yolov1/yolov1_backbone.py

@@ -0,0 +1,209 @@
+import torch
+import torch.nn as nn
+import torch.utils.model_zoo as model_zoo
+
+try:
+    from .yolov1_basic import conv1x1, BasicBlock, Bottleneck
+except:
+    from  yolov1_basic import conv1x1, BasicBlock, Bottleneck
+
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
+}
+
+
+# --------------------- Yolov1's Backbone -----------------------
+class Yolov1Backbone(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.backbone, self.feat_dim = build_resnet(cfg.backbone, cfg.use_pretrained)
+
+    def forward(self, x):
+        c5 = self.backbone(x)
+
+        return c5
+
+
+# --------------------- ResNet -----------------------
+class ResNet(nn.Module):
+
+    def __init__(self, block, layers, zero_init_residual=False):
+        super(ResNet, self).__init__()
+        self.inplanes = 64
+        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = nn.BatchNorm2d(64)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, nn.BatchNorm2d):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1):
+        downsample = None
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                nn.BatchNorm2d(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        """
+        Input:
+            x: (Tensor) -> [B, C, H, W]
+        Output:
+            c5: (Tensor) -> [B, C, H/32, W/32]
+        """
+        c1 = self.conv1(x)     # [B, C, H/2, W/2]
+        c1 = self.bn1(c1)      # [B, C, H/2, W/2]
+        c1 = self.relu(c1)     # [B, C, H/2, W/2]
+        c2 = self.maxpool(c1)  # [B, C, H/4, W/4]
+
+        c2 = self.layer1(c2)   # [B, C, H/4, W/4]
+        c3 = self.layer2(c2)   # [B, C, H/8, W/8]
+        c4 = self.layer3(c3)   # [B, C, H/16, W/16]
+        c5 = self.layer4(c4)   # [B, C, H/32, W/32]
+
+        return c5
+
+
+# --------------------- Functions -----------------------
+def build_resnet(model_name="resnet18", pretrained=False):
+    if model_name == 'resnet18':
+        model = resnet18(pretrained)
+        feat_dim = 512
+    elif model_name == 'resnet34':
+        model = resnet34(pretrained)
+        feat_dim = 512
+    elif model_name == 'resnet50':
+        model = resnet50(pretrained)
+        feat_dim = 2048
+    elif model_name == 'resnet101':
+        model = resnet101(pretrained)
+        feat_dim = 2048
+    else:
+        raise NotImplementedError("Unknown resnet: {}".format(model_name))
+    
+    return model, feat_dim
+
+def resnet18(pretrained=False, **kwargs):
+    """Constructs a ResNet-18 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
+    if pretrained:
+        # strict = False as we don't need fc layer params.
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet18']), strict=False)
+    return model
+
+def resnet34(pretrained=False, **kwargs):
+    """Constructs a ResNet-34 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet34']), strict=False)
+    return model
+
+def resnet50(pretrained=False, **kwargs):
+    """Constructs a ResNet-50 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)
+    return model
+
+def resnet101(pretrained=False, **kwargs):
+    """Constructs a ResNet-101 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet101']), strict=False)
+    return model
+
+def resnet152(pretrained=False, **kwargs):
+    """Constructs a ResNet-152 model.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+    """
+    model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
+    if pretrained:
+        model.load_state_dict(model_zoo.load_url(model_urls['resnet152']), strict=False)
+    return model
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # YOLOv1-Base config
+    class Yolov1BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.out_stride = 32
+            self.max_stride = 32
+            ## Backbone
+            self.backbone       = 'resnet18'
+            self.use_pretrained = True
+
+    cfg = Yolov1BaseConfig()
+    # Build backbone
+    model = Yolov1Backbone(cfg)
+
+    # Inference
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    output = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print(output.shape)
+
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))    

+ 147 - 0
models/yolov1/yolov1_basic.py

@@ -0,0 +1,147 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.act(self.norm1(self.conv1(x)))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- ResNet modules ---------------------
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out

+ 121 - 0
models/yolov1/yolov1_head.py

@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov1_basic import BasicConv
+except ImportError:
+    from yolov1_basic import BasicConv
+
+
+class Yolov1DetHead(nn.Module):
+    def __init__(self, cfg, in_dim: int = 256):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.cls_head_dim = cfg.head_dim
+        self.reg_head_dim = cfg.head_dim
+        self.num_cls_head = cfg.num_cls_head
+        self.num_reg_head = cfg.num_reg_head
+        self.act_type     = cfg.head_act
+        self.norm_type    = cfg.head_norm
+        self.depthwise    = cfg.head_depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        for i in range(self.num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type  = self.act_type,
+                              norm_type = self.norm_type,
+                              depthwise = self.depthwise)
+                              )
+            else:
+                cls_feats.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type  = self.act_type,
+                              norm_type = self.norm_type,
+                              depthwise = self.depthwise)
+                              )
+        ## reg head
+        reg_feats = []
+        for i in range(self.num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type  = self.act_type,
+                              norm_type = self.norm_type,
+                              depthwise = self.depthwise)
+                              )
+            else:
+                reg_feats.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type  = self.act_type,
+                              norm_type = self.norm_type,
+                              depthwise = self.depthwise)
+                              )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv1-Base config
+    class Yolov1BaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.out_stride = 32
+            self.max_stride = 32
+            ## Head
+            self.head_act  = 'lrelu'
+            self.head_norm = 'BN'
+            self.head_depthwise = False
+            self.head_dim  = 256
+            self.num_cls_head   = 2
+            self.num_reg_head   = 2
+
+    cfg = Yolov1BaseConfig()
+    # Build a head
+    head = Yolov1DetHead(cfg, 512)
+
+
+    # Inference
+    x = torch.randn(1, 512, 20, 20)
+    t0 = time.time()
+    cls_feat, reg_feat = head(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print(cls_feat.shape, reg_feat.shape)
+
+    print('==============================')
+    flops, params = profile(head, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))    

+ 33 - 0
models/yolov1/yolov1_neck.py

@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+from .yolov1_basic import BasicConv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    """
+        This code is adapted from https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim):
+        super().__init__()
+        ## ----------- Basic Parameters -----------
+        inter_dim = round(in_dim * cfg.neck_expand_ratio)
+        self.out_dim = out_dim
+        ## ----------- Network Parameters -----------
+        self.cv1 = BasicConv(in_dim, inter_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.cv2 = BasicConv(inter_dim * 4, out_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.m = nn.MaxPool2d(kernel_size=cfg.spp_pooling_size,
+                              stride=1,
+                              padding=cfg.spp_pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
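
SPPF replaces SPP's parallel 5/9/13 max-pools with three chained pools of a single kernel size, reusing intermediate results: two stacked 5x5 stride-1 pools are equivalent to one 9x9 pool, and three to one 13x13 pool. The sketch below is a self-contained illustration with plain `nn` modules and made-up channel sizes, not the repo's `BasicConv`/config-driven version:

```python
import torch
import torch.nn as nn

class TinySPPF(nn.Module):
    """Illustrative SPPF: 1x1 reduce -> three chained 5x5 max-pools -> concat -> 1x1 project."""
    def __init__(self, in_dim=256, out_dim=256, pool_size=5):
        super().__init__()
        inter_dim = in_dim // 2
        self.cv1 = nn.Conv2d(in_dim, inter_dim, kernel_size=1)
        self.cv2 = nn.Conv2d(inter_dim * 4, out_dim, kernel_size=1)
        self.m = nn.MaxPool2d(kernel_size=pool_size, stride=1, padding=pool_size // 2)

    def forward(self, x):
        x = self.cv1(x)
        y1 = self.m(x)    # 5x5 pooling window
        y2 = self.m(y1)   # two stacked 5x5 pools == one 9x9 window
        y3 = self.m(y2)   # three stacked pools == one 13x13 window
        return self.cv2(torch.cat((x, y1, y2, y3), dim=1))

x = torch.randn(1, 256, 20, 20)
print(TinySPPF()(x).shape)  # torch.Size([1, 256, 20, 20])
```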

+ 95 - 0
models/yolov1/yolov1_pred.py

@@ -0,0 +1,95 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# -------------------- Detection Pred Layer --------------------
+## Single-level pred layer
+class Yolov1DetPredLayer(nn.Module):
+    def __init__(self,
+                 cfg,
+                 num_classes):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride  = cfg.out_stride
+        self.cls_dim = cfg.head_dim
+        self.reg_dim = cfg.head_dim
+
+        # --------- Network Parameters ----------
+        self.obj_pred = nn.Conv2d(self.cls_dim, 1, kernel_size=1)
+        self.cls_pred = nn.Conv2d(self.cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(self.reg_dim, 4, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # obj pred
+        b = self.obj_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.obj_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # cls pred
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # feature map height and width
+        fmp_h, fmp_w = fmp_size
+
+        # generate the x and y coordinates of the grid
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # stack the x and y coordinates together: [H, W, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float()
+        
+        # [H, W, 2] -> [HW, 2]
+        anchors = anchors.view(-1, 2)
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # prediction layers
+        obj_pred = self.obj_pred(cls_feat)
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # generate grid (anchor) coordinates
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+
+        # reshape the predictions to simplify the subsequent processing
+        # [B, C, H, W] -> [B, H, W, C] -> [B, H*W, C]
+        obj_pred = obj_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().flatten(1, 2)
+        
+        # decode the bounding box coordinates
+        cxcy_pred = (torch.sigmoid(reg_pred[..., :2]) + anchors[..., :2]) * self.stride
+        bwbh_pred = torch.exp(reg_pred[..., 2:]) * self.stride
+        pred_x1y1 = cxcy_pred - bwbh_pred * 0.5
+        pred_x2y2 = cxcy_pred + bwbh_pred * 0.5
+        box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        # output dict
+        outputs = {"pred_obj": obj_pred,       # (torch.Tensor) [B, M, 1]
+                   "pred_cls": cls_pred,       # (torch.Tensor) [B, M, C]
+                   "pred_reg": reg_pred,       # (torch.Tensor) [B, M, 4]
+                   "pred_box": box_pred,       # (torch.Tensor) [B, M, 4]
+                   "anchors" : anchors,        # (torch.Tensor) [M, 2]
+                   "fmp_size": fmp_size,
+                   "stride"  : self.stride,    # (Int)
+                   }
+
+        return outputs
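
The decoding in `forward` maps each grid cell's raw regression output to an xyxy box: a sigmoid offset added to the cell coordinate gives the center, an exponential of the size outputs gives width/height, and both are scaled by the output stride. A worked, self-contained numeric sketch of those same formulas (toy values, no repo imports):

```python
import torch

stride = 32
# One grid cell at (x=3, y=5) on the stride-32 feature map.
anchor = torch.tensor([3., 5.])
# Raw regression output for that cell: (tx, ty, tw, th).
reg = torch.tensor([0.2, -0.4, 0.1, 0.3])

# Center: sigmoid offset within the cell, then scale to input-image pixels.
cxcy = (torch.sigmoid(reg[:2]) + anchor) * stride
# Size: exponential of the raw output, scaled by the stride.
bwbh = torch.exp(reg[2:]) * stride

x1y1 = cxcy - 0.5 * bwbh
x2y2 = cxcy + 0.5 * bwbh
print(torch.cat([x1y1, x2y2]))  # ~tensor([ 95.9, 151.2, 131.3, 194.4])
```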

+ 47 - 0
models/yolov8/README.md

@@ -0,0 +1,47 @@
+# YOLOv8:
+
+|  Model    |  Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) |  ckpt  | logs |
+|-----------|--------|-------|------------------------|-------------------|-------------------|--------------------|--------|------|
+| YOLOv8-S  | 8xb16  |  640  |                        |                   |                   |                    |  |  |
+
+
+## Train YOLO
+### Single GPU
+Taking training YOLOv8-S on COCO as an example:
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m yolov8_s -bs 16  --fp16
+```
+
+### Multi GPU
+Taking training YOLO on COCO as an example:
+```Shell
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda --distributed -d coco --root /data/datasets/ -m yolov8_s -bs 128 --fp16 --sybn 
+```
+
+## Test YOLO
+Taking testing YOLO on COCO-val as an example:
+```Shell
+python test.py --cuda -d coco --root path/to/coco -m yolov8_s --weight path/to/yolo.pth --show 
+```
+
+## Evaluate YOLO
+Taking evaluating YOLO on COCO-val as an example:
+```Shell
+python eval.py --cuda -d coco --root path/to/coco -m yolov8_s --weight path/to/yolo.pth
+```
+
+## Demo
+### Detect with Image
+```Shell
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m yolov8_s --weight path/to/weight --show
+```
+
+### Detect with Video
+```Shell
+python demo.py --mode video --path_to_vid path/to/video --cuda -m yolov8_s --weight path/to/weight --show
+```
+
+### Detect with Camera
+```Shell
+python demo.py --mode camera --cuda -m yolov8_s --weight path/to/weight --show
+```

+ 24 - 0
models/yolov8/build.py

@@ -0,0 +1,24 @@
+import torch.nn as nn
+
+from .loss import SetCriterion
+from .yolov8 import Yolov8
+
+
+# build object detector
+def build_yolov8(cfg, is_val=False):
+    # -------------- Build YOLO --------------
+    model = Yolov8(cfg, is_val)
+
+    # -------------- Initialize YOLO --------------
+    for m in model.modules():
+        if isinstance(m, nn.BatchNorm2d):
+            m.eps = 1e-3
+            m.momentum = 0.03    
+            
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 187 - 0
models/yolov8/loss.py

@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from utils.box_ops import bbox2dist, bbox_iou
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import TaskAlignedAssigner
+
+
+class SetCriterion(object):
+    def __init__(self, cfg):
+        # --------------- Basic parameters ---------------
+        self.cfg = cfg
+        self.reg_max = cfg.reg_max
+        self.num_classes = cfg.num_classes
+        # --------------- Loss config ---------------
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
+        self.loss_dfl_weight = cfg.loss_dfl
+        # --------------- Matcher config ---------------
+        self.matcher = TaskAlignedAssigner(num_classes     = cfg.num_classes,
+                                           topk_candidates = cfg.tal_topk_candidates,
+                                           alpha           = cfg.tal_alpha,
+                                           beta            = cfg.tal_beta
+                                           )
+
+    def loss_classes(self, pred_cls, gt_score):
+        # compute bce loss
+        loss_cls = F.binary_cross_entropy_with_logits(pred_cls, gt_score, reduction='none')
+
+        return loss_cls
+    
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight):
+        # regression loss
+        ious = bbox_iou(pred_box, gt_box, xywh=False, CIoU=True)
+        loss_box = (1.0 - ious.squeeze(-1)) * bbox_weight
+
+        return loss_box
+    
+    def loss_dfl(self, pred_reg, gt_box, anchor, stride, bbox_weight=None):
+        # rescale coords by stride
+        gt_box_s = gt_box / stride
+        anchor_s = anchor / stride
+
+        # compute deltas
+        gt_ltrb_s = bbox2dist(anchor_s, gt_box_s, self.reg_max - 1)
+
+        gt_left = gt_ltrb_s.to(torch.long)
+        gt_right = gt_left + 1
+
+        weight_left = gt_right.to(torch.float) - gt_ltrb_s
+        weight_right = 1 - weight_left
+
+        # loss left
+        loss_left = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_left.view(-1),
+            reduction='none').view(gt_left.shape) * weight_left
+        # loss right
+        loss_right = F.cross_entropy(
+            pred_reg.view(-1, self.reg_max),
+            gt_right.view(-1),
+            reduction='none').view(gt_left.shape) * weight_right
+
+        loss_dfl = (loss_left + loss_right).mean(-1)
+        
+        if bbox_weight is not None:
+            loss_dfl *= bbox_weight
+
+        return loss_dfl
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_cls']: List(Tensor) [B, M, C]
+            outputs['pred_reg']: List(Tensor) [B, M, 4*(reg_max+1)]
+            outputs['pred_box']: List(Tensor) [B, M, 4]
+            outputs['anchors']: List(Tensor) [M, 2]
+            outputs['strides']: List(Int) [8, 16, 32] output stride
+            outputs['stride_tensor']: List(Tensor) [M, 1]
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+        bs, num_anchors = cls_preds.shape[:2]
+        device = cls_preds.device
+        anchors = torch.cat(outputs['anchors'], dim=0)
+        
+        # --------------- label assignment ---------------
+        gt_score_targets = []
+        gt_bbox_targets = []
+        fg_masks = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)     # [Mp,]
+            tgt_boxs = targets[batch_idx]["boxes"].to(device)        # [Mp, 4]
+
+            if self.cfg.normalize_coords:
+                img_h, img_w = outputs['image_size']
+                tgt_boxs[..., [0, 2]] *= img_w
+                tgt_boxs[..., [1, 3]] *= img_h
+            
+            if self.cfg.box_format == 'xywh':
+                tgt_boxs_x1y1 = tgt_boxs[..., :2] - 0.5 * tgt_boxs[..., 2:]
+                tgt_boxs_x2y2 = tgt_boxs[..., :2] + 0.5 * tgt_boxs[..., 2:]
+                tgt_boxs = torch.cat([tgt_boxs_x1y1, tgt_boxs_x2y2], dim=-1)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_boxs.max().item() == 0.:
+                # There is no valid gt
+                fg_mask  = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box   = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
+            else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_boxs = tgt_boxs[None]                   # [1, Mp, 4]
+                (
+                    _,
+                    gt_box,     # [1, M, 4]
+                    gt_score,   # [1, M, C]
+                    fg_mask,    # [1, M,]
+                    _
+                ) = self.matcher(
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_boxs
+                    )
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
+            fg_masks.append(fg_mask)
+
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        num_fgs = gt_score_targets.sum()
+        
+        # Average loss normalizer across all the GPUs
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0)
+
+        # ------------------ Classification loss ------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets)
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # ------------------ Regression loss ------------------
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        box_targets_pos = gt_bbox_targets.view(-1, 4)[fg_masks]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1)
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
+        loss_box = loss_box.sum() / num_fgs
+
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.reg_max)[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
+        loss_dfl = loss_dfl.sum() / num_fgs
+
+        # total loss
+        losses = loss_cls * self.loss_cls_weight + loss_box * self.loss_box_weight + loss_dfl * self.loss_dfl_weight
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
+
+        return loss_dict
+    
+
+if __name__ == "__main__":
+    pass
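
The `loss_dfl` above implements the distribution focal loss: each box side is predicted as a categorical distribution over `reg_max` integer bins, and the two bins bracketing the stride-normalized target distance are supervised with cross-entropy, weighted by how close the target lies to each bin. A minimal, self-contained sketch of that bin interpolation for a single side (toy logits; `reg_max=16` is assumed here, the actual value comes from the config):

```python
import torch
import torch.nn.functional as F

reg_max = 16
# Target distance (already divided by the stride), e.g. 3.7 cells.
t = torch.tensor(3.7)
# Raw logits over the 16 integer bins for this one box side.
logits = torch.randn(reg_max)

left = t.long()               # bin 3
right = left + 1              # bin 4
w_left = right.float() - t    # 0.3 -> weight on the left bin
w_right = 1.0 - w_left        # 0.7 -> weight on the right bin

loss_dfl = (F.cross_entropy(logits[None], left[None]) * w_left
            + F.cross_entropy(logits[None], right[None]) * w_right)
print(loss_dfl)
```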

+ 199 - 0
models/yolov8/matcher.py

@@ -0,0 +1,199 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from utils.box_ops import bbox_iou
+
+
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
+    def __init__(self,
+                 num_classes     = 80,
+                 topk_candidates = 10,
+                 alpha           = 0.5,
+                 beta            = 6.0, 
+                 eps             = 1e-9):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk_candidates = topk_candidates
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # Assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, mask_in_gts):
+        """Compute alignment metric given predicted and ground truth bounding boxes."""
+        na = pd_bboxes.shape[-2]
+        mask_in_gts = mask_in_gts.bool()  # b, max_num_obj, h*w
+        overlaps = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([self.bs, self.n_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
+
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).expand(-1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
+        # Get the scores of each grid for each gt cls
+        bbox_scores[mask_in_gts] = pd_scores[ind[0], :, ind[1]][mask_in_gts]  # b, max_num_obj, h*w
+
+        # (b, max_num_obj, 1, 4), (b, 1, h*w, 4)
+        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, self.n_max_boxes, -1, -1)[mask_in_gts]
+        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[mask_in_gts]
+        overlaps[mask_in_gts] = bbox_iou(gt_boxes, pd_boxes, xywh=False, CIoU=True).squeeze(-1).clamp_(0)
+
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+        return align_metric, overlaps
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            largest: (bool) if True, select the top-k largest metrics.
+        """
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk_candidates, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(topk_idxs)
+        # (b, max_num_obj, topk)
+        topk_idxs.masked_fill_(~topk_mask, 0)
+
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        count_tensor = torch.zeros(metrics.shape, dtype=torch.int8, device=topk_idxs.device)
+        ones = torch.ones_like(topk_idxs[:, :, :1], dtype=torch.int8, device=topk_idxs.device)
+        for k in range(self.topk_candidates):
+            # Expand topk_idxs for each value of k and add 1 at the specified positions
+            count_tensor.scatter_add_(-1, topk_idxs[:, :, k:k + 1], ones)
+        # count_tensor.scatter_add_(-1, topk_idxs, torch.ones_like(topk_idxs, dtype=torch.int8, device=topk_idxs.device))
+        # Filter invalid bboxes
+        count_tensor.masked_fill_(count_tensor > 1, 0)
+
+        return count_tensor.to(metrics.dtype)
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        # Assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # Assigned target boxes, (b, max_num_obj, 4) -> (b, h*w, 4)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # Assigned target scores
+        target_labels.clamp_(0)
+
+        # 10x faster than F.one_hot()
+        target_scores = torch.zeros((target_labels.shape[0], target_labels.shape[1], self.num_classes),
+                                    dtype=torch.int64,
+                                    device=target_labels.device)  # (b, h*w, 80)
+        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
+
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
+    
+
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(-2)
+    if fg_mask.max() > 1:  # one anchor is assigned to multiple gt_bboxes
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, n_max_boxes, -1)  # (b, n_max_boxes, h*w)
+        max_overlaps_idx = overlaps.argmax(1)  # (b, h*w)
+
+        is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
+        is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
+
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()  # (b, n_max_boxes, h*w)
+        fg_mask = mask_pos.sum(-2)
+    # Find which gt each grid cell is assigned to (index)
+    target_gt_idx = mask_pos.argmax(-2)  # (b, h*w)
+
+    return target_gt_idx, fg_mask, mask_pos
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union
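
The assigner ranks anchors with the task-aligned metric `t = s**alpha * u**beta`, where `s` is the predicted score for the gt class and `u` is the (C)IoU with the gt box; the top-k anchors by `t`, restricted to those whose centers fall inside the box, become positives. A tiny worked example of the metric itself, using plain IoU values and the `alpha=0.5`, `beta=6.0` defaults from the class above:

```python
import torch

alpha, beta = 0.5, 6.0

# Predicted score for the gt class and IoU with the gt box, per anchor.
scores = torch.tensor([0.80, 0.30, 0.60])
ious   = torch.tensor([0.50, 0.90, 0.70])

align_metric = scores.pow(alpha) * ious.pow(beta)
print(align_metric)                  # tensor([0.0140, 0.2911, 0.0911])
print(align_metric.topk(2).indices)  # tensor([1, 2]) -> anchors 1 and 2 win
```

High overlap dominates here because `beta` is much larger than `alpha`: anchor 0 has the best classification score but the weakest IoU, so it loses to the two better-localized anchors.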

+ 145 - 0
models/yolov8/yolov8.py

@@ -0,0 +1,145 @@
+# --------------- Torch components ---------------
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .yolov8_backbone import Yolov8Backbone
+from .yolov8_neck     import SPPF
+from .yolov8_pafpn    import Yolov8PaFPN
+from .yolov8_head     import Yolov8DetHead
+from .yolov8_pred     import Yolov8DetPredLayer
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# YOLOv8
+class Yolov8(nn.Module):
+    def __init__(self,
+                 cfg,
+                 is_val = False,
+                 ) -> None:
+        super(Yolov8, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = False if is_val else True
+        
+        # ---------------------- Network Parameters ----------------------
+        self.backbone = Yolov8Backbone(cfg)
+        self.neck     = SPPF(cfg, self.backbone.feat_dims[-1], self.backbone.feat_dims[-1])
+        self.fpn      = Yolov8PaFPN(cfg, self.backbone.feat_dims)
+        self.head     = Yolov8DetHead(cfg, self.fpn.out_dims)
+        self.pred     = Yolov8DetPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
+
+    def post_process(self, cls_preds, box_preds):
+        """
+        We process predictions at each scale hierarchically
+        Input:
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
+        Output:
+            bboxes: np.array -> [N, 4]
+            scores: np.array -> [N,]
+            labels: np.array -> [N,]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            if self.no_multi_labels:
+                # [M,]
+                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # topk candidates
+                predicted_prob, topk_idxs = scores.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                labels = labels[topk_idxs]
+                bboxes = box_pred_i[topk_idxs]
+            else:
+                # [M, C] -> [MC,]
+                scores_i = cls_pred_i.sigmoid().flatten()
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # torch.sort is actually faster than .topk (at least on GPUs)
+                predicted_prob, topk_idxs = scores_i.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+                labels = topk_idxs % self.num_classes
+
+                bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
+        return bboxes, scores, labels
+    
+    def forward(self, x):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+        # ---------------- Neck: SPP ----------------
+        pyramid_feats[-1] = self.neck(pyramid_feats[-1])
+
+        # ---------------- Neck: PaFPN ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        cls_feats, reg_feats = self.head(pyramid_feats)
+
+        # ---------------- Preds ----------------
+        outputs = self.pred(cls_feats, reg_feats)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
+
+        if not self.training:
+            all_cls_preds = outputs['pred_cls']
+            all_box_preds = outputs['pred_box']
+
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+        
+        return outputs 
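
After the per-level results are concatenated, the detections go through `multiclass_nms` from `utils.misc`, whose implementation is not shown here. One common way to implement class-aware NMS is to shift every box by a class-dependent offset so boxes of different classes can never suppress each other; the numpy sketch below illustrates that idea and is not necessarily how `multiclass_nms` is written:

```python
import numpy as np

def simple_nms(boxes, scores, iou_thresh):
    """Greedy NMS on xyxy boxes (numpy). Returns kept indices."""
    x1, y1, x2, y2 = boxes.T
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-9)
        order = order[1:][iou <= iou_thresh]
    return np.array(keep, dtype=np.int64)

def class_aware_nms(scores, labels, boxes, iou_thresh):
    """Run NMS per class by offsetting boxes with a class-dependent shift."""
    offsets = labels[:, None].astype(boxes.dtype) * (boxes.max() + 1.0)
    keep = simple_nms(boxes + offsets, scores, iou_thresh)
    return scores[keep], labels[keep], boxes[keep]

# Two overlapping boxes of the same class and one of another class.
boxes  = np.array([[0., 0., 10., 10.], [1., 1., 11., 11.], [0., 0., 10., 10.]])
scores = np.array([0.9, 0.8, 0.7])
labels = np.array([0, 0, 1])
print(class_aware_nms(scores, labels, boxes, iou_thresh=0.5))
# keeps boxes 0 and 2: box 1 is suppressed by box 0 (same class), box 2 survives (other class)
```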

+ 183 - 0
models/yolov8/yolov8_backbone.py

@@ -0,0 +1,183 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolov8_basic import BasicConv, ELANLayer
+except ImportError:
+    from yolov8_basic import BasicConv, ELANLayer
+
+
+# IN1K pretrained weight
+pretrained_urls = {
+    'n': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/rtcnet_n_in1k_62.1.pth",
+    's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/rtcnet_s_in1k_71.3.pth",
+    'm': None,
+    'l': None,
+    'x': None,
+}
+
+
+# ---------------------------- Basic functions ----------------------------
+class Yolov8Backbone(nn.Module):
+    def __init__(self, cfg):
+        super(Yolov8Backbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.model_scale = cfg.scale
+        self.feat_dims = [round(64  * cfg.width),
+                          round(128 * cfg.width),
+                          round(256 * cfg.width),
+                          round(512 * cfg.width),
+                          round(512 * cfg.width * cfg.ratio)]
+        
+        # ------------------ Network setting ------------------
+        ## P1/2
+        self.layer_1 = BasicConv(3, self.feat_dims[0],
+                                 kernel_size=3, padding=1, stride=2,
+                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims[0], self.feat_dims[1],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[1],
+                      out_dim    = self.feat_dims[1],
+                      num_blocks = round(3*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            BasicConv(self.feat_dims[1], self.feat_dims[2],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[2],
+                      out_dim    = self.feat_dims[2],
+                      num_blocks = round(6*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            BasicConv(self.feat_dims[2], self.feat_dims[3],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[3],
+                      out_dim    = self.feat_dims[3],
+                      num_blocks = round(6*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            BasicConv(self.feat_dims[3], self.feat_dims[4],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[4],
+                      out_dim    = self.feat_dims[4],
+                      num_blocks = round(3*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+
+        # Initialize all layers
+        self.init_weights()
+        
+        # Load imagenet pretrained weight
+        if cfg.use_pretrained:
+            self.load_pretrained()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def load_pretrained(self):
+        url = pretrained_urls[self.model_scale]
+        if url is not None:
+            print('Loading backbone pretrained weight from : {}'.format(url))
+            # checkpoint state dict
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = self.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print('Unused key: ', k)
+            # load the weight (non-strict: mismatched keys were popped above)
+            self.load_state_dict(checkpoint_state_dict, strict=False)
+        else:
+            print('No pretrained weight for model scale: {}.'.format(self.model_scale))
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## build Yolo's Backbone
+def build_backbone(cfg): 
+    # model
+    backbone = Yolov8Backbone(cfg)
+        
+    return backbone
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            self.bk_act = 'silu'
+            self.bk_norm = 'BN'
+            self.bk_depthwise = False
+            self.width = 1.0
+            self.depth = 1.0
+            self.ratio = 1.0
+            self.scale = "n"
+            self.use_pretrained = True
+
+    cfg = BaseConfig()
+    model = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 640, 640)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
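
The backbone's channel widths are derived from the `width` and `ratio` multipliers, and its per-stage block counts from `depth`; the concrete values per model scale live in `config/yolov8_config.py`, which is not shown in this excerpt. The arithmetic below uses assumed values (`width=0.25`, `depth=0.33`, `ratio=2.0`) purely to illustrate how `feat_dims` and the ELAN block counts come out:

```python
# Assumed multipliers, for illustration only; real values come from the config.
width, depth, ratio = 0.25, 0.33, 2.0

feat_dims = [round(64  * width),          # P1 stem
             round(128 * width),          # P2
             round(256 * width),          # P3 -> fed to the FPN
             round(512 * width),          # P4 -> fed to the FPN
             round(512 * width * ratio)]  # P5 -> fed to the FPN (after SPPF)

print(feat_dims)                          # [16, 32, 64, 128, 256]
print(round(3 * depth), round(6 * depth)) # 1 2 -> ELAN blocks per shallow/deep stage
```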

+ 171 - 0
models/yolov8/yolov8_basic.py

@@ -0,0 +1,171 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.act(self.norm1(self.conv1(x)))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- Yolov8 modules ---------------------
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 kernel_size :List  = [1, 3],
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = False,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'BN',
+                 depthwise   :bool  = False,
+                 ) -> None:
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expansion)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = BasicConv(in_dim, inter_dim,
+                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer2 = BasicConv(inter_dim, out_dim,
+                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class CSPLayer(nn.Module):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 num_blocks  :int   = 1,
+                 kernel_size :List = [3, 3],
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = True,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'BN',
+                 depthwise   :bool  = False,
+                 ) -> None:
+        super().__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.input_proj_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.output_proj  = BasicConv(2 * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.module       = nn.Sequential(*[YoloBottleneck(inter_dim,
+                                                           inter_dim,
+                                                           kernel_size,
+                                                           expansion   = 1.0,
+                                                           shortcut    = shortcut,
+                                                           act_type    = act_type,
+                                                           norm_type   = norm_type,
+                                                           depthwise   = depthwise,
+                                                           ) for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x1 = self.input_proj_1(x)
+        x2 = self.input_proj_2(x)
+        x2 = self.module(x2)
+        out = self.output_proj(torch.cat([x1, x2], dim=1))
+
+        return out
+
+class ELANLayer(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expansion  :float = 0.5,
+                 num_blocks :int   = 1,
+                 shortcut   :bool  = False,
+                 act_type   :str   = 'silu',
+                 norm_type  :str   = 'BN',
+                 depthwise  :bool  = False,
+                 ) -> None:
+        super(ELANLayer, self).__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj  = BasicConv(in_dim, inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.output_proj = BasicConv((2 + num_blocks) * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module      = nn.ModuleList([YoloBottleneck(inter_dim,
+                                                         inter_dim,
+                                                         kernel_size = [3, 3],
+                                                         expansion   = 1.0,
+                                                         shortcut    = shortcut,
+                                                         act_type    = act_type,
+                                                         norm_type   = norm_type,
+                                                         depthwise   = depthwise)
+                                                         for _ in range(num_blocks)])
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = list([x1, x2])
+
+        # Bottleneck
+        out.extend(m(out[-1]) for m in self.module)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out
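+
+
+# A minimal shape-check sketch (an illustrative assumption, not part of the training code;
+# run it e.g. via `python -m models.yolov8.yolov8_basic`). It only shows that ELANLayer keeps
+# the spatial resolution and maps in_dim -> out_dim channels.
+if __name__ == '__main__':
+    x = torch.randn(1, 64, 32, 32)
+    layer = ELANLayer(in_dim=64, out_dim=128, expansion=0.5, num_blocks=2)
+    # inter_dim = round(128 * 0.5) = 64, so output_proj sees (2 + 2) * 64 = 256 channels
+    print(layer(x).shape)  # expected: torch.Size([1, 128, 32, 32])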

+ 277 - 0
models/yolov8/yolov8_head.py

@@ -0,0 +1,277 @@
+import torch
+import torch.nn as nn
+
+from .yolov8_basic import BasicConv
+
+
+# -------------------- Detection Head --------------------
+## Single-level Detection Head
+class DetHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 act_type     :str  = "silu",
+                 norm_type    :str  = "BN",
+                 depthwise    :bool = False):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                cls_feats.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## reg head
+        reg_feats = []
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                reg_feats.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+
+        return cls_feats, reg_feats
+    
+## Multi-level Detection Head
+class Yolov8DetHead(nn.Module):
+    def __init__(self, cfg, in_dims):
+        super().__init__()
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [DetHead(in_dim       = in_dims[level],
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 100)),
+                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     act_type     = cfg.head_act,
+                     norm_type    = cfg.head_norm,
+                     depthwise    = cfg.head_depthwise)
+                     for level in range(cfg.num_levels)
+                     ])
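+        # For example (illustrative values only): with in_dims = [256, 512, 512], num_classes = 80
+        # and reg_max = 16, cls_head_dim = max(256, min(80, 100)) = 256 and
+        # reg_head_dim = max(256 // 4, 16, 4 * 16) = 64.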
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+
+        return cls_feats, reg_feats
+
+
+# -------------------- Segmentation Head --------------------
+## Single-level Segmentation Head (not complete yet)
+class SegHead(nn.Module):
+    def __init__(self,
+                 in_dim       :int  = 256,
+                 cls_head_dim :int  = 256,
+                 reg_head_dim :int  = 256,
+                 seg_head_dim :int  = 256,
+                 num_cls_head :int  = 2,
+                 num_reg_head :int  = 2,
+                 num_seg_head :int  = 2,
+                 act_type     :str  = "silu",
+                 norm_type    :str  = "BN",
+                 depthwise    :bool = False):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.in_dim = in_dim
+        self.num_cls_head = num_cls_head
+        self.num_reg_head = num_reg_head
+        self.num_seg_head = num_seg_head
+        self.act_type = act_type
+        self.norm_type = norm_type
+        self.depthwise = depthwise
+        
+        # --------- Network Parameters ----------
+        ## cls head
+        cls_feats = []
+        self.cls_head_dim = cls_head_dim
+        for i in range(num_cls_head):
+            if i == 0:
+                cls_feats.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                cls_feats.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## reg head
+        reg_feats = []
+        self.reg_head_dim = reg_head_dim
+        for i in range(num_reg_head):
+            if i == 0:
+                reg_feats.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                reg_feats.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        ## seg head
+        seg_feats = []
+        self.seg_head_dim = seg_head_dim
+        for i in range(num_seg_head):
+            if i == 0:
+                seg_feats.append(
+                    BasicConv(in_dim, self.seg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+            else:
+                seg_feats.append(
+                    BasicConv(self.seg_head_dim, self.seg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=act_type,
+                              norm_type=norm_type,
+                              depthwise=depthwise)
+                              )
+        self.cls_feats = nn.Sequential(*cls_feats)
+        self.reg_feats = nn.Sequential(*reg_feats)
+        self.seg_feats = nn.Sequential(*seg_feats)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        """
+            in_feats: (Tensor) [B, C, H, W]
+        """
+        cls_feats = self.cls_feats(x)
+        reg_feats = self.reg_feats(x)
+        seg_feats = self.seg_feats(x)
+
+        return cls_feats, reg_feats, seg_feats
+
+## Multi-level Segmentation Head (not complete yet)
+class YoloSegHead(nn.Module):
+    def __init__(self, cfg, in_dims):
+        super().__init__()
+        ## ----------- Network Parameters -----------
+        self.multi_level_heads = nn.ModuleList(
+            [SegHead(in_dim       = in_dims[level],
+                     cls_head_dim = max(in_dims[0], min(cfg.num_classes, 100)),
+                     reg_head_dim = max(in_dims[0]//4, 16, 4*cfg.reg_max),
+                     seg_head_dim = in_dims[0],
+                     num_cls_head = cfg.num_cls_head,
+                     num_reg_head = cfg.num_reg_head,
+                     num_seg_head = cfg.num_seg_head,
+                     act_type     = cfg.head_act,
+                     norm_type    = cfg.head_norm,
+                     depthwise    = cfg.head_depthwise)
+                     for level in range(cfg.num_levels)
+                     ])
+        # --------- Basic Parameters ----------
+        self.in_dims = in_dims
+        self.cls_head_dim = self.multi_level_heads[0].cls_head_dim
+        self.reg_head_dim = self.multi_level_heads[0].reg_head_dim
+        self.seg_head_dim = self.multi_level_heads[0].seg_head_dim
+
+    def forward(self, feats):
+        """
+            feats: List[(Tensor)] [[B, C, H, W], ...]
+        """
+        cls_feats = []
+        reg_feats = []
+        seg_feats = []
+        for feat, head in zip(feats, self.multi_level_heads):
+            # ---------------- Pred ----------------
+            cls_feat, reg_feat, seg_feat = head(feat)
+
+            cls_feats.append(cls_feat)
+            reg_feats.append(reg_feat)
+            seg_feats.append(seg_feat)
+
+        return cls_feats, reg_feats, seg_feats

+ 33 - 0
models/yolov8/yolov8_neck.py

@@ -0,0 +1,33 @@
+import torch
+import torch.nn as nn
+
+from .yolov8_basic import BasicConv
+
+
+# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher
+class SPPF(nn.Module):
+    """
+        This code references https://github.com/ultralytics/yolov5
+    """
+    def __init__(self, cfg, in_dim, out_dim):
+        super().__init__()
+        ## ----------- Basic Parameters -----------
+        inter_dim = round(in_dim * cfg.neck_expand_ratio)
+        self.out_dim = out_dim
+        ## ----------- Network Parameters -----------
+        self.cv1 = BasicConv(in_dim, inter_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.cv2 = BasicConv(inter_dim * 4, out_dim,
+                             kernel_size=1, padding=0, stride=1,
+                             act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.m = nn.MaxPool2d(kernel_size=cfg.spp_pooling_size,
+                              stride=1,
+                              padding=cfg.spp_pooling_size // 2)
+
+    def forward(self, x):
+        x = self.cv1(x)
+        y1 = self.m(x)
+        y2 = self.m(y1)
+
+        return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1))
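+
+
+# A minimal shape-check sketch with a stand-in config object (an assumption; the real fields
+# come from the model config). Three cascaded k=5 max-poolings give receptive fields equivalent
+# to the parallel 5/9/13 poolings of the original SPP block.
+if __name__ == '__main__':
+    from types import SimpleNamespace
+    cfg = SimpleNamespace(neck_expand_ratio=0.5, neck_act='silu', neck_norm='BN', spp_pooling_size=5)
+    sppf = SPPF(cfg, in_dim=512, out_dim=512)
+    print(sppf(torch.randn(1, 512, 20, 20)).shape)  # expected: torch.Size([1, 512, 20, 20])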

+ 104 - 0
models/yolov8/yolov8_pafpn.py

@@ -0,0 +1,104 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import List
+
+from .yolov8_basic import BasicConv, ELANLayer
+
+
+# PaFPN-ELAN
+class Yolov8PaFPN(nn.Module):
+    def __init__(self,
+                 cfg,
+                 in_dims :List = [256, 512, 1024],
+                 ) -> None:
+        super(Yolov8PaFPN, self).__init__()
+        print('==============================')
+        print('FPN: {}'.format("Yolo PaFPN"))
+        # --------------------------- Basic Parameters ---------------------------
+        self.in_dims = in_dims[::-1]
+        self.out_dims = [round(256*cfg.width), round(512*cfg.width), round(512*cfg.width*cfg.ratio)]
+
+        # ---------------- Top down ----------------
+        ## P5 -> P4
+        self.top_down_layer_1 = ELANLayer(in_dim     = self.in_dims[0] + self.in_dims[1],
+                                          out_dim    = round(512*cfg.width),
+                                          expansion  = 0.5,
+                                          num_blocks = round(3 * cfg.depth),
+                                          shortcut   = False,
+                                          act_type   = cfg.fpn_act,
+                                          norm_type  = cfg.fpn_norm,
+                                          depthwise  = cfg.fpn_depthwise,
+                                          )
+        ## P4 -> P3
+        self.top_down_layer_2 = ELANLayer(in_dim     = self.in_dims[2] + round(512*cfg.width),
+                                          out_dim    = round(256*cfg.width),
+                                          expansion  = 0.5,
+                                          num_blocks = round(3 * cfg.depth),
+                                          shortcut   = False,
+                                          act_type   = cfg.fpn_act,
+                                          norm_type  = cfg.fpn_norm,
+                                          depthwise  = cfg.fpn_depthwise,
+                                          )
+        # ---------------- Bottom up ----------------
+        ## P3 -> P4
+        self.dowmsample_layer_1 = BasicConv(round(256*cfg.width), round(256*cfg.width),
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+        self.bottom_up_layer_1 = ELANLayer(in_dim     = round(256*cfg.width) + round(512*cfg.width),
+                                           out_dim    = round(512*cfg.width),
+                                           expansion  = 0.5,
+                                           num_blocks = round(3 * cfg.depth),
+                                           shortcut   = False,
+                                           act_type   = cfg.fpn_act,
+                                           norm_type  = cfg.fpn_norm,
+                                           depthwise  = cfg.fpn_depthwise,
+                                           )
+        ## P4 -> P5
+        self.dowmsample_layer_2 = BasicConv(round(512*cfg.width), round(512*cfg.width),
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=cfg.fpn_act, norm_type=cfg.fpn_norm, depthwise=cfg.fpn_depthwise)
+        self.bottom_up_layer_2 = ELANLayer(in_dim     = round(512*cfg.width) + self.in_dims[0],
+                                           out_dim    = round(512*cfg.width*cfg.ratio),
+                                           expansion  = 0.5,
+                                           num_blocks = round(3 * cfg.depth),
+                                           shortcut   = False,
+                                           act_type   = cfg.fpn_act,
+                                           norm_type  = cfg.fpn_norm,
+                                           depthwise  = cfg.fpn_depthwise,
+                                           )
+        
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # ------------------ Top down FPN ------------------
+        ## P5 -> P4
+        p5_up = F.interpolate(c5, scale_factor=2.0)
+        p4 = self.top_down_layer_1(torch.cat([p5_up, c4], dim=1))
+
+        ## P4 -> P3
+        p4_up = F.interpolate(p4, scale_factor=2.0)
+        p3 = self.top_down_layer_2(torch.cat([p4_up, c3], dim=1))
+
+        # ------------------ Bottom up FPN ------------------
+        ## P3 -> P4
+        p3_ds = self.dowmsample_layer_1(p3)
+        p4 = self.bottom_up_layer_1(torch.cat([p3_ds, p4], dim=1))
+
+        ## P4 -> P5
+        p4_ds = self.dowmsample_layer_2(p4)
+        p5 = self.bottom_up_layer_2(torch.cat([p4_ds, c5], dim=1))
+
+        out_feats = [p3, p4, p5] # [P3, P4, P5]
+        
+        return out_feats
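+
+
+# A minimal shape-check sketch with assumed scaling factors (width=1.0, depth=1.0, ratio=2.0);
+# run it e.g. via `python -m models.yolov8.yolov8_pafpn`. With a 640x640 input, the backbone
+# features C3/C4/C5 would be 80x80, 40x40 and 20x20.
+if __name__ == '__main__':
+    from types import SimpleNamespace
+    cfg = SimpleNamespace(width=1.0, depth=1.0, ratio=2.0,
+                          fpn_act='silu', fpn_norm='BN', fpn_depthwise=False)
+    fpn = Yolov8PaFPN(cfg, in_dims=[256, 512, 1024])
+    feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
+    for p in fpn(feats):
+        print(p.shape)  # expected: [1, 256, 80, 80], [1, 512, 40, 40], [1, 1024, 20, 20]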

+ 315 - 0
models/yolov8/yolov8_pred.py

@@ -0,0 +1,315 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# -------------------- Detection Pred Layer --------------------
+## Single-level pred layer
+class DetPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 stride      :int = 32,
+                 reg_max     :int = 16,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.reg_max = reg_max
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred bias
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred bias
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
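+        # e.g. fmp_size = [2, 2] and stride = 32 give anchors
+        # [[16, 16], [48, 16], [16, 48], [48, 48]] (grid-cell centers in image pixels)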
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat):
+        # pred
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, C]
+                   "pred_reg": reg_pred,            # (Tensor) [B, M, 4*reg_max]
+                   "anchors": anchors,              # (Tensor) [M, 2]
+                   "strides": self.stride,          # (Int) stride of this level
+                   "stride_tensor": stride_tensor   # (Tensor) [M, 1]
+                   }
+
+        return outputs
+
+## Multi-level pred layer
+class Yolov8DetPredLayer(nn.Module):
+    def __init__(self,
+                 cfg,
+                 cls_dim,
+                 reg_dim,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [DetPredLayer(cls_dim     = cls_dim,
+                          reg_dim     = reg_dim,
+                          stride      = cfg.out_stride[level],
+                          reg_max     = cfg.reg_max,
+                          num_classes = cfg.num_classes,
+                          num_coords  = 4 * cfg.reg_max)
+                          for level in range(cfg.num_levels)
+                          ])
+        ## proj conv
+        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
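+        # The frozen 1x1 conv with weights [0, 1, ..., reg_max-1] turns the softmax over the
+        # reg_max bins into an expectation, i.e. the DFL-style decoding of each box side offset.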
+
+    def forward(self, cls_feats, reg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_box_preds = []
+        for level in range(self.cfg.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(box_pred)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
+        
+        # output dict
+        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
+                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
+                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
+                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+
+
+# -------------------- Segmentation Pred Layer --------------------
+## Single-level pred layer (not complete yet)
+class SegPredLayer(nn.Module):
+    def __init__(self,
+                 cls_dim     :int = 256,
+                 reg_dim     :int = 256,
+                 seg_dim     :int = 256,
+                 stride      :int = 32,
+                 reg_max     :int = 16,
+                 num_classes :int = 80,
+                 num_coords  :int = 4):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.stride = stride
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.seg_dim = seg_dim
+        self.reg_max = reg_max
+        self.num_classes = num_classes
+        self.num_coords = num_coords
+
+        # --------- Network Parameters ----------
+        self.cls_pred = nn.Conv2d(cls_dim, num_classes, kernel_size=1)
+        self.reg_pred = nn.Conv2d(reg_dim, num_coords, kernel_size=1)                
+        self.seg_pred = nn.Conv2d(seg_dim, 1, kernel_size=1)                
+
+        self.init_bias()
+        
+    def init_bias(self):
+        # cls pred bias
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(math.log(5 / self.num_classes / (640. / self.stride) ** 2))
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred bias
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # seg pred bias
+        b = self.seg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.seg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+        # [H, W, 2] -> [HW, 2]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5  # add center offset
+        anchors *= self.stride
+
+        return anchors
+        
+    def forward(self, cls_feat, reg_feat, seg_feat):
+        # pred
+        cls_pred = self.cls_pred(cls_feat)
+        reg_pred = self.reg_pred(reg_feat)
+        seg_pred = self.seg_pred(seg_feat)
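+        # NOTE: seg_pred is computed here but not yet added to the output dict below
+        # (this segmentation head is marked as "not complete yet").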
+
+        # generate anchor boxes: [M, 4]
+        B, _, H, W = cls_pred.size()
+        fmp_size = [H, W]
+        anchors = self.generate_anchors(fmp_size)
+        anchors = anchors.to(cls_pred.device)
+        # stride tensor: [M, 1]
+        stride_tensor = torch.ones_like(anchors[..., :1]) * self.stride
+        
+        # [B, C, H, W] -> [B, H, W, C] -> [B, M, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4*self.reg_max)
+        
+        # output dict
+        outputs = {"pred_cls": cls_pred,            # (Tensor) [B, M, C]
+                   "pred_reg": reg_pred,            # (Tensor) [B, M, 4*reg_max]
+                   "anchors": anchors,              # (Tensor) [M, 2]
+                   "strides": self.stride,          # (Int) stride of this level
+                   "stride_tensor": stride_tensor   # (Tensor) [M, 1]
+                   }
+
+        return outputs
+
+## Multi-level pred layer (not complete yet)
+class YoloSegPredLayer(nn.Module):
+    def __init__(self,
+                 cfg,
+                 cls_dim,
+                 reg_dim,
+                 seg_dim,
+                 ):
+        super().__init__()
+        # --------- Basic Parameters ----------
+        self.cfg = cfg
+        self.cls_dim = cls_dim
+        self.reg_dim = reg_dim
+        self.seg_dim = seg_dim
+
+        # ----------- Network Parameters -----------
+        ## pred layers
+        self.multi_level_preds = nn.ModuleList(
+            [SegPredLayer(cls_dim     = cls_dim,
+                          reg_dim     = reg_dim,
+                          seg_dim     = seg_dim,
+                          stride      = cfg.out_stride[level],
+                          reg_max     = cfg.reg_max,
+                          num_classes = cfg.num_classes,
+                          num_coords  = 4 * cfg.reg_max)
+                          for level in range(cfg.num_levels)
+                          ])
+        ## proj conv
+        proj_init = torch.arange(cfg.reg_max, dtype=torch.float)
+        self.proj_conv = nn.Conv2d(cfg.reg_max, 1, kernel_size=1, bias=False).requires_grad_(False)
+        self.proj_conv.weight.data[:] = nn.Parameter(proj_init.view([1, cfg.reg_max, 1, 1]), requires_grad=False)
+
+    def forward(self, cls_feats, reg_feats, seg_feats):
+        all_anchors = []
+        all_strides = []
+        all_cls_preds = []
+        all_reg_preds = []
+        all_seg_preds = []
+        all_box_preds = []
+        for level in range(self.cfg.num_levels):
+            # -------------- Single-level prediction --------------
+            outputs = self.multi_level_preds[level](cls_feats[level], reg_feats[level], seg_feats[level])
+
+            # -------------- Decode bbox --------------
+            B, M = outputs["pred_reg"].shape[:2]
+            # [B, M, 4*(reg_max)] -> [B, M, 4, reg_max]
+            delta_pred = outputs["pred_reg"].reshape([B, M, 4, self.cfg.reg_max])
+            # [B, M, 4, reg_max] -> [B, reg_max, 4, M]
+            delta_pred = delta_pred.permute(0, 3, 2, 1).contiguous()
+            # [B, reg_max, 4, M] -> [B, 1, 4, M]
+            delta_pred = self.proj_conv(F.softmax(delta_pred, dim=1))
+            # [B, 1, 4, M] -> [B, 4, M] -> [B, M, 4]
+            delta_pred = delta_pred.view(B, 4, M).permute(0, 2, 1).contiguous()
+            ## tlbr -> xyxy
+            x1y1_pred = outputs["anchors"][None] - delta_pred[..., :2] * self.cfg.out_stride[level]
+            x2y2_pred = outputs["anchors"][None] + delta_pred[..., 2:] * self.cfg.out_stride[level]
+            box_pred = torch.cat([x1y1_pred, x2y2_pred], dim=-1)
+
+            # collect results
+            all_cls_preds.append(outputs["pred_cls"])
+            all_reg_preds.append(outputs["pred_reg"])
+            all_box_preds.append(box_pred)
+            all_anchors.append(outputs["anchors"])
+            all_strides.append(outputs["stride_tensor"])
+        
+        # output dict
+        outputs = {"pred_cls":      all_cls_preds,         # List(Tensor) [B, M, C]
+                   "pred_reg":      all_reg_preds,         # List(Tensor) [B, M, 4*(reg_max)]
+                   "pred_box":      all_box_preds,         # List(Tensor) [B, M, 4]
+                   "anchors":       all_anchors,           # List(Tensor) [M, 2]
+                   "stride_tensor": all_strides,           # List(Tensor) [M, 1]
+                   "strides":       self.cfg.out_stride,   # List(Int) = [8, 16, 32]
+                   }
+
+        return outputs
+

+ 27 - 0
requirements.txt

@@ -0,0 +1,27 @@
+torch
+
+torchvision
+
+opencv-python
+
+thop
+
+scipy
+
+matplotlib
+
+numpy
+
+imageio
+
+pycocotools
+
+onnxsim
+
+onnxruntime
+
+openvino
+
+loguru
+
+albumentations

+ 156 - 0
test.py

@@ -0,0 +1,156 @@
+import argparse
+import cv2
+import os
+import time
+import numpy as np
+from copy import deepcopy
+import torch
+
+# load transform
+from dataset.build import build_dataset, build_transform
+
+# load some utils
+from utils.misc import load_weight, compute_flops
+from utils.box_ops import rescale_bboxes
+from utils.vis_tools import visualize
+
+from config import build_config
+from models import build_model
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
+    # Basic setting
+    parser.add_argument('-size', '--img_size', default=640, type=int,
+                        help='the max size of input image')
+    parser.add_argument('--show', action='store_true', default=False,
+                        help='show the visualization results.')
+    parser.add_argument('--save', action='store_true', default=False,
+                        help='save the visualization results.')
+    parser.add_argument('--cuda', action='store_true', default=False, 
+                        help='use cuda.')
+    parser.add_argument('--save_folder', default='det_results/', type=str,
+                        help='Dir to save results')
+    parser.add_argument('-ws', '--window_scale', default=1.0, type=float,
+                        help='resize window of cv2 for visualization.')
+
+    # Model setting
+    parser.add_argument('-m', '--model', default='yolo_n', type=str,
+                        help='build yolo')
+    parser.add_argument('--weight', default=None,
+                        type=str, help='Trained state_dict file path to open')
+    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
+                        help='fuse Conv & BN')
+    parser.add_argument('--fuse_rep_conv', action='store_true', default=False,
+                        help='fuse RepConv.')
+
+    # Data setting
+    parser.add_argument('--root', default='D:/python_work/dataset/COCO/',
+                        help='data root')
+    parser.add_argument('-d', '--dataset', default='coco',
+                        help='coco, voc.')
+
+    return parser.parse_args()
+
+
+@torch.no_grad()
+def test_det(args,
+             model, 
+             device, 
+             dataset,
+             transform=None,
+             class_colors=None, 
+             class_names=None):
+    num_images = len(dataset)
+    save_path = os.path.join(args.save_folder, args.dataset, args.model)
+    os.makedirs(save_path, exist_ok=True)
+
+    for index in range(num_images):
+        print('Testing image {:d}/{:d}....'.format(index+1, num_images))
+        image, _ = dataset.pull_image(index)
+
+        orig_h, orig_w, _ = image.shape
+        orig_size = [orig_w, orig_h]
+
+        # prepare
+        x, _, ratio = transform(image)
+        x = x.unsqueeze(0).to(device)
+
+        t0 = time.time()
+        # inference
+        outputs = model(x)
+        scores = outputs['scores']
+        labels = outputs['labels']
+        bboxes = outputs['bboxes']
+        print("detection time used ", time.time() - t0, "s")
+        
+        # rescale bboxes
+        bboxes = rescale_bboxes(bboxes, orig_size, ratio)
+
+        # vis detection
+        img_processed = visualize(image=image,
+                                  bboxes=bboxes,
+                                  scores=scores,
+                                  labels=labels,
+                                  class_colors=class_colors,
+                                  class_names=class_names)
+        if args.show:
+            h, w = img_processed.shape[:2]
+            sw, sh = int(w*args.window_scale), int(h*args.window_scale)
+            cv2.namedWindow('detection', 0)
+            cv2.resizeWindow('detection', sw, sh)
+            cv2.imshow('detection', img_processed)
+            cv2.waitKey(0)
+
+        if args.save:
+            # save result
+            cv2.imwrite(os.path.join(save_path, str(index).zfill(6) +'.jpg'), img_processed)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    # cuda
+    if args.cuda:
+        print('use cuda')
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    # Dataset & Model Config
+    cfg = build_config(args)
+
+    # Transform
+    transform = build_transform(cfg, is_train=False)
+
+    # Dataset
+    dataset = build_dataset(args, cfg, transform, is_train=False)
+
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(cfg.num_classes)]
+
+    # build model
+    model = build_model(args, cfg, is_val=False)
+
+    # load trained weight
+    model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_rep_conv)
+    model.to(device).eval()
+
+    # compute FLOPs and Params
+    model_copy = deepcopy(model)
+    model_copy.trainable = False
+    model_copy.eval()
+    compute_flops(model_copy, cfg.test_img_size, device)
+    del model_copy
+        
+    print("================= DETECT =================")
+    # run
+    test_det(args         = args,
+             model        = model, 
+             device       = device, 
+             dataset      = dataset,
+             transform    = transform,
+             class_colors = class_colors,
+             class_names  = cfg.class_labels,
+             )
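+
+# Example invocation (a sketch; the dataset root and weight path are placeholders):
+#   python test.py --cuda -d coco --root path/to/COCO -m yolo_n --weight path/to/checkpoint.pth --show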

+ 208 - 0
train.py

@@ -0,0 +1,208 @@
+from __future__ import division
+
+import os
+import random
+import numpy as np
+import argparse
+from copy import deepcopy
+
+# ----------------- Torch Components -----------------
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+# ----------------- Extra Components -----------------
+from utils import distributed_utils
+from utils.misc import compute_flops, build_dataloader, CollateFunc, ModelEMA
+
+# ----------------- Config Components -----------------
+from config import build_config
+
+# ----------------- Data Components -----------------
+from dataset.build import build_dataset, build_transform
+
+# ----------------- Evaluator Components -----------------
+from evaluator.build import build_evluator
+
+# ----------------- Model Components -----------------
+from models import build_model
+
+# ----------------- Train Components -----------------
+from engine import build_trainer
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Real-time Object Detection LAB')
+    # Random seed
+    parser.add_argument('--seed', default=42, type=int)
+
+    # GPU
+    parser.add_argument('--cuda', action='store_true', default=False,
+                        help='use cuda.')
+    
+    # Image size
+    parser.add_argument('--eval_first', action='store_true', default=False,
+                        help='evaluate model before training.')
+    
+    # Outputs
+    parser.add_argument('--tfboard', action='store_true', default=False,
+                        help='use tensorboard')
+    parser.add_argument('--save_folder', default='weights/', type=str, 
+                        help='path to save weight')
+    parser.add_argument('--vis_tgt', action="store_true", default=False,
+                        help="visualize training data.")
+    parser.add_argument('--vis_aux_loss', action="store_true", default=False,
+                        help="visualize aux loss.")
+    
+    # Mixing precision
+    parser.add_argument('--fp16', dest="fp16", action="store_true", default=False,
+                        help="Adopting mix precision training.")
+    
+    # Batchsize
+    parser.add_argument('-bs', '--batch_size', default=16, type=int, 
+                        help='batch size on all the GPUs.')
+
+    # Model
+    parser.add_argument('-m', '--model', default='yolo_n', type=str,
+                        help='build yolo')
+    parser.add_argument('-p', '--pretrained', default=None, type=str,
+                        help='load pretrained weight')
+    parser.add_argument('-r', '--resume', default=None, type=str,
+                        help='keep training')
+
+    # Dataset
+    parser.add_argument('--root', default='D:/python_work/dataset/VOCdevkit/',
+                        help='data root')
+    parser.add_argument('-d', '--dataset', default='coco',
+                        help='coco, voc')
+    parser.add_argument('--num_workers', default=4, type=int, 
+                        help='Number of workers used in dataloading')
+    
+    # DDP train
+    parser.add_argument('-dist', '--distributed', action='store_true', default=False,
+                        help='distributed training')
+    parser.add_argument('--dist_url', default='env://', 
+                        help='url used to set up distributed training')
+    parser.add_argument('--world_size', default=1, type=int,
+                        help='number of distributed processes')
+    parser.add_argument('--sybn', action='store_true', default=False, 
+                        help='use sybn.')
+    parser.add_argument('--find_unused_parameters', action='store_true', default=False,
+                        help='set find_unused_parameters as True.')
+    
+    # Debug mode
+    parser.add_argument('--debug', action='store_true', default=False, 
+                        help='debug mode.')
+
+    return parser.parse_args()
+
+
+def fix_random_seed(args):
+    seed = args.seed + distributed_utils.get_rank()
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    random.seed(seed)
+
+
+def train():
+    args = parse_args()
+    print("Setting Arguments.. : ", args)
+    print("----------------------------------------------------------")
+
+    # ---------------------------- Build DDP ----------------------------
+    local_rank = local_process_rank = -1
+    if args.distributed:
+        distributed_utils.init_distributed_mode(args)
+        print("git:\n  {}\n".format(distributed_utils.get_sha()))
+        try:
+            # Multiple machines & multiple GPUs (world size > 8)
+            local_rank = torch.distributed.get_rank()
+            local_process_rank = int(os.getenv('LOCAL_PROCESS_RANK', '0'))
+        except:
+            # Single machine & multiple GPUs (world size <= 8)
+            local_rank = local_process_rank = torch.distributed.get_rank()
+    world_size = distributed_utils.get_world_size()
+    print("LOCAL RANK: ", local_rank)
+    print("LOCAL_PROCESS_RANK: ", local_process_rank)
+    print('WORLD SIZE: {}'.format(world_size))
+
+    # ---------------------------- Build CUDA ----------------------------
+    if args.cuda and torch.cuda.is_available():
+        print('use cuda')
+        device = torch.device("cuda")
+    else:
+        device = torch.device("cpu")
+
+    # ---------------------------- Fix random seed ----------------------------
+    fix_random_seed(args)
+
+    # ---------------------------- Build config ----------------------------
+    cfg = build_config(args)
+
+    # ---------------------------- Build Transform ----------------------------
+    train_transform = build_transform(cfg, is_train=True)
+    val_transform   = build_transform(cfg, is_train=False)
+
+    # ---------------------------- Build Dataset & Dataloader ----------------------------
+    dataset      = build_dataset(args, cfg, train_transform, is_train=True)
+    train_loader = build_dataloader(args, dataset, args.batch_size // world_size, CollateFunc())
+
+    # ---------------------------- Build Evaluator ----------------------------
+    evaluator = build_evluator(args, cfg, val_transform, device)
+
+    # ---------------------------- Build model ----------------------------
+    ## Build model
+    model, criterion = build_model(args, cfg, is_val=True)
+    model = model.to(device).train()
+    model_without_ddp = model
+
+    # ---------------------------- Build Model-EMA ----------------------------
+    if cfg.use_ema and distributed_utils.get_rank() in [-1, 0]:
+        print('Build ModelEMA for {} ...'.format(args.model))
+        model_ema = ModelEMA(model, cfg.ema_decay, cfg.ema_tau)
+    else:
+        model_ema = None
+
+    ## Calculate Params & GFLOPs
+    if distributed_utils.is_main_process():
+        model_copy = deepcopy(model_without_ddp)
+        model_copy.trainable = False
+        model_copy.eval()
+        compute_flops(model=model_copy,
+                      img_size=cfg.test_img_size,
+                      device=device)
+        del model_copy
+    if args.distributed:
+        dist.barrier()
+
+    ## Build DDP model
+    if args.distributed:
+        if args.sybn:
+            print('use SyncBatchNorm ...')
+            # convert BN to SyncBN before wrapping the model with DDP
+            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
+        model = DDP(model, device_ids=[args.gpu], find_unused_parameters=args.find_unused_parameters)
+        model_without_ddp = model.module
+
+    if args.distributed:
+        dist.barrier()
+
+    # ---------------------------- Build Trainer ----------------------------
+    trainer = build_trainer(args, cfg, device, model, model_ema, criterion, train_transform, val_transform, dataset, train_loader, evaluator)
+
+    ## Eval before training
+    if args.eval_first and distributed_utils.is_main_process():
+        # to check whether the evaluator can work
+        model_eval = model_without_ddp
+        trainer.eval(model_eval)
+        return
+
+    # ---------------------------- Train pipeline ----------------------------
+    trainer.train(model)
+
+    # Empty cache after train loop
+    del trainer
+    if args.cuda:
+        torch.cuda.empty_cache()
+
+if __name__ == '__main__':
+    train()

+ 36 - 0
train.sh

@@ -0,0 +1,36 @@
+# Args parameters
+MODEL=$1
+DATASET=$2
+DATASET_ROOT=$3
+BATCH_SIZE=$4
+WORLD_SIZE=$5
+MASTER_PORT=$6
+RESUME=$7
+
+
+# -------------------------- Train Pipeline --------------------------
+if [ $WORLD_SIZE == 1 ]; then
+    python train.py \
+            --cuda \
+            --dataset ${DATASET} \
+            --root ${DATASET_ROOT} \
+            --model ${MODEL} \
+            --batch_size ${BATCH_SIZE} \
+            --resume ${RESUME} \
+            --fp16
+elif [[ $WORLD_SIZE -gt 1 && $WORLD_SIZE -le 8 ]]; then
+    python -m torch.distributed.run --nproc_per_node=${WORLD_SIZE} --master_port ${MASTER_PORT} train.py \
+            --cuda \
+            --distributed \
+            --dataset ${DATASET} \
+            --root ${DATASET_ROOT} \
+            --model ${MODEL} \
+            --batch_size ${BATCH_SIZE} \
+            --resume ${RESUME} \
+            --fp16 \
+            --sybn
+else
+    echo "The WORLD_SIZE is set to a value greater than 8, indicating the use of multi-machine \
+          multi-card training mode, which is currently unsupported."
+    exit 1
+fi

+ 0 - 0
utils/__init__.py


+ 206 - 0
utils/box_ops.py

@@ -0,0 +1,206 @@
+from typing import List
+import math
+import numpy as np
+import torch
+from torchvision.ops.boxes import box_area
+
+
+# ------------------ Box ops ------------------
+def box_cxcywh_to_xyxy(x):
+    x_c, y_c, w, h = x.unbind(-1)
+    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
+         (x_c + 0.5 * w), (y_c + 0.5 * h)]
+    return torch.stack(b, dim=-1)
+
+def box_xyxy_to_cxcywh(x):
+    x0, y0, x1, y1 = x.unbind(-1)
+    b = [(x0 + x1) / 2, (y0 + y1) / 2,
+         (x1 - x0), (y1 - y0)]
+    return torch.stack(b, dim=-1)
+
+def rescale_bboxes(bboxes, origin_size, ratio):
+    # rescale bboxes
+    if isinstance(ratio, float):
+        bboxes /= ratio
+    elif isinstance(ratio, List) and len(ratio) == 2:
+        bboxes[..., [0, 2]] /= ratio[0]
+        bboxes[..., [1, 3]] /= ratio[1]
+    else:
+        raise NotImplementedError("ratio should be a float or a List of two floats.")
+
+    # clip bboxes
+    bboxes[..., [0, 2]] = np.clip(bboxes[..., [0, 2]], a_min=0., a_max=origin_size[0])
+    bboxes[..., [1, 3]] = np.clip(bboxes[..., [1, 3]], a_min=0., a_max=origin_size[1])
+
+    return bboxes
+
+def bbox2dist(anchor_points, bbox, reg_max):
+    '''Transform bbox(xyxy) to dist(ltrb).'''
+    x1y1, x2y2 = torch.split(bbox, 2, -1)
+    lt = anchor_points - x1y1
+    rb = x2y2 - anchor_points
+    dist = torch.cat([lt, rb], -1).clamp(0, reg_max - 0.01)
+    return dist
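+
+# Worked example (illustrative values): anchor_points = [[8., 8.]], bbox = [[2., 3., 14., 12.]],
+# reg_max = 16  ->  dist = [[6., 5., 6., 4.]]  (left, top, right, bottom offsets, clamped to reg_max - 0.01)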
+
+def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
+    # hack for matcher
+    if proposals.size() != gt.size():
+        proposals = proposals[:, None]
+        gt = gt[None]
+
+    proposals = proposals.float()
+    gt = gt.float()
+    px, py, pw, ph = proposals.unbind(-1)
+    gx, gy, gw, gh = gt.unbind(-1)
+
+    dx = (gx - px) / (pw + 0.1)
+    dy = (gy - py) / (ph + 0.1)
+    dw = torch.log(gw / (pw + 0.1))
+    dh = torch.log(gh / (ph + 0.1))
+    deltas = torch.stack([dx, dy, dw, dh], dim=-1)
+
+    means = deltas.new_tensor(means).unsqueeze(0)
+    stds = deltas.new_tensor(stds).unsqueeze(0)
+    deltas = deltas.sub_(means).div_(stds)
+
+    return deltas
+
+# ------------------ IoU ops ------------------
+def box_iou(boxes1, boxes2):
+    area1 = box_area(boxes1)
+    area2 = box_area(boxes2)
+
+    lt = torch.max(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
+    rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    inter = wh[:, :, 0] * wh[:, :, 1]  # [N,M]
+
+    union = area1[:, None] + area2 - inter
+
+    iou = inter / union
+    return iou, union
+
+def generalized_box_iou(boxes1, boxes2):
+    """
+    Generalized IoU from https://giou.stanford.edu/
+
+    The boxes should be in [x0, y0, x1, y1] format
+
+    Returns a [N, M] pairwise matrix, where N = len(boxes1)
+    and M = len(boxes2)
+    """
+    # degenerate boxes gives inf / nan results
+    # so do an early check
+    assert (boxes1[:, 2:] >= boxes1[:, :2]).all()
+    assert (boxes2[:, 2:] >= boxes2[:, :2]).all()
+    iou, union = box_iou(boxes1, boxes2)
+
+    lt = torch.min(boxes1[:, None, :2], boxes2[:, :2])
+    rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:])
+
+    wh = (rb - lt).clamp(min=0)  # [N,M,2]
+    area = wh[:, :, 0] * wh[:, :, 1]
+
+    return iou - (area - union) / area
+
+def get_ious(bboxes1,
+             bboxes2,
+             box_mode="xyxy",
+             iou_type="iou"):
+    """
+    Compute IoU of type ['iou', 'giou'] between two sets of boxes.
+
+    Args:
+        bboxes1 (tensor): predicted boxes
+        bboxes2 (tensor): target boxes
+        box_mode (str): 'xyxy' or 'ltrb'
+        iou_type (str): 'iou' or 'giou'
+
+    Returns:
+        ious (tensor): computed IoU (or GIoU) values.
+    """
+    if box_mode == "ltrb":
+        bboxes1 = torch.cat((-bboxes1[..., :2], bboxes1[..., 2:]), dim=-1)
+        bboxes2 = torch.cat((-bboxes2[..., :2], bboxes2[..., 2:]), dim=-1)
+    elif box_mode != "xyxy":
+        raise NotImplementedError
+
+    eps = torch.finfo(torch.float32).eps
+
+    bboxes1_area = (bboxes1[..., 2] - bboxes1[..., 0]).clamp_(min=0) \
+        * (bboxes1[..., 3] - bboxes1[..., 1]).clamp_(min=0)
+    bboxes2_area = (bboxes2[..., 2] - bboxes2[..., 0]).clamp_(min=0) \
+        * (bboxes2[..., 3] - bboxes2[..., 1]).clamp_(min=0)
+
+    w_intersect = (torch.min(bboxes1[..., 2], bboxes2[..., 2])
+                   - torch.max(bboxes1[..., 0], bboxes2[..., 0])).clamp_(min=0)
+    h_intersect = (torch.min(bboxes1[..., 3], bboxes2[..., 3])
+                   - torch.max(bboxes1[..., 1], bboxes2[..., 1])).clamp_(min=0)
+
+    area_intersect = w_intersect * h_intersect
+    area_union = bboxes2_area + bboxes1_area - area_intersect
+    ious = area_intersect / area_union.clamp(min=eps)
+
+    if iou_type == "iou":
+        return ious
+    elif iou_type == "giou":
+        g_w_intersect = torch.max(bboxes1[..., 2], bboxes2[..., 2]) \
+            - torch.min(bboxes1[..., 0], bboxes2[..., 0])
+        g_h_intersect = torch.max(bboxes1[..., 3], bboxes2[..., 3]) \
+            - torch.min(bboxes1[..., 1], bboxes2[..., 1])
+        ac_union = g_w_intersect * g_h_intersect
+        gious = ious - (ac_union - area_union) / ac_union.clamp(min=eps)
+        return gious
+    else:
+        raise NotImplementedError
+
+# copy from YOLOv5
+def bbox_iou(box1, box2, xywh=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-7):
+    # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
+
+    # Get the coordinates of bounding boxes
+    if xywh:  # transform from xywh to xyxy
+        (x1, y1, w1, h1), (x2, y2, w2, h2) = box1.chunk(4, -1), box2.chunk(4, -1)
+        w1_, h1_, w2_, h2_ = w1 / 2, h1 / 2, w2 / 2, h2 / 2
+        b1_x1, b1_x2, b1_y1, b1_y2 = x1 - w1_, x1 + w1_, y1 - h1_, y1 + h1_
+        b2_x1, b2_x2, b2_y1, b2_y2 = x2 - w2_, x2 + w2_, y2 - h2_, y2 + h2_
+    else:  # x1, y1, x2, y2 = box1
+        b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
+        b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
+        w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+        w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+    # Intersection area
+    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * \
+            (b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)).clamp(0)
+
+    # Union Area
+    union = w1 * h1 + w2 * h2 - inter + eps
+
+    # IoU
+    iou = inter / union
+    if CIoU or DIoU or GIoU:
+        cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
+        ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
+        if CIoU or DIoU:  # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1
+            c2 = cw ** 2 + ch ** 2 + eps  # convex diagonal squared
+            rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
+            if CIoU:  # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
+                v = (4 / math.pi ** 2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
+                with torch.no_grad():
+                    alpha = v / (v - iou + (1 + eps))
+                return iou - (rho2 / c2 + v * alpha)  # CIoU
+            return iou - rho2 / c2  # DIoU
+        c_area = cw * ch + eps  # convex area
+        return iou - (c_area - union) / c_area  # GIoU https://arxiv.org/pdf/1902.09630.pdf
+    return iou  # IoU
+
+
+if __name__ == '__main__':
+    box1 = torch.tensor([[10, 10, 20, 20]])
+    box2 = torch.tensor([[15, 15, 20, 20]])
+    iou = box_iou(box1, box2)
+    print(iou)
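+    # Quick sketch of the other IoU variants on the same (toy) xyxy boxes.
+    print(bbox_iou(box1.float(), box2.float(), xywh=False, GIoU=True))
+    print(bbox_iou(box1.float(), box2.float(), xywh=False, CIoU=True))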

+ 166 - 0
utils/distributed_utils.py

@@ -0,0 +1,166 @@
+# from github: https://github.com/ruinmessi/ASFF/blob/master/utils/distributed_util.py
+
+import torch
+import torch.distributed as dist
+import os
+import subprocess
+import pickle
+
+
+def all_gather(data):
+    """
+    Run all_gather on arbitrary picklable data (not necessarily tensors)
+    Args:
+        data: any picklable object
+    Returns:
+        list[data]: list of data gathered from each rank
+    """
+    world_size = get_world_size()
+    if world_size == 1:
+        return [data]
+
+    # serialized to a Tensor
+    buffer = pickle.dumps(data)
+    storage = torch.ByteStorage.from_buffer(buffer)
+    tensor = torch.ByteTensor(storage).to("cuda")
+
+    # obtain Tensor size of each rank
+    local_size = torch.tensor([tensor.numel()], device="cuda")
+    size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)]
+    dist.all_gather(size_list, local_size)
+    size_list = [int(size.item()) for size in size_list]
+    max_size = max(size_list)
+
+    # receiving Tensor from all ranks
+    # we pad the tensor because torch all_gather does not support
+    # gathering tensors of different shapes
+    tensor_list = []
+    for _ in size_list:
+        tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
+    if local_size != max_size:
+        padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
+        tensor = torch.cat((tensor, padding), dim=0)
+    dist.all_gather(tensor_list, tensor)
+
+    data_list = []
+    for size, tensor in zip(size_list, tensor_list):
+        buffer = tensor.cpu().numpy().tobytes()[:size]
+        data_list.append(pickle.loads(buffer))
+
+    return data_list
+
+
+def reduce_dict(input_dict, average=True):
+    """
+    Args:
+        input_dict (dict): all the values will be reduced
+        average (bool): whether to do average or sum
+    Reduce the values in the dictionary from all processes so that all processes
+    have the averaged results. Returns a dict with the same fields as
+    input_dict, after reduction.
+    """
+    world_size = get_world_size()
+    if world_size < 2:
+        return input_dict
+    with torch.no_grad():
+        names = []
+        values = []
+        # sort the keys so that they are consistent across processes
+        for k in sorted(input_dict.keys()):
+            names.append(k)
+            values.append(input_dict[k])
+        values = torch.stack(values, dim=0)
+        dist.all_reduce(values)
+        if average:
+            values /= world_size
+        reduced_dict = {k: v for k, v in zip(names, values)}
+    return reduced_dict
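+
+# Usage sketch (assumes the process group is initialized and all ranks pass the
+# same keys), e.g. to log losses averaged over GPUs on the main process:
+#   loss_dict_reduced = reduce_dict({'loss_cls': loss_cls, 'loss_box': loss_box})
+#   if is_main_process():
+#       print({k: v.item() for k, v in loss_dict_reduced.items()})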
+
+
+def get_sha():
+    cwd = os.path.dirname(os.path.abspath(__file__))
+
+    def _run(command):
+        return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
+    sha = 'N/A'
+    diff = "clean"
+    branch = 'N/A'
+    try:
+        sha = _run(['git', 'rev-parse', 'HEAD'])
+        subprocess.check_output(['git', 'diff'], cwd=cwd)
+        diff = _run(['git', 'diff-index', 'HEAD'])
+        diff = "has uncommited changes" if diff else "clean"
+        branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
+    except Exception:
+        pass
+    message = f"sha: {sha}, status: {diff}, branch: {branch}"
+    return message
+
+
+def setup_for_distributed(is_master):
+    """
+    This function disables printing when not in master process
+    """
+    import builtins as __builtin__
+    builtin_print = __builtin__.print
+
+    def print(*args, **kwargs):
+        force = kwargs.pop('force', False)
+        if is_master or force:
+            builtin_print(*args, **kwargs)
+
+    __builtin__.print = print
+
+
+def is_dist_avail_and_initialized():
+    if not dist.is_available():
+        return False
+    if not dist.is_initialized():
+        return False
+    return True
+
+
+def get_world_size():
+    if not is_dist_avail_and_initialized():
+        return 1
+    return dist.get_world_size()
+
+
+def get_rank():
+    if not is_dist_avail_and_initialized():
+        return 0
+    return dist.get_rank()
+
+
+def is_main_process():
+    return get_rank() == 0
+
+
+def save_on_master(*args, **kwargs):
+    if is_main_process():
+        torch.save(*args, **kwargs)
+
+
+def init_distributed_mode(args):
+    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+        args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ['WORLD_SIZE'])
+        args.gpu = int(os.environ['LOCAL_RANK'])
+    elif 'SLURM_PROCID' in os.environ:
+        args.rank = int(os.environ['SLURM_PROCID'])
+        args.gpu = args.rank % torch.cuda.device_count()
+    else:
+        print('Not using distributed mode')
+        args.distributed = False
+        return
+
+    args.distributed = True
+
+    torch.cuda.set_device(args.gpu)
+    args.dist_backend = 'nccl'
+    print('| distributed init (rank {}): {}'.format(
+        args.rank, args.dist_url), flush=True)
+    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                         world_size=args.world_size, rank=args.rank)
+    torch.distributed.barrier()
+    setup_for_distributed(args.rank == 0)

+ 574 - 0
utils/misc.py

@@ -0,0 +1,574 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+from torch.utils.data import DataLoader, DistributedSampler
+
+import cv2
+import math
+import time
+import datetime
+import numpy as np
+from copy import deepcopy
+from thop import profile
+from collections import defaultdict, deque
+
+from .distributed_utils import is_dist_avail_and_initialized
+
+
+# ---------------------------- Train tools ----------------------------
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        if torch.cuda.is_available():
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}',
+                'max mem: {memory:.0f}'
+            ])
+        else:
+            log_msg = self.delimiter.join([
+                header,
+                '[{0' + space_fmt + '}/{1}]',
+                'eta: {eta}',
+                '{meters}',
+                'time: {time}',
+                'data: {data}'
+            ])
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB))
+                else:
+                    print(log_msg.format(
+                        i, len(iterable), eta=eta_string,
+                        meters=str(self),
+                        time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print('{} Total time: {} ({:.4f} s / it)'.format(
+            header, total_time_str, total_time / len(iterable)))
+
+
+# ---------------------------- For Dataset ----------------------------
+## build dataloader
+def build_dataloader(args, dataset, batch_size, collate_fn=None):
+    # distributed
+    if args.distributed:
+        sampler = DistributedSampler(dataset)
+    else:
+        sampler = torch.utils.data.RandomSampler(dataset)
+
+    batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
+
+    dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
+                            collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)
+    
+    return dataloader
+    
+## collate_fn for dataloader
+class CollateFunc(object):
+    def __call__(self, batch):
+        targets = []
+        images = []
+
+        for sample in batch:
+            image = sample[0]
+            target = sample[1]
+
+            images.append(image)
+            targets.append(target)
+
+        images = torch.stack(images, 0) # [B, C, H, W]
+
+        return images, targets
+
+
+# ---------------------------- For Loss ----------------------------
+## FocalLoss
+def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+    """
+    Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
+    Args:
+        inputs: A float tensor of arbitrary shape.
+                The predictions for each example.
+        targets: A float tensor with the same shape as inputs. Stores the binary
+                 classification label for each element in inputs
+                (0 for the negative class and 1 for the positive class).
+        alpha: (optional) Weighting factor in range (0, 1) to balance
+                positive vs negative examples. Default = 0.25.
+        gamma: Exponent of the modulating factor (1 - p_t) to
+               balance easy vs hard examples.
+    Returns:
+        Loss tensor
+    """
+    prob = inputs.sigmoid()
+    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
+    p_t = prob * targets + (1 - prob) * (1 - targets)
+    loss = ce_loss * ((1 - p_t) ** gamma)
+
+    if alpha >= 0:
+        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
+        loss = alpha_t * loss
+
+    return loss.mean(1).sum() / num_boxes
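+
+# Shape sketch (assumed, DETR-style inputs): `inputs` and `targets` are
+# [B, num_queries, num_classes] logits / one-hot targets, and `num_boxes` is the
+# number of matched GT boxes used as the normalizer:
+#   loss_cls = sigmoid_focal_loss(pred_logits, gt_onehot, num_boxes=num_gts)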
+
+## Variable FocalLoss
+def varifocal_loss_with_logits(pred_logits,
+                               gt_score,
+                               label,
+                               normalizer=1.0,
+                               alpha=0.75,
+                               gamma=2.0):
+    pred_score = pred_logits.sigmoid()
+    weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
+    loss = F.binary_cross_entropy_with_logits(
+        pred_logits, gt_score, weight=weight, reduction='none')
+    return loss.mean(1).sum() / normalizer
+
+## InverseSigmoid
+def inverse_sigmoid(x, eps=1e-5):
+    x = x.clamp(min=0, max=1)
+    x1 = x.clamp(min=eps)
+    x2 = (1 - x).clamp(min=eps)
+    return torch.log(x1/x2)
+
+
+# ---------------------------- For Model ----------------------------
+## fuse Conv & BN layer
+def fuse_conv_bn(module):
+    """Recursively fuse conv and bn in a module.
+    During inference, the functionary of batch norm layers is turned off
+    but only the mean and var alone channels are used, which exposes the
+    chance to fuse it with the preceding conv layers to save computations and
+    simplify network structures.
+    Args:
+        module (nn.Module): Module to be fused.
+    Returns:
+        nn.Module: Fused module.
+    """
+    last_conv = None
+    last_conv_name = None
+    
+    def _fuse_conv_bn(conv, bn):
+        """Fuse conv and bn into one module.
+        Args:
+            conv (nn.Module): Conv to be fused.
+            bn (nn.Module): BN to be fused.
+        Returns:
+            nn.Module: Fused module.
+        """
+        conv_w = conv.weight
+        conv_b = conv.bias if conv.bias is not None else torch.zeros_like(
+            bn.running_mean)
+
+        factor = bn.weight / torch.sqrt(bn.running_var + bn.eps)
+        conv.weight = nn.Parameter(conv_w *
+                                factor.reshape([conv.out_channels, 1, 1, 1]))
+        conv.bias = nn.Parameter((conv_b - bn.running_mean) * factor + bn.bias)
+        return conv
+    for name, child in module.named_children():
+        if isinstance(child,
+                      (nn.modules.batchnorm._BatchNorm, nn.SyncBatchNorm)):
+            if last_conv is None:  # only fuse BN that is after Conv
+                continue
+            fused_conv = _fuse_conv_bn(last_conv, child)
+            module._modules[last_conv_name] = fused_conv
+            # To reduce changes, set BN as Identity instead of deleting it.
+            module._modules[name] = nn.Identity()
+            last_conv = None
+        elif isinstance(child, nn.Conv2d):
+            last_conv = child
+            last_conv_name = name
+        else:
+            fuse_conv_bn(child)
+    return module
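+
+# Sanity-check sketch: fusing should leave eval-mode outputs (numerically) unchanged.
+#   m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8)).eval()
+#   x = torch.randn(1, 3, 32, 32)
+#   y0 = m(x); y1 = fuse_conv_bn(m)(x)   # expect torch.allclose(y0, y1, atol=1e-5)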
+
+## replace module
+def replace_module(module, replaced_module_type, new_module_type, replace_func=None) -> nn.Module:
+    """
+    Replace all modules of a given type in `module` with a new type; mostly used for deployment.
+
+    Args:
+        module (nn.Module): model to apply the replace operation on.
+        replaced_module_type (Type): module type to be replaced.
+        new_module_type (Type): module type used as the replacement.
+        replace_func (function): python function describing the replace logic. Default value is None.
+
+    Returns:
+        model (nn.Module): module that already been replaced.
+    """
+
+    def default_replace_func(replaced_module_type, new_module_type):
+        return new_module_type()
+
+    if replace_func is None:
+        replace_func = default_replace_func
+
+    model = module
+    if isinstance(module, replaced_module_type):
+        model = replace_func(replaced_module_type, new_module_type)
+    else:  # recursively replace
+        for name, child in module.named_children():
+            new_child = replace_module(child, replaced_module_type, new_module_type, replace_func)
+            if new_child is not child:  # child is already replaced
+                model.add_module(name, new_child)
+
+    return model
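+
+# Usage sketch (deploy): e.g. swap nn.SiLU for the export-friendly SiLU defined below.
+#   model = replace_module(model, nn.SiLU, SiLU)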
+
+## compute FLOPs & Parameters
+def compute_flops(model, img_size, device):
+    x = torch.randn(1, 3, img_size, img_size).to(device)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))
+
+## load trained weight
+def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_rep_conv=False):
+    # check ckpt file
+    if path_to_ckpt is None:
+        print('no weight file ...')
+    else:
+        checkpoint = torch.load(path_to_ckpt, map_location='cpu')
+        print('--------------------------------------')
+        print('Best model info:')
+        print('Epoch: {}'.format(checkpoint["epoch"]))
+        print('mAP: {}'.format(checkpoint["mAP"]))
+        print('--------------------------------------')
+        checkpoint_state_dict = checkpoint["model"]
+        model.load_state_dict(checkpoint_state_dict)
+
+        print('Finished loading model!')
+
+    # fuse rep conv
+    if fuse_rep_conv:
+        print("Fusing RepConv ...")
+        for m in model.modules():
+            if hasattr(m, 'fuse_convs'):
+                m.fuse_convs()
+
+    # fuse conv & bn
+    if fuse_cbn:
+        print('Fusing Conv & BN ...')
+        model = fuse_conv_bn(model)
+
+    return model
+
+## Model EMA
+class ModelEMA(object):
+    def __init__(self, model, ema_decay=0.9999, ema_tau=2000, updates=0):
+        # Create EMA
+        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
+        self.updates = updates  # number of EMA updates
+        self.decay = lambda x: ema_decay * (1 - math.exp(-x / ema_tau))  # decay exponential ramp (to help early epochs)
+        for p in self.ema.parameters():
+            p.requires_grad_(False)
+
+    def is_parallel(self, model):
+        # Returns True if model is of type DP or DDP
+        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+    def de_parallel(self, model):
+        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+        return model.module if self.is_parallel(model) else model
+
+    def copy_attr(self, a, b, include=(), exclude=()):
+        # Copy attributes from b to a, options to only include [...] and to exclude [...]
+        for k, v in b.__dict__.items():
+            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+                continue
+            else:
+                setattr(a, k, v)
+
+    def update(self, model):
+        # Update EMA parameters
+        self.updates += 1
+        d = self.decay(self.updates)
+
+        msd = self.de_parallel(model).state_dict()  # model state_dict
+        for k, v in self.ema.state_dict().items():
+            if v.dtype.is_floating_point:  # true for FP16 and FP32
+                v *= d
+                v += (1 - d) * msd[k].detach()
+        # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
+
+    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
+        # Update EMA attributes
+        self.copy_attr(self.ema, model, include, exclude)
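+
+# Usage sketch (assumed training loop):
+#   ema = ModelEMA(model)
+#   ... optimizer.step(); ema.update(model)    # after every optimizer step
+#   evaluator.evaluate(ema.ema)                # evaluate with the EMA weights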
+
+## SiLU
+class SiLU(nn.Module):
+    """export-friendly version of nn.SiLU()"""
+
+    @staticmethod
+    def forward(x):
+        return x * torch.sigmoid(x)
+
+
+# ---------------------------- NMS ----------------------------
+## basic NMS
+def nms(bboxes, scores, nms_thresh):
+    """"Pure Python NMS."""
+    x1 = bboxes[:, 0]  #xmin
+    y1 = bboxes[:, 1]  #ymin
+    x2 = bboxes[:, 2]  #xmax
+    y2 = bboxes[:, 3]  #ymax
+
+    areas = (x2 - x1) * (y2 - y1)
+    order = scores.argsort()[::-1]
+
+    keep = []
+    while order.size > 0:
+        i = order[0]
+        keep.append(i)
+        # compute iou
+        xx1 = np.maximum(x1[i], x1[order[1:]])
+        yy1 = np.maximum(y1[i], y1[order[1:]])
+        xx2 = np.minimum(x2[i], x2[order[1:]])
+        yy2 = np.minimum(y2[i], y2[order[1:]])
+
+        w = np.maximum(1e-10, xx2 - xx1)
+        h = np.maximum(1e-10, yy2 - yy1)
+        inter = w * h
+
+        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-14)
+        # keep only the boxes whose IoU with the current box is not above the threshold
+        inds = np.where(iou <= nms_thresh)[0]
+        order = order[inds + 1]
+
+    return keep
+
+## class-agnostic NMS 
+def multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh):
+    # nms
+    keep = nms(bboxes, scores, nms_thresh)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## class-aware NMS 
+def multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes):
+    # nms
+    keep = np.zeros(len(bboxes), dtype=np.int32)
+    for i in range(num_classes):
+        inds = np.where(labels == i)[0]
+        if len(inds) == 0:
+            continue
+        c_bboxes = bboxes[inds]
+        c_scores = scores[inds]
+        c_keep = nms(c_bboxes, c_scores, nms_thresh)
+        keep[inds[c_keep]] = 1
+    keep = np.where(keep > 0)
+    scores = scores[keep]
+    labels = labels[keep]
+    bboxes = bboxes[keep]
+
+    return scores, labels, bboxes
+
+## multi-class NMS 
+def multiclass_nms(scores, labels, bboxes, nms_thresh, num_classes, class_agnostic=False):
+    if class_agnostic:
+        return multiclass_nms_class_agnostic(scores, labels, bboxes, nms_thresh)
+    else:
+        return multiclass_nms_class_aware(scores, labels, bboxes, nms_thresh, num_classes)
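+
+# Input sketch (assumed numpy arrays): scores [M,], labels [M,], bboxes [M, 4] in xyxy;
+# returns the filtered copies after score-ordered suppression, e.g.:
+#   scores, labels, bboxes = multiclass_nms(scores, labels, bboxes, 0.5, num_classes=80)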
+
+
+# ---------------------------- Processor for Deployment ----------------------------
+## Pre-processer
+class PreProcessor(object):
+    def __init__(self, img_size, keep_ratio=True):
+        self.img_size = img_size
+        self.keep_ratio = keep_ratio
+        self.input_size = [img_size, img_size]
+        
+
+    def __call__(self, image, swap=(2, 0, 1)):
+        """
+        Input:
+            image: (ndarray) [H, W, 3] or [H, W]
+            swap: (tuple) axis order used to convert [H, W, C] to [C, H, W]
+        """
+        if len(image.shape) == 3:
+            padded_img = np.ones((self.input_size[0], self.input_size[1], 3), np.float32) * 114.
+        else:
+            padded_img = np.ones(self.input_size, np.float32) * 114.
+        # resize
+        if self.keep_ratio:
+            orig_h, orig_w = image.shape[:2]
+            r = min(self.input_size[0] / orig_h, self.input_size[1] / orig_w)
+            resize_size = (int(orig_w * r), int(orig_h * r))
+            if r != 1:
+                resized_img = cv2.resize(image, resize_size, interpolation=cv2.INTER_LINEAR)
+            else:
+                resized_img = image
+
+            # padding
+            padded_img[:resized_img.shape[0], :resized_img.shape[1]] = resized_img
+            
+            # [H, W, C] -> [C, H, W]
+            padded_img = padded_img.transpose(swap)
+            padded_img = np.ascontiguousarray(padded_img, dtype=np.float32) / 255.
+
+            return padded_img, r
+        else:
+            orig_h, orig_w = image.shape[:2]
+            r = np.array([self.input_size[1] / orig_w, self.input_size[0] / orig_h])
+            if [orig_h, orig_w] == self.input_size:
+                resized_img = image
+            else:
+                resized_img = cv2.resize(image, self.input_size, interpolation=cv2.INTER_LINEAR)
+
+            return resized_img, r
+
+## Post-processer
+class PostProcessor(object):
+    def __init__(self, num_classes, conf_thresh=0.15, nms_thresh=0.5):
+        self.num_classes = num_classes
+        self.conf_thresh = conf_thresh
+        self.nms_thresh = nms_thresh
+
+
+    def __call__(self, predictions):
+        """
+        Input:
+            predictions: (ndarray) [n_anchors_all, 4+1+C]
+        """
+        bboxes = predictions[..., :4]
+        scores = predictions[..., 4:]
+
+        # scores & labels
+        labels = np.argmax(scores, axis=1)                      # [M,]
+        scores = scores[(np.arange(scores.shape[0]), labels)]   # [M,]
+
+        # thresh
+        keep = np.where(scores > self.conf_thresh)
+        scores = scores[keep]
+        labels = labels[keep]
+        bboxes = bboxes[keep]
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes, True)
+
+        return bboxes, scores, labels
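+
+
+if __name__ == '__main__':
+    # A small sketch of the deployment pre/post-processing path on dummy data.
+    # The sizes, class count and prediction layout (xyxy boxes followed by class
+    # scores) are assumptions for illustration only.
+    image = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
+    preprocessor = PreProcessor(img_size=640, keep_ratio=True)
+    x, ratio = preprocessor(image)                                # [3, 640, 640], float32
+
+    num_classes = 20
+    preds = np.random.rand(100, 4 + num_classes).astype(np.float32)
+    preds[:, 2:4] = preds[:, :2] + 0.2                            # make x2 > x1 and y2 > y1
+    postprocessor = PostProcessor(num_classes, conf_thresh=0.3, nms_thresh=0.5)
+    bboxes, scores, labels = postprocessor(preds)
+    print(x.shape, ratio, bboxes.shape, scores.shape, labels.shape)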

+ 0 - 0
utils/solver/__init__.py


+ 49 - 0
utils/solver/lr_scheduler.py

@@ -0,0 +1,49 @@
+import math
+import torch
+
+
+# ------------------------- WarmUp LR Scheduler -------------------------
+## Warmup LR Scheduler
+class LinearWarmUpLrScheduler(object):
+    def __init__(self, base_lr=0.01, wp_iter=500):
+        self.base_lr = base_lr
+        self.wp_iter = wp_iter
+        self.warmup_factor = 0.00066667
+
+    def set_lr(self, optimizer, cur_lr):
+        for param_group in optimizer.param_groups:
+            init_lr = param_group['initial_lr']
+            ratio = init_lr / self.base_lr
+            param_group['lr'] = cur_lr * ratio
+
+    def __call__(self, iter, optimizer):
+        # warmup
+        assert iter < self.wp_iter
+        alpha = iter / self.wp_iter
+        warmup_factor = self.warmup_factor * (1 - alpha) + alpha
+        tmp_lr = self.base_lr * warmup_factor
+        self.set_lr(optimizer, tmp_lr)
+        
+                           
+# ------------------------- LR Scheduler -------------------------
+def build_lr_scheduler(cfg, optimizer, resume=None):
+    print('==============================')
+    print('LR Scheduler: {}'.format(cfg.lr_scheduler))
+
+    if cfg.lr_scheduler == "step":
+        lr_step = [cfg.max_epoch // 3, cfg.max_epoch // 3 * 2]
+        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_step, gamma=0.1)
+    elif cfg.lr_scheduler == "cosine":
+        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg.max_epoch - cfg.warmup_epoch - 1, eta_min=cfg.min_lr)
+    else:
+        raise NotImplementedError("Unknown lr scheduler: {}".format(cfg.lr_scheduler))
+        
+    if resume is not None and resume.lower() != "none":
+        checkpoint = torch.load(resume)
+        if 'lr_scheduler' in checkpoint.keys():
+            print('--Load lr scheduler from the checkpoint: ', resume)
+            # checkpoint state dict
+            checkpoint_state_dict = checkpoint.pop("lr_scheduler")
+            lr_scheduler.load_state_dict(checkpoint_state_dict)
+
+    return lr_scheduler
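+
+
+if __name__ == '__main__':
+    # A small sketch (dummy cfg/model, assumed values) of how the per-iteration
+    # warmup and the per-epoch scheduler are combined.
+    from types import SimpleNamespace
+    cfg = SimpleNamespace(lr_scheduler="cosine", max_epoch=30, warmup_epoch=1, min_lr=1e-5)
+    base_lr = 0.01
+    optimizer = torch.optim.SGD(torch.nn.Linear(8, 8).parameters(), lr=base_lr)
+    for group in optimizer.param_groups:
+        group['initial_lr'] = base_lr
+
+    warmup = LinearWarmUpLrScheduler(base_lr=base_lr, wp_iter=10)
+    lr_scheduler = build_lr_scheduler(cfg, optimizer)
+    for it in range(10):                  # warmup iterations
+        warmup(it, optimizer)
+    for epoch in range(cfg.max_epoch):    # epochs after warmup
+        lr_scheduler.step()
+    print(optimizer.param_groups[0]['lr'])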

+ 104 - 0
utils/solver/optimizer.py

@@ -0,0 +1,104 @@
+import torch
+
+
+def build_yolo_optimizer(cfg, model, resume=None):
+    print('==============================')
+    print('Optimizer: {}'.format(cfg.optimizer))
+    print('--base lr: {}'.format(cfg.base_lr))
+    print('--min lr:  {}'.format(cfg.min_lr))
+    print('--momentum: {}'.format(cfg.momentum))
+    print('--weight_decay: {}'.format(cfg.weight_decay))
+    print('--grad accumulate: {}'.format(cfg.grad_accumulate))
+
+    # ------------- Divide model's parameters -------------
+    param_dicts = [], [], []
+    norm_names = ["norm"] + ["norm{}".format(i) for i in range(10000)]
+    for n, p in model.named_parameters():
+        if p.requires_grad:
+            if "bias" == n.split(".")[-1]:
+                param_dicts[0].append(p)      # no weight decay for all layers' bias
+            else:
+                if n.split(".")[-2] in norm_names:
+                    param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
+                else:
+                    param_dicts[2].append(p)  # weight decay for all Non-NormLayers' weight
+
+    # Build optimizer
+    if cfg.optimizer == 'sgd':
+        optimizer = torch.optim.SGD(param_dicts[0], lr=cfg.base_lr, momentum=cfg.momentum, weight_decay=0.0)
+    elif cfg.optimizer =='adamw':
+        optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg.base_lr, weight_decay=0.0)
+    else:
+        raise NotImplementedError("Unknown optimizer: {}".format(cfg.optimizer))
+    
+    # Add param groups
+    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
+    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg.weight_decay})
+
+    start_epoch = 0
+    if resume and resume != 'None':
+        checkpoint = torch.load(resume)
+        # checkpoint state dict
+        try:
+            checkpoint_state_dict = checkpoint.pop("optimizer")
+            print('--Load optimizer from the checkpoint: ', resume)
+            optimizer.load_state_dict(checkpoint_state_dict)
+            start_epoch = checkpoint.pop("epoch") + 1
+            del checkpoint, checkpoint_state_dict
+        except Exception:
+            print("No optimizer in the given checkpoint.")
+                                                        
+    return optimizer, start_epoch
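+
+# Resulting param groups (sketch of the split above):
+#   group 0: biases               -> weight_decay = 0.0
+#   group 1: norm-layer weights   -> weight_decay = 0.0
+#   group 2: other weights        -> weight_decay = cfg.weight_decay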
+
+
+def build_rtdetr_optimizer(cfg, model, resume=None):
+    print('==============================')
+    print('Optimizer: {}'.format(cfg.optimizer))
+    print('--base lr: {}'.format(cfg.base_lr))
+    print('--weight_decay: {}'.format(cfg.weight_decay))
+    print('--grad accumulate: {}'.format(cfg.grad_accumulate))
+
+    # ------------- Divide model's parameters -------------
+    param_dicts = [], [], [], [], [], []
+    norm_names = ["norm"] + ["norm{}".format(i) for i in range(10000)]
+    for n, p in model.named_parameters():
+        # Non-Backbone's learnable parameters
+        if "backbone" not in n and p.requires_grad:
+            if "bias" == n.split(".")[-1]:
+                param_dicts[0].append(p)      # no weight decay for all layers' bias
+            else:
+                if n.split(".")[-2] in norm_names:
+                    param_dicts[1].append(p)  # no weight decay for all NormLayers' weight
+                else:
+                    param_dicts[2].append(p)  # weight decay for all Non-NormLayers' weight
+        # Backbone's learnable parameters
+        elif "backbone" in n and p.requires_grad:
+            if "bias" == n.split(".")[-1]:
+                param_dicts[3].append(p)      # no weight decay for all layers' bias
+            else:
+                if n.split(".")[-2] in norm_names:
+                    param_dicts[4].append(p)  # no weight decay for all NormLayers' weight
+                else:
+                    param_dicts[5].append(p)  # weight decay for all Non-NormLayers' weight
+
+    # Non-Backbone's learnable parameters
+    optimizer = torch.optim.AdamW(param_dicts[0], lr=cfg.base_lr, weight_decay=0.0)
+    optimizer.add_param_group({"params": param_dicts[1], "weight_decay": 0.0})
+    optimizer.add_param_group({"params": param_dicts[2], "weight_decay": cfg.weight_decay})
+
+    # Backbone's learnable parameters
+    backbone_lr = cfg.base_lr * cfg.backbone_lr_ratio
+    optimizer.add_param_group({"params": param_dicts[3], "lr": backbone_lr, "weight_decay": 0.0})
+    optimizer.add_param_group({"params": param_dicts[4], "lr": backbone_lr, "weight_decay": 0.0})
+    optimizer.add_param_group({"params": param_dicts[5], "lr": backbone_lr, "weight_decay": cfg.weight_decay})
+
+    start_epoch = 0
+    if resume and resume != 'None':
+        print('--Load optimizer from the checkpoint: ', resume)
+        checkpoint = torch.load(resume)
+        # checkpoint state dict
+        checkpoint_state_dict = checkpoint.pop("optimizer")
+        optimizer.load_state_dict(checkpoint_state_dict)
+        start_epoch = checkpoint.pop("epoch") + 1
+                                                        
+    return optimizer, start_epoch

+ 169 - 0
utils/vis_tools.py

@@ -0,0 +1,169 @@
+import cv2
+import os
+import numpy as np
+import matplotlib.pyplot as plt
+
+
+# -------------------------- For Detection Task --------------------------
+## Draw bbox & label on the image
+def plot_bbox_labels(img, bbox, label=None, cls_color=None, text_scale=0.4):
+    x1, y1, x2, y2 = bbox
+    x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
+    # plot bbox
+    cv2.rectangle(img, (x1, y1), (x2, y2), cls_color, 2)
+    
+    if label is not None:
+        # plot title bbox
+        t_size = cv2.getTextSize(label, 0, fontScale=1, thickness=2)[0]
+        cv2.rectangle(img, (x1, y1-t_size[1]), (int(x1 + t_size[0] * text_scale), y1), cls_color, -1)
+        # put the text on the title bbox
+        cv2.putText(img, label, (int(x1), int(y1 - 5)), 0, text_scale, (0, 0, 0), 1, lineType=cv2.LINE_AA)
+
+    return img
+
+## Visualize the detection results
+def visualize(image, bboxes, scores, labels, class_colors, class_names):
+    ts = 0.4
+    for i, bbox in enumerate(bboxes):
+        cls_id = int(labels[i])
+        cls_color = class_colors[cls_id]
+            
+        mess = '%s: %.2f' % (class_names[cls_id], scores[i])
+        image = plot_bbox_labels(image, bbox, mess, cls_color, text_scale=ts)
+
+    return image
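+
+# Usage sketch (assumed COCO-style class names and detections):
+#   class_colors = [tuple(np.random.randint(0, 255, 3).tolist()) for _ in class_names]
+#   img = visualize(img, bboxes, scores, labels, class_colors, class_names)
+#   cv2.imwrite('det_result.jpg', img)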
+        
+## Visualize the input data during the training stage
+def vis_data(images, targets, num_classes=80, normalized_bbox=False, color_format='bgr', pixel_mean=None, pixel_std=None, box_format="xyxy"):
+    """
+        images: (tensor) [B, 3, H, W]
+        targets: (list) a list of targets
+    """
+    batch_size = images.size(0)
+    np.random.seed(0)
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(num_classes)]
+
+    for bi in range(batch_size):
+        tgt_boxes = targets[bi]['boxes']
+        tgt_labels = targets[bi]['labels']
+        # to numpy
+        image = images[bi].permute(1, 2, 0).cpu().numpy()
+
+        # denormalize image
+        if pixel_mean is not None and pixel_std is not None:
+            image = image * pixel_std + pixel_mean
+        
+        if color_format == 'rgb':
+            image = image[..., (2, 1, 0)] # RGB to BGR
+            
+        image = image.astype(np.uint8)
+        image = image.copy()
+        img_h, img_w = image.shape[:2]
+
+        # visualize target
+        for box, label in zip(tgt_boxes, tgt_labels):
+            if box_format == "xyxy":
+                x1, y1, x2, y2 = box
+            elif box_format == "xywh":
+                cx, cy, bw, bh = box
+                x1 = cx - 0.5 * bw
+                y1 = cy - 0.5 * bh
+                x2 = cx + 0.5 * bw
+                y2 = cy + 0.5 * bh
+
+            if normalized_bbox:
+                x1 *= img_w
+                y1 *= img_h
+                x2 *= img_w
+                y2 *= img_h
+
+            x1, y1 = int(x1), int(y1)
+            x2, y2 = int(x2), int(y2)
+            cls_id = int(label)
+
+            # draw box
+            color = class_colors[cls_id]
+            cv2.rectangle(image, (x1, y1), (x2, y2), color, 2)
+
+        cv2.imshow('train target', image)
+        cv2.waitKey(0)
+
+## convert feature to the heatmap
+def convert_feature_heatmap(feature):
+    """
+        feature: (ndarray) [H, W, C]
+    """
+    # a simple heatmap: average over channels, then min-max normalize to [0, 1]
+    heatmap = feature.mean(axis=-1)
+    heatmap = (heatmap - heatmap.min()) / (heatmap.max() - heatmap.min() + 1e-8)
+
+    return heatmap
+
+## draw feature on the image
+def draw_feature(img, features, save=None):
+    """
+        img: (ndarray & cv2.Mat) [H, W, C], where the C is 3 for RGB or 1 for Gray.
+        features: (List[ndarray]). It is a list of the multiple feature map whose shape is [H, W, C].
+        save: (bool) save the result or not.
+    """
+    img_h, img_w = img.shape[:2]
+
+    for i, fmp in enumerate(features):
+        hmp = convert_feature_heatmap(fmp)
+        hmp = cv2.resize(hmp, (img_w, img_h))
+        hmp = (hmp * 255).astype(np.uint8)
+        hmp_rgb = cv2.applyColorMap(hmp, cv2.COLORMAP_JET)
+        
+        superimposed_img = np.clip(hmp_rgb * 0.4 + img, 0, 255).astype(np.uint8)
+
+        # show the heatmap
+        plt.imshow(hmp)
+        plt.close()
+
+        # show the image with heatmap
+        cv2.imshow("image with heatmap", superimposed_img)
+        cv2.waitKey(0)
+        cv2.destroyAllWindows()
+
+        if save:
+            save_dir = 'feature_heatmap'
+            os.makedirs(save_dir, exist_ok=True)
+            cv2.imwrite(os.path.join(save_dir, 'feature_{}.png'.format(i) ), superimposed_img)    
+
+
+# -------------------------- For Tracking Task --------------------------
+def get_color(idx):
+    idx = idx * 3
+    color = ((37 * idx) % 255, (17 * idx) % 255, (29 * idx) % 255)
+
+    return color
+
+def plot_tracking(image, tlwhs, obj_ids, scores=None, frame_id=0, fps=0., ids2=None):
+    im = np.ascontiguousarray(np.copy(image))
+    im_h, im_w = im.shape[:2]
+
+    top_view = np.zeros([im_w, im_w, 3], dtype=np.uint8) + 255
+
+    #text_scale = max(1, image.shape[1] / 1600.)
+    #text_thickness = 2
+    #line_thickness = max(1, int(image.shape[1] / 500.))
+    text_scale = 2
+    text_thickness = 2
+    line_thickness = 3
+
+    radius = max(5, int(im_w/140.))
+    cv2.putText(im, 'frame: %d fps: %.2f num: %d' % (frame_id, fps, len(tlwhs)),
+                (0, int(15 * text_scale)), cv2.FONT_HERSHEY_PLAIN, 2, (0, 0, 255), thickness=2)
+
+    for i, tlwh in enumerate(tlwhs):
+        x1, y1, w, h = tlwh
+        intbox = tuple(map(int, (x1, y1, x1 + w, y1 + h)))
+        obj_id = int(obj_ids[i])
+        id_text = '{}'.format(int(obj_id))
+        if ids2 is not None:
+            id_text = id_text + ', {}'.format(int(ids2[i]))
+        color = get_color(abs(obj_id))
+        cv2.rectangle(im, intbox[0:2], intbox[2:4], color=color, thickness=line_thickness)
+        cv2.putText(im, id_text, (intbox[0], intbox[1]), cv2.FONT_HERSHEY_PLAIN, text_scale, (0, 0, 255),
+                    thickness=text_thickness)
+    return im