yjh0410 1 year ago
parent
revision
ddf611a65e

+ 14 - 125
engine.py

@@ -759,7 +759,7 @@ class YoloxTrainer(object):
 
         return images, targets, new_img_size
 
-## RTCDet Trainer
+## Real-time Convolutional Object Detector Trainer
 class RTCTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- basic parameters -------------------
@@ -1121,7 +1121,7 @@ class RTCTrainer(object):
             args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
    
-## RTRDet Trainer
+## Real-time Transformer-based Object Detector Trainer
 class RTRTrainer(object):
     def __init__(self, args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size):
         # ------------------- Basic parameters -------------------
@@ -1132,21 +1132,14 @@ class RTRTrainer(object):
         self.criterion = criterion
         self.world_size = world_size
         self.grad_accumulate = args.grad_accumulate
-        self.clip_grad = 35
-        self.heavy_eval = False
-        # weak augmentatino stage
-        self.second_stage = False
-        self.third_stage = False
-        self.second_stage_epoch = args.no_aug_epoch
-        self.third_stage_epoch = args.no_aug_epoch // 2
+        self.clip_grad = 0.1
         # path to save model
         self.path_to_save = os.path.join(args.save_folder, args.dataset, args.model)
         os.makedirs(self.path_to_save, exist_ok=True)
 
         # ---------------------------- Hyperparameters refer to RTMDet ----------------------------
         self.optimizer_dict = {'optimizer': 'adamw', 'momentum': None, 'weight_decay': 1e-4, 'lr0': 0.0001, 'backbone_lr_ratio': 0.1}
-        self.ema_dict = {'ema_decay': 0.9998, 'ema_tau': 2000}
-        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.05}
+        self.lr_schedule_dict = {'scheduler': 'cosine', 'lrf': 0.1}
         self.warmup_dict = {'warmup_momentum': 0.8, 'warmup_bias_lr': 0.1}        
 
         # ---------------------------- Build Dataset & Model & Trans. Config ----------------------------
@@ -1175,70 +1168,26 @@ class RTRTrainer(object):
         self.optimizer, self.start_epoch = build_detr_optimizer(self.optimizer_dict, model, self.args.resume)
 
         # ---------------------------- Build LR Scheduler ----------------------------
-        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch - args.no_aug_epoch)
+        self.lr_scheduler, self.lf = build_lr_scheduler(self.lr_schedule_dict, self.optimizer, args.max_epoch)
         self.lr_scheduler.last_epoch = self.start_epoch - 1  # do not move
         if self.args.resume and self.args.resume != 'None':
             self.lr_scheduler.step()
 
-        # ---------------------------- Build Model-EMA ----------------------------
-        if self.args.ema and distributed_utils.get_rank() in [-1, 0]:
-            print('Build ModelEMA ...')
-            self.model_ema = ModelEMA(self.ema_dict, model, self.start_epoch * len(self.train_loader))
-        else:
-            self.model_ema = None
-
-
     def train(self, model):
         for epoch in range(self.start_epoch, self.args.max_epoch):
             if self.args.distributed:
                 self.train_loader.batch_sampler.sampler.set_epoch(epoch)
 
-            # check second stage
-            if epoch >= (self.args.max_epoch - self.second_stage_epoch - 1) and not self.second_stage:
-                self.check_second_stage()
-                # save model of the last mosaic epoch
-                weight_name = '{}_last_mosaic_epoch.pth'.format(self.args.model)
-                checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                print('Saving state of the last Mosaic epoch-{}.'.format(self.epoch))
-                torch.save({'model': model.state_dict(),
-                            'mAP': round(self.evaluator.map*100, 1),
-                            'optimizer': self.optimizer.state_dict(),
-                            'epoch': self.epoch,
-                            'args': self.args}, 
-                            checkpoint_path)
-
-            # check third stage
-            if epoch >= (self.args.max_epoch - self.third_stage_epoch - 1) and not self.third_stage:
-                self.check_third_stage()
-                # save model of the last mosaic epoch
-                weight_name = '{}_last_weak_augment_epoch.pth'.format(self.args.model)
-                checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                print('Saving state of the last weak augment epoch-{}.'.format(self.epoch))
-                torch.save({'model': model.state_dict(),
-                            'mAP': round(self.evaluator.map*100, 1),
-                            'optimizer': self.optimizer.state_dict(),
-                            'epoch': self.epoch,
-                            'args': self.args}, 
-                            checkpoint_path)
-
             # train one epoch
             self.epoch = epoch
             self.train_one_epoch(model)
 
             # eval one epoch
-            if self.heavy_eval:
+            if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
                 model_eval = model.module if self.args.distributed else model
                 self.eval(model_eval)
-            else:
-                model_eval = model.module if self.args.distributed else model
-                if (epoch % self.args.eval_epoch) == 0 or (epoch == self.args.max_epoch - 1):
-                    self.eval(model_eval)
-
 
     def eval(self, model):
-        # chech model
-        model_eval = model if self.model_ema is None else self.model_ema.ema
-
         if distributed_utils.is_main_process():
             # check evaluator
             if self.evaluator is None:
@@ -1246,7 +1195,7 @@ class RTRTrainer(object):
                 print('Saving state, epoch: {}'.format(self.epoch))
                 weight_name = '{}_no_eval.pth'.format(self.args.model)
                 checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                torch.save({'model': model_eval.state_dict(),
+                torch.save({'model': model.state_dict(),
                             'mAP': -1.,
                             'optimizer': self.optimizer.state_dict(),
                             'epoch': self.epoch,
@@ -1255,12 +1204,12 @@ class RTRTrainer(object):
             else:
                 print('eval ...')
                 # set eval mode
-                model_eval.trainable = False
-                model_eval.eval()
+                model.trainable = False
+                model.eval()
 
                 # evaluate
                 with torch.no_grad():
-                    self.evaluator.evaluate(model_eval)
+                    self.evaluator.evaluate(model)
 
                 # save model
                 cur_map = self.evaluator.map
@@ -1271,7 +1220,7 @@ class RTRTrainer(object):
                     print('Saving state, epoch:', self.epoch)
                     weight_name = '{}_best.pth'.format(self.args.model)
                     checkpoint_path = os.path.join(self.path_to_save, weight_name)
-                    torch.save({'model': model_eval.state_dict(),
+                    torch.save({'model': model.state_dict(),
                                 'mAP': round(self.best_map*100, 1),
                                 'optimizer': self.optimizer.state_dict(),
                                 'epoch': self.epoch,
@@ -1279,14 +1228,13 @@ class RTRTrainer(object):
                                 checkpoint_path)                      
 
                 # set train mode.
-                model_eval.trainable = True
-                model_eval.train()
+                model.trainable = True
+                model.train()
 
         if self.args.distributed:
             # wait for all processes to synchronize
             dist.barrier()
 
-
     def train_one_epoch(self, model):
         # basic parameters
         epoch_size = len(self.train_loader)
@@ -1383,7 +1331,6 @@ class RTRTrainer(object):
-        if not self.second_stage:
-            self.lr_scheduler.step()
+        self.lr_scheduler.step()
         
-
     def refine_targets(self, targets, min_box_size):
         # rescale targets
         for tgt in targets:
@@ -1399,7 +1346,6 @@ class RTRTrainer(object):
         
         return targets
 
-
     def normalize_bbox(self, targets, img_size):
         # normalize targets
         for tgt in targets:
@@ -1407,7 +1353,6 @@ class RTRTrainer(object):
         
         return targets
 
-
     def denormalize_bbox(self, targets, img_size):
         # denormalize targets
         for tgt in targets:
@@ -1415,7 +1360,6 @@ class RTRTrainer(object):
         
         return targets
 
-
     def rescale_image_targets(self, images, targets, stride, min_box_size, multi_scale_range=[0.5, 1.5]):
         """
             Deployed for the multi-scale training trick.
@@ -1455,61 +1399,6 @@ class RTRTrainer(object):
         return images, targets, new_img_size
 
 
-    def check_second_stage(self):
-        # set second stage
-        print('============== Second stage of Training ==============')
-        self.second_stage = True
-
-        # close mosaic augmentation
-        if self.train_loader.dataset.mosaic_prob > 0.:
-            print(' - Close < Mosaic Augmentation > ...')
-            self.train_loader.dataset.mosaic_prob = 0.
-            self.heavy_eval = True
-
-        # close mixup augmentation
-        if self.train_loader.dataset.mixup_prob > 0.:
-            print(' - Close < Mixup Augmentation > ...')
-            self.train_loader.dataset.mixup_prob = 0.
-            self.heavy_eval = True
-
-        # close rotation augmentation
-        if 'degrees' in self.trans_cfg.keys() and self.trans_cfg['degrees'] > 0.0:
-            print(' - Close < degress of rotation > ...')
-            self.trans_cfg['degrees'] = 0.0
-        if 'shear' in self.trans_cfg.keys() and self.trans_cfg['shear'] > 0.0:
-            print(' - Close < shear of rotation >...')
-            self.trans_cfg['shear'] = 0.0
-        if 'perspective' in self.trans_cfg.keys() and self.trans_cfg['perspective'] > 0.0:
-            print(' - Close < perspective of rotation > ...')
-            self.trans_cfg['perspective'] = 0.0
-
-        # build a new transform for second stage
-        print(' - Rebuild transforms ...')
-        self.train_transform, self.trans_cfg = build_transform(
-            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
-        self.train_loader.dataset.transform = self.train_transform
-        
-
-    def check_third_stage(self):
-        # set third stage
-        print('============== Third stage of Training ==============')
-        self.third_stage = True
-
-        # close random affine
-        if 'translate' in self.trans_cfg.keys() and self.trans_cfg['translate'] > 0.0:
-            print(' - Close < translate of affine > ...')
-            self.trans_cfg['translate'] = 0.0
-        if 'scale' in self.trans_cfg.keys():
-            print(' - Close < scale of affine >...')
-            self.trans_cfg['scale'] = [1.0, 1.0]
-
-        # build a new transform for second stage
-        print(' - Rebuild transforms ...')
-        self.train_transform, self.trans_cfg = build_transform(
-            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
-        self.train_loader.dataset.transform = self.train_transform
-        
-
 # ----------------------- Det + Seg trainers -----------------------
 ## RTCDet Trainer for Det + Seg
 class RTCTrainerDS(object):
@@ -2206,7 +2095,7 @@ def build_trainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion
         return YoloxTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
     elif model_cfg['trainer_type'] == 'rtcdet':
         return RTCTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
-    elif model_cfg['trainer_type'] == 'rtrdet':
+    elif model_cfg['trainer_type'] == 'rtdetr':
         return RTRTrainer(args, data_cfg, model_cfg, trans_cfg, device, model, criterion, world_size)
     
     # ----------------------- Det + Seg trainers -----------------------

+ 56 - 0
models/detectors/rtdetr/README.md

@@ -0,0 +1,56 @@
+# Real-time Transformer-based Object Detector
+This model is not yet complete.
+
+## Results on COCO-val
+|     Model    | Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
+|--------------|-------|-------|------------------------|-------------------|-------------------|--------------------|--------|
+| RT-DETR-R18  | 2xb8  |  640  |                        |                   |                   |                    |  |
+| RT-DETR-R50  | 2xb8  |  640  |                        |                   |                   |                    |  |
+| RT-DETR-R101 | 2xb8  |  640  |                        |                   |                   |                    |  |
+
+- For the backbone of the image encoder, we use IN-1K classification pretrained weights. It might be hard to train RT-DETR from scratch without them.
+- For training, we train the RT-DETR series with a 6x (~72 epochs) schedule on COCO.
+- For data augmentation, we use `color jitter`, `random hflip`, `random crop`, and the multi-scale training trick.
+- For the optimizer, we use AdamW with a weight decay of 0.0001 and a base learning rate of 0.001 / 16 per image (see the sketch below).
+- For the learning rate scheduler, we use `cosine` decay.
+
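+As a minimal sketch of the last two bullets (not the repository's `build_detr_optimizer` / `build_lr_scheduler`, and with a hypothetical stand-in module in place of the detector), the optimizer and scheduler setup looks roughly like:
+```Python
+import math
+
+import torch.nn as nn
+import torch.optim as optim
+
+model = nn.Linear(256, 80)  # hypothetical stand-in for the detector
+
+# Base per-image lr of 0.001 / 16, scaled by the total batch size (16 here).
+batch_size = 16
+lr0 = 0.001 / 16 * batch_size
+optimizer = optim.AdamW(model.parameters(), lr=lr0, weight_decay=1e-4)
+
+# Cosine decay from lr0 down to lr0 * lrf over max_epoch epochs (lrf = 0.1 in engine.py).
+max_epoch, lrf = 72, 0.1
+lf = lambda e: ((1 - math.cos(e * math.pi / max_epoch)) / 2) * (lrf - 1) + 1
+scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)
+```
+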
+## Train RT-DETR
+### Single GPU
+Taking training RT-DETR-R18 on COCO as an example:
+```Shell
+python train.py --cuda -d coco --root path/to/coco -m rtdetr_r18 -bs 16 -size 640 --max_epoch 72 --eval_epoch 5 --no_aug_epoch -1 --ema --fp16 --multi_scale 
+```
+
+### Multi GPU
+Taking training RT-DETR-R18 on COCO as an example:
+```Shell
+python -m torch.distributed.run --nproc_per_node=8 train.py --cuda -dist -d coco --root /data/datasets/ -m rtdetr_r18 -bs 16 -size 640 --max_epoch 72 --eval_epoch 5 --no_aug_epoch -1 --ema --fp16 --sybn --multi_scale --save_folder weights/ 
+```
+
+## Test RT-DETR
+Taking testing RT-DETR-R18 on COCO-val as an example:
+```Shell
+python test.py --cuda -d coco --root path/to/coco -m rtdetr_r18 --weight path/to/rtdetr_r18.pth -size 640 -vt 0.4 --show 
+```
+
+## Evaluate RT-DETR
+Taking evaluating RT-DETR-R18 on COCO-val as an example:
+```Shell
+python eval.py --cuda -d coco-val --root path/to/coco -m rtdetr_r18 --weight path/to/rtdetr_r18.pth 
+```
+
+## Demo
+### Detect with Image
+```Shell
+python demo.py --mode image --path_to_img path/to/image_dirs/ --cuda -m rtdetr_r18 --weight path/to/weight -size 640 -vt 0.4 --show
+```
+
+### Detect with Video
+```Shell
+python demo.py --mode video --path_to_vid path/to/video --cuda -m rtdetr_r18 --weight path/to/weight -size 640 -vt 0.4 --show --gif
+```
+
+### Detect with Camera
+```Shell
+python demo.py --mode camera --cuda -m rtdetr_r18 --weight path/to/weight -size 640 -vt 0.4 --show --gif
+```

+ 86 - 156
models/detectors/rtdetr/basic_modules/backbone.py

@@ -1,22 +1,24 @@
 import torch
-import torch.nn as nn
-from torch import Tensor
-from typing import Callable, List, Optional, Type, Union
-
+import torchvision
+from torch import nn
+from torchvision.models._utils import IntermediateLayerGetter
+from torchvision.models.resnet import (ResNet18_Weights,
+                                       ResNet34_Weights,
+                                       ResNet50_Weights,
+                                       ResNet101_Weights)
 try:
-    from .basic import conv1x1, BasicBlock, Bottleneck
+    from .basic import FrozenBatchNorm2d
 except:
-    from basic import conv1x1, BasicBlock, Bottleneck
+    from basic import FrozenBatchNorm2d
    
 
 # IN1K pretrained weights
 pretrained_urls = {
     # ResNet series
-    'resnet18': None,
-    'resnet34': None,
-    'resnet50': None,
-    'resnet101': None,
-    'resnet152': None,
+    'resnet18':  ResNet18_Weights,
+    'resnet34':  ResNet34_Weights,
+    'resnet50':  ResNet50_Weights,
+    'resnet101': ResNet101_Weights,
     # ShuffleNet series
 }
 
@@ -24,164 +26,92 @@ pretrained_urls = {
 # ----------------- Model functions -----------------
 ## Build backbone network
 def build_backbone(cfg, pretrained):
+    print('==============================')
+    print('Backbone: {}'.format(cfg['backbone']))
+    # ResNet
     if 'resnet' in cfg['backbone']:
-        # Build ResNet
-        model, feats = build_resnet(cfg, pretrained)
+        pretrained_weight = cfg['pretrained_weight'] if pretrained else None
+        model, feats = build_resnet(cfg, pretrained_weight)
+    elif 'svnetv2' in cfg['backbone']:
+        pretrained_weight = cfg['pretrained_weight'] if pretrained else None
+        model, feats = build_scnetv2(cfg, pretrained_weight)
     else:
         raise NotImplementedError("Unknown backbone: {}.".format(cfg['backbone']))
     
     return model, feats
 
-## Load pretrained weight
-def load_pretrained(model_name):
-    return
-
 
 # ----------------- ResNet Backbone -----------------
 class ResNet(nn.Module):
-    def __init__(self,
-                 block: Type[Union[BasicBlock, Bottleneck]],
-                 layers: List[int],
-                 num_classes: int = 1000,
-                 zero_init_residual: bool = False,
-                 groups: int = 1,
-                 width_per_group: int = 64,
-                 replace_stride_with_dilation: Optional[List[bool]] = None,
-                 norm_layer: Optional[Callable[..., nn.Module]] = None,
-                 ) -> None:
+    """ResNet backbone with frozen BatchNorm."""
+    def __init__(self, name: str, res5_dilation: bool, norm_type: str, pretrained_weights: str = "imagenet1k_v1"):
         super().__init__()
-        # --------------- Basic parameters ----------------
-        self.groups = groups
-        self.base_width = width_per_group
-        self.inplanes = 64
-        self.dilation = 1
-        self.zero_init_residual = zero_init_residual
-        self.replace_stride_with_dilation = [False, False, False] if replace_stride_with_dilation is None else replace_stride_with_dilation
-        if len(self.replace_stride_with_dilation) != 3:
-            raise ValueError(
-                "replace_stride_with_dilation should be None "
-                f"or a 3-element tuple, got {self.replace_stride_with_dilation}"
-            )
-
-        # --------------- Network parameters ----------------
-        self._norm_layer = nn.BatchNorm2d if norm_layer is None else norm_layer
-        ## Stem layer
-        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, bias=False)
-        self.bn1 = self._norm_layer(self.inplanes)
-        self.relu = nn.ReLU(inplace=True)
-        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
-        ## Res Layer
-        self.layer1 = self._make_layer(block, 64, layers[0])
-        self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=self.replace_stride_with_dilation[0])
-        self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=self.replace_stride_with_dilation[1])
-        self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=self.replace_stride_with_dilation[2])
-        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
-        self.fc = nn.Linear(512 * block.expansion, num_classes)
-
-        self._init_layer()
-
-    def _init_layer(self):
-        for m in self.modules():
-            if isinstance(m, nn.Conv2d):
-                nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
-            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
-                nn.init.constant_(m.weight, 1)
-                nn.init.constant_(m.bias, 0)
-
-        if self.zero_init_residual:
-            for m in self.modules():
-                if isinstance(m, Bottleneck) and m.bn3.weight is not None:
-                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
-                elif isinstance(m, BasicBlock) and m.bn2.weight is not None:
-                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
-
-    def _make_layer(
-        self,
-        block: Type[Union[BasicBlock, Bottleneck]],
-        planes: int,
-        blocks: int,
-        stride: int = 1,
-        dilate: bool = False,
-    ) -> nn.Sequential:
-        norm_layer = self._norm_layer
-        downsample = None
-        previous_dilation = self.dilation
-        if dilate:
-            self.dilation *= stride
-            stride = 1
-        if stride != 1 or self.inplanes != planes * block.expansion:
-            downsample = nn.Sequential(
-                conv1x1(self.inplanes, planes * block.expansion, stride),
-                norm_layer(planes * block.expansion),
-            )
-
-        layers = []
-        layers.append(
-            block(
-                self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation, norm_layer
-            )
-        )
-        self.inplanes = planes * block.expansion
-        for _ in range(1, blocks):
-            layers.append(
-                block(
-                    self.inplanes,
-                    planes,
-                    groups=self.groups,
-                    base_width=self.base_width,
-                    dilation=self.dilation,
-                    norm_layer=norm_layer,
-                )
-            )
-
-        return nn.Sequential(*layers)
-
-    def forward(self, x: Tensor) -> Tensor:
-        # See note [TorchScript super()]
-        x = self.conv1(x)
-        x = self.bn1(x)
-        x = self.relu(x)
-        x = self.maxpool(x)
-
-        x = self.layer1(x)
-        x = self.layer2(x)
-        x = self.layer3(x)
-        x = self.layer4(x)
-
-        x = self.avgpool(x)
-        x = torch.flatten(x, 1)
-        x = self.fc(x)
-
-        return x
+        # Pretrained weights: ResNet-18/34 only ship IMAGENET1K_V1 in torchvision.
+        assert pretrained_weights in [None, "imagenet1k_v1", "imagenet1k_v2"]
+        if pretrained_weights is not None:
+            if name in ('resnet18', 'resnet34') or pretrained_weights == "imagenet1k_v1":
+                pretrained_weights = pretrained_urls[name].IMAGENET1K_V1
+            else:
+                pretrained_weights = pretrained_urls[name].IMAGENET1K_V2
+        print('ImageNet pretrained weight: ', pretrained_weights)
+        # Norm layer
+        if norm_type == 'BN':
+            norm_layer = nn.BatchNorm2d
+        elif norm_type == 'FrozeBN':
+            norm_layer = FrozenBatchNorm2d
+        else:
+            raise NotImplementedError("Unknown norm type: {}.".format(norm_type))
+        # Backbone
+        backbone = getattr(torchvision.models, name)(
+            replace_stride_with_dilation=[False, False, res5_dilation],
+            norm_layer=norm_layer, weights=pretrained_weights)
+        return_layers = {"layer2": "0", "layer3": "1", "layer4": "2"}
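+        # Expose the C3/C4/C5 feature maps (strides 8/16/32; res5 keeps stride 16 when dilated).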
+        self.body = IntermediateLayerGetter(backbone, return_layers=return_layers)
+        self.feat_dims = [128, 256, 512] if name in ('resnet18', 'resnet34') else [512, 1024, 2048]
+        # Freeze the stem and layer1; only layer2-4 are fine-tuned.
+        for param_name, parameter in backbone.named_parameters():
+            if 'layer2' not in param_name and 'layer3' not in param_name and 'layer4' not in param_name:
+                parameter.requires_grad_(False)
+
+    def forward(self, x):
+        xs = self.body(x)
+        fmp_list = []
+        for name, fmp in xs.items():
+            fmp_list.append(fmp)
+
+        return fmp_list
+
+def build_resnet(cfg, pretrained_weight=None):
+    # ResNet series
+    backbone = ResNet(cfg['backbone'], cfg['res5_dilation'], cfg['backbone_norm'], pretrained_weight)
 
-def _resnet(block: Type[Union[BasicBlock, Bottleneck]], layers: List[int], **kwargs) -> ResNet:
-    return ResNet(block, layers, **kwargs)
+    return backbone, backbone.feat_dims
 
-def build_resnet(cfg, pretrained=False, **kwargs):
-    # ---------- Build ResNet ----------
-    if   cfg['backbone'] == 'resnet18':
-        model = _resnet(BasicBlock, [2, 2, 2, 2], **kwargs)
-        feats = [128, 256, 512]
-    elif cfg['backbone'] == 'resnet34':
-        model = _resnet(BasicBlock, [3, 4, 6, 3], **kwargs)
-        feats = [128, 256, 512]
-    elif cfg['backbone'] == 'resnet50':
-        model = _resnet(Bottleneck, [3, 4, 6, 3], **kwargs)
-        feats = [512, 1024, 2048]
-    elif cfg['backbone'] == 'resnet101':
-        model = _resnet(Bottleneck, [3, 4, 23, 3], **kwargs)
-        feats = [512, 1024, 2048]
-    elif cfg['backbone'] == 'resnet152':
-        model = _resnet(Bottleneck, [3, 8, 36, 3], **kwargs)
-        feats = [512, 1024, 2048]
 
-    # ---------- Load pretrained ----------
-    if pretrained:
-        # TODO: load IN1K pretrained
-        pass
+# ----------------- ShuffleNet Backbone -----------------
+## TODO: Add shufflenet-v2
+class ShuffleNetv2:
+    pass
 
-    return model, feats
+def build_scnetv2(cfg, pretrained_weight=None):
+    # Stub: fail loudly until ShuffleNet-v2 is implemented, since callers
+    # expect a (model, feat_dims) tuple.
+    raise NotImplementedError("ShuffleNet-v2 backbone is not implemented yet.")
 
 
-# ----------------- ShuffleNet Backbone -----------------
-## TODO: Add shufflenet-v2
+if __name__ == '__main__':
+    cfg = {
+        'backbone':      'resnet18',
+        'backbone_norm': 'FrozeBN',
+        'res5_dilation': False,
+        'pretrained': True,
+        'pretrained_weight': 'imagenet1k_v1',
+    }
+    model, feat_dim = build_backbone(cfg, cfg['pretrained'])
+    print(feat_dim)
+
+    x = torch.randn(2, 3, 320, 320)
+    output = model(x)
+    for y in output:
+        print(y.size())

+ 30 - 101
models/detectors/rtdetr/basic_modules/basic.py

@@ -1,7 +1,5 @@
 import torch
 import torch.nn as nn
-from torch import Tensor
-from typing import List, Optional, Callable
 
 
 # ----------------- CNN modules -----------------
@@ -51,6 +49,36 @@ def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
     """1x1 convolution"""
     return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
 
+class FrozenBatchNorm2d(torch.nn.Module):
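+    """BatchNorm2d with affine parameters and running statistics frozen as buffers.
+
+    A common choice (cf. torchvision and DETR) when fine-tuning an ImageNet-pretrained
+    backbone with small detection batch sizes, where updated BN statistics would be noisy.
+    """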
+    def __init__(self, n):
+        super(FrozenBatchNorm2d, self).__init__()
+        self.register_buffer("weight", torch.ones(n))
+        self.register_buffer("bias", torch.zeros(n))
+        self.register_buffer("running_mean", torch.zeros(n))
+        self.register_buffer("running_var", torch.ones(n))
+
+    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+                              missing_keys, unexpected_keys, error_msgs):
+        num_batches_tracked_key = prefix + 'num_batches_tracked'
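+        # Drop 'num_batches_tracked' from incoming state dicts: this module does not
+        # register that buffer, so strict loading would flag it as an unexpected key.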
+        if num_batches_tracked_key in state_dict:
+            del state_dict[num_batches_tracked_key]
+
+        super(FrozenBatchNorm2d, self)._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict,
+            missing_keys, unexpected_keys, error_msgs)
+
+    def forward(self, x):
+        # move reshapes to the beginning
+        # to make it fuser-friendly
+        w = self.weight.reshape(1, -1, 1, 1)
+        b = self.bias.reshape(1, -1, 1, 1)
+        rv = self.running_var.reshape(1, -1, 1, 1)
+        rm = self.running_mean.reshape(1, -1, 1, 1)
+        eps = 1e-5
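+        # Fold BN into one affine map: y = (x - rm) / sqrt(rv + eps) * w + b = x * scale + bias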
+        scale = w * (rv + eps).rsqrt()
+        bias = b - rm * scale
+        return x * scale + bias
+    
 class Conv(nn.Module):
     def __init__(self, 
                  c1,                   # in channels
@@ -92,104 +120,5 @@ class Conv(nn.Module):
     def forward(self, x):
         return self.convs(x)
 
-class BasicBlock(nn.Module):
-    expansion: int = 1
-
-    def __init__(
-        self,
-        inplanes: int,
-        planes: int,
-        stride: int = 1,
-        downsample: Optional[nn.Module] = None,
-        groups: int = 1,
-        base_width: int = 64,
-        dilation: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-    ) -> None:
-        super().__init__()
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        if groups != 1 or base_width != 64:
-            raise ValueError("BasicBlock only supports groups=1 and base_width=64")
-        if dilation > 1:
-            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
-        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
-        self.conv1 = conv3x3(inplanes, planes, stride)
-        self.bn1 = norm_layer(planes)
-        self.relu = nn.ReLU(inplace=True)
-        self.conv2 = conv3x3(planes, planes)
-        self.bn2 = norm_layer(planes)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x: Tensor) -> Tensor:
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
-class Bottleneck(nn.Module):
-    expansion: int = 4
-
-    def __init__(
-        self,
-        inplanes: int,
-        planes: int,
-        stride: int = 1,
-        downsample: Optional[nn.Module] = None,
-        groups: int = 1,
-        base_width: int = 64,
-        dilation: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
-    ) -> None:
-        super().__init__()
-        if norm_layer is None:
-            norm_layer = nn.BatchNorm2d
-        width = int(planes * (base_width / 64.0)) * groups
-        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
-        self.conv1 = conv1x1(inplanes, width)
-        self.bn1 = norm_layer(width)
-        self.conv2 = conv3x3(width, width, stride, groups, dilation)
-        self.bn2 = norm_layer(width)
-        self.conv3 = conv1x1(width, planes * self.expansion)
-        self.bn3 = norm_layer(planes * self.expansion)
-        self.relu = nn.ReLU(inplace=True)
-        self.downsample = downsample
-        self.stride = stride
-
-    def forward(self, x: Tensor) -> Tensor:
-        identity = x
-
-        out = self.conv1(x)
-        out = self.bn1(out)
-        out = self.relu(out)
-
-        out = self.conv2(out)
-        out = self.bn2(out)
-        out = self.relu(out)
-
-        out = self.conv3(out)
-        out = self.bn3(out)
-
-        if self.downsample is not None:
-            identity = self.downsample(x)
-
-        out += identity
-        out = self.relu(out)
-
-        return out
-
 
 # ----------------- Transformer modules -----------------

+ 6 - 0
train.sh

@@ -15,6 +15,12 @@ if [[ $MODEL == *"rtcdet"* ]]; then
     WP_EPOCH=3
     EVAL_EPOCH=10
     NO_AUG_EPOCH=20
+elif [[ $MODEL == *"rtdetr"* ]]; then
+    # Epoch setting
+    MAX_EPOCH=72
+    WP_EPOCH=-1
+    EVAL_EPOCH=4
+    NO_AUG_EPOCH=-1
 elif [[ $MODEL == *"yolov8"* ]]; then
     # Epoch setting
     MAX_EPOCH=500