yjh0410 2 years ago
Parent
Current commit
e06b02bb2a
6 changed files with 16 additions and 119 deletions
  1. +11 -10 dataset/build.py
  2. +1 -3 dataset/coco.py
  3. +1 -4 dataset/ourdataset.py
  4. +1 -3 dataset/voc.py
  5. +2 -2 train.py
  6. +0 -97 utils/misc.py

+ 11 - 10
dataset/build.py

@@ -64,16 +64,17 @@ def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
 # ------------------------------ Transform ------------------------------
 def build_transform(args, trans_config, max_stride=32, is_train=False):
     # Modify trans_config
-    ## mosaic prob.
-    if args.mosaic is not None:
-        trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
-    else:
-        trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
-    ## mixup prob.
-    if args.mixup is not None:
-        trans_config['mixup_prob']=args.mixup if is_train else 0.0
-    else:
-        trans_config['mixup_prob']=trans_config['mixup_prob']  if is_train else 0.0
+    if is_train:
+        ## mosaic prob.
+        if args.mosaic is not None:
+            trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
+        else:
+            trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
+        ## mixup prob.
+        if args.mixup is not None:
+            trans_config['mixup_prob']=args.mixup if is_train else 0.0
+        else:
+            trans_config['mixup_prob']=trans_config['mixup_prob']  if is_train else 0.0
 
     # Transform
     if trans_config['aug_type'] == 'ssd':
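
Note: with the override logic now wrapped in "if is_train:", the inner "if is_train else 0.0" ternaries are always-true dead branches. A minimal sketch of the equivalent, simplified logic, assuming args.mosaic and args.mixup default to None when not set on the command line:

    if is_train:
        # Command-line flags override the config file; otherwise keep the config values.
        if args.mosaic is not None:
            trans_config['mosaic_prob'] = args.mosaic
        if args.mixup is not None:
            trans_config['mixup_prob'] = args.mixup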

+ 1 - 3
dataset/coco.py

@@ -13,10 +13,8 @@ except:
     print("It seems that the COCOAPI is not installed.")
 
 try:
-    from .data_augment import build_transform
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 except:
-    from data_augment import build_transform
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 
 
@@ -223,7 +221,7 @@ class COCODataset(Dataset):
 
 if __name__ == "__main__":
     import argparse
-    from data_augment import build_transform
+    from build import build_transform
     
     parser = argparse.ArgumentParser(description='COCO-Dataset')
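
Note: the old data_augment build_transform import is removed from the module-level try/except block, and the __main__ test block now imports the helper from dataset/build.py, where it lives after this commit. A sketch of how the relocated helper is invoked; the trans_config values here are assumptions, only the keys are taken from the build.py hunk above:

    # Hypothetical smoke-test usage inside the __main__ block:
    trans_config = {'aug_type': 'ssd', 'mosaic_prob': 1.0, 'mixup_prob': 0.15}
    transform, trans_config = build_transform(args, trans_config, max_stride=32, is_train=True)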
 

+ 1 - 4
dataset/ourdataset.py

@@ -12,10 +12,8 @@ except:
     print("It seems that the COCOAPI is not installed.")
 
 try:
-    from .data_augment import build_transform
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 except:
-    from data_augment import build_transform
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 
 # please define our class labels
@@ -191,8 +189,7 @@ class OurDataset(Dataset):
 if __name__ == "__main__":
     import argparse
     import sys
-    from data_augment import build_transform
-    sys.path.append('.')
+    from build import build_transform
     
     parser = argparse.ArgumentParser(description='Our-Dataset')
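
Note: besides the same import relocation as in coco.py, the sys.path.append('.') workaround is dropped here, which leaves the remaining "import sys" in this test block apparently unused. A hypothetical, further-cleaned preamble (not part of this commit):

    if __name__ == "__main__":
        import argparse
        from build import build_transform  # relocated helper; sys is no longer needed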
 

+ 1 - 3
dataset/voc.py

@@ -13,10 +13,8 @@ import numpy as np
 import xml.etree.ElementTree as ET
 
 try:
-    from .data_augment import build_transform
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 except:
-    from data_augment import build_transform
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 
 
@@ -246,7 +244,7 @@ class VOCDetection(data.Dataset):
 
 if __name__ == "__main__":
     import argparse
-    from data_augment import build_transform
+    from build import build_transform
     
     parser = argparse.ArgumentParser(description='VOC-Dataset')
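
Note: "from build import build_transform" is an absolute import of dataset/build.py; it resolves inside these __main__ blocks because running a file as a script puts the script's own directory on sys.path:

    # Running the smoke test directly (dataset/ lands on sys.path automatically):
    #   python dataset/voc.py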
 

+ 2 - 2
train.py

@@ -12,7 +12,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 # ----------------- Extra Components -----------------
 from utils import distributed_utils
 from utils.misc import compute_flops
-from utils.misc import ModelEMA, CollateFunc, build_dataset, build_dataloader
+from utils.misc import ModelEMA, CollateFunc, build_dataloader
 
 # ----------------- Evaluator Components -----------------
 from evaluator.build import build_evluator
@@ -145,7 +145,7 @@ def train():
     train_transform, trans_config = build_transform(
         args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
     val_transform, _ = build_transform(
-        args=args, max_stride=model_cfg['max_stride'], is_train=False)
+        args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
 
     # Dataset
     dataset, dataset_info = build_dataset(args, data_cfg, trans_config, train_transform, is_train=True)
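
Note: the old val_transform call omitted trans_config, which has no default in the build_transform signature above, so evaluation setup failed immediately:

    # Before this fix (sketch of the failure mode):
    val_transform, _ = build_transform(args=args, max_stride=model_cfg['max_stride'], is_train=False)
    # TypeError: build_transform() missing 1 required positional argument: 'trans_config'

Both calls now pass the same trans_cfg; since build_transform only rewrites the mosaic/mixup probabilities when is_train=True, the command-line overrides never touch the validation configuration.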

+ 0 - 97
utils/misc.py

@@ -3,111 +3,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, DistributedSampler
 
-import os
 import cv2
 import math
 import numpy as np
 from copy import deepcopy
 from thop import profile
 
-from evaluator.coco_evaluator import COCOAPIEvaluator
-from evaluator.voc_evaluator import VOCAPIEvaluator
-from evaluator.ourdataset_evaluator import OurDatasetEvaluator
-
-from dataset.voc import VOCDetection, VOC_CLASSES
-from dataset.coco import COCODataset, coco_class_index, coco_class_labels
-from dataset.ourdataset import OurDataset, our_class_labels
-from dataset.data_augment import build_transform
-
 
 # ---------------------------- For Dataset ----------------------------
-## build dataset
-def build_dataset(args, trans_config, device, is_train=False):
-    # transform
-    print('==============================')
-    print('Transform Config: {}'.format(trans_config))
-    train_transform = build_transform(args.img_size, trans_config, True)
-    val_transform = build_transform(args.img_size, trans_config, False)
-
-    # dataset
-    if args.dataset == 'voc':
-        data_dir = os.path.join(args.root, 'VOCdevkit')
-        num_classes = 20
-        class_names = VOC_CLASSES
-        class_indexs = None
-
-        # dataset
-        dataset = VOCDetection(
-            img_size=args.img_size,
-            data_dir=data_dir,
-            image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
-            transform=train_transform,
-            trans_config=trans_config,
-            is_train=is_train
-            )
-        
-        # evaluator
-        evaluator = VOCAPIEvaluator(
-            data_dir=data_dir,
-            device=device,  
-            transform=val_transform
-            )
-
-    elif args.dataset == 'coco':
-        data_dir = os.path.join(args.root, 'COCO')
-        num_classes = 80
-        class_names = coco_class_labels
-        class_indexs = coco_class_index
-
-        # dataset
-        dataset = COCODataset(
-            img_size=args.img_size,
-            data_dir=data_dir,
-            image_set='train2017' if is_train else 'val2017',
-            transform=train_transform,
-            trans_config=trans_config,
-            is_train=is_train
-            )
-        # evaluator
-        evaluator = COCOAPIEvaluator(
-            data_dir=data_dir,
-            device=device,
-            transform=val_transform
-            )
-
-    elif args.dataset == 'ourdataset':
-        data_dir = os.path.join(args.root, 'OurDataset')
-        class_names = our_class_labels
-        num_classes = len(our_class_labels)
-        class_indexs = None
-
-        # dataset
-        dataset = OurDataset(
-            data_dir=data_dir,
-            img_size=args.img_size,
-            image_set='train' if is_train else 'val',
-            transform=train_transform,
-            trans_config=trans_config,
-            is_train=is_train
-            )
-        # evaluator
-        evaluator = OurDatasetEvaluator(
-            data_dir=data_dir,
-            device=device,
-            image_set='val',
-            transform=val_transform
-        )
-
-    else:
-        print('unknow dataset !! Only support voc, coco !!')
-        exit(0)
-
-    print('==============================')
-    print('Training model on:', args.dataset)
-    print('The dataset size:', len(dataset))
-
-    return dataset, (num_classes, class_names, class_indexs), evaluator
-
 ## build dataloader
 def build_dataloader(args, dataset, batch_size, collate_fn=None):
     # distributed
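
Note: the monolithic build_dataset helper moves out of utils/misc.py together with its dataset and evaluator imports, so the utils module no longer depends on the dataset and evaluator packages. Dataset construction now lives in dataset/build.py (first hunk above); a hedged sketch of the presumed new call site in train.py:

    # Presumed imports in train.py after this commit, with dataset/build.py owning both helpers:
    from dataset.build import build_dataset, build_transform

    # Mirrors the call shown in the train.py hunk above:
    dataset, dataset_info = build_dataset(args, data_cfg, trans_config, train_transform, is_train=True)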