yjh0410 1 year ago
Parent
Current commit
ca92f8b686

+ 2 - 0
odlab/config/detr_config.py

@@ -23,6 +23,7 @@ class DetrBaseConfig(object):
         self.hidden_dim = 256
         self.num_heads = 8
         self.feedforward_dim = 2048
+        self.num_queries = 100
         self.num_enc_layers = 6
         self.num_dec_layers = 6
         self.dropout = 0.1
@@ -67,6 +68,7 @@ class DetrBaseConfig(object):
         self.eval_epoch = 2
 
         # --------- Data process ---------
+        self.use_coco_labels_91 = True
         ## input size
         self.train_min_size = [800]   # short edge of image
         self.train_min_size2 = [400, 500, 600]
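
A quick sketch of how the two new fields are consumed downstream (the Cfg stand-in below is hypothetical; in the repo the values live on DetrBaseConfig): num_queries sizes DETR's learned object-query embedding, and use_coco_labels_91 switches the label set used at visualization time, as the detr.py and test.py hunks below show.

import torch.nn as nn

class Cfg:  # hypothetical stand-in for DetrBaseConfig
    hidden_dim = 256
    num_queries = 100          # one learned embedding per object query
    use_coco_labels_91 = True  # visualize with the original 91-category names

cfg = Cfg()
query_embed = nn.Embedding(cfg.num_queries, cfg.hidden_dim)
print(query_embed.weight.shape)  # torch.Size([100, 256])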

+ 1 - 0
odlab/config/fcos_config.py

@@ -93,6 +93,7 @@ class FcosBaseConfig(object):
         self.eval_epoch = 2
 
         # --------- Data process ---------
+        self.use_coco_labels_91 = False
         ## input size
         self.train_min_size = [800]   # short edge of image
         self.train_max_size = 1333

+ 1 - 0
odlab/config/yolof_config.py

@@ -94,6 +94,7 @@ class YolofBaseConfig(object):
         self.eval_epoch = 2
 
         # --------- Data process ---------
+        self.use_coco_labels_91 = False
         ## input size
         self.train_min_size = [800]   # short edge of image
         self.train_max_size = 1333

+ 5 - 5
odlab/datasets/coco.py

@@ -16,16 +16,16 @@ except:
     from transforms import build_transform
 
 
-# coco_labels = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
-coco_labels = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',  'traffic light',  'fire hydrant',  'stop sign',  'parking meter',  'bench',  'bird',  'cat',  'dog',  'horse',  'sheep',  'cow',  'elephant',  'bear',  'zebra',  'giraffe',  'backpack',  'umbrella',  'handbag',  'tie',  'suitcase',  'frisbee',  'skis',  'snowboard',  'sports ball',  'kite',  'baseball bat',  'baseball glove',  'skateboard',  'surfboard',  'tennis racket',  'bottle',  'wine glass',  'cup',  'fork',  'knife',  'spoon',  'bowl',  'banana',  'apple',  'sandwich',  'orange',  'broccoli',  'carrot',  'hot dog',  'pizza',  'donut',  'cake',  'chair',  'couch',  'potted plant',  'bed',  'dining table',  'toilet',  'tv',  'laptop',  'mouse',  'remote',  'keyboard',  'cell phone',  'microwave',  'oven',  'toaster',  'sink',  'refrigerator',  'book',  'clock',  'vase',  'scissors',  'teddy bear',  'hair drier',  'toothbrush')
+coco_labels_91 = ('background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack', 'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk', 'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+coco_labels_80 = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat',  'traffic light',  'fire hydrant',  'stop sign',  'parking meter',  'bench',  'bird',  'cat',  'dog',  'horse',  'sheep',  'cow',  'elephant',  'bear',  'zebra',  'giraffe',  'backpack',  'umbrella',  'handbag',  'tie',  'suitcase',  'frisbee',  'skis',  'snowboard',  'sports ball',  'kite',  'baseball bat',  'baseball glove',  'skateboard',  'surfboard',  'tennis racket',  'bottle',  'wine glass',  'cup',  'fork',  'knife',  'spoon',  'bowl',  'banana',  'apple',  'sandwich',  'orange',  'broccoli',  'carrot',  'hot dog',  'pizza',  'donut',  'cake',  'chair',  'couch',  'potted plant',  'bed',  'dining table',  'toilet',  'tv',  'laptop',  'mouse',  'remote',  'keyboard',  'cell phone',  'microwave',  'oven',  'toaster',  'sink',  'refrigerator',  'book',  'clock',  'vase',  'scissors',  'teddy bear',  'hair drier',  'toothbrush')
 coco_indexs = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 67, 70, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 84, 85, 86, 87, 88, 89, 90]
 
 
 class CocoDetection(torchvision.datasets.CocoDetection):
     def __init__(self, img_folder, ann_file, transforms):
         super(CocoDetection, self).__init__(img_folder, ann_file)
-        self.coco_labels = coco_labels  # 80 coco labels for detection task
-        self.coco_indexs = coco_indexs  # all original coco label index
+        self.coco_labels = coco_labels_80  # 80 coco labels for detection task
+        self.coco_indexs = coco_indexs     # all original coco label index
         self._transforms = transforms
 
     def prepare(self, image, target):
@@ -166,7 +166,7 @@ if __name__ == "__main__":
             # get box target
             x1, y1, x2, y2 = box.long()
             # get class label
-            cls_name = coco_labels[label.item()]
+            cls_name = coco_labels_80[label.item()]
             color = class_colors[label.item()]
             # draw bbox
             image = cv2.rectangle(image, (int(x1), int(y1)), (int(x2), int(y2)), color, 2)
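
The renamed tuples make the mapping between the two label sets explicit: coco_indexs holds the original COCO category id of each contiguous 80-class index, and coco_labels_91 is indexed by category id (with 'background' at 0). A quick sanity check, assuming the three definitions above:

# 'stop sign' is index 11 in the 80-class tuple and category id 13
# in the 91-entry tuple; the loop checks that invariant for all classes.
for i, cat_id in enumerate(coco_indexs):
    assert coco_labels_91[cat_id] == coco_labels_80[i], (i, cat_id)
print("80-class and 91-class label sets are consistent")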

+ 2 - 0
odlab/models/backbone/resnet.py

@@ -149,5 +149,7 @@ if __name__ == '__main__':
 
     x = torch.randn(2, 3, 320, 320)
     output = model(x)
+    for k in model.state_dict():
+        print(k)
     for y in output:
         print(y.size())

+ 1 - 1
odlab/models/detectors/detr/build.py

@@ -9,7 +9,7 @@ from .detr import DETR
 def build_detr(cfg, is_val=False):
     # -------------- Build DETR --------------
     model = DETR(cfg         = cfg,
-                 num_classes = cfg.num_classes,
+                 num_classes = 91,
                  conf_thresh = cfg.train_conf_thresh if is_val else cfg.test_conf_thresh,
                  topk        = cfg.train_topk        if is_val else cfg.test_topk,
                  )

+ 12 - 8
odlab/models/detectors/detr/detr.py

@@ -12,7 +12,7 @@ from ...basic.mlp   import MLP
 class DETR(nn.Module):
     def __init__(self, 
                  cfg,
-                 num_classes :int   = 80, 
+                 num_classes :int   = 90, 
                  conf_thresh :float = 0.05,
                  topk        :int   = 1000,
                  ):
@@ -25,20 +25,20 @@ class DETR(nn.Module):
 
         # ---------------------- Network Parameters ----------------------
         ## Backbone
-        self.backbone, feat_dims = build_backbone(cfg)
-
+        backbone, feat_dims = build_backbone(cfg)
+        self.backbone = nn.Sequential(backbone)
         ## Input proj
         self.input_proj = nn.Conv2d(feat_dims[-1], cfg.hidden_dim, kernel_size=1)
 
-        ## Object Queries
-        self.query_embed = nn.Embedding(cfg.num_queries, cfg.hidden_dim)
-        
         ## Transformer
         self.transformer = build_transformer(cfg, return_intermediate_dec=True)
 
+        ## Object queries
+        self.query_embed = nn.Embedding(cfg.num_queries, cfg.hidden_dim)
+        
         ## Output
         self.class_embed = nn.Linear(cfg.hidden_dim, num_classes + 1)
-        self.bbox_embed  = MLP(cfg.hidden_dim, cfg.feedward_dim, 4, 3)
+        self.bbox_embed  = MLP(cfg.hidden_dim, cfg.hidden_dim, 4, 3)
 
     @torch.jit.unused
     def set_aux_loss(self, outputs_class, outputs_coord):
@@ -102,10 +102,14 @@ class DETR(nn.Module):
             outputs = {'pred_logits': outputs_class[-1], 'pred_boxes': outputs_coord[-1]}
             outputs['aux_outputs'] = self.set_aux_loss(outputs_class, outputs_coord)
         else:
-            # [B, N, C] -> [N, C]
             cls_pred = outputs_class[-1].softmax(-1)[..., :-1]
             box_pred = outputs_coord[-1]
 
+            # [B, N, C] -> [N, C]
+            cls_pred = cls_pred[0]
+            box_pred = box_pred[0]
+
+            # xywh -> xyxy
             cxcy_pred = box_pred[..., :2]
             bwbh_pred = box_pred[..., 2:]
             x1y1_pred = cxcy_pred - 0.5 * bwbh_pred
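
The reworked inference branch first drops the batch dimension (single-image inference), then converts boxes from (cx, cy, w, h) to corner form. The conversion in isolation, as a self-contained sketch:

import torch

def box_cxcywh_to_xyxy(box_pred: torch.Tensor) -> torch.Tensor:
    # Split center/size, then derive the two corners, mirroring the hunk above.
    cxcy_pred = box_pred[..., :2]
    bwbh_pred = box_pred[..., 2:]
    x1y1_pred = cxcy_pred - 0.5 * bwbh_pred
    x2y2_pred = cxcy_pred + 0.5 * bwbh_pred
    return torch.cat([x1y1_pred, x2y2_pred], dim=-1)

boxes = torch.tensor([[0.5, 0.5, 0.2, 0.4]])  # (cx, cy, w, h)
print(box_cxcywh_to_xyxy(boxes))  # tensor([[0.4000, 0.3000, 0.6000, 0.7000]])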

+ 4 - 4
odlab/models/transformer/transformer.py

@@ -39,6 +39,7 @@ class DETRTransformer(nn.Module):
             hidden_dim, num_heads, ffn_dim, dropout, act_type, pre_norm)
         encoder_norm = nn.LayerNorm(hidden_dim) if pre_norm else None
         self.encoder = TransformerEncoder(encoder_layer, num_enc_layers, encoder_norm)
+
         ## Decoder module
         decoder_layer = TransformerDecoderLayer(
             hidden_dim, num_heads, ffn_dim, dropout, act_type, pre_norm)
@@ -67,7 +68,7 @@ class DETRTransformer(nn.Module):
             y_embed = y_embed / (y_embed[:, -1:, :] + 1e-6) * scale
             x_embed = x_embed / (x_embed[:, :, -1:] + 1e-6) * scale
     
-        dim_t = torch.arange(num_pos_feats, dtype=torch.float32)
+        dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=src_mask.device)
         dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
         dim_t = temperature ** (2 * dim_t_)
 
@@ -82,9 +83,8 @@ class DETRTransformer(nn.Module):
         return pos_embed
 
     def forward(self, src, src_mask, query_embed):
-        bs, c, h, w = src.shape
-
         # Get position embedding
+        bs, c, h, w = src.shape
         pos_embed = self.get_posembed(c, src_mask, normalize=True)
 
         # reshape: [B, C, H, W] -> [N, B, C], H=HW
@@ -96,7 +96,7 @@ class DETRTransformer(nn.Module):
         query_embed = query_embed.unsqueeze(1).repeat(1, bs, 1)
 
         # Encoder
-        memory = self.encoder(src, src_key_padding_mask=src_mask, pos_embed=pos_embed)
+        memory = self.encoder(src, src_mask, pos_embed=pos_embed)
 
         # Decoder
         tgt = torch.zeros_like(query_embed)
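
The dim_t change is a device fix: torch.arange allocates on the CPU by default, so the old code could fail with a device-mismatch error as soon as src_mask lived on the GPU. A minimal sketch of the corrected pattern:

import torch

def make_dim_t(num_pos_feats: int, temperature: float, device: torch.device):
    # Allocate on the same device as the mask so the later division
    # x_embed[..., None] / dim_t never mixes CPU and CUDA tensors.
    dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=device)
    dim_t_ = torch.div(dim_t, 2, rounding_mode='floor') / num_pos_feats
    return temperature ** (2 * dim_t_)

src_mask = torch.zeros(1, 4, 4, dtype=torch.bool)  # stand-in padding mask
dim_t = make_dim_t(128, 10000.0, src_mask.device)
print(dim_t.shape, dim_t.device)  # torch.Size([128]) cpu, or cuda:0 on GPU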

+ 0 - 1
odlab/models/transformer/transformer_decoder.py

@@ -90,7 +90,6 @@ class TransformerDecoderLayer(nn.Module):
         self.dropout3   = nn.Dropout(dropout)
         self.norm3      = nn.LayerNorm(hidden_dim)
 
-
     def with_pos_embed(self, tensor, pos_embed):
         return tensor if pos_embed is None else tensor + pos_embed
 

+ 2 - 3
odlab/models/transformer/transformer_encoder.py

@@ -62,14 +62,13 @@ class TransformerEncoderLayer(nn.Module):
         self.dropout2   = nn.Dropout(dropout)
         self.norm2      = nn.LayerNorm(hidden_dim)
 
-
     def with_pos_embed(self, tensor, pos_embed):
         return tensor if pos_embed is None else tensor + pos_embed
 
     def forward_post(self, src, src_mask, pos_embed):
         # MSHA
         q = k = self.with_pos_embed(src, pos_embed)
-        src2 = self.self_attn(q, k, src, src_key_padding_mask=src_mask)[0]
+        src2 = self.self_attn(q, k, src, key_padding_mask=src_mask)[0]
         src = src + self.dropout1(src2)
         src = self.norm1(src)
 
@@ -84,7 +83,7 @@ class TransformerEncoderLayer(nn.Module):
         # MSHA
         src2 = self.norm1(src)
         q = k = self.with_pos_embed(src2, pos_embed)
-        src2 = self.self_attn(q, k, src2, src_key_padding_mask=src_mask)[0]
+        src2 = self.self_attn(q, k, src2, key_padding_mask=src_mask)[0]
         src = src + self.dropout1(src2)
 
         # FFN
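
The keyword rename fixes a real API mismatch: src_key_padding_mask belongs to nn.TransformerEncoderLayer, while nn.MultiheadAttention (what self.self_attn is here) expects key_padding_mask. A minimal demonstration:

import torch
import torch.nn as nn

self_attn = nn.MultiheadAttention(embed_dim=256, num_heads=8)
src = torch.randn(100, 2, 256)                         # [N, B, C], sequence-first
padding_mask = torch.zeros(2, 100, dtype=torch.bool)   # [B, N], True = padded

# key_padding_mask= is the argument nn.MultiheadAttention understands;
# src_key_padding_mask= would raise a TypeError here.
out, _ = self_attn(src, src, src, key_padding_mask=padding_mask)
print(out.shape)  # torch.Size([100, 2, 256])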

+ 11 - 7
odlab/test.py

@@ -8,6 +8,7 @@ import torch
 
 # load transform
 from datasets import build_dataset, build_transform
+from datasets.coco import coco_labels_91
 
 # load some utils
 from utils.misc import load_weight, compute_flops
@@ -98,6 +99,7 @@ def test_det(args, model, device, dataset, transform, class_colors, class_names)
 
 
 if __name__ == '__main__':
+    np.random.seed(0)
     args = parse_args()
     # cuda
     if args.cuda:
@@ -115,11 +117,6 @@ if __name__ == '__main__':
     # Dataset
     dataset = build_dataset(args, cfg, is_train=False)
 
-    np.random.seed(0)
-    class_colors = [(np.random.randint(255),
-                     np.random.randint(255),
-                     np.random.randint(255)) for _ in range(cfg.num_classes)]
-
     # Model
     model = build_model(args, cfg, is_val=False)
     model = load_weight(model, args.weight, args.fuse_conv_bn)
@@ -137,12 +134,19 @@ if __name__ == '__main__':
     del model_copy
         
     print("================= DETECT =================")
-    # run
+    if cfg.use_coco_labels_91:
+        class_names = coco_labels_91
+    else:
+        class_names = cfg.class_labels
+    class_colors = [(np.random.randint(255),
+                     np.random.randint(255),
+                     np.random.randint(255)) for _ in range(len(class_names))]
+    # Run
     test_det(args         = args,
              model        = model, 
              device       = device, 
              dataset      = dataset,
              transform    = transform,
              class_colors = class_colors,
-             class_names  = cfg.class_labels,
+             class_names  = class_names,
              )
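
Seeding NumPy once at startup and building the palette after class_names is chosen keeps the colors reproducible across runs and correctly sized for the active label set. A tiny sketch of the pattern:

import numpy as np

np.random.seed(0)  # seed once, at program start
class_names = ('person', 'bicycle', 'car')  # stand-in for the selected label set
class_colors = [(np.random.randint(255),
                 np.random.randint(255),
                 np.random.randint(255)) for _ in range(len(class_names))]
print(class_colors)  # the same palette on every run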

+ 6 - 5
odlab/utils/misc.py

@@ -310,11 +310,12 @@ def load_weight(model, path_to_ckpt, fuse_cbn=False):
         print('no weight file ...')
     else:
         checkpoint = torch.load(path_to_ckpt, map_location='cpu')
-        print('--------------------------------------')
-        print('Best model infor:')
-        print('Epoch: {}'.format(checkpoint.pop("epoch")))
-        print('mAP: {}'.format(checkpoint.pop("mAP")))
-        print('--------------------------------------')
+        if "epoch" in checkpoint and "mAP" in checkpoint:
+            print('--------------------------------------')
+            print('Best model infor:')
+            print('Epoch: {}'.format(checkpoint.pop("epoch")))
+            print('mAP: {}'.format(checkpoint.pop("mAP")))
+            print('--------------------------------------')
         checkpoint_state_dict = checkpoint.pop("model")
         model.load_state_dict(checkpoint_state_dict)
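
The new guard matters because dict.pop raises KeyError on a missing key, so a bare checkpoint saved without training metadata would previously crash the loader. A minimal illustration:

# Hypothetical bare checkpoint: only a state dict, no "epoch"/"mAP" keys.
checkpoint = {"model": {}}
if "epoch" in checkpoint and "mAP" in checkpoint:
    print('Epoch: {}'.format(checkpoint.pop("epoch")))
    print('mAP: {}'.format(checkpoint.pop("mAP")))
state_dict = checkpoint.pop("model")  # "model" is always present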