
add RT-DETR transform

yjh0410 · 1 year ago
commit b77da8e973

+ 27 - 0
config/data_config/transform_config.py

@@ -19,6 +19,7 @@ yolov5_x_trans_config = {
     'mixup_prob': 0.2,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -39,6 +40,7 @@ yolov5_l_trans_config = {
     'mixup_prob': 0.15,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -59,6 +61,7 @@ yolov5_m_trans_config = {
     'mixup_prob': 0.10,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -79,6 +82,7 @@ yolov5_s_trans_config = {
     'mixup_prob': 0.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -99,6 +103,7 @@ yolov5_n_trans_config = {
     'mixup_prob': 0.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -119,6 +124,7 @@ yolov5_p_trans_config = {
     'mixup_prob': 0.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -141,6 +147,7 @@ yolox_x_trans_config = {
     'mixup_prob': 1.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolox_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]
 }
 
@@ -161,6 +168,7 @@ yolox_l_trans_config = {
     'mixup_prob': 1.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolox_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -181,6 +189,7 @@ yolox_m_trans_config = {
     'mixup_prob': 1.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolox_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -201,6 +210,7 @@ yolox_s_trans_config = {
     'mixup_prob': 1.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolox_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -221,6 +231,7 @@ yolox_n_trans_config = {
     'mixup_prob': 0.5,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolox_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -241,6 +252,7 @@ yolox_p_trans_config = {
     'mixup_prob': 0.0,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolox_mixup',
+    'mosaic_keep_ratio': True,
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
 
@@ -254,5 +266,20 @@ ssd_trans_config = {
     'mixup_prob': 0.,
     'mosaic_type': 'yolov5_mosaic',
     'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': False,
+    'mixup_scale': [0.5, 1.5]
+}
+
+
+# ----------------------- RT_DETR-Style Transform -----------------------
+rtdetr_trans_config = {
+    'aug_type': 'rtdetr',
+    'use_ablu': False,
+    # Mosaic & Mixup are not used for RT_DETR-style augmentation
+    'mosaic_prob': 0.,
+    'mixup_prob': 0.,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mosaic_keep_ratio': False,
     'mixup_scale': [0.5, 1.5]
 }
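For reference, a minimal sketch of how these config dicts might be selected by name. The `build_trans_config` helper below is hypothetical and only illustrates the lookup; the repo may wire this differently:

    from config.data_config.transform_config import (
        yolov5_s_trans_config, ssd_trans_config, rtdetr_trans_config)

    # Hypothetical lookup helper, for illustration only.
    def build_trans_config(name='rtdetr'):
        cfg_map = {
            'yolov5_s': yolov5_s_trans_config,
            'ssd':      ssd_trans_config,
            'rtdetr':   rtdetr_trans_config,
        }
        if name not in cfg_map:
            raise NotImplementedError('Unknown transform config: <{}>'.format(name))
        return cfg_map[name]

    trans_config = build_trans_config('rtdetr')
    assert trans_config['mosaic_prob'] == 0.   # Mosaic is off for RT_DETR-style augmentation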

+ 21 - 11
dataset/build.py

@@ -1,22 +1,28 @@
 import os
 
 try:
+    # dataset class
     from .voc import VOCDataset
     from .coco import COCODataset
     from .crowdhuman import CrowdHumanDataset
     from .widerface import WiderFaceDataset
     from .customed import CustomedDataset
+    # transform class
     from .data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
     from .data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
+    from .data_augment.rtdetr_augment import RTDetrAugmentation, RTDetrBaseTransform
 
 except:
+    # dataset class
     from voc import VOCDataset
     from coco import COCODataset
     from crowdhuman import CrowdHumanDataset
     from widerface import WiderFaceDataset
     from customed import CustomedDataset
+    # transform class
     from data_augment.ssd_augment import SSDAugmentation, SSDBaseTransform
     from data_augment.yolov5_augment import YOLOv5Augmentation, YOLOv5BaseTransform
+    from data_augment.rtdetr_augment import RTDetrAugmentation, RTDetrBaseTransform
 
 
 # ------------------------------ Dataset ------------------------------
@@ -92,28 +98,23 @@ def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
 
 # ------------------------------ Transform ------------------------------
 def build_transform(args, trans_config, max_stride=32, is_train=False):
-    # Modify trans_config
+    # ---------------- Modify trans_config ----------------
     if is_train:
         ## mosaic prob.
         if args.mosaic is not None:
-            trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
-        else:
-            trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
+            trans_config['mosaic_prob'] = args.mosaic
         ## mixup prob.
         if args.mixup is not None:
-            trans_config['mixup_prob']=args.mixup if is_train else 0.0
-        else:
-            trans_config['mixup_prob']=trans_config['mixup_prob']  if is_train else 0.0
+            trans_config['mixup_prob'] = args.mixup
 
-    # Transform
+    # ---------------- Build transform ----------------
+    ## SSD-style transform
     if trans_config['aug_type'] == 'ssd':
         if is_train:
             transform = SSDAugmentation(img_size=args.img_size,)
         else:
             transform = SSDBaseTransform(img_size=args.img_size,)
-        trans_config['mosaic_prob'] = 0.0
-        trans_config['mixup_prob'] = 0.0
-
+    ## YOLO-style transform
     elif trans_config['aug_type'] == 'yolov5':
         if is_train:
             transform = YOLOv5Augmentation(
@@ -126,5 +127,14 @@ def build_transform(args, trans_config, max_stride=32, is_train=False):
                 img_size=args.img_size,
                 max_stride=max_stride
                 )
+    ## RT_DETR-style transform
+    elif trans_config['aug_type'] == 'rtdetr':
+        if is_train:
+            use_mosaic = trans_config['mosaic_prob'] >= 0.2
+            transform = RTDetrAugmentation(
+                img_size=args.img_size, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], use_mosaic=use_mosaic)
+        else:
+            transform = RTDetrBaseTransform(
+                img_size=args.img_size, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375])
 
     return transform, trans_config
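A quick usage sketch for the updated `build_transform`. The `args` namespace below is a hand-built stand-in for the repo's argparse arguments, so treat its fields as an assumption:

    from types import SimpleNamespace
    from dataset.build import build_transform

    # Hand-built stand-in for the parsed command-line args (assumption).
    args = SimpleNamespace(img_size=640, mosaic=None, mixup=None)
    trans_config = {'aug_type': 'rtdetr', 'mosaic_prob': 0., 'mixup_prob': 0.}

    transform, trans_config = build_transform(args, trans_config, max_stride=32, is_train=True)
    # mosaic_prob (0.) < 0.2, so RTDetrAugmentation is built with use_mosaic=False
    # and therefore includes the RandomSampleCrop processor.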

+ 14 - 4
dataset/coco.py

@@ -125,7 +125,7 @@ class COCODataset(Dataset):
         # Mosaic
         if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
             image, target = yolov5_mosaic_augment(
-                image_list, target_list, self.img_size, self.trans_config, self.is_train)
+                image_list, target_list, self.img_size, self.trans_config, self.trans_config['mosaic_keep_ratio'], self.is_train)
 
         return image, target
 
@@ -253,7 +253,7 @@ if __name__ == "__main__":
     parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='input image size.')
     parser.add_argument('--aug_type', type=str, default='ssd',
-                        help='augmentation type')
+                        help='augmentation type: ssd, yolov5, rtdetr.')
     parser.add_argument('--mosaic', default=0., type=float,
                         help='mosaic augmentation.')
     parser.add_argument('--mixup', default=0., type=float,
@@ -284,10 +284,13 @@ if __name__ == "__main__":
         'mixup_prob': args.mixup,
         'mosaic_type': 'yolov5_mosaic',
         'mixup_type': args.mixup_type,   # optional: yolov5_mixup, yolox_mixup
+        'mosaic_keep_ratio': False,
         'mixup_scale': [0.5, 1.5]
     }
-
     transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
+    pixel_mean = transform.pixel_mean
+    pixel_std  = transform.pixel_std
+    color_format = transform.color_format
 
     dataset = COCODataset(
         img_size=args.img_size,
@@ -312,6 +315,13 @@ if __name__ == "__main__":
 
         # to numpy
         image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * pixel_std + pixel_mean
+        if color_format == 'rgb':
+            # RGB to BGR
+            image = image[..., (2, 1, 0)]
+
         # to uint8
         image = image.astype(np.uint8)
         image = image.copy()
@@ -326,7 +336,7 @@ if __name__ == "__main__":
             color = class_colors[cls_id]
             # class name
             label = coco_class_labels[coco_class_index[cls_id]]
-            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
+            image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
             # put the text on the bbox
             cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
         cv2.imshow('gt', image)
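The new denormalize step simply inverts what `Normalize` did, then swaps RGB back to BGR for `cv2.imshow`. A self-contained round-trip check, using the RT_DETR mean/std from this commit:

    import numpy as np

    pixel_mean = np.array([123.675, 116.28, 103.53])   # RGB, as in RTDetrAugmentation
    pixel_std  = np.array([58.395, 57.12, 57.375])

    rgb = np.random.randint(0, 256, (4, 4, 3)).astype(np.float32)
    normed = (rgb - pixel_mean) / pixel_std        # what Normalize() does
    restored = normed * pixel_std + pixel_mean     # what the viewer loop does
    assert np.allclose(restored, rgb, atol=1e-3)
    bgr = restored[..., (2, 1, 0)]                 # back to OpenCV's BGR order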

+ 12 - 2
dataset/crowdhuman.py

@@ -77,7 +77,7 @@ class CrowdHumanDataset(Dataset):
         # Mosaic
         if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
             image, target = yolov5_mosaic_augment(
-                image_list, target_list, self.img_size, self.trans_config, self.is_train)
+                image_list, target_list, self.img_size, self.trans_config, self.trans_config['mosaic_keep_ratio'], self.is_train)
 
         return image, target
 
@@ -219,10 +219,13 @@ if __name__ == "__main__":
         'mixup_prob': args.mixup,
         'mosaic_type': 'yolov5_mosaic',
         'mixup_type': args.mixup_type,   # optional: yolov5_mixup, yolox_mixup
+        'mosaic_keep_ratio': False,
         'mixup_scale': [0.5, 1.5]
     }
-
     transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
+    pixel_mean = transform.pixel_mean
+    pixel_std  = transform.pixel_std
+    color_format = transform.color_format
 
     dataset = CrowdHumanDataset(
         img_size=args.img_size,
@@ -245,6 +248,13 @@ if __name__ == "__main__":
 
         # to numpy
         image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * pixel_std + pixel_mean
+        if color_format == 'rgb':
+            # RGB to BGR
+            image = image[..., (2, 1, 0)]
+
         # to uint8
         image = image.astype(np.uint8)
         image = image.copy()

+ 19 - 7
dataset/customed.py

@@ -120,7 +120,7 @@ class CustomedDataset(Dataset):
         # Mosaic
         if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
             image, target = yolov5_mosaic_augment(
-                image_list, target_list, self.img_size, self.trans_config, self.is_train)
+                image_list, target_list, self.img_size, self.trans_config, self.trans_config['mosaic_keep_ratio'], self.is_train)
 
         return image, target
 
@@ -262,25 +262,29 @@ if __name__ == "__main__":
     args = parser.parse_args()
 
     trans_config = {
-        'aug_type': 'yolov5',  # optional: ssd, yolov5
+        'aug_type': args.aug_type,    # optional: ssd, yolov5
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,
-        'scale': [0.5, 2.0],
+        'scale': [0.1, 2.0],
         'shear': 0.0,
         'perspective': 0.0,
         'hsv_h': 0.015,
         'hsv_s': 0.7,
         'hsv_v': 0.4,
+        'use_ablu': True,
         # Mosaic & Mixup
-        'mosaic_prob': 1.0,
-        'mixup_prob': 1.0,
+        'mosaic_prob': args.mosaic,
+        'mixup_prob': args.mixup,
         'mosaic_type': 'yolov5_mosaic',
-        'mixup_type': 'yolov5_mixup',
+        'mixup_type': args.mixup_type,   # optional: yolov5_mixup, yolox_mixup
+        'mosaic_keep_ratio': False,
         'mixup_scale': [0.5, 1.5]
     }
-
     transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
+    pixel_mean = transform.pixel_mean
+    pixel_std  = transform.pixel_std
+    color_format = transform.color_format
 
     dataset = CustomedDataset(
         img_size=args.img_size,
@@ -305,6 +309,14 @@ if __name__ == "__main__":
 
         # to numpy
         image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * pixel_std + pixel_mean
+        if color_format == 'rgb':
+            # RGB to BGR
+            image = image[..., (2, 1, 0)]
+
+        # to uint8
         image = image.astype(np.uint8)
         image = image.copy()
         img_h, img_w = image.shape[:2]

+ 355 - 11
dataset/data_augment/rtdetr_augment.py

@@ -1,23 +1,367 @@
+# ------------------------------------------------------------
 # Data preprocessor for Real-time DETR
+# ------------------------------------------------------------
+import cv2
+import numpy as np
+from numpy import random
+
+import torch
+import torch.nn.functional as F
 
 
 # ------------------------- Augmentations -------------------------
+class Compose(object):
+    """Composes several augmentations together.
+    Args:
+        transforms (List[Transform]): list of transforms to compose.
+    Example:
+        >>> augmentations.Compose([
+        >>>     transforms.CenterCrop(10),
+        >>>     transforms.ToTensor(),
+        >>> ])
+    """
+
+    def __init__(self, transforms):
+        self.transforms = transforms
+
+    def __call__(self, image, target=None):
+        for t in self.transforms:
+            image, target = t(image, target)
+        return image, target
+
+## Convert color format
+class ConvertColorFormat(object):
+    def __init__(self, color_format='rgb'):
+        self.color_format = color_format
+
+    def __call__(self, image, target=None):
+        """
+        Input:
+            image: (np.array) a OpenCV image with BGR color format.
+            target: None
+        Output:
+            image: (np.array) a OpenCV image with given color format.
+            target: None
+        """
+        # Convert color format
+        if self.color_format == 'rgb':
+            image = image[..., (2, 1, 0)]    # BGR -> RGB
+        elif self.color_format == 'bgr':
+            image = image
+        else:
+            raise NotImplementedError("Unknown color format: <{}>".format(self.color_format))
+
+        return image, target
+
+## Random Photometric Distort
+class RandomPhotometricDistort(object):
+    """
+    Distort image w.r.t hue, saturation and exposure.
+    """
+
+    def __init__(self, hue=0.1, saturation=1.5, exposure=1.5):
+        super().__init__()
+        self.hue = hue
+        self.saturation = saturation
+        self.exposure = exposure
+
+    def __call__(self, image: np.ndarray, target=None):
+        """
+        Args:
+            image (ndarray): of shape HxWxC, of type uint8 in range [0, 255].
+            target: annotation dict, passed through unchanged.
+
+        Returns:
+            (ndarray, target): the distorted image and the unchanged target.
+        """
+        if random.random() < 0.5:
+            dhue = np.random.uniform(low=-self.hue, high=self.hue)
+            dsat = self._rand_scale(self.saturation)
+            dexp = self._rand_scale(self.exposure)
+
+            image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
+            image = np.asarray(image, dtype=np.float32) / 255.
+            image[:, :, 1] *= dsat
+            image[:, :, 2] *= dexp
+            H = image[:, :, 0] + dhue * 179 / 255.
+
+            if dhue > 0:
+                H[H > 1.0] -= 1.0
+            else:
+                H[H < 0.0] += 1.0
+
+            image[:, :, 0] = H
+            image = (image * 255).clip(0, 255).astype(np.uint8)
+            image = cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
+            image = np.asarray(image, dtype=np.uint8)
+
+        return image, target
+
+    def _rand_scale(self, upper_bound):
+        """
+        Calculate random scaling factor.
+
+        Args:
+            upper_bound (float): range of the random scale.
+        Returns:
+            random scaling factor (float) whose range is
+            from 1 / s to s .
+        """
+        scale = np.random.uniform(low=1, high=upper_bound)
+        if np.random.rand() > 0.5:
+            return scale
+        return 1 / scale
+
+## Random IoU based Sample Crop
+class RandomSampleCrop(object):
+    def __init__(self):
+        self.sample_options = (
+            # using entire original input image
+            None,
+            # sample a patch s.t. MIN jaccard w/ obj in .1,.3,.4,.7,.9
+            (0.1, None),
+            (0.3, None),
+            (0.7, None),
+            (0.9, None),
+            # randomly sample a patch
+            (None, None),
+        )
+
+    def intersect(self, box_a, box_b):
+        max_xy = np.minimum(box_a[:, 2:], box_b[2:])
+        min_xy = np.maximum(box_a[:, :2], box_b[:2])
+        inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
+
+        return inter[:, 0] * inter[:, 1]
+
+    def compute_iou(self, box_a, box_b):
+        inter = self.intersect(box_a, box_b)
+        area_a = ((box_a[:, 2]-box_a[:, 0]) *
+                (box_a[:, 3]-box_a[:, 1]))  # [A,B]
+        area_b = ((box_b[2]-box_b[0]) *
+                (box_b[3]-box_b[1]))  # [A,B]
+        union = area_a + area_b - inter
+        return inter / union  # [A,B]
+
+    def __call__(self, image, target=None):
+        height, width, _ = image.shape
+
+        # check target
+        if len(target["boxes"]) == 0:
+            return image, target
+
+        while True:
+            # randomly choose a mode
+            sample_id = np.random.randint(len(self.sample_options))
+            mode = self.sample_options[sample_id]
+            if mode is None:
+                return image, target
+
+            boxes = target["boxes"]
+            labels = target["labels"]
+
+            min_iou, max_iou = mode
+            if min_iou is None:
+                min_iou = float('-inf')
+            if max_iou is None:
+                max_iou = float('inf')
+
+            # max trials (50)
+            for _ in range(50):
+                current_image = image
+
+                w = random.uniform(0.3 * width, width)
+                h = random.uniform(0.3 * height, height)
+
+                # aspect ratio constraint b/t .5 & 2
+                if h / w < 0.5 or h / w > 2:
+                    continue
+
+                left = random.uniform(width - w)
+                top = random.uniform(height - h)
+
+                # convert to integer rect x1,y1,x2,y2
+                rect = np.array([int(left), int(top), int(left+w), int(top+h)])
+
+                # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
+                overlap = self.compute_iou(boxes, rect)
+
+                # is min and max overlap constraint satisfied? if not try again
+                if overlap.min() < min_iou and max_iou < overlap.max():
+                    continue
+
+                # cut the crop from the image
+                current_image = current_image[rect[1]:rect[3], rect[0]:rect[2],
+                                              :]
+
+                # keep overlap with gt box IF center in sampled patch
+                centers = (boxes[:, :2] + boxes[:, 2:]) / 2.0
+
+                # mask in all gt boxes whose centers are right of and below the crop's top-left
+                m1 = (rect[0] < centers[:, 0]) * (rect[1] < centers[:, 1])
+
+                # mask in all gt boxes whose centers are left of and above the crop's bottom-right
+                m2 = (rect[2] > centers[:, 0]) * (rect[3] > centers[:, 1])
+
+                # mask in that both m1 and m2 are true
+                mask = m1 * m2
+
+                # have any valid boxes? try again if not
+                if not mask.any():
+                    continue
+
+                # take only matching gt boxes
+                current_boxes = boxes[mask, :].copy()
+
+                # take only matching gt labels
+                current_labels = labels[mask]
+
+                # clip box corners to the crop window, then shift into crop coordinates
+                current_boxes[:, :2] = np.maximum(current_boxes[:, :2],
+                                                  rect[:2])
+                # adjust to crop (by subtracting crop's left,top)
+                current_boxes[:, :2] -= rect[:2]
+
+                current_boxes[:, 2:] = np.minimum(current_boxes[:, 2:],
+                                                  rect[2:])
+                # adjust to crop (by subtracting crop's left,top)
+                current_boxes[:, 2:] -= rect[:2]
+
+                # update target
+                target["boxes"] = current_boxes
+                target["labels"] = current_labels
+
+                return current_image, target
+
+## Random HFlip
+class RandomHorizontalFlip(object):
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, image, target=None):
+        if random.random() < self.p:
+            orig_h, orig_w = image.shape[:2]
+            image = image[:, ::-1]
+            if target is not None:
+                if "boxes" in target:
+                    boxes = target["boxes"].copy()
+                    boxes[..., [0, 2]] = orig_w - boxes[..., [2, 0]]
+                    target["boxes"] = boxes
+
+        return image, target
+
+## Resize image
+class Resize(object):
+    def __init__(self, img_size=640):
+        self.img_size = img_size
+
+    def __call__(self, image, target=None):
+        orig_h, orig_w = image.shape[:2]
+
+        # resize
+        image = cv2.resize(image, (self.img_size, self.img_size)).astype(np.float32)
+        img_h, img_w = image.shape[:2]
+
+        # rescale bboxes
+        if target is not None:
+            boxes = target["boxes"]
+            boxes[:, [0, 2]] = boxes[:, [0, 2]] / orig_w * img_w
+            boxes[:, [1, 3]] = boxes[:, [1, 3]] / orig_h * img_h
+            target["boxes"] = boxes
+
+        return image, target
+
+## Normalize image
+class Normalize(object):
+    def __init__(self, pixel_mean, pixel_std):
+        self.pixel_mean = pixel_mean
+        self.pixel_std = pixel_std
+
+    def __call__(self, image, target=None):
+        # normalize image
+        image = (image - self.pixel_mean) / self.pixel_std
+
+        return image, target
+
+## Convert ndarray to torch.Tensor
+class ToTensor(object):
+    def __call__(self, image, target=None):        
+        # Convert torch.Tensor
+        image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float()
+
+        if target is not None:
+            target["boxes"] = torch.as_tensor(target["boxes"]).float()
+            target["labels"] = torch.as_tensor(target["labels"]).long()
+
+        return image, target
 
 
 # ------------------------- Preprocessers -------------------------
 ## Transform for Train
 class RTDetrAugmentation(object):
-    def __init__(self):
-        return
-    
-    def __call__(self,):
-        pass
+    def __init__(self, img_size=640, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375], use_mosaic=False):
+        # ----------------- Basic parameters -----------------
+        self.img_size = img_size
+        self.use_mosaic = use_mosaic
+        self.pixel_mean = pixel_mean  # RGB format
+        self.pixel_std = pixel_std    # RGB format
+        self.color_format = 'rgb'
+
+        # ----------------- Transforms -----------------
+        if use_mosaic:
+            # When mosaic is enabled, we do not use the RandomSampleCrop processor.
+            self.augment = Compose([
+                RandomPhotometricDistort(hue=0.5, saturation=1.5, exposure=1.5),
+                RandomHorizontalFlip(p=0.5),
+                Resize(img_size=self.img_size),
+                ConvertColorFormat(self.color_format),
+                Normalize(self.pixel_mean, self.pixel_std),
+                ToTensor()
+            ])
+        else:
+            # When mosaic is disabled, we use the RandomSampleCrop processor.
+            self.augment = Compose([
+                RandomPhotometricDistort(hue=0.5, saturation=1.5, exposure=1.5),
+                RandomSampleCrop(),
+                RandomHorizontalFlip(p=0.5),
+                Resize(img_size=self.img_size),
+                ConvertColorFormat(self.color_format),
+                Normalize(self.pixel_mean, self.pixel_std),
+                ToTensor()
+            ])
+
+    def __call__(self, image, target, mosaic=False):
+        orig_h, orig_w = image.shape[:2]
+        ratio = [self.img_size / orig_w, self.img_size / orig_h]
 
-## Transform for Val
+        image, target = self.augment(image, target)
+
+        return image, target, ratio
+
+
+## Transform for Eval
 class RTDetrBaseTransform(object):
-    def __init__(self):
-        return
-    
-    def __call__(self,):
-        pass
+    def __init__(self, img_size=640, pixel_mean=[123.675, 116.28, 103.53], pixel_std=[58.395, 57.12, 57.375]):
+        # ----------------- Basic parameters -----------------
+        self.img_size = img_size
+        self.pixel_mean = pixel_mean  # RGB format
+        self.pixel_std = pixel_std    # RGB format
+        self.color_format = 'rgb'
+
+        # ----------------- Transforms -----------------
+        self.transform = Compose([
+            Resize(img_size=self.img_size),
+            ConvertColorFormat(self.color_format),
+            Normalize(self.pixel_mean, self.pixel_std),
+            ToTensor()
+        ])
+
+
+    def __call__(self, image, target, mosaic=False):
+        orig_h, orig_w = image.shape[:2]
+        ratio = [self.img_size / orig_w, self.img_size / orig_h]
+
+        image, target = self.transform(image, target)
 
+        return image, target, ratio
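A minimal sketch of driving the two new preprocessors on a dummy image, assuming the `{"boxes", "labels"}` target format used above:

    import numpy as np
    from dataset.data_augment.rtdetr_augment import RTDetrAugmentation, RTDetrBaseTransform

    image = np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8)   # dummy BGR image
    target = {"boxes": np.array([[50., 60., 200., 220.]]), "labels": np.array([1])}

    eval_tf = RTDetrBaseTransform(img_size=640)
    img_t, tgt_t, ratio = eval_tf(image, target)
    print(img_t.shape)   # torch.Size([3, 640, 640]) -- normalized RGB tensor
    print(ratio)         # [1.0, 1.333...] -- per-axis (x, y) resize ratios

    train_tf = RTDetrAugmentation(img_size=640, use_mosaic=False)      # adds RandomSampleCrop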

+ 40 - 42
dataset/data_augment/ssd_augment.py

@@ -1,36 +1,13 @@
+# ------------------------------------------------------------
+# Data preprocessor for SSD
+# ------------------------------------------------------------
 import cv2
 import numpy as np
 import torch
 from numpy import random
 
 
-def intersect(box_a, box_b):
-    max_xy = np.minimum(box_a[:, 2:], box_b[2:])
-    min_xy = np.maximum(box_a[:, :2], box_b[:2])
-    inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
-    return inter[:, 0] * inter[:, 1]
-
-
-def jaccard_numpy(box_a, box_b):
-    """Compute the jaccard overlap of two sets of boxes.  The jaccard overlap
-    is simply the intersection over union of two boxes.
-    E.g.:
-        A ∩ B / A ∪ B = A ∩ B / (area(A) + area(B) - A ∩ B)
-    Args:
-        box_a: Multiple bounding boxes, Shape: [num_boxes,4]
-        box_b: Single bounding box, Shape: [4]
-    Return:
-        jaccard overlap: Shape: [box_a.shape[0], box_a.shape[1]]
-    """
-    inter = intersect(box_a, box_b)
-    area_a = ((box_a[:, 2]-box_a[:, 0]) *
-              (box_a[:, 3]-box_a[:, 1]))  # [A,B]
-    area_b = ((box_b[2]-box_b[0]) *
-              (box_b[3]-box_b[1]))  # [A,B]
-    union = area_a + area_b - inter
-    return inter / union  # [A,B]
-
-
+# ------------------------- Augmentations -------------------------
 class Compose(object):
     """Composes several augmentations together.
     Args:
@@ -50,12 +27,12 @@ class Compose(object):
             img, boxes, labels = t(img, boxes, labels)
         return img, boxes, labels
 
-
+## Convert Image to float type
 class ConvertFromInts(object):
     def __call__(self, image, boxes=None, labels=None):
         return image.astype(np.float32), boxes, labels
 
-
+## Convert color format
 class ConvertColor(object):
     def __init__(self, current='BGR', transform='HSV'):
         self.transform = transform
@@ -70,7 +47,7 @@ class ConvertColor(object):
             raise NotImplementedError
         return image, boxes, labels
 
-
+## Resize image
 class Resize(object):
     def __init__(self, img_size=640):
         self.img_size = img_size
@@ -86,7 +63,7 @@ class Resize(object):
 
         return image, boxes, labels
 
-
+## Random Saturation
 class RandomSaturation(object):
     def __init__(self, lower=0.5, upper=1.5):
         self.lower = lower
@@ -100,7 +77,7 @@ class RandomSaturation(object):
 
         return image, boxes, labels
 
-
+## Random Hue
 class RandomHue(object):
     def __init__(self, delta=18.0):
         assert delta >= 0.0 and delta <= 360.0
@@ -113,7 +90,7 @@ class RandomHue(object):
             image[:, :, 0][image[:, :, 0] < 0.0] += 360.0
         return image, boxes, labels
 
-
+## Random Lighting noise
 class RandomLightingNoise(object):
     def __init__(self):
         self.perms = ((0, 1, 2), (0, 2, 1),
@@ -127,7 +104,7 @@ class RandomLightingNoise(object):
             image = shuffle(image)
         return image, boxes, labels
 
-
+## Random Contrast
 class RandomContrast(object):
     def __init__(self, lower=0.5, upper=1.5):
         self.lower = lower
@@ -142,7 +119,7 @@ class RandomContrast(object):
             image *= alpha
         return image, boxes, labels
 
-
+## Random Brightness
 class RandomBrightness(object):
     def __init__(self, delta=32):
         assert delta >= 0.0
@@ -155,7 +132,7 @@ class RandomBrightness(object):
             image += delta
         return image, boxes, labels
 
-
+## Random SampleCrop
 class RandomSampleCrop(object):
     """Crop
     Arguments:
@@ -182,6 +159,21 @@ class RandomSampleCrop(object):
             (None, None),
         )
 
+    def intersect(self, box_a, box_b):
+        max_xy = np.minimum(box_a[:, 2:], box_b[2:])
+        min_xy = np.maximum(box_a[:, :2], box_b[:2])
+        inter = np.clip((max_xy - min_xy), a_min=0, a_max=np.inf)
+        return inter[:, 0] * inter[:, 1]
+
+    def compute_iou(self, box_a, box_b):
+        inter = self.intersect(box_a, box_b)
+        area_a = ((box_a[:, 2]-box_a[:, 0]) *
+                (box_a[:, 3]-box_a[:, 1]))  # [A,B]
+        area_b = ((box_b[2]-box_b[0]) *
+                (box_b[3]-box_b[1]))  # [A,B]
+        union = area_a + area_b - inter
+        return inter / union  # [A,B]
+
     def __call__(self, image, boxes=None, labels=None):
         height, width, _ = image.shape
         # check
@@ -219,7 +211,7 @@ class RandomSampleCrop(object):
                 rect = np.array([int(left), int(top), int(left+w), int(top+h)])
 
                 # calculate IoU (jaccard overlap) b/t the cropped and gt boxes
-                overlap = jaccard_numpy(boxes, rect)
+                overlap = self.compute_iou(boxes, rect)
 
                 # is min and max overlap constraint satisfied? if not try again
                 if overlap.min() < min_iou and max_iou < overlap.max():
@@ -264,7 +256,7 @@ class RandomSampleCrop(object):
 
                 return current_image, current_boxes, current_labels
 
-
+## Random expand
 class Expand(object):
     def __call__(self, image, boxes, labels):
         if random.randint(2):
@@ -288,7 +280,7 @@ class Expand(object):
 
         return image, boxes, labels
 
-
+## Random HFlip
 class RandomHorizontalFlip(object):
     def __call__(self, image, boxes, classes):
         _, width, _ = image.shape
@@ -298,7 +290,7 @@ class RandomHorizontalFlip(object):
             boxes[:, 0::2] = width - boxes[:, 2::-2]
         return image, boxes, classes
 
-
+## Swap image channels
 class SwapChannels(object):
     """Transforms a tensorized image by swapping the channels in the order
      specified in the swap tuple.
@@ -324,7 +316,7 @@ class SwapChannels(object):
         image = image[:, :, self.swaps]
         return image
 
-
+## Random color jitter
 class PhotometricDistort(object):
     def __init__(self):
         self.pd = [
@@ -348,11 +340,14 @@ class PhotometricDistort(object):
         return im, boxes, labels
 
 
-# ----------------------- Main Functions -----------------------
+# ------------------------- Preprocessers -------------------------
 ## SSD-style Augmentation
 class SSDAugmentation(object):
     def __init__(self, img_size=640):
         self.img_size = img_size
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [1., 1., 1.]
+        self.color_format = 'bgr'
         self.augment = Compose([
             ConvertFromInts(),                         # convert image dtype from int to float32
             PhotometricDistort(),                      # random color jitter
@@ -384,6 +379,9 @@ class SSDAugmentation(object):
 class SSDBaseTransform(object):
     def __init__(self, img_size):
         self.img_size = img_size
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [1., 1., 1.]
+        self.color_format = 'bgr'
 
     def __call__(self, image, target=None, mosaic=False):
         # resize

+ 15 - 6
dataset/data_augment/yolov5_augment.py

@@ -123,7 +123,7 @@ class Albumentations(object):
 
 # ------------------------- Strong augmentations -------------------------
 ## YOLOv5-Mosaic
-def yolov5_mosaic_augment(image_list, target_list, img_size, affine_params, is_train=False):
+def yolov5_mosaic_augment(image_list, target_list, img_size, affine_params, keep_ratio=True, is_train=False):
     assert len(image_list) == 4
 
     mosaic_img = np.ones([img_size*2, img_size*2, image_list[0].shape[2]], dtype=np.uint8) * 114
@@ -141,10 +141,14 @@ def yolov5_mosaic_augment(image_list, target_list, img_size, affine_params, is_t
         orig_h, orig_w, _ = img_i.shape
 
         # resize
-        r = img_size / max(orig_h, orig_w)
-        if r != 1: 
-            interp = cv2.INTER_LINEAR if (is_train or r > 1) else cv2.INTER_AREA
-            img_i = cv2.resize(img_i, (int(orig_w * r), int(orig_h * r)), interpolation=interp)
+        if keep_ratio:
+            r = img_size / max(orig_h, orig_w)
+            if r != 1: 
+                interp = cv2.INTER_LINEAR if (is_train or r > 1) else cv2.INTER_AREA
+                img_i = cv2.resize(img_i, (int(orig_w * r), int(orig_h * r)), interpolation=interp)
+        else:
+            interp = cv2.INTER_LINEAR if is_train else cv2.INTER_AREA
+            img_i = cv2.resize(img_i, (img_size, img_size), interpolation=interp)
         h, w, _ = img_i.shape
 
         # place img in img4
@@ -332,6 +336,9 @@ class YOLOv5Augmentation(object):
     def __init__(self, img_size=640, trans_config=None, use_ablu=False):
         # Basic parameters
         self.img_size = img_size
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [1., 1., 1.]
+        self.color_format = 'bgr'
         self.trans_config = trans_config
         # Albumentations
         self.ablu_trans = Albumentations(img_size) if use_ablu else None
@@ -413,7 +420,9 @@ class YOLOv5BaseTransform(object):
     def __init__(self, img_size=640, max_stride=32):
         self.img_size = img_size
         self.max_stride = max_stride
-
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [1., 1., 1.]
+        self.color_format = 'bgr'
 
     def __call__(self, image, target=None, mosaic=False):
         # --------------- Keep ratio Resize ---------------
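The new `keep_ratio` flag switches the per-tile resize in `yolov5_mosaic_augment` between an aspect-preserving scale of the longer side and a plain stretch to the square `img_size`. A standalone illustration of the two branches:

    import cv2
    import numpy as np

    img = np.zeros((480, 320, 3), dtype=np.uint8)   # H=480, W=320
    img_size = 640

    # keep_ratio=True: scale the longer side to img_size, aspect preserved.
    r = img_size / max(img.shape[:2])
    kept = cv2.resize(img, (int(320 * r), int(480 * r)))
    print(kept.shape)       # (640, 426, 3)

    # keep_ratio=False: stretch both sides to img_size.
    stretched = cv2.resize(img, (img_size, img_size))
    print(stretched.shape)  # (640, 640, 3)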

+ 14 - 3
dataset/voc.py

@@ -164,7 +164,7 @@ class VOCDataset(data.Dataset):
         # Mosaic
         if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
             image, target = yolov5_mosaic_augment(
-                image_list, target_list, self.img_size, self.trans_config, self.is_train)
+                image_list, target_list, self.img_size, self.trans_config, self.trans_config['mosaic_keep_ratio'], self.is_train)
 
         return image, target
 
@@ -257,7 +257,7 @@ if __name__ == "__main__":
     parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='input image size.')
     parser.add_argument('--aug_type', type=str, default='ssd',
-                        help='augmentation type')
+                        help='augmentation type: ssd, yolov5, rtdetr.')
     parser.add_argument('--mosaic', default=0., type=float,
                         help='mosaic augmentation.')
     parser.add_argument('--mixup', default=0., type=float,
@@ -288,9 +288,13 @@ if __name__ == "__main__":
         'mixup_prob': args.mixup,
         'mosaic_type': 'yolov5_mosaic',
         'mixup_type': args.mixup_type,   # optional: yolov5_mixup, yolox_mixup
+        'mosaic_keep_ratio': False,
         'mixup_scale': [0.5, 1.5]
     }
     transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
+    pixel_mean = transform.pixel_mean
+    pixel_std  = transform.pixel_std
+    color_format = transform.color_format
 
     dataset = VOCDataset(
         img_size=args.img_size,
@@ -315,6 +319,13 @@ if __name__ == "__main__":
 
         # to numpy
         image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * pixel_std + pixel_mean
+        if color_format == 'rgb':
+            # RGB to BGR
+            image = image[..., (2, 1, 0)]
+
         # to uint8
         image = image.astype(np.uint8)
         image = image.copy()
@@ -330,7 +341,7 @@ if __name__ == "__main__":
                 color = class_colors[cls_id]
                 # class name
                 label = VOC_CLASSES[cls_id]
-                image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), (0,0,255), 2)
+                image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
                 # put the text on the bbox
                 cv2.putText(image, label, (int(x1), int(y1 - 5)), 0, 0.5, color, 1, lineType=cv2.LINE_AA)
         cv2.imshow('gt', image)

+ 12 - 2
dataset/widerface.py

@@ -81,7 +81,7 @@ class WiderFaceDataset(Dataset):
         # Mosaic
         if self.trans_config['mosaic_type'] == 'yolov5_mosaic':
             image, target = yolov5_mosaic_augment(
-                image_list, target_list, self.img_size, self.trans_config, self.is_train)
+                image_list, target_list, self.img_size, self.trans_config, self.trans_config['mosaic_keep_ratio'], self.is_train)
 
         return image, target
 
@@ -222,10 +222,13 @@ if __name__ == "__main__":
         'mixup_prob': args.mixup,
         'mosaic_type': 'yolov5_mosaic',
         'mixup_type': args.mixup_type,   # optional: yolov5_mixup, yolox_mixup
+        'mosaic_keep_ratio': False,
         'mixup_scale': [0.5, 1.5]
     }
-
     transform, trans_cfg = build_transform(args, trans_config, 32, args.is_train)
+    pixel_mean = transform.pixel_mean
+    pixel_std  = transform.pixel_std
+    color_format = transform.color_format
 
     dataset = WiderFaceDataset(
         img_size=args.img_size,
@@ -248,6 +251,13 @@ if __name__ == "__main__":
 
         # to numpy
         image = image.permute(1, 2, 0).numpy()
+        
+        # denormalize
+        image = image * pixel_std + pixel_mean
+        if color_format == 'rgb':
+            # RGB to BGR
+            image = image[..., (2, 1, 0)]
+
         # to uint8
         image = image.astype(np.uint8)
         image = image.copy()

+ 0 - 0
models/detectors/rtrdet/README.md → models/detectors/rtdetr/README.md


+ 187 - 0
models/detectors/rtdetr/basic_modules/backbone.py

@@ -0,0 +1,187 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import Callable, List, Optional, Type, Union
+
+try:
+    from .basic import conv1x1, BasicBlock, Bottleneck
+except:
+    from basic import conv1x1, BasicBlock, Bottleneck
+   
+
+# IN1K pretrained weights
+pretrained_urls = {
+    # ResNet series
+    'resnet18': None,
+    'resnet34': None,
+    'resnet50': None,
+    'resnet101': None,
+    'resnet152': None,
+    # ShuffleNet series
+}
+
+
+# ----------------- Model functions -----------------
+## Build backbone network
+def build_backbone(cfg, pretrained):
+    if 'resnet' in cfg['backbone']:
+        # Build ResNet
+        model, feats = build_resnet(cfg, pretrained)
+    else:
+        raise NotImplementedError("Unknown backbone: <{}>.".format(cfg['backbone']))
+    
+    return model, feats
+
+## Load pretrained weight
+def load_pretrained(model_name):
+    return
+
+
+# ----------------- ResNet Backbone -----------------
+class ResNet(nn.Module):
+    def __init__(self,
+                 block: Type[Union[BasicBlock, Bottleneck]],
+                 layers: List[int],
+                 num_classes: int = 1000,
+                 zero_init_residual: bool = False,
+                 groups: int = 1,
+                 width_per_group: int = 64,
+                 replace_stride_with_dilation: Optional[List[bool]] = None,
+                 norm_layer: Optional[Callable[..., nn.Module]] = None,
+                 ) -> None:
+        super().__init__()
+        # --------------- Basic parameters ----------------
+        self.groups = groups
+        self.base_width = width_per_group
+        self.inplanes = 64
+        self.dilation = 1
+        self.zero_init_residual = zero_init_residual
+        self.replace_stride_with_dilation = [False, False, False] if replace_stride_with_dilation is None else replace_stride_with_dilation
+        if len(self.replace_stride_with_dilation) != 3:
+            raise ValueError(
+                "replace_stride_with_dilation should be None "
+                f"or a 3-element tuple, got {self.replace_stride_with_dilation}"
+            )
+
+        # --------------- Network parameters ----------------
+        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
+        ## Stem layer
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+        self.bn1 = self._norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        ## Res Layer
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=self.replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=self.replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=self.replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        self._init_layer()
+
+    def _init_layer(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        if self.zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
+                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
+                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
+
+    def _make_layer(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
+        planes: int,
+        blocks: int,
+        stride: int = 1,
+        dilate: bool = False,
+    ) -> nn.Sequential:
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(
+            block(
+                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
+            )
+        )
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(
+                    self.inplanes,
+                    planes,
+                    groups=self.groups,
+                    base_width=self.base_width,
+                    dilation=self.dilation,
+                    norm_layer=norm_layer,
+                )
+            )
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x: Tensor) -> Tensor:
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+        x = self.fc(x)
+
+        return x
+
+def _resnet(block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], **kwargs) -> ResNet:
+    return ResNet(block, layers, **kwargs)
+
+def build_resnet(cfg, pretrained=False, **kwargs):
+    # ---------- Build ResNet ----------
+    if   cfg['backbone'] == 'resnet18':
+        model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
+        feats = [128, 256, 512]
+    elif cfg['backbone'] == 'resnet34':
+        model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs)
+        feats = [128, 256, 512]
+    elif cfg['backbone'] == 'resnet50':
+        model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
+        feats = [512, 1024, 2048]
+    elif cfg['backbone'] == 'resnet101':
+        model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
+        feats = [512, 1024, 2048]
+    elif cfg['backbone'] == 'resnet152':
+        model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
+        feats = [512, 1024, 2048]
+
+    # ---------- Load pretrained ----------
+    if pretrained:
+        # TODO: load IN1K pretrained
+        pass
+
+    return model, feats
+
+
+# ----------------- ShuffleNet Backbone -----------------
+## TODO: Add shufflenet-v2
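A quick shape check for the new ResNet builder; the import path assumes the package layout introduced in this commit:

    import torch
    from models.detectors.rtdetr.basic_modules.backbone import build_backbone

    cfg = {'backbone': 'resnet50'}
    model, feats = build_backbone(cfg, pretrained=False)
    print(feats)                     # [512, 1024, 2048]
    x = torch.randn(1, 3, 224, 224)
    print(model(x).shape)            # torch.Size([1, 1000]); still ends in a classification head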

+ 195 - 0
models/detectors/rtdetr/basic_modules/basic.py

@@ -0,0 +1,195 @@
+import torch
+import torch.nn as nn
+from torch import Tensor
+from typing import List, Optional, Callable
+
+
+# ----------------- CNN modules -----------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+    """3x3 convolution with padding"""
+    return nn.Conv2d(
+        in_planes,
+        out_planes,
+        kernel_size=3,
+        stride=stride,
+        padding=dilation,
+        groups=groups,
+        bias=False,
+        dilation=dilation,
+    )
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type  :str  = 'lrelu',   # activation
+                 norm_type :str  ='BN',       # normalization
+                 depthwise :bool =False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            # depthwise conv
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+
+    def forward(self, x):
+        return self.convs(x)
+
+class BasicBlock(nn.Module):
+    expansion: int = 1
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion: int = 4
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None,
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.0)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+# ----------------- Transformer modules -----------------
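A sanity check for the `Conv` wrapper and the residual blocks above (import path per this commit's layout):

    import torch
    from models.detectors.rtdetr.basic_modules.basic import Conv, BasicBlock

    conv = Conv(3, 64, k=3, p=1, s=2, act_type='silu', norm_type='BN')
    x = torch.randn(1, 3, 64, 64)
    y = conv(x)
    print(y.shape)                              # torch.Size([1, 64, 32, 32])

    block = BasicBlock(inplanes=64, planes=64)  # stride 1, identity shortcut
    print(block(y).shape)                       # torch.Size([1, 64, 32, 32])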

+ 7 - 0
models/detectors/rtdetr/basic_modules/neck.py

@@ -0,0 +1,7 @@
+import torch
+import torch.nn as nn
+
+
+# Build neck
+def build_neck(cfg, in_dim, out_dim):
+    return

+ 110 - 0
models/detectors/rtdetr/basic_modules/pafpn.py

@@ -0,0 +1,110 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .basic import Conv, RTCBlock
+
+
+# Build PaFPN
+def build_pafpn(cfg, in_dims, out_dim):
+    return
+
+
+# ----------------- Feature Pyramid Network -----------------
+## Real-time Convolutional PaFPN
+class RTCPaFPN(nn.Module):
+    def __init__(self, 
+                 in_dims   = [256, 512, 512],
+                 width     = 1.0,
+                 depth     = 1.0,
+                 ratio     = 1.0,
+                 act_type  = 'silu',
+                 norm_type = 'BN',
+                 depthwise = False):
+        super(RTCPaFPN, self).__init__()
+        print('==============================')
+        print('FPN: {}'.format("RTC-PaFPN"))
+        # ---------------- Basic parameters ----------------
+        self.in_dims = in_dims
+        self.width = width
+        self.depth = depth
+        self.out_dim = [round(256 * width), round(512 * width), round(512 * width * ratio)]
+        c3, c4, c5 = in_dims
+
+        # ---------------- Top down ----------------
+        ## P5 -> P4
+        self.top_down_layer_1 = RTCBlock(in_dim       = c5 + c4,
+                                         out_dim      = round(512*width),
+                                         num_blocks   = round(3*depth),
+                                         shortcut     = False,
+                                         act_type     = act_type,
+                                         norm_type    = norm_type,
+                                         depthwise    = depthwise,
+                                         )
+        ## P4 -> P3
+        self.top_down_layer_2 = RTCBlock(in_dim       = round(512*width) + c3,
+                                         out_dim      = round(256*width),
+                                         num_blocks   = round(3*depth),
+                                         shortcut     = False,
+                                         act_type     = act_type,
+                                         norm_type    = norm_type,
+                                         depthwise    = depthwise,
+                                         )
+        # ---------------- Bottom up ----------------
+        ## P3 -> P4
+        self.downsample_layer_1 = Conv(round(256*width), round(256*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_1  = RTCBlock(in_dim       = round(256*width) + round(512*width),
+                                           out_dim      = round(512*width),
+                                           num_blocks   = round(3*depth),
+                                           shortcut     = False,
+                                           act_type     = act_type,
+                                           norm_type    = norm_type,
+                                           depthwise    = depthwise,
+                                           )
+        ## P4 -> P5
+        self.downsample_layer_2 = Conv(round(512*width), round(512*width), k=3, p=1, s=2, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.bottom_up_layer_2 = RTCBlock(in_dim       = round(512 * width) + c5,
+                                          out_dim      = round(512 * width * ratio),
+                                          num_blocks   = round(3*depth),
+                                          shortcut     = False,
+                                          act_type     = act_type,
+                                          norm_type    = norm_type,
+                                          depthwise    = depthwise,
+                                          )
+
+        self.init_weights()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, features):
+        c3, c4, c5 = features
+
+        # Top down
+        ## P5 -> P4
+        c6 = F.interpolate(c5, scale_factor=2.0)
+        c7 = torch.cat([c6, c4], dim=1)
+        c8 = self.top_down_layer_1(c7)
+        ## P4 -> P3
+        c9 = F.interpolate(c8, scale_factor=2.0)
+        c10 = torch.cat([c9, c3], dim=1)
+        c11 = self.top_down_layer_2(c10)
+
+        # Bottom up
+        ## P3 -> P4
+        c12 = self.downsample_layer_1(c11)
+        c13 = torch.cat([c12, c8], dim=1)
+        c14 = self.bottom_up_layer_1(c13)
+        ## P4 -> P5
+        c15 = self.downsample_layer_2(c14)
+        c16 = torch.cat([c15, c5], dim=1)
+        c17 = self.bottom_up_layer_2(c16)
+
+        out_feats = [c11, c14, c17] # [P3, P4, P5]
+        
+        return out_feats
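As a quick sanity check of the channel bookkeeping above, here is a minimal smoke test (not part of the commit; the input shapes assume stride-8/16/32 backbone features):

import torch

# Illustrative smoke test for RTCPaFPN; assumes the class above is in scope.
fpn = RTCPaFPN(in_dims=[256, 512, 512], width=1.0, depth=1.0, ratio=1.0)
c3 = torch.randn(1, 256, 80, 80)   # stride 8
c4 = torch.randn(1, 512, 40, 40)   # stride 16
c5 = torch.randn(1, 512, 20, 20)   # stride 32
p3, p4, p5 = fpn([c3, c4, c5])
print(p3.shape, p4.shape, p5.shape)
# expected, per self.out_dim: [1, 256, 80, 80], [1, 512, 40, 40], [1, 512, 20, 20]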

+ 0 - 0
models/detectors/rtrdet/build.py → models/detectors/rtdetr/build.py


+ 0 - 0
models/detectors/rtrdet/loss.py → models/detectors/rtdetr/loss.py


+ 0 - 0
models/detectors/rtrdet/matcher.py → models/detectors/rtdetr/matcher.py


+ 1 - 1
models/detectors/rtrdet/rtrdet.py → models/detectors/rtdetr/rtdetr.py

@@ -1,5 +1,5 @@
 # Real-time Transformer-based Object Detector
 
 
-class RTRDet():
+class RT_DETR():
     pass

+ 35 - 0
models/detectors/rtdetr/rtdetr_decoder.py

@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# ----------------- Decoder for Detection task -----------------
+class DetDecoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        self.backbone = None
+        self.neck = None
+        self.fpn = None
+
+    def forward(self, x):
+        return
+
+
+# ----------------- Decoder for Segmentation task -----------------
+class SegDecoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        # TODO: design seg-decoder
+
+    def forward(self, x):
+        return
+
+
+# ----------------- Decoder for Pose estimation task -----------------
+class PosDecoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        # TODO: design pose-decoder
+
+    def forward(self, x):
+        return
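The three stubs suggest a shared-encoder, per-task-decoder layout. A hypothetical composition, for illustration only (the wrapper class and its arguments are assumptions, not part of this commit):

import torch.nn as nn

# Hypothetical multi-task wrapper: one shared image encoder feeding
# whichever task decoder is requested (illustrative sketch only).
class MultiTaskRTDETR(nn.Module):
    def __init__(self, encoder, task='det'):
        super().__init__()
        decoders = {'det': DetDecoder, 'seg': SegDecoder, 'pos': PosDecoder}
        self.encoder = encoder
        self.decoder = decoders[task]()

    def forward(self, x):
        feats = self.encoder(x)
        return self.decoder(feats)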

+ 19 - 0
models/detectors/rtdetr/rtdetr_encoder.py

@@ -0,0 +1,19 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .basic_modules.backbone import build_backbone
+from .basic_modules.pafpn    import build_pafpn
+
+
+# ----------------- Image Encoder -----------------
+class ImageEncoder(nn.Module):
+    def __init__(self, ):
+        super().__init__()
+        self.backbone = None
+        self.neck = None
+        self.fpn = None
+
+    def forward(self, x):
+        return
+    
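Given the build_backbone / build_pafpn imports, one plausible way the stub gets filled in, reusing the file's existing imports (the cfg layout and the builders' return signatures below are assumptions):

# Sketch of how ImageEncoder might compose the imported builders
# (cfg keys and builder signatures are assumptions, not from this commit).
class ImageEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.backbone, feat_dims = build_backbone(cfg)   # C3/C4/C5 extractor
        self.fpn = build_pafpn(cfg, in_dims=feat_dims)   # e.g. RTCPaFPN above

    def forward(self, x):
        pyramid_feats = self.backbone(x)   # [C3, C4, C5]
        return self.fpn(pyramid_feats)     # [P3, P4, P5]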

+ 0 - 129
models/detectors/rtrdet/rtrdet_basic.py

@@ -1,129 +0,0 @@
-import torch
-import torch.nn as nn
-from typing import List
-
-
-# ----------------- CNN modules -----------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class Conv(nn.Module):
-    def __init__(self, 
-                 c1,                   # in channels
-                 c2,                   # out channels 
-                 k=1,                  # kernel size 
-                 p=0,                  # padding
-                 s=1,                  # stride
-                 d=1,                  # dilation
-                 act_type  :str  = 'lrelu',   # activation
-                 norm_type :str  ='BN',       # normalization
-                 depthwise :bool =False):
-        super(Conv, self).__init__()
-        convs = []
-        add_bias = False if norm_type else True
-        if depthwise:
-            # depthwise conv
-            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c1))
-            if act_type:
-                convs.append(get_activation(act_type))
-            # pointwise conv
-            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-
-        else:
-            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
-            if norm_type:
-                convs.append(get_norm(norm_type, c2))
-            if act_type:
-                convs.append(get_activation(act_type))
-            
-        self.convs = nn.Sequential(*convs)
-
-
-    def forward(self, x):
-        return self.convs(x)
-
-class Bottleneck(nn.Module):
-    def __init__(self,
-                 in_dim       :int,
-                 out_dim      :int,
-                 expand_ratio :float = 0.5,
-                 kernel_sizes :List = [3, 3],
-                 shortcut     :bool = True,
-                 act_type     :str  = 'silu',
-                 norm_type    :str  = 'BN',
-                 depthwise    :bool = False,):
-        super(Bottleneck, self).__init__()
-        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
-        self.cv1 = Conv(in_dim, inter_dim, k=kernel_sizes[0], p=kernel_sizes[0]//2, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
-        self.cv2 = Conv(inter_dim, out_dim, k=kernel_sizes[1], p=kernel_sizes[1]//2, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
-        self.shortcut = shortcut and in_dim == out_dim
-
-    def forward(self, x):
-        h = self.cv2(self.cv1(x))
-
-        return x + h if self.shortcut else h
-
-class RTCBlock(nn.Module):
-    def __init__(self,
-                 in_dim     :int,
-                 out_dim    :int,
-                 num_blocks :int  = 1,
-                 shortcut   :bool = False,
-                 act_type   :str  = 'silu',
-                 norm_type  :str  = 'BN',
-                 depthwise  :bool = False,):
-        super(RTCBlock, self).__init__()
-        self.inter_dim = out_dim // 2
-        self.input_proj = Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-        self.m = nn.Sequential(*(
-            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
-            for _ in range(num_blocks)))
-        self.output_proj = Conv((2 + num_blocks) * self.inter_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
-
-    def forward(self, x):
-        # Input proj
-        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
-        out = [x1, x2]
-
-        # Bottleneck
-        out.extend(m(out[-1]) for m in self.m)
-
-        # Output proj
-        out = self.output_proj(torch.cat(out, dim=1))
-
-        return out
-
-
-# ----------------- Transformer modules -----------------

+ 0 - 0
models/detectors/rtrdet/rtrdet_decoder.py


+ 0 - 0
models/detectors/rtrdet/rtrdet_encoder.py


+ 0 - 0
models/detectors/rtrdet/rtrdet_head.py