Browse Source

modify rtdetr_r18 config

yjh0410 1 year ago
parent
commit
c981fb6612
2 changed files with 5 additions and 5 deletions
  1. 4 4
      engine.py
  2. 1 1
      utils/misc.py

+ 4 - 4
engine.py

@@ -1156,9 +1156,9 @@ class RTRTrainer(object):
 
         # ---------------------------- Build Transform ----------------------------
         self.train_transform, self.trans_cfg = build_transform(
-            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['out_stride'][-1], is_train=True)
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.val_transform, _ = build_transform(
-            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['out_stride'][-1], is_train=False)
+            args=args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=False)
         if self.trans_cfg["mosaic_prob"] > 0.5:
             self.second_stage_epoch = 5
 
@@ -1303,7 +1303,7 @@ class RTRTrainer(object):
             # Multi scale
             if self.args.multi_scale:
                 images, targets, img_size = self.rescale_image_targets(
-                    images, targets, self.model_cfg['out_stride'][-1], self.args.min_box_size, self.model_cfg['multi_scale'])
+                    images, targets, self.model_cfg['max_stride'], self.args.min_box_size, self.model_cfg['multi_scale'])
             else:
                 targets = self.refine_targets(img_size, targets, self.args.min_box_size)
 
@@ -1478,7 +1478,7 @@ class RTRTrainer(object):
         # build a new transform for second stage
         print(' - Rebuild transforms ...')
         self.train_transform, self.trans_cfg = build_transform(
-            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['out_stride'][-1], is_train=True)
+            args=self.args, trans_config=self.trans_cfg, max_stride=self.model_cfg['max_stride'], is_train=True)
         self.train_loader.dataset.transform = self.train_transform
         
 

+ 1 - 1
utils/misc.py

@@ -22,7 +22,7 @@ class SmoothedValue(object):
     window or the global series average.
     """
 
-    def __init__(self, window_size=20, fmt=None):
+    def __init__(self, window_size=1, fmt=None):
         if fmt is None:
             fmt = "{median:.4f} ({global_avg:.4f})"
         self.deque = deque(maxlen=window_size)