Explorar o código

add yolof & fcos

yjh0410 hai 1 ano
pai
achega
03540ca7ae

+ 7 - 0
yolo/config/__init__.py

@@ -9,6 +9,9 @@ from .yolov6_config  import build_yolov6_config
 from .yolov7_config  import build_yolov7_config
 from .yolov8_config  import build_yolov8_config
 from .yolov9_config  import build_yolov9_config
+
+from .yolof_config   import build_yolof_config
+from .fcos_config    import build_fcos_config
 from .rtdetr_config  import build_rtdetr_config
 
 
@@ -38,6 +41,10 @@ def build_config(args):
         cfg = build_yolov9_config(args)
         
     # ----------- RT-DETR -----------
+    elif 'yolof' in args.model:
+        cfg = build_yolof_config(args)
+    elif 'fcos' in args.model:
+        cfg = build_fcos_config(args)
     elif 'rtdetr' in args.model:
         cfg = build_rtdetr_config(args)
 

+ 107 - 0
yolo/config/fcos_config.py

@@ -0,0 +1,107 @@
+# Fcos Config
+
+
+def build_fcos_config(args):
+    if args.model == 'fcos_r18':
+        return FcosR18Config()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# Fcos-Base config
+class FcosBaseConfig(object):
+    """Default hyper-parameters for the FCOS models.
+
+    Every setting is a plain instance attribute assigned in ``__init__``;
+    subclasses override individual attributes after calling ``super().__init__()``.
+    """
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        self.out_stride = [8, 16, 32, 64]
+        self.max_stride = 64
+        ## Backbone
+        self.backbone = 'resnet50'
+        self.use_pretrained = True
+        ## Head
+        self.head_dim  = 256
+        self.num_cls_head = 4
+        self.num_reg_head = 4
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.05
+        self.val_nms_thresh  = 0.6
+        self.test_topk = 100
+        self.test_conf_thresh = 0.3
+        self.test_nms_thresh  = 0.45
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.center_sampling_radius = 1.5
+        # NOTE(review): five size ranges are listed here but out_stride above has
+        # only four levels -- confirm how the matcher pairs ranges with levels.
+        self.object_sizes_of_interest = [[-1, 64], [64, 128], [128, 256], [256, 512], [512, float('inf')]]
+
+        ## Loss weight
+        self.focal_loss_alpha = 0.25
+        self.focal_loss_gamma = 2.0
+        self.loss_cls = 1.0
+        self.loss_reg = 1.0
+        self.loss_ctn = 1.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema   = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.base_lr      = 0.001     # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.batch_size_base = 64
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = 35.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 150
+        self.eval_epoch   = 10
+        self.no_aug_epoch = -1
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0          # approximated by the YOLOX's mixup
+        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.1,
+            'scale': [0.5, 1.5],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        """Print every instance attribute as a ``name : value`` line."""
+        # Keys starting with a double underscore are filtered out before printing.
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# Fcos-R18
+class FcosR18Config(FcosBaseConfig):
+    """FCOS config that replaces the base backbone with 'resnet18'."""
+    def __init__(self) -> None:
+        super().__init__()
+        self.backbone = 'resnet18'
+
+# Fcos-R50
+class FcosR50Config(FcosBaseConfig):
+    """FCOS config that keeps every FcosBaseConfig default (backbone 'resnet50')."""
+    def __init__(self) -> None:
+        super().__init__()
+        # TODO: Try your best.

+ 115 - 0
yolo/config/yolof_config.py

@@ -0,0 +1,115 @@
+# Yolof Config
+
+
+def build_yolof_config(args):
+    if   args.model == 'yolof_r18':
+        return YolofR18Config()
+    elif args.model == 'yolof_r50':
+        return YolofR50Config()
+    else:
+        raise NotImplementedError("No config for model: {}".format(args.model))
+    
+# Yolof-Base config
+class YolofBaseConfig(object):
+    """Default hyper-parameters for the YOLOF models.
+
+    Every setting is a plain instance attribute assigned in ``__init__``;
+    subclasses override individual attributes after calling ``super().__init__()``.
+    """
+    def __init__(self) -> None:
+        # ---------------- Model config ----------------
+        # A single output stride (an int here, not a per-level list).
+        self.out_stride = 32
+        self.max_stride = 32
+        ## Backbone
+        self.backbone = 'resnet50'
+        self.use_pretrained = True
+        ## Encoder
+        self.neck_expand_ratio = 0.25
+        self.neck_dilations = [2, 4, 6, 8]
+        ## Head
+        self.head_dim  = 512
+        self.num_cls_head = 2
+        self.num_reg_head = 4
+
+        # ---------------- Post-process config ----------------
+        ## Post process
+        self.val_topk = 1000
+        self.val_conf_thresh = 0.05
+        self.val_nms_thresh  = 0.6
+        self.test_topk = 300
+        self.test_conf_thresh = 0.3
+        self.test_nms_thresh  = 0.45
+
+        # ---------------- Assignment config ----------------
+        ## Matcher
+        self.center_clamp = 32
+        self.match_topk_candidates = 4
+        self.match_iou_thresh = 0.15
+        self.ignore_thresh = 0.7
+        self.anchor_size  = [[32, 32], [64, 64], [128, 128], [256, 256], [512, 512]]
+
+        ## Loss weight
+        self.focal_loss_alpha = 0.25
+        self.focal_loss_gamma = 2.0
+        self.loss_cls = 1.0
+        self.loss_reg = 1.0
+        # NOTE(review): same attribute as in the FCOS config -- confirm the YOLOF
+        # criterion actually reads loss_ctn.
+        self.loss_ctn = 1.0
+
+        # ---------------- ModelEMA config ----------------
+        self.use_ema   = True
+        self.ema_decay = 0.9998
+        self.ema_tau   = 2000
+
+        # ---------------- Optimizer config ----------------
+        self.trainer      = 'yolo'
+        self.optimizer    = 'adamw'
+        self.base_lr      = 0.001     # base_lr = per_image_lr * batch_size
+        self.min_lr_ratio = 0.01      # min_lr  = base_lr * min_lr_ratio
+        self.batch_size_base = 64
+        self.momentum     = 0.9
+        self.weight_decay = 0.05
+        self.clip_max_norm   = 35.0
+        self.warmup_bias_lr  = 0.1
+        self.warmup_momentum = 0.8
+
+        # ---------------- Lr Scheduler config ----------------
+        self.warmup_epoch = 3
+        self.lr_scheduler = "cosine"
+        self.max_epoch    = 150
+        self.eval_epoch   = 10
+        self.no_aug_epoch = -1
+
+        # ---------------- Data process config ----------------
+        self.aug_type = 'yolo'
+        self.mosaic_prob = 0.0
+        self.mixup_prob  = 0.0
+        self.copy_paste  = 0.0          # approximated by the YOLOX's mixup
+        self.multi_scale = [0.5, 1.5]   # multi scale: [img_size * 0.5, img_size * 1.5]
+        ## Pixel mean & std
+        self.pixel_mean = [0., 0., 0.]
+        self.pixel_std  = [255., 255., 255.]
+        ## Transforms
+        self.train_img_size = 640
+        self.test_img_size  = 640
+        self.affine_params = {
+            'degrees': 0.0,
+            'translate': 0.1,
+            'scale': [0.5, 1.5],
+            'shear': 0.0,
+            'perspective': 0.0,
+            'hsv_h': 0.015,
+            'hsv_s': 0.7,
+            'hsv_v': 0.4,
+        }
+
+    def print_config(self):
+        """Print every instance attribute as a ``name : value`` line."""
+        # Keys starting with a double underscore are filtered out before printing.
+        config_dict = {key: value for key, value in self.__dict__.items() if not key.startswith('__')}
+        for k, v in config_dict.items():
+            print("{} : {}".format(k, v))
+
+# Yolof-R18
+class YolofR18Config(YolofBaseConfig):
+    """YOLOF config that replaces the base backbone with 'resnet18'."""
+    def __init__(self) -> None:
+        super().__init__()
+        self.backbone = 'resnet18'
+
+# Yolof-R50
+class YolofR50Config(YolofBaseConfig):
+    """YOLOF config that keeps every YolofBaseConfig default (backbone 'resnet50')."""
+    def __init__(self) -> None:
+        super().__init__()
+        # TODO: Try your best.

+ 10 - 0
yolo/models/__init__.py

@@ -12,6 +12,9 @@ from .yolov6.build import build_yolov6
 from .yolov7.build import build_yolov7
 from .yolov8.build import build_yolov8
 from .yolov9.build import build_gelan
+
+from .yolof.build  import build_yolof
+from .fcos.build   import build_fcos
 from .rtdetr.build import build_rtdetr
 
 
@@ -48,6 +51,13 @@ def build_model(args, cfg, is_val=False):
     ## GElan
     elif 'yolov9' in args.model:
         model, criterion = build_gelan(cfg, is_val)
+
+    ## Yolof
+    elif 'yolof' in args.model:
+        model, criterion = build_yolof(cfg, is_val)
+    ## Fcos
+    elif 'fcos' in args.model:
+        model, criterion = build_fcos(cfg, is_val)
     ## RT-DETR
     elif 'rtdetr' in args.model:
         model, criterion = build_rtdetr(cfg, is_val)

+ 16 - 0
yolo/models/fcos/build.py

@@ -0,0 +1,16 @@
+from .loss import SetCriterion
+from .fcos import Fcos
+
+
+# build object detector
+def build_fcos(cfg, is_val=False):
+    # -------------- Build YOLO --------------
+    model = Fcos(cfg, is_val)
+  
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 115 - 0
yolo/models/fcos/fcos.py

@@ -0,0 +1,115 @@
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .fcos_backbone import FcosBackbone
+from .fcos_fpn import FcosFPN
+from .fcos_head import FcosHead
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# ------------------------ Fully Convolutional One-Stage Detector ------------------------
+class Fcos(nn.Module):
+    """FCOS detector: backbone -> FPN -> head, with post-processing outside training mode."""
+    def __init__(self, 
+                 cfg,
+                 is_val = False,
+                 ) -> None:
+        """Store post-processing settings and build the three sub-networks.
+
+        ``is_val`` selects the ``val_*`` thresholds from ``cfg``; otherwise the
+        ``test_*`` thresholds are used.
+        """
+        super(Fcos, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        # Stored but not read anywhere else in this class.
+        self.no_multi_labels  = False if is_val else True
+
+        # ---------------------- Network Parameters ----------------------
+        self.backbone = FcosBackbone(cfg)
+        # Only the last three backbone feature dims are handed to the FPN.
+        self.fpn      = FcosFPN(cfg, self.backbone.feat_dims[-3:])
+        self.head     = FcosHead(cfg, self.fpn.out_dim)
+
+    def post_process(self, cls_preds, ctn_preds, box_preds):
+        """
+        Turn per-level head predictions into final detections for the first
+        image of the batch (only index 0 of every tensor is used).
+
+        Input:
+            cls_preds: List(Tensor) [[B, H x W, C], ...]
+            ctn_preds: List(Tensor) [[B, H x W, 1], ...]
+            box_preds: List(Tensor) [[B, H x W, 4], ...]
+        Output:
+            bboxes, scores, labels as returned by multiclass_nms.
+        """
+        all_scores = []
+        all_labels = []
+        all_bboxes = []
+        
+        for cls_pred_i, ctn_pred_i, box_pred_i in zip(cls_preds, ctn_preds, box_preds):
+            # Keep only the first sample of the batch.
+            cls_pred_i = cls_pred_i[0]
+            ctn_pred_i = ctn_pred_i[0]
+            box_pred_i = box_pred_i[0]
+            
+            # Score = sqrt(class prob * centerness prob), flattened to (H x W x C,)
+            scores_i = torch.sqrt(cls_pred_i.sigmoid() * ctn_pred_i.sigmoid()).flatten()
+
+            # Keep top k top scoring indices only.
+            num_topk = min(self.topk_candidates, box_pred_i.size(0))
+
+            # torch.sort is actually faster than .topk (at least on GPUs)
+            predicted_prob, topk_idxs = scores_i.sort(descending=True)
+            topk_scores = predicted_prob[:num_topk]
+            topk_idxs = topk_idxs[:num_topk]
+
+            # filter out the proposals with low confidence score
+            keep_idxs = topk_scores > self.conf_thresh
+            topk_idxs = topk_idxs[keep_idxs]
+
+            # final scores
+            scores = topk_scores[keep_idxs]
+            # final labels: the flattened index is (location * num_classes + class)
+            labels = topk_idxs % self.num_classes
+            # final bboxes: recover the location index to pick the matching box
+            anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+            bboxes = box_pred_i[anchor_idxs]
+
+            all_scores.append(scores)
+            all_labels.append(labels)
+            all_bboxes.append(bboxes)
+
+        # Merge the candidates of every pyramid level.
+        scores = torch.cat(all_scores)
+        labels = torch.cat(all_labels)
+        bboxes = torch.cat(all_bboxes)
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+
+        return bboxes, scores, labels
+
+    def forward(self, x):
+        """Run the detector on ``x``.
+
+        In training mode the raw head outputs are returned. Otherwise the
+        predictions are post-processed and a dict with 'scores', 'labels' and
+        'bboxes' is returned.
+        """
+        # ---------------- Backbone ----------------
+        pyramid_feats = self.backbone(x)
+
+        # ---------------- Neck ----------------
+        pyramid_feats = self.fpn(pyramid_feats)
+
+        # ---------------- Heads ----------------
+        outputs = self.head(pyramid_feats)
+
+        if not self.training:
+            # ---------------- PostProcess ----------------
+            cls_pred = outputs["pred_cls"]
+            ctn_pred = outputs["pred_ctn"]
+            box_pred = outputs["pred_box"]
+            bboxes, scores, labels = self.post_process(cls_pred, ctn_pred, box_pred)
+
+            outputs = {
+                'scores': scores,
+                'labels': labels,
+                'bboxes': bboxes
+            }
+
+        return outputs 

+ 52 - 0
yolo/models/fcos/fcos_backbone.py

@@ -0,0 +1,52 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .resnet import build_resnet
+except:
+    from  resnet import build_resnet
+
+
+# --------------------- Fcos's Backbone -----------------------
+class FcosBackbone(nn.Module):
+    """Thin wrapper around the ResNet returned by ``build_resnet``."""
+    def __init__(self, cfg):
+        """Build the ResNet named by ``cfg.backbone``; ``cfg.use_pretrained`` is forwarded to the builder."""
+        super().__init__()
+        # build_resnet returns the network together with its feature dims.
+        self.backbone, self.feat_dims = build_resnet(cfg.backbone, cfg.use_pretrained)
+
+    def forward(self, x):
+        """Return whatever the wrapped backbone produces for ``x``."""
+        pyramid_feats = self.backbone(x)
+
+        return pyramid_feats # [C3, C4, C5]
+
+
+if __name__=='__main__':
+    # Smoke test: build the backbone, print its output shapes, then profile it.
+    from thop import profile
+
+    # Minimal stand-in config for the FCOS backbone
+    class FcosBaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.out_stride = [8, 16, 32]
+            self.max_stride = 32
+            ## Backbone
+            self.backbone = 'resnet18'
+            self.use_pretrained = True
+    cfg = FcosBaseConfig()
+
+    # Build backbone
+    model = FcosBackbone(cfg)
+
+    # Randomly generate input data
+    x = torch.randn(2, 3, 640, 640)
+
+    # Inference
+    outputs = model(x)
+    print(' - the shape of input :  ', x.shape)
+    for i, out in enumerate(outputs):
+        print(f' - the shape of level-{i} output : ', out.shape)
+
+    # Profile with a single-image batch.
+    x = torch.randn(1, 3, 640, 640)
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('============== FLOPs & Params ================')
+    # NOTE(review): the count is doubled, presumably because thop reports MACs -- confirm.
+    print(' - FLOPs  : {:.2f} G'.format(flops / 1e9 * 2))
+    print(' - Params : {:.2f} M'.format(params / 1e6))

+ 69 - 49
yolo/models/fcos/fcos_fpn.py

@@ -1,68 +1,88 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from typing import List
+
 
 # ------------------ Basic Feature Pyramid Network ------------------
-class BasicFPN(nn.Module):
-    def __init__(self, cfg, 
-                 in_dims=[512, 1024, 2048],
-                 out_dim=256,
-                 ):
+class FcosFPN(nn.Module):
+    def __init__(self, cfg, in_dims: List = [512, 1024, 2048]):
         super().__init__()
         # ------------------ Basic parameters -------------------
-        self.p6_feat = cfg.fpn_p6_feat
-        self.p7_feat = cfg.fpn_p7_feat
-        self.from_c5 = cfg.fpn_p6_from_c5
+        self.out_dim = cfg.head_dim
 
         # ------------------ Network parameters -------------------
-        ## latter layers
-        self.input_projs = nn.ModuleList()
-        self.smooth_layers = nn.ModuleList()
-        for in_dim in in_dims[::-1]:
-            self.input_projs.append(nn.Conv2d(in_dim, out_dim, kernel_size=1))
-            self.smooth_layers.append(nn.Conv2d(out_dim, out_dim, kernel_size=3, padding=1))
+        self.input_proj_1 = nn.Conv2d(in_dims[0], self.out_dim, kernel_size=1)
+        self.input_proj_2 = nn.Conv2d(in_dims[1], self.out_dim, kernel_size=1)
+        self.input_proj_3 = nn.Conv2d(in_dims[2], self.out_dim, kernel_size=1)
+
+        self.smooth_layer_1 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
+        self.smooth_layer_2 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
+        self.smooth_layer_3 = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=1)
 
-        ## P6/P7 layers
-        if self.p6_feat:
-            if self.from_c5:
-                self.p6_conv = nn.Conv2d(in_dims[-1], out_dim, kernel_size=3, stride=2, padding=1)
-            else: # from p5
-                self.p6_conv = nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
-        if self.p7_feat:
-            self.p7_conv = nn.Sequential(
-                nn.ReLU(inplace=True),
-                nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=2, padding=1)
-            )
+        self.p6_conv = nn.Conv2d(self.out_dim, self.out_dim, kernel_size=3, stride=2, padding=1)
 
     def forward(self, feats):
         """
-            feats: (List of Tensor) [C3, C4, C5], C_i ∈ R^(B x C_i x H_i x W_i)
+            feats: (List of Tensor) [C3, C4, C5]
         """
-        outputs = []
-        # [C3, C4, C5] -> [C5, C4, C3]
-        feats = feats[::-1]
-        top_level_feat = feats[0]
-        prev_feat = self.input_projs[0](top_level_feat)
-        outputs.append(self.smooth_layers[0](prev_feat))
+        c3, c4, c5 = feats
 
-        for feat, input_proj, smooth_layer in zip(feats[1:], self.input_projs[1:], self.smooth_layers[1:]):
-            feat = input_proj(feat)
-            top_down_feat = F.interpolate(prev_feat, size=feat.shape[2:], mode='nearest')
-            prev_feat = feat + top_down_feat
-            outputs.insert(0, smooth_layer(prev_feat))
+        # -------- Input projection --------
+        p3 = self.input_proj_1(c3)
+        p4 = self.input_proj_2(c4)
+        p5 = self.input_proj_3(c5)
+        
+        # -------- Feature fusion --------
+        outputs = [self.smooth_layer_3(p5)]
+        # P5 -> P4
+        p4 = p4 + F.interpolate(p5, size=p4.shape[2:], mode='nearest')
+        outputs.insert(0, self.smooth_layer_2(p4))
 
-        if self.p6_feat:
-            if self.from_c5:
-                p6_feat = self.p6_conv(feats[0])
-            else:
-                p6_feat = self.p6_conv(outputs[-1])
-            # [P3, P4, P5] -> [P3, P4, P5, P6]
-            outputs.append(p6_feat)
+        # P4 -> P3
+        p3 = p3 + F.interpolate(p4, size=p3.shape[2:], mode='nearest')
+        outputs.insert(0, self.smooth_layer_1(p3))
 
-            if self.p7_feat:
-                p7_feat = self.p7_conv(p6_feat)
-                # [P3, P4, P5, P6] -> [P3, P4, P5, P6, P7]
-                outputs.append(p7_feat)
+        # P5 -> P6
+        outputs.append(self.p6_conv(outputs[-1]))
 
-        # [P3, P4, P5] or [P3, P4, P5, P6, P7]
+        # [P3, P4, P5, P6]
         return outputs
+
+
+if __name__=='__main__':
+    # Smoke test: run the FPN on random features, time it, then profile it.
+    import time
+    from thop import profile
+    # Model config
+    
+    # Minimal stand-in config for the FCOS FPN
+    class FcosBaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width    = 0.50
+            self.depth    = 0.34
+            self.out_stride = [8, 16, 32, 64]
+            ## Head
+            self.head_dim = 256
+
+    cfg = FcosBaseConfig()
+    # Build the FPN
+    in_dims  = [128, 256, 512]
+    fpn = FcosFPN(cfg, in_dims)
+
+    # Inference on three random feature maps, one per input level
+    x = [torch.randn(1, in_dims[0], 80, 80),
+         torch.randn(1, in_dims[1], 40, 40),
+         torch.randn(1, in_dims[2], 20, 20)]
+    t0 = time.time()
+    output = fpn(x)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print('====== FPN output ====== ')
+    for level, feat in enumerate(output):
+        print("- Level-{} : ".format(level), feat.shape)
+
+    flops, params = profile(fpn, inputs=(x, ), verbose=False)
+    print('==============================')
+    # NOTE(review): the count is doubled, presumably because thop reports MACs -- confirm.
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 74 - 50
yolo/models/fcos/fcos_head.py

@@ -1,7 +1,11 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
-from .modules import BasicConv
+try:
+    from .modules import ConvModule
+except:
+    from  modules import ConvModule
 
 
 class Scale(nn.Module):
@@ -25,53 +29,34 @@ class Scale(nn.Module):
         return x * self.scale
 
 class FcosHead(nn.Module):
-    def __init__(self, cfg, in_dim, out_dim,):
+    def __init__(self, cfg, in_dim: int = 256,):
         super().__init__()
-        self.fmp_size = None
         # ------------------ Basic parameters -------------------
         self.cfg = cfg
         self.in_dim = in_dim
-        self.stride       = cfg.out_stride
-        self.num_classes  = cfg.num_classes
-        self.num_cls_head = cfg.num_cls_head
-        self.num_reg_head = cfg.num_reg_head
-        self.act_type     = cfg.head_act
-        self.norm_type    = cfg.head_norm
+        self.out_dim = cfg.head_dim
+        self.out_stride  = cfg.out_stride
+        self.num_classes = cfg.num_classes
 
         # ------------------ Network parameters -------------------
-        ## cls head
+        ## classification head
         cls_heads = []
-        self.cls_head_dim = out_dim
-        for i in range(self.num_cls_head):
+        self.cls_head_dim = cfg.head_dim
+        for i in range(cfg.num_cls_head):
             if i == 0:
-                cls_heads.append(
-                    BasicConv(in_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
+                cls_heads.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                cls_heads.append(
-                    BasicConv(self.cls_head_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
+                cls_heads.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
         
-        ## reg head
+        ## bbox regression head
         reg_heads = []
-        self.reg_head_dim = out_dim
-        for i in range(self.num_reg_head):
+        self.reg_head_dim = cfg.head_dim
+        for i in range(cfg.num_reg_head):
             if i == 0:
-                reg_heads.append(
-                    BasicConv(in_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
+                reg_heads.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                reg_heads.append(
-                    BasicConv(self.reg_head_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
+                reg_heads.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
+        
         self.cls_heads = nn.Sequential(*cls_heads)
         self.reg_heads = nn.Sequential(*reg_heads)
 
@@ -82,7 +67,7 @@ class FcosHead(nn.Module):
         
         ## scale layers
         self.scales = nn.ModuleList(
-            Scale() for _ in range(len(self.stride))
+            Scale() for _ in range(len(self.out_stride))
         )
         
         # init bias
@@ -112,8 +97,9 @@ class FcosHead(nn.Module):
         fmp_h, fmp_w = fmp_size
         anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
         # [H, W, 2] -> [HW, 2]
-        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
-        anchors *= self.stride[level]
+        anchors = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2)
+        anchors += 0.5
+        anchors *= self.out_stride[level]
 
         return anchors
         
@@ -130,8 +116,7 @@ class FcosHead(nn.Module):
 
         return pred_box
     
-    def forward(self, pyramid_feats, mask=None):
-        all_masks = []
+    def forward(self, pyramid_feats):
         all_anchors = []
         all_cls_preds = []
         all_reg_preds = []
@@ -158,16 +143,10 @@ class FcosHead(nn.Module):
             cls_pred = cls_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, self.num_classes)
             ctn_pred = ctn_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 1)
             reg_pred = reg_pred.permute(0, 2, 3, 1).contiguous().view(B, -1, 4)
-            reg_pred = nn.functional.relu(self.scales[level](reg_pred)) * self.stride[level]
+            reg_pred = F.relu(self.scales[level](reg_pred)) * self.out_stride[level]
+
             ## Decode bbox
             box_pred = self.decode_boxes(reg_pred, anchors)
-            ## Adjust mask
-            if mask is not None:
-                # [B, H, W]
-                mask_i = torch.nn.functional.interpolate(mask[None].float(), size=[H, W]).bool()[0]
-                # [B, H, W] -> [B, M]
-                mask_i = mask_i.flatten(1)     
-                all_masks.append(mask_i)
                 
             all_anchors.append(anchors)
             all_cls_preds.append(cls_pred)
@@ -180,7 +159,52 @@ class FcosHead(nn.Module):
                    "pred_box": all_box_preds,  # List [B, M, 4]
                    "pred_ctn": all_ctn_preds,  # List [B, M, 1]
                    "anchors": all_anchors,     # List [B, M, 2]
-                   "strides": self.stride,
-                   "mask": all_masks}          # List [B, M,]
+                   "strides": self.out_stride,
+                   }
 
         return outputs 
+
+
+if __name__=='__main__':
+    import time
+    from thop import profile
+    # Model config
+    
+    # YOLOv3-Base config
+    class FcosBaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.width = 0.50
+            self.depth = 0.34
+
+            self.out_stride = [8, 16, 32, 64]
+            self.num_classes = 20
+
+            ## Head
+            self.head_dim  = 256
+            self.num_cls_head = 4
+            self.num_reg_head = 4
+
+    cfg = FcosBaseConfig()
+    feat_dim = 256
+    pyramid_feats = [torch.randn(1, feat_dim, 80, 80),
+                     torch.randn(1, feat_dim, 40, 40),
+                     torch.randn(1, feat_dim, 20, 20),
+                     torch.randn(1, feat_dim, 10, 10)]
+
+    # Build a head
+    head = FcosHead(cfg, feat_dim)
+
+    # Inference
+    t0 = time.time()
+    outputs = head(pyramid_feats)
+    t1 = time.time()
+    print('Time: ', t1 - t0)
+    print("====== FCOS Head output ======")
+    for k in outputs:
+        print(f" - shape of {k}: ", outputs[k].shape )
+
+    flops, params = profile(head, inputs=(pyramid_feats, ), verbose=False)
+    print('==============================')
+    print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 32 - 180
yolo/models/fcos/loss.py

@@ -2,42 +2,33 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from utils.box_ops import get_ious
 from utils.misc import sigmoid_focal_loss
 from utils.distributed_utils import get_world_size, is_dist_avail_and_initialized
 
-from .matcher import FcosMatcher, AlignedOTAMatcher
+from .matcher import FcosMatcher
 
 
-class SetCriterion(nn.Module):
+class SetCriterion(object):
     def __init__(self, cfg):
-        super().__init__()
         # ------------- Basic parameters -------------
         self.cfg = cfg
         self.num_classes = cfg.num_classes
+
         # ------------- Focal loss -------------
         self.alpha = cfg.focal_loss_alpha
         self.gamma = cfg.focal_loss_gamma
+
         # ------------- Loss weight -------------
-        # ------------- Matcher & Loss weight -------------
-        self.matcher_cfg = cfg.matcher_hpy
-        if cfg.matcher == 'fcos_matcher':
-            self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
-                                'loss_reg': cfg.loss_reg_weight,
-                                'loss_ctn': cfg.loss_ctn_weight}
-            self.matcher = FcosMatcher(cfg.num_classes,
-                                       self.matcher_cfg['center_sampling_radius'],
-                                       self.matcher_cfg['object_sizes_of_interest'],
-                                       [1., 1., 1., 1.]
-                                       )
-        elif cfg.matcher == 'simota':
-            self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
-                                'loss_reg': cfg.loss_reg_weight}
-            self.matcher = AlignedOTAMatcher(cfg.num_classes,
-                                             self.matcher_cfg['soft_center_radius'],
-                                             self.matcher_cfg['topk_candidates'])
-        else:
-            raise NotImplementedError("Unknown matcher: {}.".format(cfg.matcher))
+        self.weight_dict = {'loss_cls': cfg.loss_cls,
+                            'loss_reg': cfg.loss_reg,
+                            'loss_ctn': cfg.loss_ctn,}
+        
+        # ------------- Matcher -------------
+        self.matcher = FcosMatcher(cfg.num_classes,
+                                   center_sampling_radius=cfg.center_sampling_radius,
+                                   object_sizes_of_interest=cfg.object_sizes_of_interest,
+                                   box_weights=[1., 1., 1., 1.],
+                                   )
 
     def loss_labels(self, pred_cls, tgt_cls, num_boxes=1.0):
         """
@@ -49,34 +40,7 @@ class SetCriterion(nn.Module):
 
         return loss_cls.sum() / num_boxes
 
-    def loss_labels_qfl(self, pred_cls, target, beta=2.0, num_boxes=1.0):
-        # Quality FocalLoss
-        """
-            pred_cls: (torch.Tensor): [N, C]。
-            target:   (tuple([torch.Tensor], [torch.Tensor])): label -> (N,), score -> (N)
-        """
-        label, score = target
-        pred_sigmoid = pred_cls.sigmoid()
-        scale_factor = pred_sigmoid
-        zerolabel = scale_factor.new_zeros(pred_cls.shape)
-
-        ce_loss = F.binary_cross_entropy_with_logits(
-            pred_cls, zerolabel, reduction='none') * scale_factor.pow(beta)
-        
-        bg_class_ind = pred_cls.shape[-1]
-        pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
-        if pos.shape[0] > 0:
-            pos_label = label[pos].long()
-
-            scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
-
-            ce_loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
-                pred_cls[pos, pos_label], score[pos],
-                reduction='none') * scale_factor.abs().pow(beta)
-
-        return ce_loss.sum() / num_boxes
-    
-    def loss_bboxes_ltrb(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0):
+    def loss_bboxes(self, pred_delta, tgt_delta, bbox_quality=None, num_boxes=1.0):
         """
             pred_box: (Tensor) [N, 4]
             tgt_box:  (Tensor) [N, 4]
@@ -114,16 +78,7 @@ class SetCriterion(nn.Module):
 
         return loss_box.sum() / num_boxes
 
-    def loss_bboxes_xyxy(self, pred_box, gt_box, num_boxes=1.0, box_weight=None):
-        ious = get_ious(pred_box, gt_box, box_mode="xyxy", iou_type='giou')
-        loss_box = 1.0 - ious
-
-        if box_weight is not None:
-            loss_box = loss_box.squeeze(-1) * box_weight
-
-        return loss_box.sum() / num_boxes
-    
-    def fcos_loss(self, outputs, targets):
+    def __call__(self, outputs, targets):
         """
             outputs['pred_cls']: (Tensor) [B, M, C]
             outputs['pred_reg']: (Tensor) [B, M, 4]
@@ -137,10 +92,10 @@ class SetCriterion(nn.Module):
         device = outputs['pred_cls'][0].device
         fpn_strides = outputs['strides']
         anchors = outputs['anchors']
-        pred_cls = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)
+
+        pred_cls   = torch.cat(outputs['pred_cls'], dim=1).view(-1, self.num_classes)
         pred_delta = torch.cat(outputs['pred_reg'], dim=1).view(-1, 4)
-        pred_ctn = torch.cat(outputs['pred_ctn'], dim=1).view(-1, 1)
-        masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
+        pred_ctn   = torch.cat(outputs['pred_ctn'], dim=1).view(-1, 1)
 
         # -------------------- Label Assignment --------------------
         gt_classes, gt_deltas, gt_centerness = self.matcher(fpn_strides, anchors, targets)
@@ -148,33 +103,31 @@ class SetCriterion(nn.Module):
         gt_deltas = gt_deltas.view(-1, 4).to(device)
         gt_centerness = gt_centerness.view(-1, 1).to(device)
 
-        foreground_idxs = (gt_classes >= 0) & (gt_classes != self.num_classes)
-        num_foreground = foreground_idxs.sum()
+        fg_masks = (gt_classes >= 0) & (gt_classes != self.num_classes)
+        num_fgs = fg_masks.sum()
 
         if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_foreground)
-        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
 
-        num_foreground_centerness = gt_centerness[foreground_idxs].sum()
+        num_fgs_ctn = gt_centerness[fg_masks].sum()
         if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_foreground_centerness)
-        num_targets = torch.clamp(num_foreground_centerness / get_world_size(), min=1).item()
+            torch.distributed.all_reduce(num_fgs_ctn)
+        num_targets = torch.clamp(num_fgs_ctn / get_world_size(), min=1).item()
 
         # -------------------- classification loss --------------------
         gt_classes_target = torch.zeros_like(pred_cls)
-        gt_classes_target[foreground_idxs, gt_classes[foreground_idxs]] = 1
-        valid_idxs = (gt_classes >= 0) & masks
-        loss_labels = self.loss_labels(
-            pred_cls[valid_idxs], gt_classes_target[valid_idxs], num_foreground)
+        gt_classes_target[fg_masks, gt_classes[fg_masks]] = 1
+        loss_labels = self.loss_labels(pred_cls, gt_classes_target, num_fgs)
 
         # -------------------- regression loss --------------------
-        loss_bboxes = self.loss_bboxes_ltrb(
-            pred_delta[foreground_idxs], gt_deltas[foreground_idxs], gt_centerness[foreground_idxs], num_targets)
+        loss_bboxes = self.loss_bboxes(
+            pred_delta[fg_masks], gt_deltas[fg_masks], gt_centerness[fg_masks], num_targets)
 
         # -------------------- centerness loss --------------------
         loss_centerness = F.binary_cross_entropy_with_logits(
-            pred_ctn[foreground_idxs],  gt_centerness[foreground_idxs], reduction='none')
-        loss_centerness = loss_centerness.sum() / num_foreground
+            pred_ctn[fg_masks],  gt_centerness[fg_masks], reduction='none')
+        loss_centerness = loss_centerness.sum() / num_fgs
 
         total_loss = loss_labels * self.weight_dict["loss_cls"] + \
                      loss_bboxes * self.weight_dict["loss_reg"] + \
@@ -187,104 +140,3 @@ class SetCriterion(nn.Module):
         )
 
         return loss_dict
-    
-    def ota_loss(self, outputs, targets):
-        """
-            outputs['pred_cls']: (Tensor) [B, M, C]
-            outputs['pred_reg']: (Tensor) [B, M, 4]
-            outputs['pred_box']: (Tensor) [B, M, 4]
-            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
-            targets: (List) [dict{'boxes': [...], 
-                                 'labels': [...], 
-                                 'orig_size': ...}, ...]
-        """
-        # -------------------- Pre-process --------------------
-        bs          = outputs['pred_cls'][0].shape[0]
-        device      = outputs['pred_cls'][0].device
-        fpn_strides = outputs['strides']
-        anchors     = outputs['anchors']
-        # preds: [B, M, C]
-        # preds: [B, M, C]
-        cls_preds = torch.cat(outputs['pred_cls'], dim=1)
-        box_preds = torch.cat(outputs['pred_box'], dim=1)
-        masks = ~torch.cat(outputs['mask'], dim=1).view(-1)
-
-        # -------------------- Label Assignment --------------------
-        cls_targets = []
-        box_targets = []
-        assign_metrics = []
-        for batch_idx in range(bs):
-            tgt_labels = targets[batch_idx]["labels"].to(device)  # [N,]
-            tgt_bboxes = targets[batch_idx]["boxes"].to(device)   # [N, 4]
-            # refine target
-            tgt_boxes_wh = tgt_bboxes[..., 2:] - tgt_bboxes[..., :2]
-            min_tgt_size = torch.min(tgt_boxes_wh, dim=-1)[0]
-            keep = (min_tgt_size >= 8)
-            tgt_bboxes = tgt_bboxes[keep]
-            tgt_labels = tgt_labels[keep]
-            # label assignment
-            assigned_result = self.matcher(fpn_strides=fpn_strides,
-                                           anchors=anchors,
-                                           pred_cls=cls_preds[batch_idx].detach(),
-                                           pred_box=box_preds[batch_idx].detach(),
-                                           gt_labels=tgt_labels,
-                                           gt_bboxes=tgt_bboxes
-                                           )
-            cls_targets.append(assigned_result['assigned_labels'])
-            box_targets.append(assigned_result['assigned_bboxes'])
-            assign_metrics.append(assigned_result['assign_metrics'])
-
-        # List[B, M, C] -> Tensor[BM, C]
-        cls_targets = torch.cat(cls_targets, dim=0)
-        box_targets = torch.cat(box_targets, dim=0)
-        assign_metrics = torch.cat(assign_metrics, dim=0)
-
-        valid_idxs = (cls_targets >= 0) & masks
-        foreground_idxs = (cls_targets >= 0) & (cls_targets != self.num_classes)
-        num_fgs = assign_metrics.sum()
-
-        if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_fgs)
-        num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
-
-        # -------------------- classification loss --------------------
-        cls_preds = cls_preds.view(-1, self.num_classes)[valid_idxs]
-        qfl_targets = (cls_targets[valid_idxs], assign_metrics[valid_idxs])
-        loss_labels = self.loss_labels_qfl(cls_preds, qfl_targets, 2.0, num_fgs)
-
-        # -------------------- regression loss --------------------
-        box_preds_pos = box_preds.view(-1, 4)[foreground_idxs]
-        box_targets_pos = box_targets[foreground_idxs]
-        box_weight = assign_metrics[foreground_idxs]
-        loss_bboxes = self.loss_bboxes_xyxy(box_preds_pos, box_targets_pos, num_fgs, box_weight)
-
-        total_loss = loss_labels * self.weight_dict["loss_cls"] + \
-                     loss_bboxes * self.weight_dict["loss_reg"]
-        loss_dict = dict(
-                loss_cls = loss_labels,
-                loss_reg = loss_bboxes,
-                losses   = total_loss,
-        )
-
-        return loss_dict
-    
-    def forward(self, outputs, targets):
-        """
-            outputs['pred_cls']: (Tensor) [B, M, C]
-            outputs['pred_reg']: (Tensor) [B, M, 4]
-            outputs['pred_ctn']: (Tensor) [B, M, 1]
-            outputs['strides']: (List) [8, 16, 32, ...] stride of the model output
-            targets: (List) [dict{'boxes': [...], 
-                                 'labels': [...], 
-                                 'orig_size': ...}, ...]
-        """
-        if self.cfg.matcher == "fcos_matcher":
-            return self.fcos_loss(outputs, targets)
-        elif self.cfg.matcher == "simota":
-            return self.ota_loss(outputs, targets)
-        else:
-            raise NotImplementedError
-
-
-if __name__ == "__main__":
-    pass

+ 0 - 160
yolo/models/fcos/matcher.py

@@ -89,7 +89,6 @@ class FcosMatcher(object):
         self.object_sizes_of_interest = object_sizes_of_interest
         self.box_weightss = box_weights
 
-
     def get_deltas(self, anchors, boxes):
         """
         Get box regression transformation deltas (dl, dt, dr, db) that can be used
@@ -107,7 +106,6 @@ class FcosMatcher(object):
                            dim=-1) * anchors.new_tensor(self.box_weightss)
         return deltas
 
-
     @torch.no_grad()
     def __call__(self, fpn_strides, anchors, targets):
         """
@@ -216,163 +214,5 @@ class FcosMatcher(object):
                 gt_anchors_deltas.append(gt_anchors_reg_deltas_i.float())
                 gt_centerness.append(gt_centerness_i.float())
 
-
         # [B, M], [B, M, 4], [B, M]
         return torch.stack(gt_classes), torch.stack(gt_anchors_deltas), torch.stack(gt_centerness)
-
-
-class AlignedOTAMatcher(object):
-    """
-    This code referenced to https://github.com/open-mmlab/mmyolo/models/task_modules/assigners/batch_dsl_assigner.py
-    """
-    def __init__(self, num_classes, soft_center_radius=3.0, topk_candidates=13):
-        self.num_classes = num_classes
-        self.soft_center_radius = soft_center_radius
-        self.topk_candidates = topk_candidates
-
-    @torch.no_grad()
-    def __call__(self, 
-                 fpn_strides, 
-                 anchors, 
-                 pred_cls, 
-                 pred_box,
-                 gt_labels,
-                 gt_bboxes):
-        # [M,]
-        strides = torch.cat([torch.ones_like(anchor_i[:, 0]) * stride_i
-                                for stride_i, anchor_i in zip(fpn_strides, anchors)], dim=-1)
-        # List[F, M, 2] -> [M, 2]
-        num_gt = len(gt_labels)
-        anchors = torch.cat(anchors, dim=0)
-
-        # check gt
-        if num_gt == 0 or gt_bboxes.max().item() == 0.:
-            return {
-                'assigned_labels': gt_labels.new_full(pred_cls[..., 0].shape,
-                                                      self.num_classes,
-                                                      dtype=torch.long),
-                'assigned_bboxes': gt_bboxes.new_full(pred_box.shape, 0),
-                'assign_metrics': gt_bboxes.new_full(pred_cls[..., 0].shape, 0)
-            }
-        
-        # get inside points: [N, M]
-        is_in_gt = self.find_inside_points(gt_bboxes, anchors)
-        valid_mask = is_in_gt.sum(dim=0) > 0  # [M,]
-
-        # ----------------------------------- soft center prior -----------------------------------
-        gt_center = (gt_bboxes[..., :2] + gt_bboxes[..., 2:]) / 2.0
-        distance = (anchors.unsqueeze(0) - gt_center.unsqueeze(1)
-                    ).pow(2).sum(-1).sqrt() / strides.unsqueeze(0)  # [N, M]
-        distance = distance * valid_mask.unsqueeze(0)
-        soft_center_prior = torch.pow(10, distance - self.soft_center_radius)
-
-        # ----------------------------------- regression cost -----------------------------------
-        pair_wise_ious, _ = box_iou(gt_bboxes, pred_box)  # [N, M]
-        pair_wise_ious_loss = -torch.log(pair_wise_ious + 1e-8) * 3.0
-
-        # ----------------------------------- classification cost -----------------------------------
-        ## select the predicted scores corresponded to the gt_labels
-        pairwise_pred_scores = pred_cls.permute(1, 0)  # [M, C] -> [C, M]
-        pairwise_pred_scores = pairwise_pred_scores[gt_labels.long(), :].float()   # [N, M]
-        ## scale factor
-        scale_factor = (pair_wise_ious - pairwise_pred_scores.sigmoid()).abs().pow(2.0)
-        ## cls cost
-        pair_wise_cls_loss = F.binary_cross_entropy_with_logits(
-            pairwise_pred_scores, pair_wise_ious,
-            reduction="none") * scale_factor # [N, M]
-            
-        del pairwise_pred_scores
-
-        ## foreground cost matrix
-        cost_matrix = pair_wise_cls_loss + pair_wise_ious_loss + soft_center_prior
-        max_pad_value = torch.ones_like(cost_matrix) * 1e9
-        cost_matrix = torch.where(valid_mask[None].repeat(num_gt, 1),   # [N, M]
-                                  cost_matrix, max_pad_value)
-
-        # ----------------------------------- dynamic label assignment -----------------------------------
-        matched_pred_ious, matched_gt_inds, fg_mask_inboxes = self.dynamic_k_matching(
-            cost_matrix, pair_wise_ious, num_gt)
-        del pair_wise_cls_loss, cost_matrix, pair_wise_ious, pair_wise_ious_loss
-
-        # -----------------------------------process assigned labels -----------------------------------
-        assigned_labels = gt_labels.new_full(pred_cls[..., 0].shape,
-                                             self.num_classes)  # [M,]
-        assigned_labels[fg_mask_inboxes] = gt_labels[matched_gt_inds].squeeze(-1)
-        assigned_labels = assigned_labels.long()  # [M,]
-
-        assigned_bboxes = gt_bboxes.new_full(pred_box.shape, 0)        # [M, 4]
-        assigned_bboxes[fg_mask_inboxes] = gt_bboxes[matched_gt_inds]  # [M, 4]
-
-        assign_metrics = gt_bboxes.new_full(pred_cls[..., 0].shape, 0) # [M,]
-        assign_metrics[fg_mask_inboxes] = matched_pred_ious            # [M,]
-
-        assigned_dict = dict(
-            assigned_labels=assigned_labels,
-            assigned_bboxes=assigned_bboxes,
-            assign_metrics=assign_metrics
-            )
-        
-        return assigned_dict
-
-    def find_inside_points(self, gt_bboxes, anchors):
-        """
-            gt_bboxes: Tensor -> [N, 2]
-            anchors:   Tensor -> [M, 2]
-        """
-        num_anchors = anchors.shape[0]
-        num_gt = gt_bboxes.shape[0]
-
-        anchors_expand = anchors.unsqueeze(0).repeat(num_gt, 1, 1)           # [N, M, 2]
-        gt_bboxes_expand = gt_bboxes.unsqueeze(1).repeat(1, num_anchors, 1)  # [N, M, 4]
-
-        # offset
-        lt = anchors_expand - gt_bboxes_expand[..., :2]
-        rb = gt_bboxes_expand[..., 2:] - anchors_expand
-        bbox_deltas = torch.cat([lt, rb], dim=-1)
-
-        is_in_gts = bbox_deltas.min(dim=-1).values > 0
-
-        return is_in_gts
-    
-    def dynamic_k_matching(self, cost_matrix, pairwise_ious, num_gt):
-        """Use IoU and matching cost to calculate the dynamic top-k positive
-        targets.
-
-        Args:
-            cost_matrix (Tensor): Cost matrix.
-            pairwise_ious (Tensor): Pairwise iou matrix.
-            num_gt (int): Number of gt.
-            valid_mask (Tensor): Mask for valid bboxes.
-        Returns:
-            tuple: matched ious and gt indexes.
-        """
-        matching_matrix = torch.zeros_like(cost_matrix, dtype=torch.uint8)
-        # select candidate topk ious for dynamic-k calculation
-        candidate_topk = min(self.topk_candidates, pairwise_ious.size(1))
-        topk_ious, _ = torch.topk(pairwise_ious, candidate_topk, dim=1)
-        # calculate dynamic k for each gt
-        dynamic_ks = torch.clamp(topk_ious.sum(1).int(), min=1)
-
-        # sorting the batch cost matirx is faster than topk
-        _, sorted_indices = torch.sort(cost_matrix, dim=1)
-        for gt_idx in range(num_gt):
-            topk_ids = sorted_indices[gt_idx, :dynamic_ks[gt_idx]]
-            matching_matrix[gt_idx, :][topk_ids] = 1
-
-        del topk_ious, dynamic_ks, topk_ids
-
-        prior_match_gt_mask = matching_matrix.sum(0) > 1
-        if prior_match_gt_mask.sum() > 0:
-            cost_min, cost_argmin = torch.min(
-                cost_matrix[:, prior_match_gt_mask], dim=0)
-            matching_matrix[:, prior_match_gt_mask] *= 0
-            matching_matrix[cost_argmin, prior_match_gt_mask] = 1
-
-        # get foreground mask inside box and center prior
-        fg_mask_inboxes = matching_matrix.sum(0) > 0
-        matched_pred_ious = (matching_matrix *
-                             pairwise_ious).sum(0)[fg_mask_inboxes]
-        matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)
-
-        return matched_pred_ious, matched_gt_inds, fg_mask_inboxes
-        

+ 12 - 138
yolo/models/fcos/modules.py

@@ -4,145 +4,19 @@ from typing import List
 
 
 # --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
+class ConvModule(nn.Module):
     def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # padding
-                 dilation=1,               # dilation
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
+                 in_dim,         # in channels
+                 out_dim,        # out channels 
+                 kernel_size=1,  # kernel size 
+                 padding=0,      # padding
+                 stride=1,       # stride
+                 dilation=1,     # dilation
                 ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        use_bias = False if norm_type is not None else True
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
+        super(ConvModule, self).__init__()
+        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
+        self.norm = nn.GroupNorm(num_groups=32, num_channels=out_dim)
+        self.act  = nn.ReLU(inplace=True)
 
     def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# --------------------- ResNet modules ---------------------
-def conv3x3(in_planes, out_planes, stride=1):
-    """3x3 convolution with padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=False)
-
-def conv1x1(in_planes, out_planes, stride=1):
-    """1x1 convolution"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-class BasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(BasicBlock, self).__init__()
-        self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.conv2 = conv3x3(planes, planes)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
-class Bottleneck(nn.Module):
-    expansion = 4
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(Bottleneck, self).__init__()
-        self.conv1 = conv1x1(inplanes, planes)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = conv3x3(planes, planes, stride)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = conv1x1(planes, planes * self.expansion)
-        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
+        return self.act(self.norm(self.conv(x)))

+ 90 - 15
yolo/models/fcos/resnet.py

@@ -2,15 +2,10 @@ import torch
 import torch.nn as nn
 import torch.utils.model_zoo as model_zoo
 
-try:
-    from .modules import conv1x1, BasicBlock, Bottleneck
-except:
-    from  modules import conv1x1, BasicBlock, Bottleneck
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152']
 
-
 model_urls = {
     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
@@ -20,9 +15,87 @@ model_urls = {
 }
 
 
+# --------------------- ResNet modules ---------------------
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
 # --------------------- ResNet -----------------------
 class ResNet(nn.Module):
-
     def __init__(self, block, layers, zero_init_residual=False):
         super(ResNet, self).__init__()
         self.inplanes = 64
@@ -86,27 +159,29 @@ class ResNet(nn.Module):
         c4 = self.layer3(c3)   # [B, C, H/16, W/16]
         c5 = self.layer4(c4)   # [B, C, H/32, W/32]
 
-        return c5
+        return [c3, c4, c5]
 
-
-# --------------------- Functions -----------------------
 def build_resnet(model_name="resnet18", pretrained=False):
     if model_name == 'resnet18':
         model = resnet18(pretrained)
-        feat_dim = 512
+        feat_dims = [128, 256, 512]
+
     elif model_name == 'resnet34':
         model = resnet34(pretrained)
-        feat_dim = 512
+        feat_dims = [128, 256, 512]
+
     elif model_name == 'resnet50':
         model = resnet50(pretrained)
-        feat_dim = 2048
+        feat_dims = [512, 1024, 2048]
+
     elif model_name == 'resnet101':
-        model = resnet34(pretrained)
-        feat_dim = 2048
+        model = resnet101(pretrained)
+        feat_dims = [512, 1024, 2048]
+
     else:
         raise NotImplementedError("Unknown resnet: {}".format(model_name))
     
-    return model, feat_dim
+    return model, feat_dims
 
 def resnet18(pretrained=False, **kwargs):
     """Constructs a ResNet-18 model.

+ 16 - 0
yolo/models/yolof/build.py

@@ -0,0 +1,16 @@
+from .loss import SetCriterion
+from .yolof import Yolof
+
+
+# build object detector
+def build_yolof(cfg, is_val=False):
+    # -------------- Build YOLOF --------------
+    model = Yolof(cfg, is_val)
+  
+    # -------------- Build criterion --------------
+    criterion = None
+    if is_val:
+        # build criterion for training (NOTE(review): gated on is_val -- confirm the flag's meaning)
+        criterion = SetCriterion(cfg)
+        
+    return model, criterion

+ 27 - 26
yolo/models/yolof/loss.py

@@ -8,24 +8,27 @@ from utils.distributed_utils import get_world_size, is_dist_avail_and_initialize
 from .matcher import UniformMatcher
 
 
-class SetCriterion(nn.Module):
+class SetCriterion(object):
     """
         This code referenced to https://github.com/megvii-model/YOLOF/blob/main/playground/detection/coco/yolof/yolof_base/yolof.py
     """
     def __init__(self, cfg):
-        super().__init__()
         # ------------- Basic parameters -------------
         self.cfg = cfg
         self.num_classes = cfg.num_classes
+
         # ------------- Focal loss -------------
         self.alpha = cfg.focal_loss_alpha
         self.gamma = cfg.focal_loss_gamma
+
         # ------------- Loss weight -------------
-        self.weight_dict = {'loss_cls': cfg.loss_cls_weight,
-                            'loss_reg': cfg.loss_reg_weight}
+        self.weight_dict = {'loss_cls': cfg.loss_cls,
+                            'loss_reg': cfg.loss_reg}
+        
         # ------------- Matcher -------------
-        self.matcher_cfg = cfg.matcher_hpy
-        self.matcher = UniformMatcher(self.matcher_cfg['topk_candidates'])
+        self.ignore_thresh = cfg.ignore_thresh
+        self.match_iou_weight = cfg.match_iou_thresh
+        self.matcher = UniformMatcher(cfg.match_topk_candidates)
 
     def loss_labels(self, pred_cls, tgt_cls, num_boxes):
         """
@@ -49,7 +52,7 @@ class SetCriterion(nn.Module):
 
         return loss_reg.sum() / num_boxes
 
-    def forward(self, outputs, targets):
+    def __call__(self, outputs, targets):
         """
             outputs['pred_cls']: (Tensor) [B, M, C]
             outputs['pred_box']: (Tensor) [B, M, 4]
@@ -61,29 +64,28 @@ class SetCriterion(nn.Module):
         pred_box = outputs['pred_box']
         pred_cls = outputs['pred_cls'].reshape(-1, self.num_classes)
         anchor_boxes = outputs['anchors']
-        masks = ~outputs['mask']
         device = pred_box.device
-        B = len(targets)
+        bs = len(targets)
 
         # -------------------- Label assignment --------------------
         indices = self.matcher(pred_box, anchor_boxes, targets)
 
         # [M, 4] -> [1, M, 4] -> [B, M, 4]
         anchor_boxes = box_cxcywh_to_xyxy(anchor_boxes)
-        anchor_boxes = anchor_boxes[None].repeat(B, 1, 1)
+        anchor_boxes = anchor_boxes[None].repeat(bs, 1, 1)
 
         ious = []
         pos_ious = []
-        for i in range(B):
+        for i in range(bs):
             src_idx, tgt_idx = indices[i]
             # iou between predbox and tgt box
-            iou, _ = box_iou(pred_box[i, ...], (targets[i]['boxes']).clone())
+            iou, _ = box_iou(pred_box[i, ...], (targets[i]['boxes']).clone().to(device))
             if iou.numel() == 0:
                 max_iou = iou.new_full((iou.size(0),), 0)
             else:
                 max_iou = iou.max(dim=1)[0]
             # iou between anchorbox and tgt box
-            a_iou, _ = box_iou(anchor_boxes[i], (targets[i]['boxes']).clone())
+            a_iou, _ = box_iou(anchor_boxes[i], (targets[i]['boxes']).clone().to(device))
             if a_iou.numel() == 0:
                 pos_iou = a_iou.new_full((0,), 0)
             else:
@@ -92,42 +94,41 @@ class SetCriterion(nn.Module):
             pos_ious.append(pos_iou)
 
         ious = torch.cat(ious)
-        ignore_idx = ious > self.matcher_cfg['ignore_thresh']
+        ignore_idx = ious > self.ignore_thresh
         pos_ious = torch.cat(pos_ious)
-        pos_ignore_idx = pos_ious < self.matcher_cfg['iou_thresh']
+        pos_ignore_idx = pos_ious < self.match_iou_weight
 
         src_idx = torch.cat(
             [src + idx * anchor_boxes[0].shape[0] for idx, (src, _) in
              enumerate(indices)])
         # [BM,]
         gt_cls = torch.full(pred_cls.shape[:1],
-                                self.num_classes,
-                                dtype=torch.int64,
-                                device=device)
+                            self.num_classes,
+                            dtype=torch.int64,
+                            device=device)
         gt_cls[ignore_idx] = -1
         tgt_cls_o = torch.cat([t['labels'][J] for t, (_, J) in zip(targets, indices)])
         tgt_cls_o[pos_ignore_idx] = -1
 
         gt_cls[src_idx] = tgt_cls_o.to(device)
 
-        foreground_idxs = (gt_cls >= 0) & (gt_cls != self.num_classes)
-        num_foreground = foreground_idxs.sum()
+        fg_mask = (gt_cls >= 0) & (gt_cls != self.num_classes)
+        num_fgs = fg_mask.sum()
 
         if is_dist_avail_and_initialized():
-            torch.distributed.all_reduce(num_foreground)
-        num_foreground = torch.clamp(num_foreground / get_world_size(), min=1).item()
+            torch.distributed.all_reduce(num_fgs)
+        num_fgs = torch.clamp(num_fgs / get_world_size(), min=1).item()
 
         # -------------------- Classification loss --------------------
         gt_cls_target = torch.zeros_like(pred_cls)
-        gt_cls_target[foreground_idxs, gt_cls[foreground_idxs]] = 1
-        valid_idxs = (gt_cls >= 0) & masks
-        loss_labels = self.loss_labels(pred_cls[valid_idxs], gt_cls_target[valid_idxs], num_foreground)
+        gt_cls_target[fg_mask, gt_cls[fg_mask]] = 1
+        loss_labels = self.loss_labels(pred_cls, gt_cls_target, num_fgs)
 
         # -------------------- Regression loss --------------------
         tgt_boxes = torch.cat([t['boxes'][i] for t, (_, i) in zip(targets, indices)], dim=0).to(device)
         tgt_boxes = tgt_boxes[~pos_ignore_idx]
         matched_pred_box = pred_box.reshape(-1, 4)[src_idx[~pos_ignore_idx.cpu()]]
-        loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_foreground)
+        loss_bboxes = self.loss_bboxes(matched_pred_box, tgt_boxes, num_fgs)
 
         total_loss = loss_labels * self.weight_dict["loss_cls"] + \
                      loss_bboxes * self.weight_dict["loss_reg"]

+ 1 - 1
yolo/models/yolof/matcher.py

@@ -30,7 +30,7 @@ class UniformMatcher(nn.Module):
         anchor_boxes = anchor_boxes.flatten(0, 1)
 
         # Also concat the target boxes
-        tgt_bbox = torch.cat([v['boxes'] for v in targets])
+        tgt_bbox = torch.cat([v['boxes'] for v in targets]).to(out_bbox.device)
 
         # Compute the L1 cost between boxes
         # Note that we use anchors and predict boxes both

+ 13 - 138
yolo/models/yolof/modules.py

@@ -4,145 +4,20 @@ from typing import List
 
 
 # --------------------- Basic modules ---------------------
-def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
-    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
-
-    return conv
-
-def get_activation(act_type=None):
-    if act_type == 'relu':
-        return nn.ReLU(inplace=True)
-    elif act_type == 'lrelu':
-        return nn.LeakyReLU(0.1, inplace=True)
-    elif act_type == 'mish':
-        return nn.Mish(inplace=True)
-    elif act_type == 'silu':
-        return nn.SiLU(inplace=True)
-    elif act_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-        
-def get_norm(norm_type, dim):
-    if norm_type == 'BN':
-        return nn.BatchNorm2d(dim)
-    elif norm_type == 'GN':
-        return nn.GroupNorm(num_groups=32, num_channels=dim)
-    elif norm_type is None:
-        return nn.Identity()
-    else:
-        raise NotImplementedError
-
-class BasicConv(nn.Module):
+class ConvModule(nn.Module):
     def __init__(self, 
-                 in_dim,                   # in channels
-                 out_dim,                  # out channels 
-                 kernel_size=1,            # kernel size 
-                 padding=0,                # padding
-                 stride=1,                 # padding
-                 dilation=1,               # dilation
-                 act_type  :str = 'lrelu', # activation
-                 norm_type :str = 'BN',    # normalization
-                 depthwise :bool = False
+                 in_dim: int,           # in channels
+                 out_dim: int,          # out channels 
+                 kernel_size: int = 1,  # kernel size 
+                 padding: int = 0,      # padding
+                 stride: int = 1,       # stride
+                 dilation: int = 1,     # dilation
+                 use_act: bool = False,
                 ):
-        super(BasicConv, self).__init__()
-        self.depthwise = depthwise
-        use_bias = False if norm_type is not None else True
-        if not depthwise:
-            self.conv = get_conv2d(in_dim, out_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=1, bias=use_bias)
-            self.norm = get_norm(norm_type, out_dim)
-        else:
-            self.conv1 = get_conv2d(in_dim, in_dim, k=kernel_size, p=padding, s=stride, d=dilation, g=in_dim, bias=use_bias)
-            self.norm1 = get_norm(norm_type, in_dim)
-            self.conv2 = get_conv2d(in_dim, out_dim, k=1, p=0, s=1, d=1, g=1)
-            self.norm2 = get_norm(norm_type, out_dim)
-        self.act  = get_activation(act_type)
+        super(ConvModule, self).__init__()
+        self.conv = nn.Conv2d(in_dim, out_dim, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False)
+        self.norm = nn.BatchNorm2d(out_dim)
+        self.act  = nn.ReLU(inplace=True) if use_act else nn.Identity()
 
     def forward(self, x):
-        if not self.depthwise:
-            return self.act(self.norm(self.conv(x)))
-        else:
-            # Depthwise conv
-            x = self.norm1(self.conv1(x))
-            # Pointwise conv
-            x = self.act(self.norm2(self.conv2(x)))
-            return x
-
-
-# --------------------- ResNet modules ---------------------
-def conv3x3(in_planes, out_planes, stride=1):
-    """3x3 convolution with padding"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
-                     padding=1, bias=False)
-
-def conv1x1(in_planes, out_planes, stride=1):
-    """1x1 convolution"""
-    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
-
-class BasicBlock(nn.Module):
-    expansion = 1
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(BasicBlock, self).__init__()
-        self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.conv2 = conv3x3(planes, planes)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
-class Bottleneck(nn.Module):
-    expansion = 4
-
-    def __init__(self, inplanes, planes, stride=1, downsample=None):
-        super(Bottleneck, self).__init__()
-        self.conv1 = conv1x1(inplanes, planes)
-        self.bn1 = nn.BatchNorm2d(planes)
-        self.conv2 = conv3x3(planes, planes, stride)
-        self.bn2 = nn.BatchNorm2d(planes)
-        self.conv3 = conv1x1(planes, planes * self.expansion)
-        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x):
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
+        return self.act(self.norm(self.conv(x)))

+ 85 - 10
yolo/models/yolof/resnet.py

@@ -2,15 +2,10 @@ import torch
 import torch.nn as nn
 import torch.utils.model_zoo as model_zoo
 
-try:
-    from .modules import conv1x1, BasicBlock, Bottleneck
-except:
-    from  modules import conv1x1, BasicBlock, Bottleneck
 
 __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
            'resnet152']
 
-
 model_urls = {
     'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
     'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
@@ -20,9 +15,87 @@ model_urls = {
 }
 
 
+# --------------------- ResNet modules ---------------------
+def conv3x3(in_planes, out_planes, stride=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=1, bias=False)
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
+
+class BasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(BasicBlock, self).__init__()
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+class Bottleneck(nn.Module):
+    expansion = 4
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None):
+        super(Bottleneck, self).__init__()
+        self.conv1 = conv1x1(inplanes, planes)
+        self.bn1 = nn.BatchNorm2d(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn2 = nn.BatchNorm2d(planes)
+        self.conv3 = conv1x1(planes, planes * self.expansion)
+        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
 # --------------------- ResNet -----------------------
 class ResNet(nn.Module):
-
     def __init__(self, block, layers, zero_init_residual=False):
         super(ResNet, self).__init__()
         self.inplanes = 64
@@ -88,21 +161,23 @@ class ResNet(nn.Module):
 
         return c5
 
-
-# --------------------- Functions -----------------------
 def build_resnet(model_name="resnet18", pretrained=False):
     if model_name == 'resnet18':
         model = resnet18(pretrained)
         feat_dim = 512
+
     elif model_name == 'resnet34':
         model = resnet34(pretrained)
         feat_dim = 512
+
     elif model_name == 'resnet50':
         model = resnet50(pretrained)
         feat_dim = 2048
+
     elif model_name == 'resnet101':
-        model = resnet34(pretrained)
+        model = resnet101(pretrained)
         feat_dim = 2048
+
     else:
         raise NotImplementedError("Unknown resnet: {}".format(model_name))
     
@@ -184,4 +259,4 @@ if __name__=='__main__':
     flops, params = profile(model, inputs=(x, ), verbose=False)
     print('==============================')
     print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
-    print('Params : {:.2f} M'.format(params / 1e6))    
+    print('Params : {:.2f} M'.format(params / 1e6))

+ 91 - 0
yolo/models/yolof/yolof.py

@@ -0,0 +1,91 @@
+import torch
+import torch.nn as nn
+
+# --------------- Model components ---------------
+from .yolof_backbone import YolofBackbone
+from .yolof_encoder  import DilatedEncoder
+from .yolof_decoder  import YolofHead
+
+# --------------- External components ---------------
+from utils.misc import multiclass_nms
+
+
+# ------------------------ You Only Look One-level Feature ------------------------
+class Yolof(nn.Module):
+    def __init__(self, cfg, is_val: bool = False):
+        super(Yolof, self).__init__()
+        # ---------------------- Basic setting ----------------------
+        self.cfg = cfg
+        self.num_classes = cfg.num_classes
+        ## Post-process parameters
+        self.topk_candidates  = cfg.val_topk        if is_val else cfg.test_topk
+        self.conf_thresh      = cfg.val_conf_thresh if is_val else cfg.test_conf_thresh
+        self.nms_thresh       = cfg.val_nms_thresh  if is_val else cfg.test_nms_thresh
+        self.no_multi_labels  = False if is_val else True
+
+        # ---------------------- Network Parameters ----------------------
+        self.backbone = YolofBackbone(cfg)
+        self.encoder  = DilatedEncoder(cfg, self.backbone.feat_dim, cfg.head_dim)
+        self.decoder  = YolofHead(cfg, self.encoder.out_dim, cfg.head_dim)
+
+    def post_process(self, cls_pred, box_pred):
+        """
+        Input:
+            cls_pred: (Tensor) [B, H x W x KA, C]
+            box_pred: (Tensor) [B, H x W x KA, 4]
+        """
+        cls_pred = cls_pred[0]
+        box_pred = box_pred[0]
+        
+        # (H x W x KA x C,)
+        scores_i = cls_pred.sigmoid().flatten()
+
+        # Keep top k top scoring indices only.
+        num_topk = min(self.topk_candidates, box_pred.size(0))
+
+        # torch.sort is actually faster than .topk (at least on GPUs)
+        predicted_prob, topk_idxs = scores_i.sort(descending=True)
+        topk_scores = predicted_prob[:num_topk]
+        topk_idxs = topk_idxs[:num_topk]
+
+        # filter out the proposals with low confidence score
+        keep_idxs = topk_scores > self.conf_thresh
+        topk_idxs = topk_idxs[keep_idxs]
+
+        # final scores
+        scores = topk_scores[keep_idxs]
+        # final labels
+        labels = topk_idxs % self.num_classes
+        # final bboxes
+        anchor_idxs = torch.div(topk_idxs, self.num_classes, rounding_mode='floor')
+        bboxes = box_pred[anchor_idxs]
+
+        # to cpu & numpy
+        scores = scores.cpu().numpy()
+        labels = labels.cpu().numpy()
+        bboxes = bboxes.cpu().numpy()
+
+        # nms
+        scores, labels, bboxes = multiclass_nms(
+            scores, labels, bboxes, self.nms_thresh, self.num_classes)
+
+        return bboxes, scores, labels
+
+    def forward(self, x):
+        x = self.backbone(x)
+        x = self.encoder(x)
+        outputs = self.decoder(x)
+
+        if not self.training:
+            # ---------------- PostProcess ----------------
+            cls_pred = outputs["pred_cls"]
+            box_pred = outputs["pred_box"]
+            bboxes, scores, labels = self.post_process(cls_pred, box_pred)
+
+            outputs = {
+                'scores': scores,
+                'labels': labels,
+                'bboxes': bboxes
+            }
+
+        return outputs 

+ 50 - 0
yolo/models/yolof/yolof_backbone.py

@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+
+try:
+    from .resnet import build_resnet
+except:
+    from  resnet import build_resnet
+
+
+# --------------------- YOLOF's Backbone -----------------------
+class YolofBackbone(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.backbone, self.feat_dim = build_resnet(cfg.backbone, cfg.use_pretrained)
+
+    def forward(self, x):
+        pyramid_feats = self.backbone(x)
+
+        return pyramid_feats # single tensor: ResNet.forward returns only c5
+
+
+if __name__=='__main__':
+    from thop import profile
+
+    # YOLOF configuration
+    class YolofBaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.out_stride = 32
+            ## Backbone
+            self.backbone = 'resnet18'
+            self.use_pretrained = True
+    cfg = YolofBaseConfig()
+
+    # Build backbone
+    model = YolofBackbone(cfg)
+
+    # Randomly generate an input tensor
+    x = torch.randn(2, 3, 640, 640)
+
+    # Inference
+    output = model(x)
+    print(' - the shape of input :  ', x.shape)
+    print(' - the shape of output : ', output.shape)
+
+    x = torch.randn(1, 3, 640, 640)
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('============== FLOPs & Params ================')
+    print(' - FLOPs  : {:.2f} G'.format(flops / 1e9 * 2))
+    print(' - Params : {:.2f} M'.format(params / 1e6))

+ 36 - 63
yolo/models/yolof/yolof_decoder.py

@@ -2,62 +2,48 @@ import math
 import torch
 import torch.nn as nn
 
-from .modules import BasicConv
+try:
+    from .modules import ConvModule
+except:
+    from  modules import ConvModule
 
 
 class YolofHead(nn.Module):
-    def __init__(self, cfg, in_dim, out_dim,):
+    def __init__(self, cfg, in_dim: int, out_dim: int,):
         super().__init__()
-        self.fmp_size = None
-        self.ctr_clamp = cfg.center_clamp
+        self.ctr_clamp = 32
         self.DEFAULT_EXP_CLAMP = math.log(1e8)
         self.DEFAULT_SCALE_CLAMP = math.log(1000.0 / 16)
         # ------------------ Basic parameters -------------------
         self.cfg = cfg
         self.in_dim = in_dim
-        self.stride       = cfg.out_stride
+        self.out_stride   = cfg.out_stride
         self.num_classes  = cfg.num_classes
         self.num_cls_head = cfg.num_cls_head
         self.num_reg_head = cfg.num_reg_head
-        self.act_type     = cfg.head_act
-        self.norm_type    = cfg.head_norm
         # Anchor config
         self.anchor_size = torch.as_tensor(cfg.anchor_size)
         self.num_anchors = len(cfg.anchor_size)
 
         # ------------------ Network parameters -------------------
-        ## cls head
+        ## classification head
         cls_heads = []
         self.cls_head_dim = out_dim
         for i in range(self.num_cls_head):
             if i == 0:
-                cls_heads.append(
-                    BasicConv(in_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
+                cls_heads.append(ConvModule(in_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                cls_heads.append(
-                    BasicConv(self.cls_head_dim, self.cls_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
-        ## reg head
+                cls_heads.append(ConvModule(self.cls_head_dim, self.cls_head_dim, kernel_size=3, padding=1, stride=1))
+
+        ## bbox regression head
         reg_heads = []
         self.reg_head_dim = out_dim
         for i in range(self.num_reg_head):
             if i == 0:
-                reg_heads.append(
-                    BasicConv(in_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
+                reg_heads.append(ConvModule(in_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
             else:
-                reg_heads.append(
-                    BasicConv(self.reg_head_dim, self.reg_head_dim,
-                              kernel_size=3, padding=1, stride=1, 
-                              act_type=self.act_type, norm_type=self.norm_type)
-                              )
+                reg_heads.append(ConvModule(self.reg_head_dim, self.reg_head_dim, kernel_size=3, padding=1, stride=1))
+
         self.cls_heads = nn.Sequential(*cls_heads)
         self.reg_heads = nn.Sequential(*reg_heads)
 
@@ -86,30 +72,25 @@ class YolofHead(nn.Module):
         """fmp_size: list -> [H, W] \n
            stride: int -> output stride
         """
-        # check anchor boxes
-        if self.fmp_size is not None and self.fmp_size == fmp_size:
-            return self.anchor_boxes
-        else:
-            # generate grid cells
-            fmp_h, fmp_w = fmp_size
-            anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
-            # [H, W, 2] -> [HW, 2]
-            anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
-            # [HW, 2] -> [HW, 1, 2] -> [HW, KA, 2] 
-            anchor_xy = anchor_xy[:, None, :].repeat(1, self.num_anchors, 1)
-            anchor_xy *= self.stride
-
-            # [KA, 2] -> [1, KA, 2] -> [HW, KA, 2]
-            anchor_wh = self.anchor_size[None, :, :].repeat(fmp_h*fmp_w, 1, 1)
-
-            # [HW, KA, 4] -> [M, 4]
-            anchor_boxes = torch.cat([anchor_xy, anchor_wh], dim=-1)
-            anchor_boxes = anchor_boxes.view(-1, 4)
-
-            self.anchor_boxes = anchor_boxes
-            self.fmp_size = fmp_size
-
-            return anchor_boxes
+        # generate grid cells
+        fmp_h, fmp_w = fmp_size
+        anchor_y, anchor_x = torch.meshgrid([torch.arange(fmp_h), torch.arange(fmp_w)])
+
+        # anchor points: [H, W, 2] -> [HW, 2]
+        anchor_xy = torch.stack([anchor_x, anchor_y], dim=-1).float().view(-1, 2) + 0.5
+
+        # [HW, 2] -> [HW, 1, 2] -> [HW, KA, 2] 
+        anchor_xy = anchor_xy[:, None, :].repeat(1, self.num_anchors, 1)
+        anchor_xy *= self.out_stride       # scale cell centers by the output stride
+
+        # anchor sizes: [KA, 2] -> [1, KA, 2] -> [HW, KA, 2]
+        anchor_wh = self.anchor_size[None, :, :].repeat(fmp_h*fmp_w, 1, 1)
+
+        # [HW, KA, 4] -> [M, 4], M = H*W*KA
+        anchor_boxes = torch.cat([anchor_xy, anchor_wh], dim=-1)
+        anchor_boxes = anchor_boxes.view(-1, 4)
+
+        return anchor_boxes
         
     def decode_boxes(self, anchor_boxes, pred_reg):
         """
@@ -135,7 +116,7 @@ class YolofHead(nn.Module):
 
         return pred_box
 
-    def forward(self, x, mask=None):
+    def forward(self, x):
         # ------------------- Decoupled head -------------------
         cls_feats = self.cls_heads(x)
         reg_feats = self.reg_heads(x)
@@ -167,19 +148,11 @@ class YolofHead(nn.Module):
         reg_pred = reg_pred.view(B, -1, 4)
         ## Decode bbox
         box_pred = self.decode_boxes(anchor_boxes[None], reg_pred)  # [B, M, 4]
-        ## adjust mask
-        if mask is not None:
-            # [B, H, W]
-            mask = torch.nn.functional.interpolate(mask[None].float(), size=fmp_size).bool()[0]
-            # [B, H, W] -> [B, HW]
-            mask = mask.flatten(1)
-            # [B, HW] -> [B, HW, KA] -> [BM,], M= HW x KA
-            mask = mask[..., None].repeat(1, 1, self.num_anchors).flatten()
 
         outputs = {"pred_cls": normalized_cls_pred,
                    "pred_reg": reg_pred,
                    "pred_box": box_pred,
                    "anchors": anchor_boxes,
-                   "mask": mask}
+                   }
 
         return outputs 

+ 52 - 34
yolo/models/yolof/yolof_encoder.py

@@ -1,23 +1,26 @@
+import torch
 import torch.nn as nn
-from utils import weight_init
 
-from .modules import BasicConv
+try:
+    from .modules import ConvModule
+except:
+    from  modules import ConvModule
 
 
 # BottleNeck
 class Bottleneck(nn.Module):
-    def __init__(self, in_dim, dilation, expand_ratio, act_type='relu', norm_type='BN'):
+    def __init__(self, in_dim: int, dilation: int = 1, expansion: float = 0.5):
         super(Bottleneck, self).__init__()
         # ------------------ Basic parameters -------------------
         self.in_dim = in_dim
         self.dilation = dilation
-        self.expand_ratio = expand_ratio
-        inter_dim = round(in_dim * expand_ratio)
+        self.expansion = expansion
+        inter_dim = round(in_dim * expansion)
         # ------------------ Network parameters -------------------
         self.branch = nn.Sequential(
-            BasicConv(in_dim, inter_dim, kernel_size=1, act_type=act_type, norm_type=norm_type),
-            BasicConv(inter_dim, inter_dim, kernel_size=3, padding=dilation, dilation=dilation, act_type=act_type, norm_type=norm_type),
-            BasicConv(inter_dim, in_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+            ConvModule(in_dim, inter_dim, kernel_size=1),
+            ConvModule(inter_dim, inter_dim, kernel_size=3, padding=dilation, dilation=dilation),
+            ConvModule(inter_dim, in_dim, kernel_size=1)
         )
 
     def forward(self, x):
@@ -32,41 +35,56 @@ class DilatedEncoder(nn.Module):
         self.out_dim = out_dim
         self.expand_ratio = cfg.neck_expand_ratio
         self.dilations    = cfg.neck_dilations
-        self.act_type     = cfg.neck_act
-        self.norm_type    = cfg.neck_norm
         # ------------------ Network parameters -------------------
         ## proj layer
         self.projector = nn.Sequential(
-            BasicConv(in_dim, out_dim, kernel_size=1, act_type=None, norm_type=self.norm_type),
-            BasicConv(out_dim, out_dim, kernel_size=3, padding=1, act_type=None, norm_type=self.norm_type)
+            ConvModule(in_dim,  out_dim, kernel_size=1, use_act=False),
+            ConvModule(out_dim, out_dim, kernel_size=3, padding=1, use_act=False)
         )
         ## encoder layers
         self.encoders = nn.Sequential(
-            *[Bottleneck(out_dim, d, self.expand_ratio, self.act_type, self.norm_type) for d in self.dilations])
-
-        self._init_weight()
-
-    def _init_weight(self):
-        for m in self.projector:
-            if isinstance(m, nn.Conv2d):
-                weight_init.c2_xavier_fill(m)
-                weight_init.c2_xavier_fill(m)
-            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-        for m in self.encoders.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.normal_(m.weight, mean=0, std=0.01)
-                if hasattr(m, 'bias') and m.bias is not None:
-                    nn.init.constant_(m.bias, 0)
-
-            if isinstance(m, (nn.GroupNorm, nn.BatchNorm2d, nn.SyncBatchNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
+            *[Bottleneck(in_dim = out_dim,
+                         dilation = d,
+                         expansion = self.expand_ratio,
+                         ) for d in self.dilations])
 
     def forward(self, x):
         x = self.projector(x)
         x = self.encoders(x)
 
         return x
+
+
+if __name__=='__main__':
+    from thop import profile
+
+    # YOLOF configuration
+    class YolofBaseConfig(object):
+        def __init__(self) -> None:
+            # ---------------- Model config ----------------
+            self.out_stride = 32
+            ## Backbone
+            self.backbone = 'resnet18'
+            self.use_pretrained = True
+
+            self.neck_expand_ratio = 0.25
+            self.neck_dilations = [2, 4, 6, 8]
+
+    cfg = YolofBaseConfig()
+
+    # Randomly generate an input tensor
+    x = torch.randn(2, 512, 20, 20)
+
+    # Build encoder
+    model = DilatedEncoder(cfg, in_dim=512, out_dim=512)
+
+    # Inference
+    output = model(x)
+    print(' - the shape of input :  ', x.shape)
+    print(' - the shape of output : ', output.shape)
+
+    x = torch.randn(1, 512, 20, 20)
+    flops, params = profile(model, inputs=(x, ), verbose=False)
+    print('============== FLOPs & Params ================')
+    print(' - FLOPs  : {:.2f} G'.format(flops / 1e9 * 2))
+    print(' - Params : {:.2f} M'.format(params / 1e6))

+ 2 - 2
yolo/utils/misc.py

@@ -204,7 +204,7 @@ class CollateFunc(object):
 
 # ---------------------------- For Loss ----------------------------
 ## FocalLoss
-def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2):
+def sigmoid_focal_loss(inputs, targets, alpha: float = 0.25, gamma: float = 2):
     """
     Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002.
     Args:
@@ -229,7 +229,7 @@ def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: f
         alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
         loss = alpha_t * loss
 
-    return loss.mean(1).sum() / num_boxes
+    return loss
 
 ## Variable FocalLoss
 def varifocal_loss_with_logits(pred_logits,