@@ -3,24 +3,21 @@ import torch.distributed as dist
 
 import time
 import os
+import math
+import numpy as np
 import random
 
 from utils import distributed_utils
 from utils.vis_tools import vis_data
 
 
-def rescale_image_targets(images, targets, stride, min_box_size):
+def rescale_image_targets(images, targets, max_stride, min_box_size):
     """
         Deployed for Multi scale trick.
     """
-    if isinstance(stride, int):
-        max_stride = stride
-    elif isinstance(stride, list):
-        max_stride = max(stride)
-
     # During training phase, the shape of input image is square.
     old_img_size = images.shape[-1]
-    new_img_size = random.randrange(old_img_size * 0.5, old_img_size * 1.0 + max_stride) // max_stride * max_stride  # size
+    new_img_size = random.randrange(old_img_size * 0.5, old_img_size * 1.5 + max_stride) // max_stride * max_stride  # size
     if new_img_size / old_img_size != 1:
         # interpolate
         images = torch.nn.functional.interpolate(
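Note on the `new_img_size` line above: `random.randrange` expects integer bounds, and `old_img_size * 0.5` is a float, which newer Python releases reject. A minimal sketch of a safer version of just that line; the `int()` casts are my addition, not part of the patch:

    # Cast the float bounds to int so random.randrange accepts them.
    new_img_size = random.randrange(int(old_img_size * 0.5),
                                    int(old_img_size * 1.5 + max_stride)) // max_stride * max_stride  # size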
@@ -54,10 +51,12 @@ def train_one_epoch(epoch,
                     ema,
                     model,
                     criterion,
+                    cfg,
                     dataloader,
                     optimizer,
-                    lr_scheduler,
-                    warmup_scheduler,
+                    scheduler,
+                    lf,
+                    scaler,
                     last_opt_step):
     epoch_size = len(dataloader)
     img_size = args.img_size
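For context on the reworked signature: `scheduler` replaces the separate `lr_scheduler`/`warmup_scheduler` pair, `lf` is the epoch-wise LR lambda the scheduler was built from (the warmup below targets `x['initial_lr'] * lf(epoch)`), `scaler` is an AMP gradient scaler, and `cfg` supplies warmup hyperparameters. A sketch of plausible caller-side setup; the linear decay and the `cfg['lrf']` key are assumptions, not shown in this patch:

    # Hypothetical setup on the caller's side.
    scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
    lf = lambda x: (1 - x / total_epochs) * (1.0 - cfg['lrf']) + cfg['lrf']  # assumed linear decay
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)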
@@ -69,34 +68,42 @@ def train_one_epoch(epoch,
     for iter_i, (images, targets) in enumerate(dataloader):
         ni = iter_i + epoch * epoch_size
         # Warmup
-        if ni < nw:
-            warmup_scheduler.warmup(ni, optimizer)
+        if ni <= nw:
+            xi = [0, nw]  # x interp
+            accumulate = max(1, np.interp(ni, xi, [1, 64 / args.batch_size]).round())
+            for j, x in enumerate(optimizer.param_groups):
+                # bias lr falls from 0.1 to lr0, all other lrs rise from 0.0 to lr0
+                x['lr'] = np.interp(
+                    ni, xi, [cfg['warmup_bias_lr'] if j == 0 else 0.0, x['initial_lr'] * lf(epoch)])
+                if 'momentum' in x:
+                    x['momentum'] = np.interp(ni, xi, [cfg['warmup_momentum'], cfg['momentum']])
 
         # visualize train targets
         if args.vis_tgt:
             vis_data(images, targets)
 
         # to device
-        images = images.to(device, non_blocking=True).float()
+        images = images.to(device, non_blocking=True).float() / 255.
 
         # multi scale
-        if args.multi_scale and ni % 10 == 0:
+        if args.multi_scale:
             images, targets, img_size = rescale_image_targets(
-                images, targets, model.stride, args.min_box_size)
+                images, targets, max(model.stride), args.min_box_size)
 
         # inference
-        outputs = model(images)
-
-        # loss
-        loss_dict = criterion(outputs=outputs, targets=targets)
-        losses = loss_dict['losses']
+        with torch.cuda.amp.autocast(enabled=args.fp16):
+            outputs = model(images)
+            # loss
+            loss_dict = criterion(outputs=outputs, targets=targets)
+            losses = loss_dict['losses']
+            losses *= images.shape[0]  # loss * bs
 
-        # reduce
-        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
+            # reduce
+            loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)
 
-        if args.distributed:
-            # gradient averaged between devices in DDP mode
-            losses *= distributed_utils.get_world_size()
+            if args.distributed:
+                # gradient averaged between devices in DDP mode
+                losses *= distributed_utils.get_world_size()
 
         # check loss
         try:
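The new warmup block replaces `warmup_scheduler` with direct linear interpolation over the first `nw` iterations: the gradient-accumulation count ramps from 1 toward a 64-image nominal batch, the bias group's lr falls from `warmup_bias_lr` to the scheduled value, and all other lrs rise from 0. A self-contained sketch of the interpolation with illustrative numbers:

    import numpy as np

    nw, ni = 1000, 250    # warmup length, current iteration
    xi = [0, nw]
    # accumulation steps ramp 1 -> 64/batch_size (batch_size = 16 here)
    accumulate = max(1, int(np.interp(ni, xi, [1, 64 / 16]).round()))
    # bias group falls 0.1 -> 0.01, other groups rise 0.0 -> 0.01
    bias_lr = np.interp(ni, xi, [0.1, 0.01])
    other_lr = np.interp(ni, xi, [0.0, 0.01])
    print(accumulate, bias_lr, other_lr)    # 2 0.0775 0.0025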
@@ -107,16 +114,20 @@ def train_one_epoch(epoch,
             print(loss_dict)
 
         # backward
-        losses /= accumulate
-        losses.backward()
+        scaler.scale(losses).backward()
 
         # Optimize
         if ni - last_opt_step >= accumulate:
+            if cfg['clip_grad'] > 0:
+                # unscale gradients
+                scaler.unscale_(optimizer)
+                # clip gradients
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg['clip_grad'])
             # optimizer.step
-            optimizer.step()
+            scaler.step(optimizer)
+            scaler.update()
             optimizer.zero_grad()
-
-            # EMA
+            # ema
             if ema:
                 ema.update(model)
             last_opt_step = ni
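The clip-and-step sequence follows the standard `torch.cuda.amp` recipe: gradients must be unscaled before clipping so the norm is measured at its true magnitude, `scaler.step` skips the update if it finds inf/NaN gradients, and `scaler.update` then re-tunes the scale factor. A condensed, self-contained sketch of that recipe with placeholder model and data:

    import torch

    model = torch.nn.Linear(4, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler(enabled=True)

    x, y = torch.randn(8, 4).cuda(), torch.randn(8, 2).cuda()
    with torch.cuda.amp.autocast(enabled=True):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    scaler.unscale_(optimizer)                                 # gradients back to true scale
    torch.nn.utils.clip_grad_norm_(model.parameters(), 10.0)   # clip at true scale
    scaler.step(optimizer)                                     # skipped on inf/NaN gradients
    scaler.update()
    optimizer.zero_grad()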
@@ -128,7 +139,7 @@ def train_one_epoch(epoch,
             # basic infor
             log = '[Epoch: {}/{}]'.format(epoch+1, total_epochs)
             log += '[Iter: {}/{}]'.format(iter_i, epoch_size)
-            log += '[lr: {:.6f}]'.format(cur_lr[0])
+            log += '[lr: {:.6f}]'.format(cur_lr[2])
             # loss infor
             for k in loss_dict_reduced.keys():
                 if k == 'losses' and args.distributed:
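Logging `cur_lr[2]` rather than `cur_lr[0]` assumes the optimizer now holds at least three parameter groups, with the bias group first (matching the `j == 0` branch in the warmup loop) and the main weight group at index 2. The `cur_lr` list itself is presumably gathered elsewhere in the loop as:

    # Assumed definition, not part of this patch.
    cur_lr = [param_group['lr'] for param_group in optimizer.param_groups]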
@@ -146,7 +157,7 @@ def train_one_epoch(epoch,
 
             t0 = time.time()
 
-    lr_scheduler.step()
+    scheduler.step()
 
     return last_opt_step
 
@@ -163,7 +174,7 @@ def val_one_epoch(args,
     if evaluator is None:
         print('No evaluator ... save model and go on training.')
         print('Saving state, epoch: {}'.format(epoch + 1))
-        weight_name = '{}_epoch_{}.pth'.format(args.model, epoch + 1)
+        weight_name = '{}_epoch_{}.pth'.format(args.version, epoch + 1)
         checkpoint_path = os.path.join(path_to_save, weight_name)
         torch.save({'model': model.state_dict(),
                     'mAP': -1.,
@@ -187,7 +198,7 @@ def val_one_epoch(args,
             best_map = cur_map
             # save model
             print('Saving state, epoch:', epoch + 1)
-            weight_name = '{}_epoch_{}_{:.2f}.pth'.format(args.model, epoch + 1, best_map*100)
+            weight_name = '{}_epoch_{}_{:.2f}.pth'.format(args.version, epoch + 1, best_map*100)
             checkpoint_path = os.path.join(path_to_save, weight_name)
             torch.save({'model': model.state_dict(),
                         'mAP': round(best_map*100, 1),