|
|
@@ -9,10 +9,15 @@ from utils import distributed_utils
|
|
|
from utils.vis_tools import vis_data
|
|
|
|
|
|
|
|
|
-def rescale_image_targets(images, targets, max_stride, min_box_size):
|
|
|
+def rescale_image_targets(images, targets, stride, min_box_size):
|
|
|
"""
|
|
|
Deployed for Multi scale trick.
|
|
|
"""
|
|
|
+ if isinstance(stride, int):
|
|
|
+ max_stride = stride
|
|
|
+ elif isinstance(stride, list):
|
|
|
+ max_stride = max(stride)
|
|
|
+
|
|
|
# During training phase, the shape of input image is square.
|
|
|
old_img_size = images.shape[-1]
|
|
|
new_img_size = random.randrange(old_img_size * 0.5, old_img_size * 1.0 + max_stride) // max_stride * max_stride # size
|
|
|
@@ -77,7 +82,7 @@ def train_one_epoch(epoch,
|
|
|
# multi scale
|
|
|
if args.multi_scale and ni % 10 == 0:
|
|
|
images, targets, img_size = rescale_image_targets(
|
|
|
- images, targets, max(model.stride), args.min_box_size)
|
|
|
+ images, targets, model.stride, args.min_box_size)
|
|
|
|
|
|
# inference
|
|
|
outputs = model(images)
|