flask_server.py

import numpy as np
import sys
import os
import torch
from flask import Flask, request, jsonify
import json

from p2ch13.model_cls import LunaModel

app = Flask(__name__)

# Load the trained classifier from the checkpoint path given on the command line.
model = LunaModel()
model.load_state_dict(torch.load(sys.argv[1],
                                 map_location='cpu')['model_state'])
model.eval()

def run_inference(in_tensor):
    with torch.no_grad():
        # LunaModel takes a batch and outputs a tuple (scores, probs)
        out_tensor = model(in_tensor.unsqueeze(0))[1].squeeze(0)
    probs = out_tensor.tolist()
    out = {'prob_malignant': probs[1]}
    return out

@app.route("/predict", methods=["POST"])
def predict():
    # The client sends two multipart fields: 'meta' (JSON holding the tensor shape)
    # and 'blob' (the raw float32 bytes of the tensor).
    meta = json.load(request.files['meta'])
    blob = request.files['blob'].read()
    in_tensor = torch.from_numpy(np.frombuffer(
        blob, dtype=np.float32))
    in_tensor = in_tensor.view(*meta['shape'])
    out = run_inference(in_tensor)
    return jsonify(out)

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=8000)
    print(sys.argv[1])
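
For reference, here is a minimal client sketch (not part of the listing above) showing what the /predict endpoint expects: a multipart POST with a 'meta' part carrying the tensor shape as JSON and a 'blob' part carrying the raw float32 bytes. It assumes the server was started with `python flask_server.py <checkpoint path>` and is reachable on localhost:8000; the (1, 32, 48, 48) crop shape and the use of the `requests` library are illustrative assumptions, so substitute whatever input shape the loaded LunaModel checkpoint actually expects.

import json

import numpy as np
import requests

# Fake single-sample input; the shape is a placeholder, not taken from the listing.
crop = np.random.rand(1, 32, 48, 48).astype(np.float32)

files = {
    'meta': ('meta.json', json.dumps({'shape': list(crop.shape)})),
    'blob': ('blob.bin', crop.tobytes()),   # raw bytes, matching np.frombuffer on the server
}

response = requests.post('http://localhost:8000/predict', files=files)
print(response.json())   # something like {'prob_malignant': 0.12}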