
add RepConv for YOLOv7

yjh0410 2 years ago
parent
commit
f54f096f67
9 changed files with 273 additions and 85 deletions
  1. README.md (+2 -3)
  2. README_CN.md (+4 -5)
  3. config/yolov7_config.py (+4 -4)
  4. eval.py (+5 -1)
  5. models/yolov7/yolov7_basic.py (+181 -4)
  6. models/yolov7/yolov7_fpn.py (+4 -7)
  7. test.py (+6 -9)
  8. train.sh (+2 -2)
  9. utils/misc.py (+65 -50)

+ 2 - 3
README.md

@@ -106,14 +106,13 @@ python train.py --cuda -d coco --root path/to/COCO -v yolov1 -bs 16 --max_epoch
 | YOLOv4        | CSPDarkNet-L       |  640  |  √   |  250  |       |        46.6            |       65.8        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov4_coco.pth) |
 | YOLOv5        | CSPDarkNet-53      |  640  |  √   |  250  |       |                        |                   |  |
 | YOLOX         | CSPDarkNet-L       |  640  |  √   |  300  |       |        46.6            |       66.1        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_coco.pth) |
-| YOLOv7-Tiny   | ELANNet-Tiny       |  640  |  √   |  300  |       |        37.7            |       56.6        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
+| YOLOv7-Tiny   | ELANNet-Tiny       |  640  |  √   |  300  |       |                    |               |  |
 | YOLOv7-Large  | ELANNet-Large      |  640  |  √   |  300  |       |                        |                   |  |
 | YOLOv8-Nano   | CSP-ELANNet-Nano   |  640  |  √   |  500  |       |                        |                   |  |
-| YOLOv8-Large  | CSP-ELANNet-Large  |  640  |  √   |  500  |       |                        |                   |  |
 
 *All models are trained with ImageNet pretrained weights (IP). All FLOPs are measured with a 640x640 image size on COCO val2017. The FPS is measured with batch size 1 on a 3090 GPU, from model inference to the NMS operation.*
 
-*Due to my limited computing resources, I had to abandon training on other YOLO detectors, including YOLOv7-Huge, YOLOv8-Small, YOLOv8-Medium and so on. If you are interested in these models and have trained them using the code from this project, I would greatly appreciate it if you could share the trained weight files with me.*
+*Due to my limited computing resources, I had to abandon training on other YOLO detectors, including YOLOv7-Huge, YOLOv8-Small, YOLOv8-Medium, YOLOv8-Large and YOLOv8-Huge. If you are interested in these models and have trained them using the code from this project, I would greatly appreciate it if you could share the trained weight files with me.*
 
 ## Train
 ### Single GPU

+ 4 - 5
README_CN.md

@@ -109,14 +109,13 @@ python train.py --cuda -d coco --root path/to/COCO -v yolov1 -bs 16 --max_epoch
 | YOLOv4        | CSPDarkNet-L       |  640  |  √   |  250  |       |        46.6            |       65.8        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov4_coco.pth) |
 | YOLOv5        | CSPDarkNet-53      |  640  |  √   |  250  |       |                        |                   |  |
 | YOLOX         | CSPDarkNet-L       |  640  |  √   |  300  |       |        46.6            |       66.1        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolox_coco.pth) |
-| YOLOv7-Tiny   | ELANNet-Tiny       |  640  |  √   |  300  |       |        37.7            |       56.6        | [ckpt](https://github.com/yjh0410/PyTorch_YOLO_Tutorial/releases/download/yolo_tutorial_ckpt/yolov7_tiny_coco.pth) |
-| YOLOv7-Large  | ELANNet-Large      |  640  |  √   |  300  |       |                        |                   |  |
-| YOLOv8-Nano   | CSP-ELANNet-Nano   |  640  |  √   |  500  |       |                        |                   |  |
-| YOLOv8-Large  | CSP-ELANNet-Large  |  640  |  √   |  500  |       |                        |                   |  |
+| YOLOv7-Tiny   | ELANNet-Tiny       |  640  |  √   |  300  |       |                    |               |  |
+| YOLOv7-Large  | ELANNet-Large      |  640  |  √   |  300  |       |                    |               |  |
+| YOLOv8-Nano   | CSP-ELANNet-Nano   |  640  |  √   |  500  |       |                    |               |  |
 
 *All models use ImageNet pretrained weights (IP). All FLOPs are measured on the COCO val set with a 640x640 or 1280x1280 input size. The FPS is measured on a single 3090 GPU with batch size 1; note that the timing covers model forward inference, post-processing, and the NMS operation.*
 
-*Due to my limited computing resources, training of more YOLO detectors was dropped, such as YOLOv7-Huge, YOLOv8-Small, and YOLOv8-Medium. If you are interested in them and have trained them with the code from this project, I sincerely hope you can share the trained weight files; I would be very grateful.*
+*Due to my limited computing resources, training of more YOLO detectors was dropped, including YOLOv7-Huge, YOLOv8-Small, YOLOv8-Medium, YOLOv8-Large, and YOLOv8-Huge. If you are interested in them and have trained them with the code from this project, I sincerely hope you can share the trained weight files; I would be very grateful.*
 
 ## Train
 ### Train with a single GPU

+ 4 - 4
config/yolov7_config.py

@@ -4,7 +4,7 @@ yolov7_cfg = {
     'yolov7_nano':{
         # input
         'trans_type': 'yolov5_nano',
-        'multi_scale': [0.5, 1.25], # 320 -> 800
+        'multi_scale': [0.5, 1.0], # 320 -> 640
         # model
         'backbone': 'elannet_nano',
         'pretrained': True,
@@ -62,7 +62,7 @@ yolov7_cfg = {
     'yolov7_tiny':{
         # input
         'trans_type': 'yolov5_weak',
-        'multi_scale': [0.5, 1.25], # 320 -> 800
+        'multi_scale': [0.5, 1.0], # 320 -> 640
         # model
         'backbone': 'elannet_tiny',
         'pretrained': True,
@@ -120,7 +120,7 @@ yolov7_cfg = {
     'yolov7_large':{
         # input
         'trans_type': 'yolov5_strong',
-        'multi_scale': [0.5, 1.25], # 320 -> 800
+        'multi_scale': [0.5, 1.0], # 320 -> 640
         # model
         'backbone': 'elannet_large',
         'pretrained': True,
@@ -178,7 +178,7 @@ yolov7_cfg = {
     'yolov7_huge':{
         # input
         'trans_type': 'yolov5_strong',
-        'multi_scale': [0.5, 1.25], # 320 -> 800
+        'multi_scale': [0.5, 1.0], # 320 -> 640
         # model
         'backbone': 'elannet_huge',
         'pretrained': True,

+ 5 - 1
eval.py

@@ -39,6 +39,10 @@ def parse_args():
                         help='topk candidates for testing')
     parser.add_argument("--no_decode", action="store_true", default=False,
                         help="not decode in inference or yes")
+    parser.add_argument('--fuse_repconv', action='store_true', default=False,
+                        help='fuse RepConv')
+    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
+                        help='fuse Conv & BN')
 
     # dataset
     parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
@@ -119,7 +123,7 @@ if __name__ == '__main__':
     model = build_model(args, model_cfg, device, num_classes, False)
 
     # load trained weight
-    model = load_weight(model=model, path_to_ckpt=args.weight)
+    model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_repconv)
     model.to(device).eval()
 
     # compute FLOPs and Params

+ 181 - 4
models/yolov7/yolov7_basic.py

@@ -1,7 +1,9 @@
+import numpy as np
 import torch
 import torch.nn as nn
 
 
+# ---------------------------- 2D CNN ----------------------------
 class SiLU(nn.Module):
     """export-friendly version of nn.SiLU()"""
 
@@ -34,7 +36,7 @@ def get_norm(norm_type, dim):
         return nn.GroupNorm(num_groups=32, num_channels=dim)
 
 
-# Basic conv layer
+## Basic conv layer
 class Conv(nn.Module):
     def __init__(self, 
                  c1,                   # in channels
@@ -77,7 +79,8 @@ class Conv(nn.Module):
         return self.convs(x)
 
 
-# ELAN Block
+# ---------------------------- YOLOv7 Modules ----------------------------
+## ELAN block proposed in YOLOv7
 class ELANBlock(nn.Module):
     def __init__(self, in_dim, out_dim, expand_ratio=0.5, depth=2.0, act_type='silu', norm_type='BN', depthwise=False):
         super(ELANBlock, self).__init__()
@@ -107,7 +110,7 @@ class ELANBlock(nn.Module):
         return out
 
 
-# ELAN Block for PaFPN
+## ELAN block for the PaFPN, proposed in YOLOv7
 class ELANBlockFPN(nn.Module):
     def __init__(self, in_dim, out_dim, expand_ratio=0.5, nbranch=4, depth=1, act_type='silu', norm_type='BN', depthwise=False):
         super(ELANBlockFPN, self).__init__()
@@ -147,7 +150,7 @@ class ELANBlockFPN(nn.Module):
         return out
 
 
-# DownSample Block
+## DownSample block proposed in YOLOv7
 class DownSample(nn.Module):
     def __init__(self, in_dim, out_dim, act_type='silu', norm_type='BN', depthwise=False):
         super().__init__()
@@ -165,3 +168,177 @@ class DownSample(nn.Module):
         out = torch.cat([x1, x2], dim=1)
 
         return out
+
+
+# ---------------------------- RepConv Modules ----------------------------
+class RepConv(nn.Module):
+    """
+        This code is adapted from https://github.com/WongKinYiu/yolov7/models/common.py
+    """
+    # RepVGG-style re-parameterized convolution
+    # https://arxiv.org/abs/2101.03697
+
+    def __init__(self, c1, c2, k=3, s=1, p=1, g=1, act_type='silu', deploy=False):
+        super(RepConv, self).__init__()
+        # -------------- Basic parameters --------------
+        self.deploy = deploy
+        self.groups = g
+        self.in_channels = c1
+        self.out_channels = c2
+
+        # -------------- Network parameters --------------
+        if deploy:
+            self.rbr_reparam = nn.Conv2d(c1, c2, k, s, p, groups=g, bias=True)
+
+        else:
+            self.rbr_identity = (nn.BatchNorm2d(num_features=c1) if c2 == c1 and s == 1 else None)
+
+            self.rbr_dense = nn.Sequential(
+                nn.Conv2d(c1, c2, k, s, p, groups=g, bias=False),
+                nn.BatchNorm2d(num_features=c2),
+            )
+
+            self.rbr_1x1 = nn.Sequential(
+                nn.Conv2d(c1, c2, kernel_size=1, stride=s, bias=False),
+                nn.BatchNorm2d(num_features=c2),
+            )
+        self.act = get_activation(act_type)
+
+
+    def forward(self, inputs):
+        if hasattr(self, "rbr_reparam"):
+            return self.act(self.rbr_reparam(inputs))
+
+        if self.rbr_identity is None:
+            id_out = 0
+        else:
+            id_out = self.rbr_identity(inputs)
+
+        return self.act(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out)
+    
+    def get_equivalent_kernel_bias(self):
+        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
+        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
+        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+        return (
+            kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid,
+            bias3x3 + bias1x1 + biasid,
+        )
+
+    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+        if kernel1x1 is None:
+            return 0
+        else:
+            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
+
+    def _fuse_bn_tensor(self, branch):
+        if branch is None:
+            return 0, 0
+        if isinstance(branch, nn.Sequential):
+            kernel = branch[0].weight
+            running_mean = branch[1].running_mean
+            running_var = branch[1].running_var
+            gamma = branch[1].weight
+            beta = branch[1].bias
+            eps = branch[1].eps
+        else:
+            assert isinstance(branch, nn.BatchNorm2d)
+            if not hasattr(self, "id_tensor"):
+                input_dim = self.in_channels // self.groups
+                kernel_value = np.zeros(
+                    (self.in_channels, input_dim, 3, 3), dtype=np.float32
+                )
+                for i in range(self.in_channels):
+                    kernel_value[i, i % input_dim, 1, 1] = 1
+                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+
+    def repvgg_convert(self):
+        kernel, bias = self.get_equivalent_kernel_bias()
+        return (
+            kernel.detach().cpu().numpy(),
+            bias.detach().cpu().numpy(),
+        )
+
+    def fuse_conv_bn(self, conv, bn):
+        # fold the BN statistics into the conv's weight and bias
+        std = (bn.running_var + bn.eps).sqrt()
+        bias = bn.bias - bn.running_mean * bn.weight / std
+
+        t = (bn.weight / std).reshape(-1, 1, 1, 1)
+        weights = conv.weight * t
+
+        bn = nn.Identity()
+        conv = nn.Conv2d(in_channels=conv.in_channels,
+                         out_channels=conv.out_channels,
+                         kernel_size=conv.kernel_size,
+                         stride=conv.stride,
+                         padding=conv.padding,
+                         dilation=conv.dilation,
+                         groups=conv.groups,
+                         bias=True,
+                         padding_mode=conv.padding_mode)
+
+        conv.weight = torch.nn.Parameter(weights)
+        conv.bias = torch.nn.Parameter(bias)
+        return conv
+
+    def fuse_repvgg_block(self):    
+        if self.deploy:
+            return
+                
+        self.rbr_dense = self.fuse_conv_bn(self.rbr_dense[0], self.rbr_dense[1])
+        
+        self.rbr_1x1 = self.fuse_conv_bn(self.rbr_1x1[0], self.rbr_1x1[1])
+        rbr_1x1_bias = self.rbr_1x1.bias
+        weight_1x1_expanded = torch.nn.functional.pad(self.rbr_1x1.weight, [1, 1, 1, 1])
+        
+        # Fuse self.rbr_identity
+        if (isinstance(self.rbr_identity, nn.BatchNorm2d) or isinstance(self.rbr_identity, nn.modules.batchnorm.SyncBatchNorm)):
+            identity_conv_1x1 = nn.Conv2d(
+                    in_channels=self.in_channels,
+                    out_channels=self.out_channels,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                    groups=self.groups, 
+                    bias=False)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.to(self.rbr_1x1.weight.data.device)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.squeeze().squeeze()
+
+            identity_conv_1x1.weight.data.fill_(0.0)
+            identity_conv_1x1.weight.data.fill_diagonal_(1.0)
+            identity_conv_1x1.weight.data = identity_conv_1x1.weight.data.unsqueeze(2).unsqueeze(3)
+
+            identity_conv_1x1 = self.fuse_conv_bn(identity_conv_1x1, self.rbr_identity)
+            bias_identity_expanded = identity_conv_1x1.bias
+            weight_identity_expanded = torch.nn.functional.pad(identity_conv_1x1.weight, [1, 1, 1, 1])            
+        else:
+            bias_identity_expanded = torch.nn.Parameter( torch.zeros_like(rbr_1x1_bias) )
+            weight_identity_expanded = torch.nn.Parameter( torch.zeros_like(weight_1x1_expanded) )            
+        
+        self.rbr_dense.weight = torch.nn.Parameter(self.rbr_dense.weight + weight_1x1_expanded + weight_identity_expanded)
+        self.rbr_dense.bias = torch.nn.Parameter(self.rbr_dense.bias + rbr_1x1_bias + bias_identity_expanded)
+                
+        self.rbr_reparam = self.rbr_dense
+        self.deploy = True
+
+        if self.rbr_identity is not None:
+            del self.rbr_identity
+            self.rbr_identity = None
+
+        if self.rbr_1x1 is not None:
+            del self.rbr_1x1
+            self.rbr_1x1 = None
+
+        if self.rbr_dense is not None:
+            del self.rbr_dense
+            self.rbr_dense = None

+ 4 - 7
models/yolov7/yolov7_fpn.py

@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from .yolov7_basic import Conv, ELANBlockFPN, DownSample
+from .yolov7_basic import Conv, ELANBlockFPN, DownSample, RepConv
 
 
 # PaFPN-ELAN (YOLOv7's)
@@ -72,12 +72,9 @@ class Yolov7PaFPN(nn.Module):
                                               )
         
         # head conv
-        self.head_conv_1 = Conv(round(128*width), round(256*width), k=3, p=1,
-                                act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.head_conv_2 = Conv(round(256*width), round(512*width), k=3, p=1,
-                                act_type=act_type, norm_type=norm_type, depthwise=depthwise)
-        self.head_conv_3 = Conv(round(512*width), round(1024*width), k=3, p=1,
-                                act_type=act_type, norm_type=norm_type, depthwise=depthwise)
+        self.head_conv_1 = RepConv(round(128*width), round(256*width), k=3, s=1, p=1, act_type=act_type)
+        self.head_conv_2 = RepConv(round(256*width), round(512*width), k=3, s=1, p=1, act_type=act_type)
+        self.head_conv_3 = RepConv(round(512*width), round(1024*width), k=3, s=1, p=1, act_type=act_type)
 
         # output proj layers
         if out_dim is not None:

+ 6 - 9
test.py

@@ -33,7 +33,7 @@ def parse_args():
                         help='use cuda.')
     parser.add_argument('--save_folder', default='det_results/', type=str,
                         help='Dir to save results')
-    parser.add_argument('-vs', '--visual_threshold', default=0.4, type=float,
+    parser.add_argument('-vt', '--visual_threshold', default=0.4, type=float,
                         help='Final confidence threshold')
     parser.add_argument('-ws', '--window_scale', default=1.0, type=float,
                         help='resize window of cv2 for visualization.')
@@ -49,10 +49,12 @@ def parse_args():
                         help='NMS threshold')
     parser.add_argument('--topk', default=100, type=int,
                         help='topk candidates for testing')
-    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
-                        help='fuse conv and bn')
     parser.add_argument("--no_decode", action="store_true", default=False,
                         help="not decode in inference or yes")
+    parser.add_argument('--fuse_repconv', action='store_true', default=False,
+                        help='fuse RepConv')
+    parser.add_argument('--fuse_conv_bn', action='store_true', default=False,
+                        help='fuse Conv & BN')
 
     # dataset
     parser.add_argument('--root', default='/mnt/share/ssd2/dataset',
@@ -193,7 +195,7 @@ if __name__ == '__main__':
     model = build_model(args, model_cfg, device, num_classes, False)
 
     # load trained weight
-    model = load_weight(model=model, path_to_ckpt=args.weight)
+    model = load_weight(model, args.weight, args.fuse_conv_bn, args.fuse_repconv)
     model.to(device).eval()
 
     # compute FLOPs and Params
@@ -206,11 +208,6 @@ if __name__ == '__main__':
         device=device)
     del model_copy
 
-    # fuse conv bn
-    if args.fuse_conv_bn:
-        print('fuse conv and bn ...')
-        model = fuse_conv_bn.fuse_conv_bn(model)
-
     # transform
     transform = build_transform(args.img_size, trans_cfg, is_train=False)
 

+ 2 - 2
train.sh

@@ -3,11 +3,11 @@ python train.py \
         --cuda \
         -d coco \
         --root /mnt/share/ssd2/dataset/ \
-        -m yolov8_nano \
+        -m yolov7_large \
         -bs 16 \
         -size 640 \
         --wp_epoch 1 \
-        --max_epoch 500 \
+        --max_epoch 150 \
         --eval_epoch 10 \
         --ema \
         --fp16 \

+ 65 - 50
utils/misc.py

@@ -14,7 +14,12 @@ from dataset.voc import VOCDetection, VOC_CLASSES
 from dataset.coco import COCODataset, coco_class_index, coco_class_labels
 from dataset.data_augment import build_transform
 
+from utils import fuse_conv_bn
+from models.yolov7.yolov7_basic import RepConv
 
+
+# ---------------------------- For Dataset ----------------------------
+## build dataset
 def build_dataset(args, trans_config, device, is_train=False):
     # transform
     print('==============================')
@@ -78,7 +83,7 @@ def build_dataset(args, trans_config, device, is_train=False):
 
     return dataset, (num_classes, class_names, class_indexs), evaluator
 
-
+## build dataloader
 def build_dataloader(args, dataset, batch_size, collate_fn=None):
     # distributed
     if args.distributed:
@@ -93,46 +98,52 @@ def build_dataloader(args, dataset, batch_size, collate_fn=None):
     
     return dataloader
     
+## collate_fn for dataloader
+class CollateFunc(object):
+    def __call__(self, batch):
+        targets = []
+        images = []
 
-def load_weight(model, path_to_ckpt):
-    # check ckpt file
-    if path_to_ckpt is None:
-        print('no weight file ...')
-
-        return model
-
-    checkpoint = torch.load(path_to_ckpt, map_location='cpu')
-    try:
-        checkpoint_state_dict = checkpoint.pop("model")
-    except:
-        checkpoint_state_dict = checkpoint
-    model.load_state_dict(checkpoint_state_dict)
+        for sample in batch:
+            image = sample[0]
+            target = sample[1]
 
-    print('Finished loading model!')
+            images.append(image)
+            targets.append(target)
 
-    return model
+        images = torch.stack(images, 0) # [B, C, H, W]
 
+        return images, targets
 
-def is_parallel(model):
-    # Returns True if model is of type DP or DDP
-    return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
 
+# ---------------------------- For Model ----------------------------
+## load trained weight
+def load_weight(model, path_to_ckpt, fuse_cbn=False, fuse_repconv=False):
+    # check ckpt file
+    if path_to_ckpt is None:
+        print('no weight file ...')
+    else:
+        checkpoint = torch.load(path_to_ckpt, map_location='cpu')
+        checkpoint_state_dict = checkpoint.pop("model")
+        model.load_state_dict(checkpoint_state_dict)
 
-def de_parallel(model):
-    # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
-    return model.module if is_parallel(model) else model
+        print('Finished loading model!')
 
+    # fuse repconv
+    if fuse_repconv:
+        print('Fusing RepConv block ...')
+        for m in model.modules():
+            if isinstance(m, RepConv):
+                m.fuse_repvgg_block()
 
-def copy_attr(a, b, include=(), exclude=()):
-    # Copy attributes from b to a, options to only include [...] and to exclude [...]
-    for k, v in b.__dict__.items():
-        if (len(include) and k not in include) or k.startswith('_') or k in exclude:
-            continue
-        else:
-            setattr(a, k, v)
+    # fuse conv & bn
+    if fuse_cbn:
+        print('Fusing Conv & BN ...')
+        model = fuse_conv_bn.fuse_conv_bn(model)
 
+    return model
 
-# Model EMA
+## Model EMA
 class ModelEMA(object):
     """ Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
     Keeps a moving average of everything in the model state_dict (parameters and buffers)
@@ -141,41 +152,45 @@ class ModelEMA(object):
 
     def __init__(self, model, decay=0.9999, tau=2000, updates=0):
         # Create EMA
-        self.ema = deepcopy(de_parallel(model)).eval()  # FP32 EMA
+        self.ema = deepcopy(self.de_parallel(model)).eval()  # FP32 EMA
         self.updates = updates  # number of EMA updates
         self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # decay exponential ramp (to help early epochs)
         for p in self.ema.parameters():
             p.requires_grad_(False)
 
+
+    def is_parallel(self, model):
+        # Returns True if model is of type DP or DDP
+        return type(model) in (nn.parallel.DataParallel, nn.parallel.DistributedDataParallel)
+
+
+    def de_parallel(self, model):
+        # De-parallelize a model: returns single-GPU model if model is of type DP or DDP
+        return model.module if self.is_parallel(model) else model
+
+
+    def copy_attr(self, a, b, include=(), exclude=()):
+        # Copy attributes from b to a, options to only include [...] and to exclude [...]
+        for k, v in b.__dict__.items():
+            if (len(include) and k not in include) or k.startswith('_') or k in exclude:
+                continue
+            else:
+                setattr(a, k, v)
+
+
     def update(self, model):
         # Update EMA parameters
         self.updates += 1
         d = self.decay(self.updates)
 
-        msd = de_parallel(model).state_dict()  # model state_dict
+        msd = self.de_parallel(model).state_dict()  # model state_dict
         for k, v in self.ema.state_dict().items():
             if v.dtype.is_floating_point:  # true for FP16 and FP32
                 v *= d
                 v += (1 - d) * msd[k].detach()
         # assert v.dtype == msd[k].dtype == torch.float32, f'{k}: EMA {v.dtype} and model {msd[k].dtype} must be FP32'
 
+
     def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
         # Update EMA attributes
-        copy_attr(self.ema, model, include, exclude)
-
-
-class CollateFunc(object):
-    def __call__(self, batch):
-        targets = []
-        images = []
-
-        for sample in batch:
-            image = sample[0]
-            target = sample[1]
-
-            images.append(image)
-            targets.append(target)
-
-        images = torch.stack(images, 0) # [B, C, H, W]
-
-        return images, targets
+        self.copy_attr(self.ema, model, include, exclude)