浏览代码

modify train code

yjh0410 1 年之前
父节点
当前提交
9472637cf3
共有 3 个文件被更改,包括 3 次插入3 次删除
  1. 1 1
      config/model_config/rtdetr_config.py
  2. 1 1
      dataset/data_augment/rtdetr_augment.py
  3. 1 1
      models/detectors/rtdetr/build.py

+ 1 - 1
config/model_config/rtdetr_config.py

@@ -10,7 +10,7 @@ rtdetr_cfg = {
         'depth': 1.0,
         ## Image Encoder - Backbone
         'backbone': 'resnet18',
-        'backbone_norm': 'FrozeBN',
+        'backbone_norm': 'BN',
         'res5_dilation': False,
         'pretrained': True,
         'pretrained_weight': 'imagenet1k_v1',

+ 1 - 1
dataset/data_augment/rtdetr_augment.py

@@ -75,7 +75,7 @@ class RandomPhotometricDistort(object):
         Returns:
             ndarray: the distorted image(s).
         """
-        if random.random() < 0.5:
+        if random.random() < 0.8:
             dhue = np.random.uniform(low=-self.hue, high=self.hue)
             dsat = self._rand_scale(self.saturation)
             dexp = self._rand_scale(self.exposure)

+ 1 - 1
models/detectors/rtdetr/build.py

@@ -20,7 +20,7 @@ def build_rtdetr(args, cfg, num_classes=80, trainable=False, deploy=False):
     model = RT_DETR(cfg             = cfg,
                     num_classes     = num_classes,
                     conf_thresh     = args.conf_thresh,
-                    topk            = 100,
+                    topk            = 300,
                     deploy          = deploy,
                     no_multi_labels = args.no_multi_labels,
                     )