|
|
@@ -30,10 +30,17 @@ class Yolov8(nn.Module):
|
|
|
self.no_multi_labels = False if is_val else True
|
|
|
|
|
|
# ---------------------- Network Parameters ----------------------
|
|
|
+ ## Backbone
|
|
|
self.backbone = Yolov8Backbone(cfg)
|
|
|
- self.neck = SPPF(cfg, self.backbone.feat_dims[-1], self.backbone.feat_dims[-1])
|
|
|
+ self.pyramid_feat_dims = self.backbone.feat_dims[-3:]
|
|
|
+ ## Neck
|
|
|
+ self.neck = SPPF(cfg, self.pyramid_feat_dims[-1], self.pyramid_feat_dims[-1])
|
|
|
+ self.pyramid_feat_dims[-1] = self.neck.out_dim
|
|
|
+ ## Neck: PaFPN
|
|
|
self.fpn = Yolov8PaFPN(cfg, self.backbone.feat_dims)
|
|
|
+ ## Head
|
|
|
self.head = Yolov8DetHead(cfg, self.fpn.out_dims)
|
|
|
+ ## Pred
|
|
|
self.pred = Yolov8DetPredLayer(cfg, self.head.cls_head_dim, self.head.reg_head_dim)
|
|
|
|
|
|
def post_process(self, cls_preds, box_preds):
|