Ver Fonte

add MemoryCompressor

yjh0410 há 2 anos
pai
commit
a041d996ff

+ 7 - 1
config/model_config/rtdetr_config.py

@@ -24,13 +24,19 @@ rtdetr_cfg = {
         'neck_norm': 'BN',
         'neck_depthwise': False,
         ### CNN-CSFM
-        'fpn': 'yolo_pafpn',
+        'fpn': 'yolovx_pafpn',
         'fpn_reduce_layer': 'conv',
         'fpn_downsample_layer': 'conv',
         'fpn_core_block': 'elanblock',
         'fpn_act': 'silu',
         'fpn_norm': 'BN',
         'fpn_depthwise': False,
+        ## ------- Memory Decoder -------
+        'dim_compressed': 900,
+        'com_dim_feedforward': 1024,
+        'com_num_heads': 8,
+        'com_dropout': 0.1,
+        'com_act': 'silu',
         ## ------- Transformer Decoder -------
         'd_model': 256,
         'attn_type': 'mhsa',

+ 1 - 1
models/detectors/rtdetr/image_encoder/cnn_pafpn.py

@@ -92,7 +92,7 @@ class YolovxPaFPN(nn.Module):
 def build_fpn(cfg, in_dims, out_dim=None, input_proj=False):
     model = cfg['fpn']
     # build pafpn
-    if model == 'YolovxPaFPN':
+    if model == 'yolovx_pafpn':
         fpn_net = YolovxPaFPN(cfg, in_dims, out_dim, input_proj)
 
     return fpn_net

+ 1 - 1
models/detectors/rtdetr/loss.py

@@ -102,7 +102,7 @@ class Criterion(nn.Module):
         return loss_map[loss](outputs, targets, indices, num_boxes, **kwargs)
 
 
-    def forward(self, outputs, targets):
+    def forward(self, outputs, targets, epoch=0):
         """ This performs the loss computation.
         Parameters:
              outputs: dict of tensors, see the output specification of the model for the format

+ 12 - 2
models/detectors/rtdetr/rtdetr.py

@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 
 from .rtdetr_encoder import build_encoder
+from .rtdetr_compressor import build_compressor
 from .rtdetr_decoder import build_decoder
 from .rtdetr_dethead import build_dethead
 
@@ -32,6 +33,9 @@ class RTDETR(nn.Module):
         ## Encoder
         self.encoder = build_encoder(cfg, trainable, 'img_encoder')
 
+        ## Compressor
+        self.compressor = build_compressor(cfg, self.d_model)
+
         ## Decoder
         self.decoder = build_decoder(cfg, self.d_model, return_intermediate=aux_loss)
 
@@ -97,8 +101,11 @@ class RTDETR(nn.Module):
         memory = memory.permute(0, 2, 1).contiguous()
         memory_pos = memory_pos.permute(0, 2, 1).contiguous()
 
+        # -------------------- Compressor --------------------
+        compressed_memory = self.compressor(memory, memory_pos)
+
         # -------------------- Decoder --------------------
-        hs, reference = self.decoder(memory, memory_pos)
+        hs, reference = self.decoder(compressed_memory, None)
 
         # -------------------- DetHead --------------------
         out_logits, out_bbox = self.dethead(hs, reference, False)
@@ -139,8 +146,11 @@ class RTDETR(nn.Module):
             memory = memory.permute(0, 2, 1).contiguous()
             memory_pos = memory_pos.permute(0, 2, 1).contiguous()
             
+            # -------------------- Compressor --------------------
+            compressed_memory = self.compressor(memory, memory_pos)
+
             # -------------------- Decoder --------------------
-            hs, reference = self.decoder(memory, memory_pos)
+            hs, reference = self.decoder(compressed_memory, None)
 
             # -------------------- DetHead --------------------
             outputs_class, outputs_coords = self.dethead(hs, reference, True)

+ 34 - 0
models/detectors/rtdetr/rtdetr_compressor.py

@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+
+from .rtdetr_basic import TRDecoderLayer
+
+
+# Memory Compressor Module
+class MemoryCompressor(nn.Module):
+    def __init__(self, cfg, in_dim):
+        super().__init__()
+        # -------------------- Basic Parameters ---------------------
+        self.d_model = in_dim
+        self.ffn_dim = round(cfg['com_dim_feedforward']*cfg['width'])
+        self.compressed_vector = nn.Embedding(cfg['dim_compressed'], in_dim)
+        # -------------------- Network Parameters ---------------------
+        self.compress_layer = TRDecoderLayer(
+            d_model=in_dim,
+            dim_feedforward=self.ffn_dim,
+            num_heads=cfg['com_num_heads'],
+            dropout=cfg['com_dropout'],
+            act_type=cfg['com_act']
+        )
+
+
+    def forward(self, memory, memory_pos):
+        bs = memory.size(0)
+        output = self.compressed_vector.weight[None].repeat(bs, 1, 1)
+        output = self.compress_layer(output, None, memory, memory_pos)
+
+        return output
+
+
+def build_compressor(cfg, in_dim):
+    return MemoryCompressor(cfg, in_dim)

+ 0 - 3
models/detectors/rtdetr/rtdetr_decoder.py

@@ -35,7 +35,6 @@ class TransformerDecoder(nn.Module):
         nn.init.normal_(self.object_query.weight.data)
         ## TODO: Group queries
 
-
         self.bbox_embed = None
         self.class_embed = None
 
@@ -86,7 +85,6 @@ class TransformerDecoder(nn.Module):
             # Conditional query
             query_sine_embed = self.query_sine_embed(num_feats, reference_points)
             query_pos = self.ref_point_head(query_sine_embed) # [B, N, C]
-
             # Decoder
             output = layer(
                     # input for decoder
@@ -96,7 +94,6 @@ class TransformerDecoder(nn.Module):
                     memory = memory,
                     memory_pos = memory_pos,
                 )
-
             # Iter update
             if self.bbox_embed is not None:
                 delta_unsig = self.bbox_embed[layer_id](output)