@@ -196,6 +196,8 @@ def train():
trainer.eval(model_eval)
return
+ garbage = torch.randn(640, 1024, 80, 80).to(device) # 15 G
+
# ---------------------------- Train pipeline ----------------------------
trainer.train(model)