yjh0410 2 年之前
父节点
当前提交
e06b02bb2a
共有 6 个文件被更改,包括 16 次插入119 次删除
  1. 11 10
      dataset/build.py
  2. 1 3
      dataset/coco.py
  3. 1 4
      dataset/ourdataset.py
  4. 1 3
      dataset/voc.py
  5. 2 2
      train.py
  6. 0 97
      utils/misc.py

+ 11 - 10
dataset/build.py

@@ -64,16 +64,17 @@ def build_dataset(args, data_cfg, trans_config, transform, is_train=False):
 # ------------------------------ Transform ------------------------------
 # ------------------------------ Transform ------------------------------
 def build_transform(args, trans_config, max_stride=32, is_train=False):
 def build_transform(args, trans_config, max_stride=32, is_train=False):
     # Modify trans_config
     # Modify trans_config
-    ## mosaic prob.
-    if args.mosaic is not None:
-        trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
-    else:
-        trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
-    ## mixup prob.
-    if args.mixup is not None:
-        trans_config['mixup_prob']=args.mixup if is_train else 0.0
-    else:
-        trans_config['mixup_prob']=trans_config['mixup_prob']  if is_train else 0.0
+    if is_train:
+        ## mosaic prob.
+        if args.mosaic is not None:
+            trans_config['mosaic_prob']=args.mosaic if is_train else 0.0
+        else:
+            trans_config['mosaic_prob']=trans_config['mosaic_prob'] if is_train else 0.0
+        ## mixup prob.
+        if args.mixup is not None:
+            trans_config['mixup_prob']=args.mixup if is_train else 0.0
+        else:
+            trans_config['mixup_prob']=trans_config['mixup_prob']  if is_train else 0.0
 
 
     # Transform
     # Transform
     if trans_config['aug_type'] == 'ssd':
     if trans_config['aug_type'] == 'ssd':

+ 1 - 3
dataset/coco.py

@@ -13,10 +13,8 @@ except:
     print("It seems that the COCOAPI is not installed.")
     print("It seems that the COCOAPI is not installed.")
 
 
 try:
 try:
-    from .data_augment import build_transform
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 except:
 except:
-    from data_augment import build_transform
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 
 
 
 
@@ -223,7 +221,7 @@ class COCODataset(Dataset):
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     import argparse
     import argparse
-    from data_augment import build_transform
+    from build import build_transform
     
     
     parser = argparse.ArgumentParser(description='COCO-Dataset')
     parser = argparse.ArgumentParser(description='COCO-Dataset')
 
 

+ 1 - 4
dataset/ourdataset.py

@@ -12,10 +12,8 @@ except:
     print("It seems that the COCOAPI is not installed.")
     print("It seems that the COCOAPI is not installed.")
 
 
 try:
 try:
-    from .data_augment import build_transform
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 except:
 except:
-    from data_augment import build_transform
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 
 
 # please define our class labels
 # please define our class labels
@@ -191,8 +189,7 @@ class OurDataset(Dataset):
 if __name__ == "__main__":
 if __name__ == "__main__":
     import argparse
     import argparse
     import sys
     import sys
-    from data_augment import build_transform
-    sys.path.append('.')
+    from build import build_transform
     
     
     parser = argparse.ArgumentParser(description='Our-Dataset')
     parser = argparse.ArgumentParser(description='Our-Dataset')
 
 

+ 1 - 3
dataset/voc.py

@@ -13,10 +13,8 @@ import numpy as np
 import xml.etree.ElementTree as ET
 import xml.etree.ElementTree as ET
 
 
 try:
 try:
-    from .data_augment import build_transform
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
     from .data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 except:
 except:
-    from data_augment import build_transform
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
     from data_augment.yolov5_augment import yolov5_mosaic_augment, yolov5_mixup_augment, yolox_mixup_augment
 
 
 
 
@@ -246,7 +244,7 @@ class VOCDetection(data.Dataset):
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
     import argparse
     import argparse
-    from data_augment import build_transform
+    from build import build_transform
     
     
     parser = argparse.ArgumentParser(description='VOC-Dataset')
     parser = argparse.ArgumentParser(description='VOC-Dataset')
 
 

+ 2 - 2
train.py

@@ -12,7 +12,7 @@ from torch.nn.parallel import DistributedDataParallel as DDP
 # ----------------- Extra Components -----------------
 # ----------------- Extra Components -----------------
 from utils import distributed_utils
 from utils import distributed_utils
 from utils.misc import compute_flops
 from utils.misc import compute_flops
-from utils.misc import ModelEMA, CollateFunc, build_dataset, build_dataloader
+from utils.misc import ModelEMA, CollateFunc, build_dataloader
 
 
 # ----------------- Evaluator Components -----------------
 # ----------------- Evaluator Components -----------------
 from evaluator.build import build_evluator
 from evaluator.build import build_evluator
@@ -145,7 +145,7 @@ def train():
     train_transform, trans_config = build_transform(
     train_transform, trans_config = build_transform(
         args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
         args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
     val_transform, _ = build_transform(
     val_transform, _ = build_transform(
-        args=args, max_stride=model_cfg['max_stride'], is_train=False)
+        args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
 
 
     # Dataset
     # Dataset
     dataset, dataset_info = build_dataset(args, data_cfg, trans_config, train_transform, is_train=True)
     dataset, dataset_info = build_dataset(args, data_cfg, trans_config, train_transform, is_train=True)

+ 0 - 97
utils/misc.py

@@ -3,111 +3,14 @@ import torch.nn as nn
 import torch.nn.functional as F
 import torch.nn.functional as F
 from torch.utils.data import DataLoader, DistributedSampler
 from torch.utils.data import DataLoader, DistributedSampler
 
 
-import os
 import cv2
 import cv2
 import math
 import math
 import numpy as np
 import numpy as np
 from copy import deepcopy
 from copy import deepcopy
 from thop import profile
 from thop import profile
 
 
-from evaluator.coco_evaluator import COCOAPIEvaluator
-from evaluator.voc_evaluator import VOCAPIEvaluator
-from evaluator.ourdataset_evaluator import OurDatasetEvaluator
-
-from dataset.voc import VOCDetection, VOC_CLASSES
-from dataset.coco import COCODataset, coco_class_index, coco_class_labels
-from dataset.ourdataset import OurDataset, our_class_labels
-from dataset.data_augment import build_transform
-
 
 
 # ---------------------------- For Dataset ----------------------------
 # ---------------------------- For Dataset ----------------------------
-## build dataset
-def build_dataset(args, trans_config, device, is_train=False):
-    # transform
-    print('==============================')
-    print('Transform Config: {}'.format(trans_config))
-    train_transform = build_transform(args.img_size, trans_config, True)
-    val_transform = build_transform(args.img_size, trans_config, False)
-
-    # dataset
-    if args.dataset == 'voc':
-        data_dir = os.path.join(args.root, 'VOCdevkit')
-        num_classes = 20
-        class_names = VOC_CLASSES
-        class_indexs = None
-
-        # dataset
-        dataset = VOCDetection(
-            img_size=args.img_size,
-            data_dir=data_dir,
-            image_sets=[('2007', 'trainval'), ('2012', 'trainval')] if is_train else [('2007', 'test')],
-            transform=train_transform,
-            trans_config=trans_config,
-            is_train=is_train
-            )
-        
-        # evaluator
-        evaluator = VOCAPIEvaluator(
-            data_dir=data_dir,
-            device=device,  
-            transform=val_transform
-            )
-
-    elif args.dataset == 'coco':
-        data_dir = os.path.join(args.root, 'COCO')
-        num_classes = 80
-        class_names = coco_class_labels
-        class_indexs = coco_class_index
-
-        # dataset
-        dataset = COCODataset(
-            img_size=args.img_size,
-            data_dir=data_dir,
-            image_set='train2017' if is_train else 'val2017',
-            transform=train_transform,
-            trans_config=trans_config,
-            is_train=is_train
-            )
-        # evaluator
-        evaluator = COCOAPIEvaluator(
-            data_dir=data_dir,
-            device=device,
-            transform=val_transform
-            )
-
-    elif args.dataset == 'ourdataset':
-        data_dir = os.path.join(args.root, 'OurDataset')
-        class_names = our_class_labels
-        num_classes = len(our_class_labels)
-        class_indexs = None
-
-        # dataset
-        dataset = OurDataset(
-            data_dir=data_dir,
-            img_size=args.img_size,
-            image_set='train' if is_train else 'val',
-            transform=train_transform,
-            trans_config=trans_config,
-            is_train=is_train
-            )
-        # evaluator
-        evaluator = OurDatasetEvaluator(
-            data_dir=data_dir,
-            device=device,
-            image_set='val',
-            transform=val_transform
-        )
-
-    else:
-        print('unknow dataset !! Only support voc, coco !!')
-        exit(0)
-
-    print('==============================')
-    print('Training model on:', args.dataset)
-    print('The dataset size:', len(dataset))
-
-    return dataset, (num_classes, class_names, class_indexs), evaluator
-
 ## build dataloader
 ## build dataloader
 def build_dataloader(args, dataset, batch_size, collate_fn=None):
 def build_dataloader(args, dataset, batch_size, collate_fn=None):
     # distributed
     # distributed