|
@@ -96,11 +96,10 @@ class Yolov1DetPredLayer(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__=='__main__':
|
|
if __name__=='__main__':
|
|
|
- import time
|
|
|
|
|
from thop import profile
|
|
from thop import profile
|
|
|
# Model config
|
|
# Model config
|
|
|
|
|
|
|
|
- # YOLOv8-Base config
|
|
|
|
|
|
|
+ # YOLOv1 configuration
|
|
|
class Yolov1BaseConfig(object):
|
|
class Yolov1BaseConfig(object):
|
|
|
def __init__(self) -> None:
|
|
def __init__(self) -> None:
|
|
|
# ---------------- Model config ----------------
|
|
# ---------------- Model config ----------------
|
|
@@ -108,19 +107,19 @@ if __name__=='__main__':
|
|
|
self.max_stride = 32
|
|
self.max_stride = 32
|
|
|
## Head
|
|
## Head
|
|
|
self.head_dim = 512
|
|
self.head_dim = 512
|
|
|
-
|
|
|
|
|
cfg = Yolov1BaseConfig()
|
|
cfg = Yolov1BaseConfig()
|
|
|
cfg.num_classes = 20
|
|
cfg.num_classes = 20
|
|
|
|
|
+
|
|
|
# Build a pred layer
|
|
# Build a pred layer
|
|
|
- pred = Yolov1DetPredLayer(cfg)
|
|
|
|
|
|
|
+ model = Yolov1DetPredLayer(cfg)
|
|
|
|
|
+
|
|
|
|
|
+ # Randomly generate a input data
|
|
|
|
|
+ cls_feat = torch.randn(2, cfg.head_dim, 20, 20)
|
|
|
|
|
+ reg_feat = torch.randn(2, cfg.head_dim, 20, 20)
|
|
|
|
|
|
|
|
# Inference
|
|
# Inference
|
|
|
- cls_feat = torch.randn(1, cfg.head_dim, 20, 20)
|
|
|
|
|
- reg_feat = torch.randn(1, cfg.head_dim, 20, 20)
|
|
|
|
|
- t0 = time.time()
|
|
|
|
|
- output = pred(cls_feat, reg_feat)
|
|
|
|
|
- t1 = time.time()
|
|
|
|
|
- print('Time: ', t1 - t0)
|
|
|
|
|
|
|
+ output = model(cls_feat, reg_feat)
|
|
|
|
|
+
|
|
|
print('====== Pred output ======= ')
|
|
print('====== Pred output ======= ')
|
|
|
for k in output:
|
|
for k in output:
|
|
|
if isinstance(output[k], torch.Tensor):
|
|
if isinstance(output[k], torch.Tensor):
|
|
@@ -128,7 +127,9 @@ if __name__=='__main__':
|
|
|
else:
|
|
else:
|
|
|
print("-{}: ".format(k), output[k])
|
|
print("-{}: ".format(k), output[k])
|
|
|
|
|
|
|
|
- flops, params = profile(pred, inputs=(cls_feat, reg_feat, ), verbose=False)
|
|
|
|
|
- print('==============================')
|
|
|
|
|
- print('GFLOPs : {:.2f}'.format(flops / 1e9 * 2))
|
|
|
|
|
- print('Params : {:.2f} M'.format(params / 1e6))
|
|
|
|
|
|
|
+ cls_feat = torch.randn(1, cfg.head_dim, 20, 20)
|
|
|
|
|
+ reg_feat = torch.randn(1, cfg.head_dim, 20, 20)
|
|
|
|
|
+ flops, params = profile(model, inputs=(cls_feat, reg_feat, ), verbose=False)
|
|
|
|
|
+ print('============== FLOPs & Params ================')
|
|
|
|
|
+ print(' - FLOPs : {:.2f} G'.format(flops / 1e9 * 2))
|
|
|
|
|
+ print(' - Params : {:.2f} M'.format(params / 1e6))
|