yjh0410 преди 1 година
родител
ревизия
0ebad53257
променени са 3 файла, в които са добавени 8 реда и са изтрити 2 реда
  1. 1 1
      odlab/train.py
  2. 6 0
      odlab/utils/optimizer.py
  3. 1 1
      yolo/utils/solver/optimizer.py

+ 1 - 1
odlab/train.py

@@ -164,7 +164,7 @@ def main():
 
     # ----------------------- Training -----------------------
     print("Start training")
-    best_map = -1.
+    best_map = cfg.best_map
     for epoch in range(start_epoch, cfg.max_epoch):
         if args.distributed:
             train_loader.batch_sampler.sampler.set_epoch(epoch)

+ 6 - 0
odlab/utils/optimizer.py

@@ -34,6 +34,7 @@ def build_optimizer(cfg, model, resume=None):
             )
                                 
     start_epoch = 0
+    cfg.best_map = -1.
     if resume is not None and resume.lower() != "none":
         print('Load optimzier from the checkpoint: ', resume)
         checkpoint = torch.load(resume)
@@ -41,5 +42,10 @@ def build_optimizer(cfg, model, resume=None):
         checkpoint_state_dict = checkpoint.pop("optimizer")
         optimizer.load_state_dict(checkpoint_state_dict)
         start_epoch = checkpoint.pop("epoch") + 1
+        if "mAP" in checkpoint:
+            print('--Load best metric from the checkpoint: ', resume)
+            best_map = checkpoint["mAP"]
+            cfg.best_map = best_map
+        del checkpoint, checkpoint_state_dict
                                                         
     return optimizer, start_epoch

+ 1 - 1
yolo/utils/solver/optimizer.py

@@ -43,13 +43,13 @@ def build_yolo_optimizer(cfg, model, resume=None):
             print('--Load optimizer from the checkpoint: ', resume)
             optimizer.load_state_dict(checkpoint_state_dict)
             start_epoch = checkpoint.pop("epoch") + 1
-            del checkpoint, checkpoint_state_dict
             if "mAP" in checkpoint:
                 print('--Load best metric from the checkpoint: ', resume)
                 best_map = checkpoint["mAP"]
                 cfg.best_map = best_map
             else:
                 cfg.best_map = -1.
+            del checkpoint, checkpoint_state_dict
         except:
             print("No optimzier in the given checkpoint.")