import math
import sys
from typing import Iterable

import torch

from utils import distributed_utils
from utils.misc import MetricLogger, SmoothedValue
from utils.vis_tools import vis_data


def train_one_epoch(cfg,
                    model: torch.nn.Module,
                    criterion: torch.nn.Module,
                    data_loader: Iterable,
                    optimizer: torch.optim.Optimizer,
                    device: torch.device,
                    epoch: int,
                    vis_target: bool,
                    warmup_lr_scheduler,
                    debug: bool = False,
                    ):
    """Train ``model`` for one pass over ``data_loader`` with gradient accumulation.

    Args:
        cfg: Config object. This function reads ``max_epoch``,
            ``grad_accumulate``, ``warmup_iters``, ``base_lr``,
            ``clip_max_norm``, ``class_labels``, ``normalize_coords`` and
            ``box_format`` from it.
        model: Network called as ``model(images, masks)``.
        criterion: Loss module called as ``criterion(outputs, targets)``; its
            result must be a dict holding the total loss tensor under the
            key ``"losses"``.
        data_loader: Sized iterable yielding ``(samples, targets)`` where
            ``samples`` unpacks to ``(images, masks)`` and ``targets`` is an
            iterable of dicts whose values support ``.to(device)``.
        optimizer: Optimizer stepped once every ``cfg.grad_accumulate``
            iterations of this epoch.
        device: Device the batch is moved to.
        epoch: Index of the current epoch; used for the log header and for
            the global iteration counter that drives warmup.
        vis_target: When true, ``vis_data`` is called on every batch.
        warmup_lr_scheduler: Callable as ``(step, optimizer)`` during warmup
            and providing ``set_lr(optimizer, lr)`` once warmup ends.
        debug: When true, stop after the first iteration.

    Returns:
        Dict mapping every meter name of the metric logger to its
        ``global_avg``.

    Note:
        A non-finite loss terminates the process via ``sys.exit(1)``.
    """
    model.train()
    criterion.train()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
    header = 'Epoch: [{} / {}]'.format(epoch, cfg.max_epoch)
    epoch_size = len(data_loader)
    print_freq = 10
    grad_accumulate = cfg.grad_accumulate

    # NOTE(review): gradients accumulated after the last optimizer step of an
    # epoch (batch count not a multiple of grad_accumulate, or debug mode) are
    # never applied here; this zero_grad() on the next call discards them.
    optimizer.zero_grad()

    batches = metric_logger.log_every(data_loader, print_freq, header)
    for iter_i, (samples, targets) in enumerate(batches):
        # Warmup is driven by the global iteration index, counted in units of
        # accumulated (optimizer) steps.
        # NOTE(review): this gate uses the global index while the optimizer
        # step below uses the per-epoch index; they only stay aligned when
        # epoch_size is a multiple of grad_accumulate.
        ni = iter_i + epoch * epoch_size
        if ni % grad_accumulate == 0:
            ni = ni // grad_accumulate
            if ni < cfg.warmup_iters:
                warmup_lr_scheduler(ni, optimizer)
            elif ni == cfg.warmup_iters:
                print('Warmup stage is over.')
                warmup_lr_scheduler.set_lr(optimizer, cfg.base_lr)

        # Move the batch to the target device.
        images, masks = samples
        images = images.to(device)
        masks = masks.to(device)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

        # Visualize train targets
        if vis_target:
            vis_data(images, targets, masks, cfg.class_labels, cfg.normalize_coords, cfg.box_format)

        # Forward pass and loss.
        outputs = model(images, masks)
        loss_dict = criterion(outputs, targets)
        losses = loss_dict["losses"]
        loss_value = losses.item()

        # Scale for gradient accumulation with an out-of-place division: an
        # in-place `/=` would also rewrite loss_dict["losses"], so the reduced
        # and logged "losses" metric would be shrunk by 1 / grad_accumulate.
        scaled_losses = losses / grad_accumulate

        # Reduce losses over all GPUs for logging purposes
        loss_dict_reduced = distributed_utils.reduce_dict(loss_dict)

        # Abort on NaN / inf before it can corrupt the weights.
        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        # Backward
        scaled_losses.backward()

        # Step the optimizer once every grad_accumulate iterations.
        if (iter_i + 1) % grad_accumulate == 0:
            if cfg.clip_max_norm > 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_max_norm)
            optimizer.step()
            optimizer.zero_grad()

        metric_logger.update(loss=loss_value, **loss_dict_reduced)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

        if debug:
            print("For debug mode, we only train the model with 1 iteration.")
            break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger)

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}