@@ -62,13 +62,11 @@ class HybridEncoder(nn.Module):
         c3, c4, c5 = in_dims

         # ---------------- Input projs ----------------
-        self.reduce_layer_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.reduce_layer_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
-        self.reduce_layer_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=act_type, norm_type=norm_type)
+        self.input_proj_1 = BasicConv(c5, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
+        self.input_proj_2 = BasicConv(c4, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)
+        self.input_proj_3 = BasicConv(c3, self.out_dim, kernel_size=1, act_type=None, norm_type=norm_type)

         # ---------------- Downsample ----------------
-        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)
-        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim, kernel_size=3, padding=1, stride=2, act_type=act_type, norm_type=norm_type)

         # ---------------- Transformer Encoder ----------------
         self.transformer_encoder = TransformerEncoder(d_model = self.out_dim,
@@ -82,38 +80,50 @@ class HybridEncoder(nn.Module):

         # ---------------- Top down FPN ----------------
         ## P5 -> P4
+        self.reduce_layer_1 = BasicConv(self.out_dim, self.out_dim,
+                                        kernel_size=1, padding=0, stride=1,
+                                        act_type=act_type, norm_type=norm_type)
         self.top_down_layer_1 = RepRTCBlock(in_dim = self.out_dim * 2,
-                                            out_dim = self.out_dim,
-                                            num_blocks = num_blocks,
-                                            expansion = expansion,
-                                            act_type = act_type,
-                                            norm_type = norm_type,
+                                            out_dim = self.out_dim,
+                                            num_blocks = num_blocks,
+                                            expansion = expansion,
+                                            act_type = act_type,
+                                            norm_type = norm_type,
                                             )
         ## P4 -> P3
+        self.reduce_layer_2 = BasicConv(self.out_dim, self.out_dim,
+                                        kernel_size=1, padding=0, stride=1,
+                                        act_type=act_type, norm_type=norm_type)
         self.top_down_layer_2 = RepRTCBlock(in_dim = self.out_dim * 2,
-                                            out_dim = self.out_dim,
-                                            num_blocks = num_blocks,
-                                            expansion = expansion,
-                                            act_type = act_type,
-                                            norm_type = norm_type,
+                                            out_dim = self.out_dim,
+                                            num_blocks = num_blocks,
+                                            expansion = expansion,
+                                            act_type = act_type,
+                                            norm_type = norm_type,
                                             )

         # ---------------- Bottom up PAN----------------
         ## P3 -> P4
-        self.bottom_up_layer_1 = RepRTCBlock(in_dim = self.out_dim * 2,
-                                             out_dim = self.out_dim,
-                                             num_blocks = num_blocks,
-                                             expansion = expansion,
-                                             act_type = act_type,
-                                             norm_type = norm_type,
+        self.dowmsample_layer_1 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type)
+        self.bottom_up_layer_1 = RepRTCBlock(in_dim = self.out_dim * 2,
+                                             out_dim = self.out_dim,
+                                             num_blocks = num_blocks,
+                                             expansion = expansion,
+                                             act_type = act_type,
+                                             norm_type = norm_type,
                                              )
         ## P4 -> P5
-        self.bottom_up_layer_2 = RepRTCBlock(in_dim = self.out_dim * 2,
-                                             out_dim = self.out_dim,
-                                             num_blocks = num_blocks,
-                                             expansion = expansion,
-                                             act_type = act_type,
-                                             norm_type = norm_type,
+        self.dowmsample_layer_2 = BasicConv(self.out_dim, self.out_dim,
+                                            kernel_size=3, padding=1, stride=2,
+                                            act_type=act_type, norm_type=norm_type)
+        self.bottom_up_layer_2 = RepRTCBlock(in_dim = self.out_dim * 2,
+                                             out_dim = self.out_dim,
+                                             num_blocks = num_blocks,
+                                             expansion = expansion,
+                                             act_type = act_type,
+                                             norm_type = norm_type,
                                              )

         self.init_weights()
@@ -130,26 +140,31 @@ class HybridEncoder(nn.Module):
         c3, c4, c5 = features

         # -------- Input projs --------
-        p5 = self.reduce_layer_1(c5)
-        p4 = self.reduce_layer_2(c4)
-        p3 = self.reduce_layer_3(c3)
+        p5 = self.input_proj_1(c5)
+        p4 = self.input_proj_2(c4)
+        p3 = self.input_proj_3(c3)

         # -------- Transformer encoder --------
         p5 = self.transformer_encoder(p5)

         # -------- Top down FPN --------
-        p5_up = F.interpolate(p5, scale_factor=2.0)
+        ## P5 -> P4
+        p5_in = self.reduce_layer_1(p5)
+        p5_up = F.interpolate(p5_in, scale_factor=2.0)
         p4 = self.top_down_layer_1(torch.cat([p4, p5_up], dim=1))

-        p4_up = F.interpolate(p4, scale_factor=2.0)
+        ## P4 -> P3
+        p4_in = self.reduce_layer_2(p4)
+        p4_up = F.interpolate(p4_in, scale_factor=2.0)
         p3 = self.top_down_layer_2(torch.cat([p3, p4_up], dim=1))

         # -------- Bottom up PAN --------
+        ## P3 -> P4
         p3_ds = self.dowmsample_layer_1(p3)
-        p4 = self.bottom_up_layer_1(torch.cat([p4, p3_ds], dim=1))
+        p4 = self.bottom_up_layer_1(torch.cat([p4_in, p3_ds], dim=1))

         p4_ds = self.dowmsample_layer_2(p4)
-        p5 = self.bottom_up_layer_2(torch.cat([p5, p4_ds], dim=1))
+        p5 = self.bottom_up_layer_2(torch.cat([p5_in, p4_ds], dim=1))

out_feats = [p3, p4, p5]
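
For orientation, the wiring that the reworked `__init__`/`forward` pair produces can be sketched in isolation. The snippet below is only a shape-flow sketch, not the project's code: `BasicConv`, `RepRTCBlock`, and `TransformerEncoder` are replaced with plain `nn.Conv2d`/`nn.Identity` stand-ins, and the `in_dims` default `(256, 512, 1024)` is an assumed example, so the reuse of the reduced `p5_in`/`p4_in` maps in the PAN stage can be checked on dummy tensors.

```python
# Minimal sketch of the reorganized FPN/PAN wiring, NOT the real HybridEncoder.
# BasicConv / RepRTCBlock / TransformerEncoder are replaced by simple placeholders.
import torch
import torch.nn as nn
import torch.nn.functional as F

class NeckSketch(nn.Module):
    def __init__(self, in_dims=(256, 512, 1024), out_dim=256):  # assumed example dims
        super().__init__()
        c3, c4, c5 = in_dims
        # Input projs: 1x1 convs without activation (act_type=None in the diff)
        self.input_proj_1 = nn.Conv2d(c5, out_dim, 1)
        self.input_proj_2 = nn.Conv2d(c4, out_dim, 1)
        self.input_proj_3 = nn.Conv2d(c3, out_dim, 1)
        # Placeholder for the transformer encoder applied to P5 only
        self.transformer_encoder = nn.Identity()
        # Top-down FPN: 1x1 reduce before each upsample, fusion block after concat
        self.reduce_layer_1 = nn.Conv2d(out_dim, out_dim, 1)
        self.top_down_layer_1 = nn.Conv2d(out_dim * 2, out_dim, 3, padding=1)
        self.reduce_layer_2 = nn.Conv2d(out_dim, out_dim, 1)
        self.top_down_layer_2 = nn.Conv2d(out_dim * 2, out_dim, 3, padding=1)
        # Bottom-up PAN: stride-2 3x3 downsample before each fusion block
        self.downsample_layer_1 = nn.Conv2d(out_dim, out_dim, 3, padding=1, stride=2)
        self.bottom_up_layer_1 = nn.Conv2d(out_dim * 2, out_dim, 3, padding=1)
        self.downsample_layer_2 = nn.Conv2d(out_dim, out_dim, 3, padding=1, stride=2)
        self.bottom_up_layer_2 = nn.Conv2d(out_dim * 2, out_dim, 3, padding=1)

    def forward(self, features):
        c3, c4, c5 = features
        p5 = self.transformer_encoder(self.input_proj_1(c5))
        p4 = self.input_proj_2(c4)
        p3 = self.input_proj_3(c3)
        # Top-down: reduce, upsample, fuse
        p5_in = self.reduce_layer_1(p5)
        p4 = self.top_down_layer_1(torch.cat([p4, F.interpolate(p5_in, scale_factor=2.0)], dim=1))
        p4_in = self.reduce_layer_2(p4)
        p3 = self.top_down_layer_2(torch.cat([p3, F.interpolate(p4_in, scale_factor=2.0)], dim=1))
        # Bottom-up: the reduced maps p4_in / p5_in serve as the lateral inputs
        p4 = self.bottom_up_layer_1(torch.cat([p4_in, self.downsample_layer_1(p3)], dim=1))
        p5 = self.bottom_up_layer_2(torch.cat([p5_in, self.downsample_layer_2(p4)], dim=1))
        return [p3, p4, p5]

if __name__ == "__main__":
    feats = [torch.randn(1, 256, 80, 80), torch.randn(1, 512, 40, 40), torch.randn(1, 1024, 20, 20)]
    for f in NeckSketch()(feats):
        print(f.shape)  # 256 channels at 80x80, 40x40, 20x20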
```