import torch
import torch.nn as nn
from torch.utils.data import DataLoader, DistributedSampler

import os
import math
from copy import deepcopy

from evaluator.coco_evaluator import COCOAPIEvaluator
from evaluator.voc_evaluator import VOCAPIEvaluator

from dataset.voc import VOCDetection, VOC_CLASSES
from dataset.coco import COCODataset, coco_class_index, coco_class_labels
from dataset.data_augment import build_transform

from utils import fuse_conv_bn
from models.yolov7.yolov7_basic import RepConv


# ---------------------------- For Dataset ----------------------------
## build dataset
def build_dataset(args, trans_config, device, is_train=False):
    # transform
    print('==============================')
    print('Transform Config: {}'.format(trans_config))
    train_transform = build_transform(args.img_size, trans_config, True)
    val_transform = build_transform(args.img_size, trans_config, False)

    # dataset
    if args.dataset == 'voc':
        data_dir = os.path.join(args.root, 'VOCdevkit')
        num_classes = 20
        class_names = VOC_CLASSES
        class_indexs = None

        # dataset
        dataset = VOCDetection(
            img_size=args.img_size,
            data_dir=data_dir,
            image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
            transform=train_transform if is_train else val_transform,  # no augmentation on the test split
            trans_config=trans_config,
            is_train=is_train
        )

        # evaluator
        evaluator = VOCAPIEvaluator(
            data_dir=data_dir,
            device=device,
            transform=val_transform
        )

    elif args.dataset == 'coco':
        data_dir = os.path.join(args.root, 'COCO')
        num_classes = 80
        class_names = coco_class_labels
        class_indexs = coco_class_index

        # dataset
        dataset = COCODataset(
            img_size=args.img_size,
            data_dir=data_dir,
            image_set='train2017' if is_train else 'val2017',
            transform=train_transform if is_train else val_transform,  # no augmentation on the val split
            trans_config=trans_config,
            is_train=is_train
        )

        # evaluator
        evaluator = COCOAPIEvaluator(
            data_dir=data_dir,
            device=device,
            transform=val_transform
        )

    else:
        print('Unknown dataset: {}. Only voc and coco are supported.'.format(args.dataset))
        exit(1)

    print('==============================')
    print('Training model on:', args.dataset)
    print('The dataset size:', len(dataset))

    return dataset, (num_classes, class_names, class_indexs), evaluator
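
# Usage sketch (a hedged example, not part of the API above): `args` is
# assumed to carry `dataset`, `root`, and `img_size`, matching how
# `build_dataset` reads them.
#
#   dataset, (num_classes, class_names, class_indexs), evaluator = \
#       build_dataset(args, trans_config, device, is_train=True)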


## build dataloader
def build_dataloader(args, dataset, batch_size, collate_fn=None):
    # distributed
    if args.distributed:
        sampler = DistributedSampler(dataset)
    else:
        sampler = torch.utils.data.RandomSampler(dataset)

    batch_sampler_train = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)
    dataloader = DataLoader(dataset, batch_sampler=batch_sampler_train,
                            collate_fn=collate_fn, num_workers=args.num_workers, pin_memory=True)

    return dataloader


## collate_fn for dataloader
class CollateFunc(object):
    def __call__(self, batch):
        targets = []
        images = []

        for sample in batch:
            image = sample[0]
            target = sample[1]

            images.append(image)
            targets.append(target)

        images = torch.stack(images, 0)  # [B, C, H, W]

        return images, targets
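
# Usage sketch (hedged): `args.distributed` and `args.num_workers` are assumed
# to come from the training script's argument parser. `CollateFunc` stacks the
# images into one tensor but leaves the targets as a plain Python list, since
# each image can carry a different number of boxes.
#
#   dataloader = build_dataloader(args, dataset, batch_size=16,
#                                 collate_fn=CollateFunc())
#   for images, targets in dataloader:
#       pass  # images: [B, C, H, W] tensor; targets: list of B label sets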


# ---------------------------- For Model ----------------------------
## load trained weight
def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_repconv=False):
    # check ckpt file
    if path_to_ckpt is None:
        print('no weight file ...')
    else:
        checkpoint = torch.load(path_to_ckpt, map_location='cpu')
        checkpoint_state_dict = checkpoint.pop("model")
        model.load_state_dict(checkpoint_state_dict)
        print('Finished loading model!')

    # fuse repconv
    if fuse_repconv:
        print('Fusing RepConv block ...')
        for m in model.modules():
            if isinstance(m, RepConv):
                m.fuse_repvgg_block()

    # fuse conv & bn
    if fuse_cbn:
        print('Fusing Conv & BN ...')
        model = fuse_conv_bn.fuse_conv_bn(model)

    return model
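
# Usage sketch (hedged): typical eval-time loading. The checkpoint is assumed
# to store its weights under the "model" key, as `load_weight` expects above;
# both fusions are inference-only rewrites, so skip them if the model will be
# trained further.
#
#   model = load_weight(model, args.weight, fuse_cbn=True, fuse_repconv=True)
#   model = model.to(device).eval()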


## Model EMA
class ModelEMA(object):
    """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """
    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def is_parallel(self, model):
        # Returns True if model is of type DP or DDP
        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)

    def de_parallel(self, model):
        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
        return model.module if self.is_parallel(model) else model

    def copy_attr(self, a, b, include=(), exclude=()):
        # Copy attributes from b to a, options to only include [...] and to exclude [...]
        for k, v in b.__dict__.items():
            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
                continue
            else:
                setattr(a, k, v)

    def update(self, model):
        # Update EMA parameters
        self.updates += 1
        d = self.decay(self.updates)

        msd = self.de_parallel(model).state_dict()  # model state_dict
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:  # true for FP16 and FP32
                v *= d
                v += (1 - d) * msd[k].detach()
                # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'

    def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
        # Update EMA attributes
        self.copy_attr(self.ema, model, include, exclude)
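
# Usage sketch (hedged): maintaining an EMA copy alongside the live model in a
# training loop; `ema.ema` holds the averaged model, typically the one that is
# evaluated and checkpointed. `criterion` and `optimizer` are placeholder
# names, not defined in this file.
#
#   ema = ModelEMA(model, decay=0.9999, tau=2000)
#   for images, targets in dataloader:
#       loss = criterion(model(images), targets)
#       optimizer.zero_grad()
#       loss.backward()
#       optimizer.step()
#       ema.update(model)  # blend the live weights into the EMA copy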
|