|
|
@@ -156,3 +156,59 @@ class Yolov3DetPredLayer(nn.Module):
|
|
|
}
|
|
|
|
|
|
return outputs
|
|
|
+
|
|
|
+
|
|
|
+if __name__=='__main__':
|
|
|
+ import time
|
|
|
+ from thop import profile
|
|
|
+ # Model config
|
|
|
+
|
|
|
+ # YOLOv8-Base config
|
|
|
+ class Yolov3BaseConfig(object):
|
|
|
+ def __init__(self) -> None:
|
|
|
+ # ---------------- Model config ----------------
|
|
|
+ self.width = 1.0
|
|
|
+ self.depth = 1.0
|
|
|
+ self.out_stride = [8, 16, 32]
|
|
|
+ self.max_stride = 32
|
|
|
+ self.num_levels = 3
|
|
|
+ ## Head
|
|
|
+ self.head_dim = 256
|
|
|
+ self.anchor_size = {0: [[10, 13], [16, 30], [33, 23]],
|
|
|
+ 1: [[30, 61], [62, 45], [59, 119]],
|
|
|
+ 2: [[116, 90], [156, 198], [373, 326]]}
|
|
|
+
|
|
|
+ cfg = Yolov3BaseConfig()
|
|
|
+ cfg.num_classes = 20
|
|
|
+ # Build a pred layer
|
|
|
+ pred = Yolov3DetPredLayer(cfg)
|
|
|
+
|
|
|
+ # Inference
|
|
|
+ cls_feats = [torch.randn(1, cfg.head_dim, 80, 80),
|
|
|
+ torch.randn(1, cfg.head_dim, 40, 40),
|
|
|
+ torch.randn(1, cfg.head_dim, 20, 20),]
|
|
|
+ reg_feats = [torch.randn(1, cfg.head_dim, 80, 80),
|
|
|
+ torch.randn(1, cfg.head_dim, 40, 40),
|
|
|
+ torch.randn(1, cfg.head_dim, 20, 20),]
|
|
|
+ t0 = time.time()
|
|
|
+ output = pred(cls_feats, reg_feats)
|
|
|
+ t1 = time.time()
|
|
|
+ print('Time: ', t1 - t0)
|
|
|
+ print('====== Pred output ======= ')
|
|
|
+ pred_obj = output["pred_obj"]
|
|
|
+ pred_cls = output["pred_cls"]
|
|
|
+ pred_reg = output["pred_reg"]
|
|
|
+ pred_box = output["pred_box"]
|
|
|
+ anchors = output["anchors"]
|
|
|
+
|
|
|
+ for level in range(cfg.num_levels):
|
|
|
+ print("- Level-{} : objectness -> {}".format(level, pred_obj[level].shape))
|
|
|
+ print("- Level-{} : classification -> {}".format(level, pred_cls[level].shape))
|
|
|
+ print("- Level-{} : delta regression -> {}".format(level, pred_reg[level].shape))
|
|
|
+ print("- Level-{} : bbox regression -> {}".format(level, pred_box[level].shape))
|
|
|
+ print("- Level-{} : anchor boxes -> {}".format(level, anchors[level].shape))
|
|
|
+
|
|
|
+ flops, params = profile(pred, inputs=(cls_feats, reg_feats, ), verbose=False)
|
|
|
+ print('==============================')
|
|
|
+ print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
|
|
|
+ print('Params : {:.2f} M'.format(params / 1e6))
|