
add YOLOF to the YOLO project

yjh0410 · 1 year ago
Parent commit: a5cdca1ef1

+ 4 - 0
yolo/models/__init__.py

@@ -10,6 +10,7 @@ from .yolov5_af.build import build_yolov5af
 from .yolov7_af.build import build_yolov7af
 from .yolov8.build    import build_yolov8
 from .gelan.build     import build_gelan
+from .yolof.build     import build_yolof
 from .rtdetr.build    import build_rtdetr
 
 
@@ -40,6 +41,9 @@ def build_model(args, cfg, is_val=False):
     ## GElan
     elif 'gelan' in args.model:
         model, criterion = build_gelan(cfg, is_val)
+    ## YOLOF
+    elif 'yolof' in args.model:
+        model, criterion = build_yolof(cfg, is_val)
     ## RT-DETR
     elif 'rtdetr' in args.model:
         model, criterion = build_rtdetr(cfg, is_val)
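
For reference, a minimal sketch of how the new branch gets selected (illustrative only: `SimpleNamespace` stands in for the training script's parsed arguments, and `yolof_cfg` for whatever YOLOF config object the project defines):

import torch
from types import SimpleNamespace
from yolo.models import build_model          # assuming this package path

args = SimpleNamespace(model='yolof_s')      # any name containing 'yolof' hits the new branch
model, criterion = build_model(args, yolof_cfg, is_val=True)   # yolof_cfg: hypothetical YOLOF config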

+ 0 - 0
yolo/models/yolof/README.md


+ 16 - 0
yolo/models/yolof/build.py

@@ -0,0 +1,16 @@
+from .loss import SetCriterion
+from .yolof import Yolof
+
+
+# build object detector
+def build_yolof(cfg, is_val=False):
+    # -------------- Build YOLO --------------
+    model = Yolof(cfg, is_val)
+  
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 131 - 0
yolo/models/yolof/loss.py

@@ -0,0 +1,131 @@
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import get_ious
+from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
+
+from .matcher import SimOtaMatcher
+
+
+class SetCriterion(object):
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        # --------------- Loss config ---------------
+        self.loss_cls_weight = cfg.loss_cls
+        self.loss_box_weight = cfg.loss_box
+        # --------------- Matcher config ---------------
+        self.matcher = SimOtaMatcher(soft_center_radius = cfg.ota_soft_center_radius,
+                                     topk_candidates    = cfg.ota_topk_candidates,
+                                     num_classes        = cfg.num_classes,
+                                     )
+
+    def loss_classes(self, pred_cls, target, beta=2.0):
+        # Quality FocalLoss
+        """
+            pred_cls (torch.Tensor): [N, C]
+            target (tuple[torch.Tensor, torch.Tensor]): label -> (N,), score -> (N,)
+        """
+        label, score = target
+        pred_sigmoid = pred_cls.sigmoid()
+        scale_factor = pred_sigmoid
+        zerolabel = scale_factor.new_zeros(pred_cls.shape)
+
+        ce_loss = F.binary_cross_entropy_with_logits(
+            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
+        
+        bg_class_ind = pred_cls.shape[-1]
+        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+        if pos.shape[0] > 0:
+            pos_label = label[pos].long()
+
+            scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+
+            ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+                pred_cls[pos, pos_label], score[pos],
+                reduction='none') * scale_factor.abs().pow(beta)
+
+        return ce_loss
+    
+    def loss_bboxes(self, pred_box, gt_box, bbox_weight=None):
+        ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
+        loss_box = 1.0 - ious
+
+        if bbox_weight is not None:
+            loss_box = loss_box.squeeze(-1) * bbox_weight
+
+        return loss_box
+
+    def __call__(self, outputs, targets):        
+        """
+            outputs['pred_cls']: (Tensor) [B, M, C]
+            outputs['pred_reg']: (Tensor) [B, M, 4]
+            outputs['pred_box']: (Tensor) [B, M, 4]
+            outputs['stride']:   (Int) output stride of the single feature level
+            targets: (List) [dict{'boxes': [...], 
+                                 'labels': [...], 
+                                 'orig_size': ...}, ...]
+        """
+        bs          = outputs['pred_cls'].shape[0]
+        device      = outputs['pred_cls'].device
+        anchors     = outputs['anchors']
+        stride      = outputs['stride']
+        # preds: [B, M, C]
+        cls_preds = outputs['pred_cls']
+        box_preds = outputs['pred_box']
+        
+        # --------------- label assignment ---------------
+        cls_targets = []
+        box_targets = []
+        assign_metrics = []
+        for batch_idx in range(bs):
+            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
+            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
+            assigned_result = self.matcher(stride=stride,
+                                           anchors=anchors[..., :2],
+                                           pred_cls=cls_preds[batch_idx].detach(),
+                                           pred_box=box_preds[batch_idx].detach(),
+                                           gt_labels=tgt_labels,
+                                           gt_bboxes=tgt_bboxes
+                                           )
+            cls_targets.append(assigned_result['assigned_labels'])
+            box_targets.append(assigned_result['assigned_bboxes'])
+            assign_metrics.append(assigned_result['assign_metrics'])
+
+        # List[B, M, C] -> Tensor[BM, C]
+        cls_targets = torch.cat(cls_targets, dim=0)
+        box_targets = torch.cat(box_targets, dim=0)
+        assign_metrics = torch.cat(assign_metrics, dim=0)
+
+        # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+        bg_class_ind = self.num_classes
+        pos_inds = ((cls_targets >= 0) & (cls_targets < bg_class_ind)).nonzero().squeeze(1)
+        num_fgs = assign_metrics.sum()
+
+        if is_dist_avail_and_initialized():
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = (num_fgs / get_world_size()).clamp(1.0).item()
+        bbox_weight = assign_metrics[pos_inds]
+
+        # ------------------ Classification loss ------------------
+        cls_preds = cls_preds.view(-1, self.num_classes)
+        loss_cls = self.loss_classes(cls_preds, (cls_targets, assign_metrics))
+        loss_cls = loss_cls.sum() / num_fgs
+
+        # ------------------ Regression loss ------------------
+        box_preds_pos = box_preds.view(-1, 4)[pos_inds]
+        box_targets_pos = box_targets[pos_inds]
+        loss_box = self.loss_bboxes(box_preds_pos, box_targets_pos, bbox_weight)
+        loss_box = loss_box.sum() / num_fgs
+
+        # total loss
+        losses = self.loss_cls_weight * loss_cls + \
+                 self.loss_box_weight * loss_box
+        loss_dict = dict(
+                loss_cls = loss_cls,
+                loss_box = loss_box,
+                losses = losses
+        )
+
+        return loss_dict
+    
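
For intuition, here is a self-contained toy run of the Quality Focal Loss computed in loss_classes above (a sketch with made-up shapes and numbers, not part of the commit):

import torch
import torch.nn.functional as F

def qfl_reference(pred_cls, label, score, num_classes, beta=2.0):
    # background targets: BCE against zero, down-weighted by sigmoid^beta
    pred_sigmoid = pred_cls.sigmoid()
    loss = F.binary_cross_entropy_with_logits(
        pred_cls, torch.zeros_like(pred_cls), reduction='none') * pred_sigmoid.pow(beta)
    # positive targets: BCE against the assigned quality score at the GT class
    pos = ((label >= 0) & (label < num_classes)).nonzero().squeeze(1)
    if pos.numel() > 0:
        pos_label = label[pos].long()
        scale = (score[pos] - pred_sigmoid[pos, pos_label]).abs().pow(beta)
        loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
            pred_cls[pos, pos_label], score[pos], reduction='none') * scale
    return loss

pred  = torch.randn(5, 3)                     # 5 anchors, 3 classes
label = torch.tensor([0, 2, 3, 3, 1])         # 3 == background here
score = torch.tensor([0.8, 0.6, 0.0, 0.0, 0.9])
print(qfl_reference(pred, label, score, num_classes=3).shape)   # torch.Size([5, 3])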

+ 160 - 0
yolo/models/yolof/matcher.py

@@ -0,0 +1,160 @@
+# ---------------------------------------------------------------------
+# Copyright (c) Megvii Inc. All rights reserved.
+# ---------------------------------------------------------------------
+
+import math
+import torch
+import torch.nn.functional as F
+
+from utils.box_ops import *
+
+
+class SimOtaMatcher(object):
+    def __init__(self, num_classes, soft_center_radius=3.0, topk_candidates=13):
+        self.num_classes = num_classes
+        self.soft_center_radius = soft_center_radius
+        self.topk_candidates = topk_candidates
+
+    @torch.no_grad()
+    def __call__(self, 
+                 stride, 
+                 anchors, 
+                 pred_cls, 
+                 pred_box,
+                 gt_labels,
+                 gt_bboxes):
+        # number of ground-truth objects
+        num_gt = len(gt_labels)
+
+        # check gt
+        if num_gt == 0 or gt_bboxes.max().item() == 0.:
+            return {
+                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
+                                                      self.num_classes,
+                                                      dtype=torch.long),
+                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
+                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
+            }
+        
+        # get inside points: [N, M]
+        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
+        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
+
+        # ----------------------------------- soft center prior -----------------------------------
+        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
+        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
+                    ).pow(2).sum(-1).sqrt() / stride  # [N, M]
+        distance = distance * valid_mask.unsqueeze(0)
+        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
+
+        # ----------------------------------- regression cost -----------------------------------
+        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
+        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * 3.0
+
+        # ----------------------------------- classification cost -----------------------------------
+        ## select the predicted scores corresponding to the gt_labels
+        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
+        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
+        ## scale factor
+        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
+        ## cls cost
+        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
+            pairwise_pred_scores, pair_wise_ious,
+            reduction="none") * scale_factor # [N, M]
+            
+        del pairwise_pred_scores
+
+        ## foreground cost matrix
+        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
+        max_pad_value = torch.ones_like(cost_matrix) * 1e9
+        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
+                                  cost_matrix, max_pad_value)
+
+        # ----------------------------------- dynamic label assignment -----------------------------------
+        matched_pred_ious, matched_gt_inds, fg_mask_inboxes = self.dynamic_k_matching(
+            cost_matrix, pair_wise_ious, num_gt)
+        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
+
+        # ----------------------------------- process assigned labels -----------------------------------
+        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
+                                             self.num_classes)  # [M,]
+        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
+        assigned_labels = assigned_labels.long()  # [M,]
+
+        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
+        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
+
+        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M,]
+        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M,]
+
+        assigned_dict = dict(
+            assigned_labels=assigned_labels,
+            assigned_bboxes=assigned_bboxes,
+            assign_metrics=assign_metrics
+            )
+        
+        return assigned_dict
+
+    def find_inside_points(self, gt_bboxes, anchors):
+        """
+            gt_bboxes: Tensor -> [N, 4]
+            anchors:   Tensor -> [M, 2]
+        """
+        num_anchors = anchors.shape[0]
+        num_gt = gt_bboxes.shape[0]
+
+        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
+        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
+
+        # offset
+        lt = anchors_expand - gt_bboxes_expand[..., :2]
+        rb = gt_bboxes_expand[..., 2:] - anchors_expand
+        bbox_deltas = torch.cat([lt, rb], dim=-1)
+
+        is_in_gts = bbox_deltas.min(dim=-1).values > 0
+
+        return is_in_gts
+    
+    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
+        """Use IoU and matching cost to calculate the dynamic top-k positive
+        targets.
+
+        Args:
+            cost_matrix (Tensor): Cost matrix.
+            pairwise_ious (Tensor): Pairwise iou matrix.
+            num_gt (int): Number of gt.
+        Returns:
+            tuple: matched ious and gt indexes.
+        """
+        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
+        # select candidate topk ious for dynamic-k calculation
+        # candidate_topk = min(self.topk_candidates, pairwise_ious.size(1))
+        candidate_topk = self.topk_candidates
+        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
+        # calculate dynamic k for each gt
+        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
+
+        # sorting the batch cost matrix is faster than topk
+        _, sorted_indices = torch.sort(cost_matrix, dim=1)
+        for gt_idx in range(num_gt):
+            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
+            matching_matrix[gt_idx, :][topk_ids] = 1
+
+        del topk_ious, dynamic_ks, topk_ids
+
+        prior_match_gt_mask = matching_matrix.sum(0) > 1
+        if prior_match_gt_mask.sum() > 0:
+            cost_min, cost_argmin = torch.min(
+                cost_matrix[:, prior_match_gt_mask], dim=0)
+            matching_matrix[:, prior_match_gt_mask] *= 0
+            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
+
+        # get foreground mask inside box and center prior
+        fg_mask_inboxes = matching_matrix.sum(0) > 0
+        matched_pred_ious = (matching_matrix *
+                             pairwise_ious).sum(0)[fg_mask_inboxes]
+        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
+
+        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
+        
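
A toy run of dynamic_k_matching may help illustrate the assignment: each GT keeps at most clamp(sum of its top-k IoUs, min=1) lowest-cost anchors, and anchors claimed by several GTs are resolved by minimum cost (sketch; the import path is assumed from this commit's layout):

import torch
from yolo.models.yolof.matcher import SimOtaMatcher   # assuming this import path

matcher = SimOtaMatcher(num_classes=80, topk_candidates=3)
ious = torch.tensor([[0.6, 0.5, 0.1, 0.0, 0.0],        # GT 0 vs 5 anchors
                     [0.0, 0.4, 0.7, 0.6, 0.1]])       # GT 1 vs 5 anchors
cost = 1.0 - ious                                       # stand-in for the full cls + IoU + center-prior cost
matched_ious, matched_gt_inds, fg_mask = matcher.dynamic_k_matching(cost, ious, num_gt=2)
print(fg_mask)           # tensor([ True, False,  True, False, False])
print(matched_gt_inds)   # tensor([0, 1]): anchor 0 -> GT 0, anchor 2 -> GT 1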

+ 139 - 0
yolo/models/yolof/yolof.py

@@ -0,0 +1,139 @@
+# --------------- Torch components ---------------
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .yolof_backbone  import YolofBackbone
+from .yolof_upsampler import YolofUpsampler
+from .yolof_encoder   import YolofEncoder
+from .yolof_decoder   import YolofDecoder
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# Yolof
+class Yolof(nn.Module):
+    def __init__(self,
+                 cfg,
+                 is_val = False,
+                 ) -> None:
+        super(Yolof, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = not is_val
+        
+        # ---------------------- Network Parameters ----------------------
+        self.backbone  = YolofBackbone(cfg)
+        self.upsampler = YolofUpsampler(cfg, self.backbone.feat_dims[-1], cfg.head_dim)
+        self.encoder   = YolofEncoder(cfg, cfg.head_dim, cfg.head_dim)
+        self.decoder   = YolofDecoder(cfg, self.encoder.out_dim)
+
+    def post_process(self, cls_preds, box_preds):
+        """
+        We process predictions at each scale hierarchically
+        Input:
+            cls_preds: List[torch.Tensor] -> [[B, M, C], ...], B=1
+            box_preds: List[torch.Tensor] -> [[B, M, 4], ...], B=1
+        Output:
+            bboxes: np.array -> [N, 4]
+            scores: np.array -> [N,]
+            labels: np.array -> [N,]
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, box_pred_i in zip(cls_preds, box_preds):
+            cls_pred_i = cls_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            if self.no_multi_labels:
+                # [M,]
+                scores, labels = torch.max(cls_pred_i.sigmoid(), dim=1)
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # topk candidates
+                predicted_prob, topk_idxs = scores.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                labels = labels[topk_idxs]
+                bboxes = box_pred_i[topk_idxs]
+            else:
+                # [M, C] -> [MC,]
+                scores_i = cls_pred_i.sigmoid().flatten()
+
+                # Keep top k top scoring indices only.
+                num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+                # torch.sort is actually faster than .topk (at least on GPUs)
+                predicted_prob, topk_idxs = scores_i.sort(descending=True)
+                topk_scores = predicted_prob[:num_topk]
+                topk_idxs = topk_idxs[:num_topk]
+
+                # filter out the proposals with low confidence score
+                keep_idxs = topk_scores > self.conf_thresh
+                scores = topk_scores[keep_idxs]
+                topk_idxs = topk_idxs[keep_idxs]
+
+                anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+                labels = topk_idxs % self.num_classes
+
+                bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        scores = torch.cat(all_scores, dim=0)
+        labels = torch.cat(all_labels, dim=0)
+        bboxes = torch.cat(all_bboxes, dim=0)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+        
+        return bboxes, scores, labels
+    
+    def forward(self, x):
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+
+        # ---------------- Encoder ----------------
+        x = self.upsampler(pyramid_feats[-1])
+        x = self.encoder(x)
+
+        # ---------------- Decoder ----------------
+        outputs = self.decoder(x)
+        outputs['image_size'] = [x.shape[2], x.shape[3]]
+
+        if not self.training:
+            all_cls_preds = [outputs['pred_cls'],]
+            all_box_preds = [outputs['pred_box'],]
+
+            # post process
+            bboxes, scores, labels = self.post_process(all_cls_preds, all_box_preds)
+            outputs = {
+                "scores": scores,
+                "labels": labels,
+                "bboxes": bboxes
+            }
+        
+        return outputs 
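
To see what the multi-label branch of post_process does before NMS, here is a toy version of its score filtering (sketch; numbers and thresholds are made up):

import torch

num_classes, conf_thresh, topk = 3, 0.3, 4
cls_pred = torch.tensor([[ 2.0, -3.0, -3.0],    # anchor 0: confident class 0
                         [-3.0, -3.0, -3.0],    # anchor 1: background
                         [-3.0,  1.0, -3.0]])   # anchor 2: confident class 1
scores = cls_pred.sigmoid().flatten()           # [M*C]
prob, idxs = scores.sort(descending=True)
prob, idxs = prob[:topk], idxs[:topk]
keep = prob > conf_thresh
anchor_idxs = torch.div(idxs[keep], num_classes, rounding_mode='floor')
labels = idxs[keep] % num_classes
print(anchor_idxs, labels)   # tensor([0, 2]) tensor([0, 1])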

+ 181 - 0
yolo/models/yolof/yolof_backbone.py

@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolof_basic import BasicConv, ELANLayer
+except ImportError:
+    from  yolof_basic import BasicConv, ELANLayer
+
+# IN1K pretrained weight
+pretrained_urls = {
+    'n': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_n_in1k_62.1.pth",
+    's': "https://github.com/yjh0410/ICLab/releases/download/in1k_pretrained/elandarknet_s_in1k_71.3.pth",
+    'm': None,
+    'l': None,
+    'x': None,
+}
+
+# ---------------------------- Basic functions ----------------------------
+class YolofBackbone(nn.Module):
+    def __init__(self, cfg):
+        super(YolofBackbone, self).__init__()
+        # ------------------ Basic setting ------------------
+        self.model_scale = cfg.scale
+        self.feat_dims = [round(64  * cfg.width),
+                          round(128 * cfg.width),
+                          round(256 * cfg.width),
+                          round(512 * cfg.width),
+                          round(512 * cfg.width * cfg.ratio)]
+        
+        # ------------------ Network setting ------------------
+        ## P1/2
+        self.layer_1 = BasicConv(3, self.feat_dims[0],
+                                 kernel_size=3, padding=1, stride=2,
+                                 act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise)
+        # P2/4
+        self.layer_2 = nn.Sequential(
+            BasicConv(self.feat_dims[0], self.feat_dims[1],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[1],
+                      out_dim    = self.feat_dims[1],
+                      num_blocks = round(3*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P3/8
+        self.layer_3 = nn.Sequential(
+            BasicConv(self.feat_dims[1], self.feat_dims[2],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[2],
+                      out_dim    = self.feat_dims[2],
+                      num_blocks = round(6*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P4/16
+        self.layer_4 = nn.Sequential(
+            BasicConv(self.feat_dims[2], self.feat_dims[3],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[3],
+                      out_dim    = self.feat_dims[3],
+                      num_blocks = round(6*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+        # P5/32
+        self.layer_5 = nn.Sequential(
+            BasicConv(self.feat_dims[3], self.feat_dims[4],
+                      kernel_size=3, padding=1, stride=2,
+                      act_type=cfg.bk_act, norm_type=cfg.bk_norm, depthwise=cfg.bk_depthwise),
+            ELANLayer(in_dim     = self.feat_dims[4],
+                      out_dim    = self.feat_dims[4],
+                      num_blocks = round(3*cfg.depth),
+                      expansion  = 0.5,
+                      shortcut   = True,
+                      act_type   = cfg.bk_act,
+                      norm_type  = cfg.bk_norm,
+                      depthwise  = cfg.bk_depthwise)
+        )
+
+        # Initialize all layers
+        self.init_weights()
+        
+        # Load imagenet pretrained weight
+        if cfg.use_pretrained:
+            self.load_pretrained()
+        
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def load_pretrained(self):
+        url = pretrained_urls[self.model_scale]
+        if url is not None:
+            print('Loading backbone pretrained weight from : {}'.format(url))
+            # checkpoint state dict
+            checkpoint = torch.hub.load_state_dict_from_url(
+                url=url, map_location="cpu", check_hash=True)
+            checkpoint_state_dict = checkpoint.pop("model")
+            # model state dict
+            model_state_dict = self.state_dict()
+            # check
+            for k in list(checkpoint_state_dict.keys()):
+                if k in model_state_dict:
+                    shape_model = tuple(model_state_dict[k].shape)
+                    shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
+                    if shape_model != shape_checkpoint:
+                        checkpoint_state_dict.pop(k)
+                else:
+                    checkpoint_state_dict.pop(k)
+                    print('Unused key: ', k)
+            # load the weight
+            self.load_state_dict(checkpoint_state_dict)
+        else:
+            print('No pretrained weight for model scale: {}.'.format(self.model_scale))
+
+    def forward(self, x):
+        c1 = self.layer_1(x)
+        c2 = self.layer_2(c1)
+        c3 = self.layer_3(c2)
+        c4 = self.layer_4(c3)
+        c5 = self.layer_5(c4)
+        outputs = [c3, c4, c5]
+
+        return outputs
+
+
+# ---------------------------- Functions ----------------------------
+## build Yolo's Backbone
+def build_backbone(cfg): 
+    # model
+    backbone = YolofBackbone(cfg)
+        
+    return backbone
+
+
+if __name__ == '__main__':
+    import time
+    from thop import profile
+    class BaseConfig(object):
+        def __init__(self) -> None:
+            self.bk_act = 'silu'
+            self.bk_norm = 'BN'
+            self.bk_depthwise = False
+            self.use_pretrained = True
+            self.width = 0.50
+            self.depth = 0.34
+            self.ratio = 2.0
+            self.scale = "s"
+
+    cfg = BaseConfig()
+    model = build_backbone(cfg)
+    x = torch.randn(1, 3, 640, 640)
+    t0 = time.time()
+    outputs = model(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    for out in outputs:
+        print(out.shape)
+
+    x = torch.randn(1, 3, 640, 640)
+    print('==============================')
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 172 - 0
yolo/models/yolof/yolof_basic.py

@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# --------------------- Basic modules ---------------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError
+
+class BasicConv(nn.Module):
+    def __init__(self, 
+                 in_dim,                   # in channels
+                 out_dim,                  # out channels 
+                 kernel_size=1,            # kernel size 
+                 padding=0,                # padding
+                 stride=1,                 # stride
+                 dilation=1,               # dilation
+                 act_type  :str = 'lrelu', # activation
+                 norm_type :str = 'BN',    # normalization
+                 depthwise :bool = False
+                ):
+        super(BasicConv, self).__init__()
+        self.depthwise = depthwise
+        use_bias = norm_type is None
+        if not depthwise:
+            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
+            self.norm = get_norm(norm_type, out_dim)
+        else:
+            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
+            self.norm1 = get_norm(norm_type, in_dim)
+            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1, bias=use_bias)
+            self.norm2 = get_norm(norm_type, out_dim)
+        self.act  = get_activation(act_type)
+
+    def forward(self, x):
+        if not self.depthwise:
+            return self.act(self.norm(self.conv(x)))
+        else:
+            # Depthwise conv
+            x = self.act(self.norm1(self.conv1(x)))
+            # Pointwise conv
+            x = self.act(self.norm2(self.conv2(x)))
+            return x
+
+
+# --------------------- Yolov8 modules ---------------------
+class YoloBottleneck(nn.Module):
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 kernel_size :List  = [1, 3],
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = False,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'BN',
+                 depthwise   :bool  = False,
+                 ) -> None:
+        super(YoloBottleneck, self).__init__()
+        inter_dim = int(out_dim * expansion)
+        # ----------------- Network setting -----------------
+        self.conv_layer1 = BasicConv(in_dim, inter_dim,
+                                     kernel_size=kernel_size[0], padding=kernel_size[0]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.conv_layer2 = BasicConv(inter_dim, out_dim,
+                                     kernel_size=kernel_size[1], padding=kernel_size[1]//2, stride=1,
+                                     act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.conv_layer2(self.conv_layer1(x))
+
+        return x + h if self.shortcut else h
+
+class CSPLayer(nn.Module):
+    # CSP Bottleneck with 3 convolutions
+    def __init__(self,
+                 in_dim      :int,
+                 out_dim     :int,
+                 num_blocks  :int   = 1,
+                 kernel_size :List = [3, 3],
+                 expansion   :float = 0.5,
+                 shortcut    :bool  = True,
+                 act_type    :str   = 'silu',
+                 norm_type   :str   = 'BN',
+                 depthwise   :bool  = False,
+                 ) -> None:
+        super().__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj_1 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.input_proj_2 = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.output_proj  = BasicConv(2 * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.module       = nn.Sequential(*[YoloBottleneck(inter_dim,
+                                                           inter_dim,
+                                                           kernel_size,
+                                                           expansion   = 1.0,
+                                                           shortcut    = shortcut,
+                                                           act_type    = act_type,
+                                                           norm_type   = norm_type,
+                                                           depthwise   = depthwise,
+                                                           ) for _ in range(num_blocks)])
+
+    def forward(self, x):
+        x1 = self.input_proj_1(x)
+        x2 = self.input_proj_2(x)
+        x2 = self.module(x2)
+        out = self.output_proj(torch.cat([x1, x2], dim=1))
+
+        return out
+
+class ELANLayer(nn.Module):
+    def __init__(self,
+                 in_dim,
+                 out_dim,
+                 expansion  :float = 0.5,
+                 num_blocks :int   = 1,
+                 shortcut   :bool  = False,
+                 act_type   :str   = 'silu',
+                 norm_type  :str   = 'BN',
+                 depthwise  :bool  = False,
+                 ) -> None:
+        super(ELANLayer, self).__init__()
+        inter_dim = round(out_dim * expansion)
+        self.input_proj  = BasicConv(in_dim, inter_dim * 2, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.output_proj = BasicConv((2 + num_blocks) * inter_dim, out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.module      = nn.ModuleList([YoloBottleneck(inter_dim,
+                                                         inter_dim,
+                                                         kernel_size = [3, 3],
+                                                         expansion   = 1.0,
+                                                         shortcut    = shortcut,
+                                                         act_type    = act_type,
+                                                         norm_type   = norm_type,
+                                                         depthwise   = depthwise)
+                                                         for _ in range(num_blocks)])
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = [x1, x2]
+
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.module)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out
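
A quick shape check of ELANLayer with toy dimensions (sketch; the module is assumed to be importable, e.g. when run from inside the yolof/ directory via the try/except fallback above):

import torch
from yolof_basic import ELANLayer    # assumed import path

layer = ELANLayer(in_dim=64, out_dim=64, expansion=0.5, num_blocks=2, shortcut=True)
x = torch.randn(1, 64, 40, 40)
print(layer(x).shape)   # torch.Size([1, 64, 40, 40]): (2 + num_blocks) * inter_dim is fused back to out_dim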

+ 150 - 0
yolo/models/yolof/yolof_decoder.py

@@ -0,0 +1,150 @@
+import math
+import torch
+import torch.nn as nn
+
+try:
+    from .yolof_basic import BasicConv
+except ImportError:
+    from  yolof_basic import BasicConv
+    
+
+class YolofDecoder(nn.Module):
+    def __init__(self, cfg, in_dim):
+        super().__init__()
+        # ------------------ Basic parameters -------------------
+        self.cfg = cfg
+        self.in_dim = in_dim
+        self.stride       = cfg.out_stride
+        self.num_classes  = cfg.num_classes
+        self.num_cls_head = cfg.num_cls_head
+        self.num_reg_head = cfg.num_reg_head
+        # Anchor config
+        self.anchor_size = torch.as_tensor(cfg.anchor_size)
+        self.num_anchors = len(cfg.anchor_size)
+
+        # ------------------ Network parameters -------------------
+        ## cls head
+        cls_heads = []
+        self.cls_head_dim = cfg.head_dim
+        for i in range(self.num_cls_head):
+            if i == 0:
+                cls_heads.append(
+                    BasicConv(in_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=cfg.head_act, norm_type=cfg.head_norm, depthwise=cfg.head_depthwise)
+                              )
+            else:
+                cls_heads.append(
+                    BasicConv(self.cls_head_dim, self.cls_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=cfg.head_act, norm_type=cfg.head_norm, depthwise=cfg.head_depthwise)
+                              )
+        ## reg head
+        reg_heads = []
+        self.reg_head_dim = cfg.head_dim
+        for i in range(self.num_reg_head):
+            if i == 0:
+                reg_heads.append(
+                    BasicConv(in_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=cfg.head_act, norm_type=cfg.head_norm, depthwise=cfg.head_depthwise)
+                              )
+            else:
+                reg_heads.append(
+                    BasicConv(self.reg_head_dim, self.reg_head_dim,
+                              kernel_size=3, padding=1, stride=1, 
+                              act_type=cfg.head_act, norm_type=cfg.head_norm, depthwise=cfg.head_depthwise)
+                              )
+        self.cls_heads = nn.Sequential(*cls_heads)
+        self.reg_heads = nn.Sequential(*reg_heads)
+
+        # pred layer
+        self.cls_pred = nn.Conv2d(self.cls_head_dim, self.num_classes * self.num_anchors, kernel_size=1)
+        self.reg_pred = nn.Conv2d(self.reg_head_dim, 4 * self.num_anchors, kernel_size=1)
+
+        self.init_weights()
+        
+    def init_weights(self):
+        # Init bias
+        init_prob = 0.01
+        bias_value = -torch.log(torch.tensor((1. - init_prob) / init_prob))
+        # cls pred
+        b = self.cls_pred.bias.view(1, -1)
+        b.data.fill_(bias_value.item())
+        self.cls_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        # reg pred
+        b = self.reg_pred.bias.view(-1, )
+        b.data.fill_(1.0)
+        self.reg_pred.bias = torch.nn.Parameter(b.view(-1), requires_grad=True)
+        w = self.reg_pred.weight
+        w.data.fill_(0.)
+        self.reg_pred.weight = torch.nn.Parameter(w, requires_grad=True)
+
+    def generate_anchors(self, fmp_size):
+        """
+            fmp_size: (List) [H, W]
+        """
+        # feature map height and width
+        fmp_h, fmp_w = fmp_size
+
+        # generate the grid x and y coordinates
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+
+        # stack the x and y coordinates: [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        # [HW, 2] -> [HW, A, 2] -> [M, 2], M=HWA
+        anchor_xy = anchor_xy.unsqueeze(1).repeat(1, self.num_anchors, 1)
+        anchor_xy = anchor_xy.view(-1, 2) + 0.5
+        anchor_xy *= self.stride
+
+        # [A, 2] -> [1, A, 2] -> [HW, A, 2] -> [M, 2], M=HWA
+        anchor_wh = self.anchor_size.unsqueeze(0).repeat(fmp_h*fmp_w, 1, 1)
+        anchor_wh = anchor_wh.view(-1, 2)
+
+        anchors = torch.cat([anchor_xy, anchor_wh], dim=-1)
+
+        return anchors
+        
+    def decode_boxes(self, anchors, reg_pred):
+        """
+            anchors:  (Tensor) [1, M, 4]
+            reg_pred: (Tensor) [B, M, 4]
+        """
+        cxcy_pred = anchors[..., :2] + reg_pred[..., :2] * self.stride
+        bwbh_pred = anchors[..., 2:] * torch.exp(reg_pred[..., 2:])
+        pred_x1y1 = cxcy_pred - bwbh_pred * 0.5
+        pred_x2y2 = cxcy_pred + bwbh_pred * 0.5
+        box_pred = torch.cat([pred_x1y1, pred_x2y2], dim=-1)
+
+        return box_pred
+
+    def forward(self, x):
+        # ------------------- Decoupled head -------------------
+        cls_feats = self.cls_heads(x)
+        reg_feats = self.reg_heads(x)
+
+        # ------------------- Prediction -------------------
+        cls_pred = self.cls_pred(cls_feats)
+        reg_pred = self.reg_pred(reg_feats)
+
+        # ------------------- Generate anchor box -------------------
+        B, _, H, W = cls_pred.size()
+        anchors = self.generate_anchors([H, W])   # [M, 4]
+        anchors = anchors.to(cls_feats.device)
+
+        # ------------------- Process preds -------------------
+        # [B, C*A, H, W] -> [B, H, W, C*A] -> [B, H*W*A, C]
+        cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
+        reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
+
+        ## Decode bbox
+        box_pred = self.decode_boxes(anchors[None], reg_pred)  # [B, M, 4]
+
+        outputs = {"pred_cls": cls_pred,   # (torch.Tensor) [B, M, C]
+                   "pred_reg": reg_pred,   # (torch.Tensor) [B, M, 4]
+                   "pred_box": box_pred,   # (torch.Tensor) [B, M, 4]
+                   "stride":   self.stride,
+                   "anchors":  anchors,    # (torch.Tensor) [M, C]
+                   }
+
+        return outputs 
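
The decode formula in isolation (a sketch with one toy anchor; anchors are (cx, cy, w, h) at input resolution, predictions are (dx, dy, dw, dh)):

import torch

stride = 32
anchor = torch.tensor([[16.0, 16.0, 64.0, 64.0]])   # center (16, 16), size 64 x 64
reg    = torch.tensor([[0.5, 0.0, 0.0, 0.0]])       # shift the center by 0.5 * stride along x
cxcy = anchor[..., :2] + reg[..., :2] * stride
bwbh = anchor[..., 2:] * torch.exp(reg[..., 2:])
box  = torch.cat([cxcy - bwbh * 0.5, cxcy + bwbh * 0.5], dim=-1)
print(box)   # tensor([[  0., -16.,  64.,  48.]])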

+ 88 - 0
yolo/models/yolof/yolof_encoder.py

@@ -0,0 +1,88 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolof_basic import BasicConv
+except ImportError:
+    from  yolof_basic import BasicConv
+
+
+# BottleNeck
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 dilation     :int,
+                 expand_ratio :float = 0.5,
+                 shortcut     :bool  = False,
+                 act_type     :str   = 'relu',
+                 norm_type    :str   = 'BN',
+                 depthwise    :bool  = False,):
+        super(Bottleneck, self).__init__()
+        # ------------------ Basic parameters -------------------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.dilation = dilation
+        self.expand_ratio = expand_ratio
+        self.shortcut = shortcut and in_dim == out_dim
+        inter_dim = round(in_dim * expand_ratio)
+        # ------------------ Network parameters -------------------
+        self.branch = nn.Sequential(
+            BasicConv(in_dim, inter_dim,
+                      kernel_size=1, padding=0, stride=1,
+                      act_type=act_type, norm_type=norm_type),
+            BasicConv(inter_dim, inter_dim,
+                      kernel_size=3, padding=dilation, dilation=dilation, stride=1,
+                      act_type=act_type, norm_type=norm_type, depthwise=depthwise),
+            BasicConv(inter_dim, in_dim,
+                      kernel_size=1, padding=0, stride=1,
+                      act_type=act_type, norm_type=norm_type)
+        )
+
+    def forward(self, x):
+        h = self.branch(x)
+
+        return x + h if self.shortcut else h
+
+# Dilated Encoder
+class YolofEncoder(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim):
+        super(YolofEncoder, self).__init__()
+        # ------------------ Basic parameters -------------------
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+        self.expand_ratio = cfg.neck_expand_ratio
+        self.dilations    = cfg.neck_dilations
+        # ------------------ Network parameters -------------------
+        ## proj layer
+        self.projector = nn.Sequential(
+            BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=cfg.neck_norm),
+            BasicConv(out_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=cfg.neck_norm)
+        )
+        ## encoder layers
+        self.encoders = nn.Sequential(*[Bottleneck(in_dim      = out_dim,
+                                                   out_dim     = out_dim,
+                                                   dilation    = d,
+                                                   expand_ratio = self.expand_ratio,
+                                                   shortcut     = True,
+                                                   act_type     = cfg.neck_act,
+                                                   norm_type    = cfg.neck_norm,
+                                                   depthwise    = cfg.neck_depthwise,
+                                                   ) for d in self.dilations])
+
+        # Initialize all layers
+        self.init_weights()
+
+    def init_weights(self):
+        """Initialize the parameters."""
+        for m in self.modules():
+            if isinstance(m, torch.nn.Conv2d):
+                # In order to be consistent with the source code,
+                # reset the Conv2d initialization parameters
+                m.reset_parameters()
+
+    def forward(self, x):
+        x = self.projector(x)
+        x = self.encoders(x)
+
+        return x

+ 29 - 0
yolo/models/yolof/yolof_upsampler.py

@@ -0,0 +1,29 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .yolof_basic import BasicConv
+except ImportError:
+    from  yolof_basic import BasicConv
+
+
+class YolofUpsampler(nn.Module):
+    def __init__(self, cfg, in_dim, out_dim):
+        super(YolofUpsampler, self).__init__()
+        # ----------- Basic parameters -----------
+        self.upscale_factor = cfg.upscale_factor
+        inter_dim = self.upscale_factor ** 2 * in_dim
+        # ----------- Model parameters -----------
+        self.input_proj = BasicConv(in_dim, inter_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+        self.output_proj = BasicConv(in_dim, out_dim, kernel_size=1, act_type=cfg.neck_act, norm_type=cfg.neck_norm)
+
+    def forward(self, x):
+        # [B, C, H, W] -> [B, r*r*C, H, W], r = upscale_factor
+        x = self.input_proj(x)
+
+        # [B, r*r*C, H, W] -> [B, C, r*H, r*W]
+        x = torch.pixel_shuffle(x, upscale_factor=self.upscale_factor)
+        
+        x = self.output_proj(x)
+        
+        return x
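
And a quick shape check of the pixel-shuffle path used here (toy dimensions, no config object needed):

import torch

r = 2                                    # plays the role of cfg.upscale_factor
x = torch.randn(1, 4 * r * r, 20, 20)    # stands in for the output of input_proj
y = torch.pixel_shuffle(x, upscale_factor=r)
print(y.shape)                           # torch.Size([1, 4, 40, 40]): channels / r^2, spatial * r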