|
|
@@ -142,13 +142,13 @@ def train():
|
|
|
trans_cfg = build_trans_config(model_cfg['trans_type'])
|
|
|
|
|
|
# Transform
|
|
|
- train_transform, trans_config = build_transform(
|
|
|
+ train_transform, trans_cfg = build_transform(
|
|
|
args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=True)
|
|
|
val_transform, _ = build_transform(
|
|
|
args=args, trans_config=trans_cfg, max_stride=model_cfg['max_stride'], is_train=False)
|
|
|
|
|
|
# Dataset
|
|
|
- dataset, dataset_info = build_dataset(args, data_cfg, trans_config, train_transform, is_train=True)
|
|
|
+ dataset, dataset_info = build_dataset(args, data_cfg, trans_cfg, train_transform, is_train=True)
|
|
|
|
|
|
# Dataloader
|
|
|
dataloader = build_dataloader(args, dataset, per_gpu_batch, CollateFunc())
|