yjh0410 1 year ago
parent
commit
492e833f1a

+ 8 - 0
README.md

@@ -62,6 +62,14 @@ cd dataset/scripts/
 sh COCO2017.sh
 ```
 
+- Clean COCO
+```Shell
+cd <RT-ODLab>
+cd tools/
+python clean_coco.py --root path/to/coco --image_set val
+python clean_coco.py --root path/to/coco --image_set train
+```
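+`clean_coco.py` removes annotations whose boxes have a non-positive width or height and drops images left without any valid annotation, writing `instances_train2017_clean.json` and `instances_val2017_clean.json` into the `annotations/` folder. Note that `dataset/coco.py` now loads these `*_clean.json` files for the train and val sets, so this step is required.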
+
 - Check COCO
 ```Shell
 cd <RT-ODLab>

+ 4 - 2
dataset/coco.py

@@ -39,9 +39,9 @@ class COCODataset(Dataset):
         # ----------- Path parameters -----------
         self.data_dir = data_dir
         if image_set == 'train2017':
-            self.json_file='instances_train2017.json'
+            self.json_file='instances_train2017_clean.json'
         elif image_set == 'val2017':
-            self.json_file='instances_val2017.json'
+            self.json_file='instances_val2017_clean.json'
         elif image_set == 'test2017':
             self.json_file='image_info_test-dev2017.json'
         else:
@@ -248,6 +248,8 @@ if __name__ == "__main__":
     # opt
     parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
                         help='data root')
+    parser.add_argument('--image_set', type=str, default='train2017',
+                        help="image set: 'train2017', 'val2017' or 'test2017'.")
     parser.add_argument('-size', '--img_size', default=640, type=int,
                         help='input image size.')
     parser.add_argument('--aug_type', type=str, default='ssd',

+ 19 - 0
dataset/data_augment/rtdetr_augment.py

@@ -0,0 +1,19 @@
+# Data preprocessor for Real-time DETR
+
+
+class RTDetrAugmentation(object):
+    # Placeholder: train-time augmentation pipeline for RT-DETR (to be implemented).
+    def __init__(self):
+        return
+
+    def __call__(self):
+        pass
+
+
+class RTDetrBaseTransform(object):
+    # Placeholder: basic val/test-time transform for RT-DETR (to be implemented).
+    def __init__(self):
+        return
+
+    def __call__(self):
+        pass
+

+ 129 - 0
models/detectors/rtrdet/rtrdet_basic.py

@@ -0,0 +1,129 @@
+import torch
+import torch.nn as nn
+from typing import List
+
+
+# ----------------- CNN modules -----------------
+def get_conv2d(c1, c2, k, p, s, d, g, bias=False):
+    conv = nn.Conv2d(c1, c2, k, stride=s, padding=p, dilation=d, groups=g, bias=bias)
+
+    return conv
+
+def get_activation(act_type=None):
+    if act_type == 'relu':
+        return nn.ReLU(inplace=True)
+    elif act_type == 'lrelu':
+        return nn.LeakyReLU(0.1, inplace=True)
+    elif act_type == 'mish':
+        return nn.Mish(inplace=True)
+    elif act_type == 'silu':
+        return nn.SiLU(inplace=True)
+    elif act_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError('Unknown activation type: {}'.format(act_type))
+        
+def get_norm(norm_type, dim):
+    if norm_type == 'BN':
+        return nn.BatchNorm2d(dim)
+    elif norm_type == 'GN':
+        return nn.GroupNorm(num_groups=32, num_channels=dim)
+    elif norm_type is None:
+        return nn.Identity()
+    else:
+        raise NotImplementedError('Unknown normalization type: {}'.format(norm_type))
+
+class Conv(nn.Module):
+    def __init__(self, 
+                 c1,                   # in channels
+                 c2,                   # out channels 
+                 k=1,                  # kernel size 
+                 p=0,                  # padding
+                 s=1,                  # stride
+                 d=1,                  # dilation
+                 act_type  :str  = 'lrelu',   # activation
+                 norm_type :str  = 'BN',      # normalization
+                 depthwise :bool = False):
+        super(Conv, self).__init__()
+        convs = []
+        add_bias = False if norm_type else True
+        if depthwise:
+            # depthwise conv
+            convs.append(get_conv2d(c1, c1, k=k, p=p, s=s, d=d, g=c1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c1))
+            if act_type:
+                convs.append(get_activation(act_type))
+            # pointwise conv
+            convs.append(get_conv2d(c1, c2, k=1, p=0, s=1, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+
+        else:
+            convs.append(get_conv2d(c1, c2, k=k, p=p, s=s, d=d, g=1, bias=add_bias))
+            if norm_type:
+                convs.append(get_norm(norm_type, c2))
+            if act_type:
+                convs.append(get_activation(act_type))
+            
+        self.convs = nn.Sequential(*convs)
+
+    def forward(self, x):
+        return self.convs(x)
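+
+    # A minimal usage sketch for Conv (illustrative only, not called anywhere in
+    # this commit); shapes assume an NCHW tensor:
+    #
+    #   conv = Conv(c1=32, c2=64, k=3, p=1, s=2, act_type='silu', norm_type='BN')
+    #   y = conv(torch.randn(1, 32, 80, 80))   # -> [1, 64, 40, 40]
+    #
+    # With depthwise=True, the same arguments build a 3x3 depthwise conv followed
+    # by a 1x1 pointwise conv, each with its own norm and activation.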
+
+class Bottleneck(nn.Module):
+    def __init__(self,
+                 in_dim       :int,
+                 out_dim      :int,
+                 expand_ratio :float = 0.5,
+                 kernel_sizes :List = [3, 3],
+                 shortcut     :bool = True,
+                 act_type     :str  = 'silu',
+                 norm_type    :str  = 'BN',
+                 depthwise    :bool = False,):
+        super(Bottleneck, self).__init__()
+        inter_dim = int(out_dim * expand_ratio)  # hidden channels            
+        self.cv1 = Conv(in_dim, inter_dim, k=kernel_sizes[0], p=kernel_sizes[0]//2, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.cv2 = Conv(inter_dim, out_dim, k=kernel_sizes[1], p=kernel_sizes[1]//2, norm_type=norm_type, act_type=act_type, depthwise=depthwise)
+        self.shortcut = shortcut and in_dim == out_dim
+
+    def forward(self, x):
+        h = self.cv2(self.cv1(x))
+
+        return x + h if self.shortcut else h
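+
+    # Usage sketch (illustrative): with the default expand_ratio=0.5 the hidden
+    # width is half of out_dim, and the residual add only fires when
+    # in_dim == out_dim:
+    #
+    #   b = Bottleneck(in_dim=64, out_dim=64, shortcut=True)
+    #   y = b(torch.randn(1, 64, 40, 40))   # -> [1, 64, 40, 40]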
+
+class RTCBlock(nn.Module):
+    def __init__(self,
+                 in_dim     :int,
+                 out_dim    :int,
+                 num_blocks :int  = 1,
+                 shortcut   :bool = False,
+                 act_type   :str  = 'silu',
+                 norm_type  :str  = 'BN',
+                 depthwise  :bool = False,):
+        super(RTCBlock, self).__init__()
+        self.inter_dim = out_dim // 2
+        self.input_proj = Conv(in_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+        self.m = nn.Sequential(*(
+            Bottleneck(self.inter_dim, self.inter_dim, 1.0, [3, 3], shortcut, act_type, norm_type, depthwise)
+            for _ in range(num_blocks)))
+        self.output_proj = Conv((2 + num_blocks) * self.inter_dim, out_dim, k=1, act_type=act_type, norm_type=norm_type)
+
+    def forward(self, x):
+        # Input proj
+        x1, x2 = torch.chunk(self.input_proj(x), 2, dim=1)
+        out = [x1, x2]
+
+        # Bottleneck blocks
+        out.extend(m(out[-1]) for m in self.m)
+
+        # Output proj
+        out = self.output_proj(torch.cat(out, dim=1))
+
+        return out
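+
+    # Usage sketch (illustrative; out_dim must be even so the projected features
+    # split into two equal halves):
+    #
+    #   block = RTCBlock(in_dim=64, out_dim=128, num_blocks=3)
+    #   y = block(torch.randn(1, 64, 40, 40))   # -> [1, 128, 40, 40]
+    #
+    # Every Bottleneck output is kept and concatenated (CSP-style), so the output
+    # projection takes (2 + num_blocks) * (out_dim // 2) channels.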
+
+
+# ----------------- Transformer modules -----------------

+ 2 - 2
models/detectors/yolov8/README.md

@@ -2,10 +2,10 @@
 
 |   Model   |  Batch | Scale | AP<sup>val<br>0.5:0.95 | AP<sup>val<br>0.5 | FLOPs<br><sup>(G) | Params<br><sup>(M) | Weight |
 |-----------|--------|-------|------------------------|-------------------|-------------------|--------------------|--------|
-| YOLOv8-N  | 8xb16  |  640  |          36.8          |        52.9       |        8.8        |         3.2        | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov8_n_coco.pth) |
+| YOLOv8-N  | 8xb16  |  640  |          37.0          |        52.9       |        8.8        |         3.2        | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov8_n_coco.pth) |
 | YOLOv8-S  | 8xb16  |  640  |                        |                   |                   |                    |  |
 | YOLOv8-M  | 8xb16  |  640  |                        |                   |                   |                    |  |
-| YOLOv8-L  | 8xb16  |  640  |          50.2          |        68.0       |       165.7       |         43.7       | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov8_l_coco.pth) |
+| YOLOv8-L  | 8xb16  |  640  |          50.7          |        68.3       |       165.7       |         43.7       | [ckpt](https://github.com/yjh0410/RT-ODLab/releases/download/yolo_tutorial_ckpt/yolov8_l_coco.pth) |
 
 - For training, we train the YOLOv8 series for 500 epochs on COCO.
 - For data augmentation, we use large-scale jitter (LSJ), Mosaic, and Mixup, following the settings of [YOLOv8](https://github.com/ultralytics/yolov8).

+ 85 - 0
tools/clean_coco.py

@@ -0,0 +1,85 @@
+import os
+import json
+
+
+if __name__ == "__main__":
+    import argparse
+    
+    parser = argparse.ArgumentParser(description='Clean COCO annotations')
+
+    # --------------- opt parameters ---------------
+    parser.add_argument('--root', default='/Users/liuhaoran/Desktop/python_work/object-detection/dataset/COCO/',
+                        help='data root')
+    parser.add_argument('--image_set', type=str, default='val',
+                        help="image set: 'train' or 'val'.")
+    parser.add_argument('--task', type=str, default='det',
+                        help="task type: 'det' (detection) or 'pos' (person keypoints).")
+    
+    args = parser.parse_args()
+
+    # --------------- load json ---------------
+    if args.task == 'det':
+        task_prefix = 'instances_{}2017.json'
+        clean_task_prefix = 'instances_{}2017_clean.json'
+    elif args.task == 'pos':
+        task_prefix = 'person_keypoints_{}2017.json'
+        clean_task_prefix = 'person_keypoints_{}2017_clean.json'
+    else:
+        raise NotImplementedError('Unknown task: {}'.format(args.task))
+    
+    json_path = os.path.join(args.root, 'annotations', task_prefix.format(args.image_set))
+
+    clean_json_file = dict()
+    with open(json_path, 'r') as file:
+        json_file = json.load(file)
+        # json_file is a Dict: dict_keys(['info', 'licenses', 'images', 'annotations', 'categories'])
+        clean_json_file['info'] = json_file['info'] 
+        clean_json_file['licenses'] = json_file['licenses']
+        clean_json_file['categories'] = json_file['categories']
+
+        images_list = json_file['images']
+        annots_list = json_file['annotations']
+        num_images = len(images_list)
+
+        # -------------- Filter annotations --------------
+        print("Processing annotations ...")
+        valid_image_ids = set()   # set for O(1) membership checks
+        clean_annots_list = [] 
+        for i, anno in enumerate(annots_list):
+            if i % 5000 == 0:
+                print("[{}] / [{}] ...".format(i, len(annots_list)))
+            x1, y1, bw, bh = anno['bbox']
+            if bw > 0 and bh > 0:
+                clean_annots_list.append(anno)
+                valid_image_ids.add(anno['image_id'])
+        print("Valid number of images: ", len(valid_image_ids))
+        print("Valid number of annots: ", len(clean_annots_list))
+        print("Original number of annots: ", len(annots_list))
+
+        # -------------- Filter images --------------
+        print("Processing images ...")
+        clean_images_list = []
+        for i in range(num_images):
+            log_step = 5000 if args.image_set == 'train' else 500
+            if i % log_step == 0:
+                print("[{}] / [{}] ...".format(i, num_images))
+            
+            # A single image dict
+            image_dict = images_list[i]
+            image_id = image_dict['id']
+
+            if image_id in valid_image_ids:
+                clean_images_list.append(image_dict)
+
+        print('Number of images after cleaning: ', len(clean_images_list))
+        print('Number of annotations after cleaning: ', len(clean_annots_list))
+
+        clean_json_file['images'] = clean_images_list
+        clean_json_file['annotations'] = clean_annots_list
+    
+    # --------------- Save filtered json file ---------------
+    new_json_path = os.path.join(args.root, 'annotations', clean_task_prefix.format(args.image_set))
+    with open(new_json_path, 'w') as f:
+        json.dump(clean_json_file, f)
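+
+# Usage (see also the README; '--task pos' cleans keypoint annotations instead):
+#   python clean_coco.py --root path/to/coco --image_set val
+#   python clean_coco.py --root path/to/coco --image_set val --task pos
+# The cleaned files are written next to the originals in <root>/annotations/.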

+ 2 - 2
train.py

@@ -71,9 +71,9 @@ def parse_args():
     # Model
     parser.add_argument('-m', '--model', default='yolov1', type=str,
                         help='build yolo')
-    parser.add_argument('-ct', '--conf_thresh', default=0.005, type=float,
+    parser.add_argument('-ct', '--conf_thresh', default=0.001, type=float,
                         help='confidence threshold')
-    parser.add_argument('-nt', '--nms_thresh', default=0.6, type=float,
+    parser.add_argument('-nt', '--nms_thresh', default=0.7, type=float,
                         help='NMS threshold')
     parser.add_argument('--topk', default=1000, type=int,
                         help='topk candidates dets of each level before NMS')