Browse Source

debug DETR-R50 training

yjh0410 1 year ago
parent
commit
d38ca985c5

+ 0 - 1
odlab/config/detr_config.py

@@ -68,7 +68,6 @@ class DetrBaseConfig(object):
         self.eval_epoch = 2
 
         # --------- Data process ---------
-        self.use_coco_labels_91 = True
         ## input size
         self.train_min_size = [800]   # short edge of image
         self.train_min_size2 = [400, 500, 600]

+ 0 - 1
odlab/config/fcos_config.py

@@ -93,7 +93,6 @@ class FcosBaseConfig(object):
         self.eval_epoch = 2
 
         # --------- Data process ---------
-        self.use_coco_labels_91 = False
         ## input size
         self.train_min_size = [800]   # short edge of image
         self.train_max_size = 1333

+ 0 - 1
odlab/config/yolof_config.py

@@ -94,7 +94,6 @@ class YolofBaseConfig(object):
         self.eval_epoch = 2
 
         # --------- Data process ---------
-        self.use_coco_labels_91 = False
         ## input size
         self.train_min_size = [800]   # short edge of image
         self.train_max_size = 1333

+ 1 - 1
odlab/models/detectors/detr/build.py

@@ -9,7 +9,7 @@ from .detr import DETR
 def build_detr(cfg, is_val=False):
     # -------------- Build DETR --------------
     model = DETR(cfg         = cfg,
-                 num_classes = 91,
+                 num_classes = cfg.num_classes,
                  conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh,
                  topk        = cfg.train_topk        if is_val else cfg.test_topk,
                  )

+ 8 - 2
odlab/models/detectors/detr/criterion.py

@@ -20,14 +20,21 @@ class SetCriterion(nn.Module):
         super().__init__()
         self.num_classes = cfg.num_classes
         self.losses = ['labels', 'boxes']
+        self.eos_coef = 0.1
+
         # -------- Loss weights --------
         self.weight_dict = {'loss_cls':  cfg.loss_cls,
                             'loss_box':  cfg.loss_box,
                             'loss_giou': cfg.loss_giou}
         for i in range(cfg.num_dec_layers - 1):
             self.weight_dict.update({k + f'_aux_{i}': v for k, v in self.weight_dict.items()})
+        empty_weight = torch.ones(self.num_classes + 1)
+        empty_weight[-1] = self.eos_coef
+        self.register_buffer('empty_weight', empty_weight)
+        
         # -------- Matcher --------
-        self.matcher = HungarianMatcher(cfg.cost_class, cfg.cost_bbox, cfg.cost_giou)
+        matcher_hpy = cfg.matcher_hpy
+        self.matcher = HungarianMatcher(matcher_hpy['cost_class'], matcher_hpy['cost_bbox'], matcher_hpy['cost_giou'])
 
     def loss_labels(self, outputs, targets, indices, num_boxes):
         assert 'pred_logits' in outputs
@@ -37,7 +44,6 @@ class SetCriterion(nn.Module):
         target_classes_o = torch.cat([t["labels"][J] for t, (_, J) in zip(targets, indices)])
         target_classes = torch.full(src_logits.shape[:2], self.num_classes,
                                     dtype=torch.int64, device=src_logits.device)
-        target_classes[idx] = target_classes_o
 
         loss_cls = F.cross_entropy(src_logits.transpose(1, 2), target_classes, self.empty_weight)
 

+ 9 - 7
odlab/test.py

@@ -99,7 +99,6 @@ def test_det(args, model, device, dataset, transform, class_colors, class_names)
 
 
 if __name__ == '__main__':
-    np.random.seed(0)
     args = parse_args()
     # cuda
     if args.cuda:
@@ -116,6 +115,10 @@ if __name__ == '__main__':
 
     # Dataset
     dataset = build_dataset(args, cfg, is_train=False)
+    if args.model == "detr_r50":
+        # Test official DETR model
+        cfg.class_labels = coco_labels_91
+        cfg.num_classes = 91
 
     # Model
     model = build_model(args, cfg, is_val=False)
@@ -134,13 +137,12 @@ if __name__ == '__main__':
     del model_copy
         
     print("================= DETECT =================")
-    if cfg.use_coco_labels_91:
-        class_names = coco_labels_91
-    else:
-        class_names = cfg.class_labels
+    # Color for beautiful visualization
+    np.random.seed(0)
     class_colors = [(np.random.randint(255),
                      np.random.randint(255),
-                     np.random.randint(255)) for _ in range(len(class_names))]
+                     np.random.randint(255))
+                     for _ in range(cfg.num_classes)]
     # Run
     test_det(args         = args,
              model        = model, 
@@ -148,5 +150,5 @@ if __name__ == '__main__':
              dataset      = dataset,
              transform    = transform,
              class_colors = class_colors,
-             class_names  = class_names,
+             class_names  = cfg.class_labels,
              )

+ 1 - 0
odlab/train.py

@@ -126,6 +126,7 @@ def main():
     ## Build model
     model, criterion = build_model(args, cfg, is_val=True)
     model.to(device)
+    criterion.to(device)
     model_without_ddp = model
     ## Calcute Params & GFLOPs
     if distributed_utils.is_main_process():