浏览代码

batch size base = 64

yjh0410 1 年之前
父节点
当前提交
0be5cd4f86
共有 2 个文件被更改,包括 2 次插入2 次删除
  1. 1 1
      yolo/engine.py
  2. 1 1
      yolo/train.py

+ 1 - 1
yolo/engine.py

@@ -62,7 +62,7 @@ class YoloTrainer(object):
         self.scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
 
         # ---------------------------- Build Optimizer ----------------------------
-        self.grad_accumulate = max(256 // args.batch_size, 1)
+        self.grad_accumulate = max(64 // args.batch_size, 1)
         cfg.base_lr = cfg.per_image_lr * args.batch_size * self.grad_accumulate
         cfg.min_lr  = cfg.base_lr * cfg.min_lr_ratio
         self.optimizer, self.start_epoch = build_yolo_optimizer(cfg, model, args.resume)

+ 1 - 1
yolo/train.py

@@ -75,7 +75,7 @@ def parse_args():
                         help='data root')
     parser.add_argument('-d', '--dataset', default='coco',
                         help='coco, voc')
-    parser.add_argument('--num_workers', default=4, type=int, 
+    parser.add_argument('--num_workers', default=16, type=int, 
                         help='Number of workers used in dataloading')
     
     # DDP train