Browse Source

debug train

yjh0410 2 years ago
parent
commit
85e7a70a45
1 changed files with 7 additions and 2 deletions
  1. 7 2
      engine.py

+ 7 - 2
engine.py

@@ -9,10 +9,15 @@ from utils import distributed_utils
 from utils.vis_tools import vis_data
 
 
-def rescale_image_targets(images, targets, max_stride, min_box_size):
+def rescale_image_targets(images, targets, stride, min_box_size):
     """
         Deployed for Multi scale trick.
     """
+    if isinstance(stride, int):
+        max_stride = stride
+    elif isinstance(stride, list):
+        max_stride = max(stride)
+
     # During training phase, the shape of input image is square.
     old_img_size = images.shape[-1]
     new_img_size = random.randrange(old_img_size * 0.5, old_img_size * 1.0 + max_stride) // max_stride * max_stride  # size
@@ -77,7 +82,7 @@ def train_one_epoch(epoch,
         # multi scale
         if args.multi_scale and ni % 10 == 0:
             images, targets, img_size = rescale_image_targets(
-                images, targets, max(model.stride), args.min_box_size)
+                images, targets, model.stride, args.min_box_size)
             
         # inference
         outputs = model(images)