|
|
@@ -2,7 +2,7 @@ import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
-from .yolov5_plus_basic import (Conv, build_reduce_layer, build_downsample_layer, build_fpn_block)
|
|
|
+from .yolov5_plus_basic import (Conv, build_downsample_layer, build_fpn_block)
|
|
|
|
|
|
|
|
|
# YOLO-Style PaFPN
|
|
|
@@ -18,21 +18,19 @@ class Yolov5PlusPaFPN(nn.Module):
|
|
|
# --------------------------- Network Parameters ---------------------------
|
|
|
## top dwon
|
|
|
### P5 -> P4
|
|
|
- self.reduce_layer_1 = build_reduce_layer(cfg, c5, round(512*width))
|
|
|
- self.top_down_layer_1 = build_fpn_block(cfg, c4 + round(512*width), round(512*width))
|
|
|
+ self.top_down_layer_1 = build_fpn_block(cfg, c4 + c5, round(512*width))
|
|
|
|
|
|
### P4 -> P3
|
|
|
- self.reduce_layer_2 = build_reduce_layer(cfg, round(512*width), round(256*width))
|
|
|
- self.top_down_layer_2 = build_fpn_block(cfg, c3 + round(256*width), round(256*width))
|
|
|
+ self.top_down_layer_2 = build_fpn_block(cfg, c3 + round(512*width), round(256*width))
|
|
|
|
|
|
## bottom up
|
|
|
### P3 -> P4
|
|
|
self.downsample_layer_1 = build_downsample_layer(cfg, round(256*width), round(256*width))
|
|
|
- self.bottom_up_layer_1 = build_fpn_block(cfg, round(256*width) + round(256*width), round(512*width))
|
|
|
+ self.bottom_up_layer_1 = build_fpn_block(cfg, round(256*width) + round(512*width), round(512*width))
|
|
|
|
|
|
### P4 -> P5
|
|
|
self.downsample_layer_2 = build_downsample_layer(cfg, round(512*width), round(512*width))
|
|
|
- self.bottom_up_layer_2 = build_fpn_block(cfg, round(512*width) + round(512*width), round(512*width*ratio))
|
|
|
+ self.bottom_up_layer_2 = build_fpn_block(cfg, c5 + round(512*width), round(512*width*ratio))
|
|
|
|
|
|
## output proj layers
|
|
|
if out_dim is not None:
|
|
|
@@ -52,27 +50,25 @@ class Yolov5PlusPaFPN(nn.Module):
|
|
|
|
|
|
# Top down
|
|
|
## P5 -> P4
|
|
|
- c6 = self.reduce_layer_1(c5)
|
|
|
- c7 = F.interpolate(c6, scale_factor=2.0)
|
|
|
- c8 = torch.cat([c7, c4], dim=1)
|
|
|
- c9 = self.top_down_layer_1(c8)
|
|
|
+ c6 = F.interpolate(c5, scale_factor=2.0)
|
|
|
+ c7 = torch.cat([c6, c4], dim=1)
|
|
|
+ c8 = self.top_down_layer_1(c7)
|
|
|
## P4 -> P3
|
|
|
- c10 = self.reduce_layer_2(c9)
|
|
|
- c11 = F.interpolate(c10, scale_factor=2.0)
|
|
|
- c12 = torch.cat([c11, c3], dim=1)
|
|
|
- c13 = self.top_down_layer_2(c12)
|
|
|
+ c9 = F.interpolate(c8, scale_factor=2.0)
|
|
|
+ c10 = torch.cat([c9, c3], dim=1)
|
|
|
+ c11 = self.top_down_layer_2(c10)
|
|
|
|
|
|
# Bottom up
|
|
|
## p3 -> P4
|
|
|
- c14 = self.downsample_layer_1(c13)
|
|
|
- c15 = torch.cat([c14, c10], dim=1)
|
|
|
- c16 = self.bottom_up_layer_1(c15)
|
|
|
+ c12 = self.downsample_layer_1(c11)
|
|
|
+ c13 = torch.cat([c12, c8], dim=1)
|
|
|
+ c14 = self.bottom_up_layer_1(c13)
|
|
|
## P4 -> P5
|
|
|
- c17 = self.downsample_layer_2(c16)
|
|
|
- c18 = torch.cat([c17, c6], dim=1)
|
|
|
- c19 = self.bottom_up_layer_2(c18)
|
|
|
+ c15 = self.downsample_layer_2(c14)
|
|
|
+ c16 = torch.cat([c15, c5], dim=1)
|
|
|
+ c17 = self.bottom_up_layer_2(c16)
|
|
|
|
|
|
- out_feats = [c13, c16, c19] # [P3, P4, P5]
|
|
|
+ out_feats = [c11, c14, c17] # [P3, P4, P5]
|
|
|
|
|
|
# output proj layers
|
|
|
if self.out_layers is not None:
|