yjh0410 1 年之前
父节点
当前提交
0ebad53257
共有 3 个文件被更改,包括 8 次插入2 次删除
  1. 1 1
      odlab/train.py
  2. 6 0
      odlab/utils/optimizer.py
  3. 1 1
      yolo/utils/solver/optimizer.py

+ 1 - 1
odlab/train.py

@@ -164,7 +164,7 @@ def main():
 
 
     # ----------------------- Training -----------------------
     # ----------------------- Training -----------------------
     print("Start training")
     print("Start training")
-    best_map = -1.
+    best_map = cfg.best_map
     for epoch in range(start_epoch, cfg.max_epoch):
     for epoch in range(start_epoch, cfg.max_epoch):
         if args.distributed:
         if args.distributed:
             train_loader.batch_sampler.sampler.set_epoch(epoch)
             train_loader.batch_sampler.sampler.set_epoch(epoch)

+ 6 - 0
odlab/utils/optimizer.py

@@ -34,6 +34,7 @@ def build_optimizer(cfg, model, resume=None):
             )
             )
                                 
                                 
     start_epoch = 0
     start_epoch = 0
+    cfg.best_map = -1.
     if resume is not None and resume.lower() != "none":
     if resume is not None and resume.lower() != "none":
         print('Load optimzier from the checkpoint: ', resume)
         print('Load optimzier from the checkpoint: ', resume)
         checkpoint = torch.load(resume)
         checkpoint = torch.load(resume)
@@ -41,5 +42,10 @@ def build_optimizer(cfg, model, resume=None):
         checkpoint_state_dict = checkpoint.pop("optimizer")
         checkpoint_state_dict = checkpoint.pop("optimizer")
         optimizer.load_state_dict(checkpoint_state_dict)
         optimizer.load_state_dict(checkpoint_state_dict)
         start_epoch = checkpoint.pop("epoch") + 1
         start_epoch = checkpoint.pop("epoch") + 1
+        if "mAP" in checkpoint:
+            print('--Load best metric from the checkpoint: ', resume)
+            best_map = checkpoint["mAP"]
+            cfg.best_map = best_map
+        del checkpoint, checkpoint_state_dict
                                                         
                                                         
     return optimizer, start_epoch
     return optimizer, start_epoch

+ 1 - 1
yolo/utils/solver/optimizer.py

@@ -43,13 +43,13 @@ def build_yolo_optimizer(cfg, model, resume=None):
             print('--Load optimizer from the checkpoint: ', resume)
             print('--Load optimizer from the checkpoint: ', resume)
             optimizer.load_state_dict(checkpoint_state_dict)
             optimizer.load_state_dict(checkpoint_state_dict)
             start_epoch = checkpoint.pop("epoch") + 1
             start_epoch = checkpoint.pop("epoch") + 1
-            del checkpoint, checkpoint_state_dict
             if "mAP" in checkpoint:
             if "mAP" in checkpoint:
                 print('--Load best metric from the checkpoint: ', resume)
                 print('--Load best metric from the checkpoint: ', resume)
                 best_map = checkpoint["mAP"]
                 best_map = checkpoint["mAP"]
                 cfg.best_map = best_map
                 cfg.best_map = best_map
             else:
             else:
                 cfg.best_map = -1.
                 cfg.best_map = -1.
+            del checkpoint, checkpoint_state_dict
         except:
         except:
             print("No optimzier in the given checkpoint.")
             print("No optimzier in the given checkpoint.")