浏览代码

modify RTCDet-v2

yjh0410 2 年之前
父节点
当前提交
8d590a1158

+ 21 - 0
config/__init__.py

@@ -39,6 +39,13 @@ from .data_config.transform_config import (
     rtcdet_v1_medium_trans_config,
     rtcdet_v1_large_trans_config,
     rtcdet_v1_huge_trans_config,
+    # RTMDet-v2-Style
+    rtcdet_v2_pico_trans_config,
+    rtcdet_v2_nano_trans_config,
+    rtcdet_v2_small_trans_config,
+    rtcdet_v2_medium_trans_config,
+    rtcdet_v2_large_trans_config,
+    rtcdet_v2_huge_trans_config,
 )
 
 def build_trans_config(trans_config='ssd'):
@@ -91,6 +98,20 @@ def build_trans_config(trans_config='ssd'):
     elif trans_config == 'rtcdet_v1_huge':
         cfg = rtcdet_v1_huge_trans_config
 
+    # RTMDetv2-style transform 
+    elif trans_config == 'rtcdet_v2_pico':
+        cfg = rtcdet_v2_pico_trans_config
+    elif trans_config == 'rtcdet_v2_nano':
+        cfg = rtcdet_v2_nano_trans_config
+    elif trans_config == 'rtcdet_v2_small':
+        cfg = rtcdet_v2_small_trans_config
+    elif trans_config == 'rtcdet_v2_medium':
+        cfg = rtcdet_v2_medium_trans_config
+    elif trans_config == 'rtcdet_v2_large':
+        cfg = rtcdet_v2_large_trans_config
+    elif trans_config == 'rtcdet_v2_huge':
+        cfg = rtcdet_v2_huge_trans_config
+
     print('Transform Config: {} \n'.format(cfg))
 
     return cfg

+ 134 - 18
config/data_config/transform_config.py

@@ -7,7 +7,7 @@ yolov5_huge_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -26,7 +26,7 @@ yolov5_large_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -45,7 +45,7 @@ yolov5_medium_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -64,7 +64,7 @@ yolov5_small_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -83,7 +83,7 @@ yolov5_nano_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.1,
-    'scale': 0.5,
+    'scale': [0.5, 1.5],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -102,7 +102,7 @@ yolov5_pico_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.1,
-    'scale': 0.5,
+    'scale': [0.5, 1.5],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -123,7 +123,7 @@ yolox_huge_trans_config = {
     # Basic Augment
     'degrees': 10.0,
     'translate': 0.1,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 2.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -142,7 +142,7 @@ yolox_large_trans_config = {
     # Basic Augment
     'degrees': 10.0,
     'translate': 0.1,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 2.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -161,7 +161,7 @@ yolox_medium_trans_config = {
     # Basic Augment
     'degrees': 10.0,
     'translate': 0.1,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 2.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -180,7 +180,7 @@ yolox_small_trans_config = {
     # Basic Augment
     'degrees': 10.0,
     'translate': 0.1,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 2.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -199,7 +199,7 @@ yolox_nano_trans_config = {
     # Basic Augment
     'degrees': 10.0,
     'translate': 0.1,
-    'scale': 0.5,
+    'scale': [0.5, 1.5],
     'shear': 2.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -218,7 +218,7 @@ yolox_pico_trans_config = {
     # Basic Augment
     'degrees': 10.0,
     'translate': 0.1,
-    'scale': 0.9,
+    'scale': [0.5, 1.5],
     'shear': 2.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -251,7 +251,7 @@ rtcdet_v1_huge_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -270,7 +270,7 @@ rtcdet_v1_large_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -289,7 +289,7 @@ rtcdet_v1_medium_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -308,7 +308,7 @@ rtcdet_v1_small_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.2,
-    'scale': 0.9,
+    'scale': [0.1, 2.0],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -327,7 +327,7 @@ rtcdet_v1_nano_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.1,
-    'scale': 0.5,
+    'scale': [0.5, 1.5],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -346,7 +346,7 @@ rtcdet_v1_pico_trans_config = {
     # Basic Augment
     'degrees': 0.0,
     'translate': 0.1,
-    'scale': 0.5,
+    'scale': [0.5, 1.5],
     'shear': 0.0,
     'perspective': 0.0,
     'hsv_h': 0.015,
@@ -359,3 +359,119 @@ rtcdet_v1_pico_trans_config = {
     'mixup_type': 'yolox_mixup',
     'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
 }
+
+
+# ----------------------- RTMDet-v2's Transform -----------------------
+rtcdet_v2_huge_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': [0.5, 2.0],
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 1.0,
+    'mixup_prob': 1.0,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
+}
+
+rtcdet_v2_large_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': [0.5, 2.0],
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 1.0,
+    'mixup_prob': 1.0,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
+}
+
+rtcdet_v2_medium_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': [0.5, 2.0],
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 1.0,
+    'mixup_prob': 1.0,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
+}
+
+rtcdet_v2_small_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': [0.5, 2.0],
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 1.0,
+    'mixup_prob': 1.0,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
+}
+
+rtcdet_v2_nano_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': [0.5, 1.5],
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 1.0,
+    'mixup_prob': 0.5,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
+}
+
+rtcdet_v2_pico_trans_config = {
+    'aug_type': 'yolov5',
+    # Basic Augment
+    'degrees': 0.0,
+    'translate': 0.2,
+    'scale': [0.5, 1.5],
+    'shear': 0.0,
+    'perspective': 0.0,
+    'hsv_h': 0.015,
+    'hsv_s': 0.7,
+    'hsv_v': 0.4,
+    # Mosaic & Mixup
+    'mosaic_prob': 0.5,
+    'mixup_prob': 0.0,
+    'mosaic_type': 'yolov5_mosaic',
+    'mixup_type': 'yolov5_mixup',
+    'mixup_scale': [0.5, 1.5]   # "mixup_scale" is not used for YOLOv5MixUp
+}

+ 10 - 4
config/model_config/rtcdet_v2_config.py

@@ -42,10 +42,13 @@ rtcdet_v2_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.5],   # 320 -> 960
-        'trans_type': 'rtcdet_v1_nano',
+        'trans_type': 'rtcdet_v2_nano',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'ota': {'center_sampling_radius': 2.5,
+        'matcher': {'tal': {'topk': 10,
+                            'alpha': 0.5,
+                            'beta': 6.0},
+                    'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
                     },
         # ---------------- Loss config ----------------
@@ -98,10 +101,13 @@ rtcdet_v2_cfg = {
         # ---------------- Train config ----------------
         ## Input
         'multi_scale': [0.5, 1.25],   # 320 -> 800
-        'trans_type': 'rtcdet_v1_large',
+        'trans_type': 'rtcdet_v2_large',
         # ---------------- Assignment config ----------------
         ## Matcher
-        'matcher': {'ota': {'center_sampling_radius': 2.5,
+        'matcher': {'tal': {'topk': 10,
+                            'alpha': 0.5,
+                            'beta': 6.0},
+                    'ota': {'center_sampling_radius': 2.5,
                              'topk_candidate': 10},
                     },
         # ---------------- Loss config ----------------

+ 1 - 1
dataset/coco.py

@@ -244,7 +244,7 @@ if __name__ == "__main__":
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,
-        'scale': 0.9,
+        'scale': [0.5, 2.0],
         'shear': 0.0,
         'perspective': 0.0,
         'hsv_h': 0.015,

+ 3 - 3
dataset/data_augment/yolov5_augment.py

@@ -11,7 +11,7 @@ def random_perspective(image,
                        targets=(),
                        degrees=10,
                        translate=.1,
-                       scale=.1,
+                       scale=[0.1, 2.0],
                        shear=10,
                        perspective=0.0,
                        border=(0, 0)):
@@ -35,7 +35,7 @@ def random_perspective(image,
     R = np.eye(3)
     a = random.uniform(-degrees, degrees)
     # a += random.choice([-180, -90, 0, 90])  # add 90deg rotations to small rotations
-    s = random.uniform(1 - scale, 1 + scale)
+    s = random.uniform(scale[0], scale[1])
     # s = 2 ** random.uniform(-scale, scale)
     R[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
 
@@ -97,7 +97,7 @@ def augment_hsv(img, hgain=0.5, sgain=0.5, vgain=0.5):
 
 # ------------------------- Strong augmentations -------------------------
 ## YOLOv5-Mosaic
-def yolov5_mosaic_augment(image_list, target_list, img_size, affine_params=None, is_train=False):
+def yolov5_mosaic_augment(image_list, target_list, img_size, affine_params, is_train=False):
     assert len(image_list) == 4
 
     mosaic_img = np.ones([img_size*2, img_size*2, image_list[0].shape[2]], dtype=np.uint8) * 114

+ 1 - 1
dataset/ourdataset.py

@@ -225,7 +225,7 @@ if __name__ == "__main__":
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,
-        'scale': 0.9,
+        'scale': [0.5, 2.0],
         'shear': 0.0,
         'perspective': 0.0,
         'hsv_h': 0.015,

+ 1 - 1
dataset/voc.py

@@ -268,7 +268,7 @@ if __name__ == "__main__":
         # Basic Augment
         'degrees': 0.0,
         'translate': 0.2,
-        'scale': 0.9,
+        'scale': [0.5, 2.0],
         'shear': 0.0,
         'perspective': 0.0,
         'hsv_h': 0.015,

+ 124 - 2
models/detectors/rtcdet_v2/loss.py

@@ -4,7 +4,7 @@ import torch.nn.functional as F
 from utils.box_ops import bbox2dist, get_ious
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import AlignedSimOTA
+from .matcher import TaskAlignedAssigner, AlignedSimOTA
 
 
 class Criterion(object):
@@ -20,6 +20,13 @@ class Criterion(object):
         self.loss_dfl_weight = cfg['loss_dfl_weight']
         # ---------------- Matcher ----------------
         matcher_config = cfg['matcher']
+        ## TAL assigner
+        self.tal_matcher = TaskAlignedAssigner(
+            topk=matcher_config['tal']['topk'],
+            alpha=matcher_config['tal']['alpha'],
+            beta=matcher_config['tal']['beta'],
+            num_classes=num_classes
+            )
         ## SimOTA assigner
         self.ota_matcher = AlignedSimOTA(
             center_sampling_radius=matcher_config['ota']['center_sampling_radius'],
@@ -27,6 +34,12 @@ class Criterion(object):
             num_classes=num_classes
         )
 
+    def __call__(self, outputs, targets, epoch=0):
+        if epoch < self.args.max_epoch // 2:
+            return self.ota_loss(outputs, targets)
+        else:
+            return self.tal_loss(outputs, targets)
+
     def ema_update(self, name: str, value, initial_value, momentum=0.9):
         if hasattr(self, name):
             old = getattr(self, name)
@@ -93,8 +106,117 @@ class Criterion(object):
 
         return loss_dfl
     
+    # ----------------- Loss with TAL assigner -----------------
+    def tal_loss(self, outputs, targets):
+        """ Compute loss with TAL assigner """
+        bs = outputs['pred_cls'][0].shape[0]
+        device = outputs['pred_cls'][0].device
+        anchors = torch.cat(outputs['anchors'], dim=0)
+        num_anchors = anchors.shape[0]
+        # preds: [B, M, C]
+        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
+        reg_preds = torch.cat(outputs['pred_reg'], dim=1)
+        box_preds = torch.cat(outputs['pred_box'], dim=1)
+
+        # --------------- label assignment ---------------
+        gt_label_targets = []
+        gt_score_targets = []
+        gt_bbox_targets = []
+        fg_masks = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)
+
+            # check target
+            if len(tgt_labels) == 0 or tgt_bboxes.max().item() == 0.:
+                # There is no valid gt
+                fg_mask = cls_preds.new_zeros(1, num_anchors).bool()               #[1, M,]
+                gt_label = cls_preds.new_zeros((1, num_anchors,))                  #[1, M,]
+                gt_score = cls_preds.new_zeros((1, num_anchors, self.num_classes)) #[1, M, C]
+                gt_box = cls_preds.new_zeros((1, num_anchors, 4))                  #[1, M, 4]
+            else:
+                tgt_labels = tgt_labels[None, :, None]      # [1, Mp, 1]
+                tgt_bboxes = tgt_bboxes[None]                   # [1, Mp, 4]
+                (
+                    gt_label,   #[1, M]
+                    gt_box,     #[1, M, 4]
+                    gt_score,   #[1, M, C]
+                    fg_mask,    #[1, M,]
+                    _
+                ) = self.tal_matcher(
+                    pd_scores = cls_preds[batch_idx:batch_idx+1].detach().sigmoid(), 
+                    pd_bboxes = box_preds[batch_idx:batch_idx+1].detach(),
+                    anc_points = anchors,
+                    gt_labels = tgt_labels,
+                    gt_bboxes = tgt_bboxes
+                    )
+            gt_label_targets.append(gt_label)
+            gt_score_targets.append(gt_score)
+            gt_bbox_targets.append(gt_box)
+            fg_masks.append(fg_mask)
+
+        # List[B, 1, M, C] -> Tensor[B, M, C] -> Tensor[BM, C]
+        fg_masks = torch.cat(fg_masks, 0).view(-1)                                    # [BM,]
+        gt_score_targets = torch.cat(gt_score_targets, 0).view(-1, self.num_classes)  # [BM, C]
+        gt_bbox_targets = torch.cat(gt_bbox_targets, 0).view(-1, 4)                   # [BM, 4]
+        gt_label_targets = torch.cat(gt_label_targets, 0).view(-1)                    # [BM,]
+        gt_label_targets = torch.where(fg_masks > 0, gt_label_targets, torch.full_like(gt_label_targets, self.num_classes))
+        gt_labels_one_hot = F.one_hot(gt_label_targets.long(), self.num_classes + 1)[..., :-1]
+        bbox_weight = gt_score_targets[fg_masks].sum(-1)
+        num_fgs = max(gt_score_targets.sum(), 1)
+
+        # average loss normalizer across all the GPUs
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = max(num_fgs / get_world_size(), 1.0)
+
+        # update loss normalizer with EMA
+        if self.use_ema_update:
+            normalizer = self.ema_update("loss_normalizer", max(num_fgs, 1), 100)
+        else:
+            normalizer = num_fgs
+
+        # ------------------ Classification loss ------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, gt_score_targets, gt_labels_one_hot, vfl=False)
+        loss_cls = loss_cls.sum() / normalizer
+
+        # ------------------ Regression loss ------------------
+        box_preds_pos = box_preds.view(-1, 4)[fg_masks]
+        box_targets_pos = gt_bbox_targets[fg_masks]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
+        loss_box = loss_box.sum() / normalizer
+
+        # ------------------ Distribution focal loss  ------------------
+        ## process anchors
+        anchors = anchors[None].repeat(bs, 1, 1).view(-1, 2)
+        ## process stride tensors
+        strides = torch.cat(outputs['stride_tensor'], dim=0)
+        strides = strides.unsqueeze(0).repeat(bs, 1, 1).view(-1, 1)
+        ## fg preds
+        reg_preds_pos = reg_preds.view(-1, 4*self.cfg['reg_max'])[fg_masks]
+        anchors_pos = anchors[fg_masks]
+        strides_pos = strides[fg_masks]
+        ## compute dfl
+        loss_dfl = self.loss_dfl(reg_preds_pos, box_targets_pos, anchors_pos, strides_pos, bbox_weight)
+        loss_dfl = loss_dfl.sum() / normalizer
+
+        # total loss
+        losses = self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box + \
+                 self.loss_dfl_weight * loss_dfl
+
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                loss_dfl = loss_dfl,
+                losses = losses
+        )
+
+        return loss_dict
+    
     # ----------------- Loss with SimOTA assigner -----------------
-    def __call__(self, outputs, targets, epoch=0):
+    def ota_loss(self, outputs, targets):
         """ Compute loss with SimOTA assigner """
         bs = outputs['pred_cls'][0].shape[0]
         device = outputs['pred_cls'][0].device

+ 194 - 1
models/detectors/rtcdet_v2/matcher.py

@@ -1,8 +1,201 @@
 import torch
+import torch.nn as nn
 import torch.nn.functional as F
-from utils.box_ops import box_iou
+from utils.box_ops import box_iou, bbox_iou
 
 
+# -------------------------- Basic Functions --------------------------
+def select_candidates_in_gts(xy_centers, gt_bboxes, eps=1e-9):
+    """select the positive anchors's center in gt
+    Args:
+        xy_centers (Tensor): shape(bs*n_max_boxes, num_total_anchors, 4)
+        gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    n_anchors = xy_centers.size(0)
+    bs, n_max_boxes, _ = gt_bboxes.size()
+    _gt_bboxes = gt_bboxes.reshape([-1, 4])
+    xy_centers = xy_centers.unsqueeze(0).repeat(bs * n_max_boxes, 1, 1)
+    gt_bboxes_lt = _gt_bboxes[:, 0:2].unsqueeze(1).repeat(1, n_anchors, 1)
+    gt_bboxes_rb = _gt_bboxes[:, 2:4].unsqueeze(1).repeat(1, n_anchors, 1)
+    b_lt = xy_centers - gt_bboxes_lt
+    b_rb = gt_bboxes_rb - xy_centers
+    bbox_deltas = torch.cat([b_lt, b_rb], dim=-1)
+    bbox_deltas = bbox_deltas.reshape([bs, n_max_boxes, n_anchors, -1])
+    return (bbox_deltas.min(axis=-1)[0] > eps).to(gt_bboxes.dtype)
+
+def select_highest_overlaps(mask_pos, overlaps, n_max_boxes):
+    """if an anchor box is assigned to multiple gts,
+        the one with the highest iou will be selected.
+    Args:
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+        overlaps (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    Return:
+        target_gt_idx (Tensor): shape(bs, num_total_anchors)
+        fg_mask (Tensor): shape(bs, num_total_anchors)
+        mask_pos (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    fg_mask = mask_pos.sum(axis=-2)
+    if fg_mask.max() > 1:
+        mask_multi_gts = (fg_mask.unsqueeze(1) > 1).repeat([1, n_max_boxes, 1])
+        max_overlaps_idx = overlaps.argmax(axis=1)
+        is_max_overlaps = F.one_hot(max_overlaps_idx, n_max_boxes)
+        is_max_overlaps = is_max_overlaps.permute(0, 2, 1).to(overlaps.dtype)
+        mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos)
+        fg_mask = mask_pos.sum(axis=-2)
+    target_gt_idx = mask_pos.argmax(axis=-2)
+    return target_gt_idx, fg_mask , mask_pos
+
+def iou_calculator(box1, box2, eps=1e-9):
+    """Calculate iou for batch
+    Args:
+        box1 (Tensor): shape(bs, n_max_boxes, 1, 4)
+        box2 (Tensor): shape(bs, 1, num_total_anchors, 4)
+    Return:
+        (Tensor): shape(bs, n_max_boxes, num_total_anchors)
+    """
+    box1 = box1.unsqueeze(2)  # [N, M1, 4] -> [N, M1, 1, 4]
+    box2 = box2.unsqueeze(1)  # [N, M2, 4] -> [N, 1, M2, 4]
+    px1y1, px2y2 = box1[:, :, :, 0:2], box1[:, :, :, 2:4]
+    gx1y1, gx2y2 = box2[:, :, :, 0:2], box2[:, :, :, 2:4]
+    x1y1 = torch.maximum(px1y1, gx1y1)
+    x2y2 = torch.minimum(px2y2, gx2y2)
+    overlap = (x2y2 - x1y1).clip(0).prod(-1)
+    area1 = (px2y2 - px1y1).clip(0).prod(-1)
+    area2 = (gx2y2 - gx1y1).clip(0).prod(-1)
+    union = area1 + area2 - overlap + eps
+
+    return overlap / union
+
+
+# -------------------------- Task Aligned Assigner --------------------------
+class TaskAlignedAssigner(nn.Module):
+    def __init__(self, topk=10, alpha=0.5, beta=6.0, eps=1e-9, num_classes=80):
+        super(TaskAlignedAssigner, self).__init__()
+        self.topk = topk
+        self.num_classes = num_classes
+        self.bg_idx = num_classes
+        self.alpha = alpha
+        self.beta = beta
+        self.eps = eps
+
+    @torch.no_grad()
+    def forward(self,
+                pd_scores,
+                pd_bboxes,
+                anc_points,
+                gt_labels,
+                gt_bboxes):
+        """This code referenced to
+           https://github.com/Nioolek/PPYOLOE_pytorch/blob/master/ppyoloe/assigner/tal_assigner.py
+        Args:
+            pd_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            pd_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            anc_points (Tensor): shape(num_total_anchors, 2)
+            gt_labels (Tensor): shape(bs, n_max_boxes, 1)
+            gt_bboxes (Tensor): shape(bs, n_max_boxes, 4)
+        Returns:
+            target_labels (Tensor): shape(bs, num_total_anchors)
+            target_bboxes (Tensor): shape(bs, num_total_anchors, 4)
+            target_scores (Tensor): shape(bs, num_total_anchors, num_classes)
+            fg_mask (Tensor): shape(bs, num_total_anchors)
+        """
+        self.bs = pd_scores.size(0)
+        self.n_max_boxes = gt_bboxes.size(1)
+
+        mask_pos, align_metric, overlaps = self.get_pos_mask(
+            pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points)
+
+        target_gt_idx, fg_mask, mask_pos = select_highest_overlaps(
+            mask_pos, overlaps, self.n_max_boxes)
+
+        # assigned target
+        target_labels, target_bboxes, target_scores = self.get_targets(
+            gt_labels, gt_bboxes, target_gt_idx, fg_mask)
+
+        # normalize
+        align_metric *= mask_pos
+        pos_align_metrics = align_metric.amax(axis=-1, keepdim=True)  # b, max_num_obj
+        pos_overlaps = (overlaps * mask_pos).amax(axis=-1, keepdim=True)  # b, max_num_obj
+        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
+        target_scores = target_scores * norm_align_metric
+
+        return target_labels, target_bboxes, target_scores, fg_mask.bool(), target_gt_idx
+
+
+    def get_pos_mask(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes, anc_points):
+        # get anchor_align metric, (b, max_num_obj, h*w)
+        align_metric, overlaps = self.get_box_metrics(pd_scores, pd_bboxes, gt_labels, gt_bboxes)
+        # get in_gts mask, (b, max_num_obj, h*w)
+        mask_in_gts = select_candidates_in_gts(anc_points, gt_bboxes)
+        # get topk_metric mask, (b, max_num_obj, h*w)
+        mask_topk = self.select_topk_candidates(align_metric * mask_in_gts)
+        # merge all mask to a final mask, (b, max_num_obj, h*w)
+        mask_pos = mask_topk * mask_in_gts
+
+        return mask_pos, align_metric, overlaps
+
+
+    def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes):
+        ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
+        ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes)  # b, max_num_obj
+        ind[1] = gt_labels.long().squeeze(-1)  # b, max_num_obj
+        # get the scores of each grid for each gt cls
+        bbox_scores = pd_scores[ind[0], :, ind[1]]  # b, max_num_obj, h*w
+
+        overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False).squeeze(3).clamp(0)
+        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
+
+        return align_metric, overlaps
+
+
+    def select_topk_candidates(self, metrics, largest=True):
+        """
+        Args:
+            metrics: (b, max_num_obj, h*w).
+            topk_mask: (b, max_num_obj, topk) or None
+        """
+
+        num_anchors = metrics.shape[-1]  # h*w
+        # (b, max_num_obj, topk)
+        topk_metrics, topk_idxs = torch.topk(metrics, self.topk, dim=-1, largest=largest)
+        topk_mask = (topk_metrics.max(-1, keepdim=True)[0] > self.eps).tile([1, 1, self.topk])
+        # (b, max_num_obj, topk)
+        topk_idxs[~topk_mask] = 0
+        # (b, max_num_obj, topk, h*w) -> (b, max_num_obj, h*w)
+        is_in_topk = F.one_hot(topk_idxs, num_anchors).sum(-2)
+        # filter invalid bboxes
+        is_in_topk = torch.where(is_in_topk > 1, 0, is_in_topk)
+        return is_in_topk.to(metrics.dtype)
+
+
+    def get_targets(self, gt_labels, gt_bboxes, target_gt_idx, fg_mask):
+        """
+        Args:
+            gt_labels: (b, max_num_obj, 1)
+            gt_bboxes: (b, max_num_obj, 4)
+            target_gt_idx: (b, h*w)
+            fg_mask: (b, h*w)
+        """
+
+        # assigned target labels, (b, 1)
+        batch_ind = torch.arange(end=self.bs, dtype=torch.int64, device=gt_labels.device)[..., None]
+        target_gt_idx = target_gt_idx + batch_ind * self.n_max_boxes  # (b, h*w)
+        target_labels = gt_labels.long().flatten()[target_gt_idx]  # (b, h*w)
+
+        # assigned target boxes, (b, max_num_obj, 4) -> (b, h*w)
+        target_bboxes = gt_bboxes.view(-1, 4)[target_gt_idx]
+
+        # assigned target scores
+        target_labels.clamp(0)
+        target_scores = F.one_hot(target_labels, self.num_classes)  # (b, h*w, 80)
+        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.num_classes)  # (b, h*w, 80)
+        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
+
+        return target_labels, target_bboxes, target_scores
+    
+
 # -------------------------- Aligned SimOTA Assigner --------------------------
 class AlignedSimOTA(object):
     """