Bläddra i källkod

keep training RT-DETR-R50

yjh0410 1 år sedan
förälder
incheckning
74e1cab412
1 ändrade filer med 8 tillägg och 8 borttagningar
  1. 8 8
      models/detectors/rtdetr/basic_modules/fpn.py

+ 8 - 8
models/detectors/rtdetr/basic_modules/fpn.py

@@ -22,7 +22,7 @@ def build_fpn(cfg, in_dims, out_dim):
                              depthwise   = cfg['fpn_depthwise'],
                              num_heads   = cfg['en_num_heads'],
                              num_layers  = cfg['en_num_layers'],
-                             ffn_dim   = cfg['en_ffn_dim'],
+                             ffn_dim     = cfg['en_ffn_dim'],
                              dropout     = cfg['en_dropout'],
                              pe_temperature = cfg['pe_temperature'],
                              en_act_type    = cfg['en_act'],
@@ -35,12 +35,12 @@ def build_fpn(cfg, in_dims, out_dim):
 ## Hybrid Encoder (Transformer encoder + Convolutional PaFPN)
 class HybridEncoder(nn.Module):
     def __init__(self, 
-                 in_dims     :List  = [256, 512, 1024],
-                 out_dim     :int   = 256,
-                 num_blocks  :int   = 3,
-                 act_type    :str   = 'silu',
-                 norm_type   :str   = 'BN',
-                 depthwise   :bool  = False,
+                 in_dims        :List  = [256, 512, 1024],
+                 out_dim        :int   = 256,
+                 num_blocks     :int   = 3,
+                 act_type       :str   = 'silu',
+                 norm_type      :str   = 'BN',
+                 depthwise      :bool  = False,
                  # Transformer's parameters
                  num_heads      :int   = 8,
                  num_layers     :int   = 1,
@@ -74,7 +74,7 @@ class HybridEncoder(nn.Module):
         self.transformer_encoder = TransformerEncoder(d_model        = self.out_dim,
                                                       num_heads      = num_heads,
                                                       num_layers     = num_layers,
-                                                      ffn_dim      = ffn_dim,
+                                                      ffn_dim        = ffn_dim,
                                                       pe_temperature = pe_temperature,
                                                       dropout        = dropout,
                                                       act_type       = en_act_type