request_batching_jit_server.py

import sys
import asyncio
import itertools
import functools
from sanic import Sanic
from sanic.response import json, text
from sanic.log import logger
from sanic.exceptions import ServerError
import sanic
import threading
import PIL.Image
import io
import torch
import torchvision

app = Sanic(__name__)

device = torch.device('cpu')
# we run at most one inference batch at any time (one could schedule between several runners if desired)
MAX_QUEUE_SIZE = 3  # we accept a backlog of MAX_QUEUE_SIZE before handing out "Too busy" errors
MAX_BATCH_SIZE = 2  # we put at most MAX_BATCH_SIZE things in a single batch
MAX_WAIT = 1        # we wait at most MAX_WAIT seconds for more inputs to arrive before running a (partial) batch

class HandlingError(Exception):
    def __init__(self, msg, code=500):
        super().__init__()
        self.handling_code = code
        self.handling_msg = msg

class ModelRunner:
    def __init__(self, model_name):
        self.model_name = model_name
        self.queue = []
        self.queue_lock = None

        self.model = torch.jit.load(self.model_name, map_location=device)

        self.needs_processing = None
        self.needs_processing_timer = None

    def schedule_processing_if_needed(self):
        if len(self.queue) >= MAX_BATCH_SIZE:
            logger.debug("next batch ready when processing a batch")
            self.needs_processing.set()
        elif self.queue:
            logger.debug("queue nonempty when processing a batch, setting next timer")
            self.needs_processing_timer = app.loop.call_at(
                self.queue[0]["time"] + MAX_WAIT, self.needs_processing.set)

    async def process_input(self, input):
        our_task = {"done_event": asyncio.Event(loop=app.loop),
                    "input": input,
                    "time": app.loop.time()}
        async with self.queue_lock:
            if len(self.queue) >= MAX_QUEUE_SIZE:
                raise HandlingError("I'm too busy", code=503)
            self.queue.append(our_task)
            logger.debug("enqueued task. new queue size {}".format(len(self.queue)))
            self.schedule_processing_if_needed()

        await our_task["done_event"].wait()
        return our_task["output"]

    def run_model(self, batch):  # runs in other thread
        return self.model(batch.to(device)).to('cpu')

    async def model_runner(self):
        self.queue_lock = asyncio.Lock(loop=app.loop)
        self.needs_processing = asyncio.Event(loop=app.loop)
        logger.info("started model runner for {}".format(self.model_name))
        while True:
            await self.needs_processing.wait()
            self.needs_processing.clear()
            if self.needs_processing_timer is not None:
                self.needs_processing_timer.cancel()
                self.needs_processing_timer = None
            async with self.queue_lock:
                if self.queue:
                    longest_wait = app.loop.time() - self.queue[0]["time"]
                else:  # oops
                    longest_wait = None
                logger.debug("launching processing. queue size: {}. longest wait: {}".format(
                    len(self.queue), longest_wait))
                to_process = self.queue[:MAX_BATCH_SIZE]
                del self.queue[:len(to_process)]
                self.schedule_processing_if_needed()
            # so here we copy, it would be neater to avoid this
            batch = torch.stack([t["input"] for t in to_process], dim=0)
            # we could delete inputs here...

            result = await app.loop.run_in_executor(
                None, functools.partial(self.run_model, batch)
            )
            for t, r in zip(to_process, result):
                t["output"] = r
                t["done_event"].set()
            del to_process

style_transfer_runner = ModelRunner(sys.argv[1])

@app.route('/image', methods=['PUT'], stream=True)
async def image(request):
    try:
        print(request.headers)
        content_length = int(request.headers.get('content-length', '0'))
        MAX_SIZE = 2**22  # 4 MiB
        if content_length:
            if content_length > MAX_SIZE:
                raise HandlingError("Too large")
            data = bytearray(content_length)
        else:
            data = bytearray(MAX_SIZE)
        pos = 0
        while True:
            # so this still copies too much stuff.
            data_part = await request.stream.read()
            if data_part is None:
                break
            data[pos: len(data_part) + pos] = data_part
            pos += len(data_part)
            if pos > MAX_SIZE:
                raise HandlingError("Too large")

        # ideally, we would minimize preprocessing...
        im = PIL.Image.open(io.BytesIO(data))
        im = torchvision.transforms.functional.resize(im, (228, 228))
        im = torchvision.transforms.functional.to_tensor(im)
        im = im[:3]  # drop alpha channel if present
        if im.dim() != 3 or im.size(0) < 3 or im.size(0) > 4:
            raise HandlingError("need rgb image")

        out_im = await style_transfer_runner.process_input(im)
        out_im = torchvision.transforms.functional.to_pil_image(out_im)
        imgByteArr = io.BytesIO()
        out_im.save(imgByteArr, format='JPEG')
        return sanic.response.raw(imgByteArr.getvalue(), status=200,
                                  content_type='image/jpeg')
    except HandlingError as e:
        # we don't want these to be logged...
        return sanic.response.text(e.handling_msg, status=e.handling_code)


app.add_task(style_transfer_runner.model_runner())
app.run(host="0.0.0.0", port=8000, debug=True)
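
The server takes the path to a TorchScript module (a file written with torch.jit.save) as its first command-line argument and serves PUT requests on /image, returning the styled result as a JPEG. A minimal client sketch follows; the file names and the localhost:8000 address are assumptions for illustration, not part of the server above.

# client.py -- illustrative sketch; input.jpg, output.jpg and the host/port are placeholders
# start the server first, e.g.:  python request_batching_jit_server.py <path-to-scripted-model>
import requests  # third-party HTTP client, assumed to be installed

# send the raw image bytes as the PUT body; the server replies with image/jpeg
with open('input.jpg', 'rb') as f:
    r = requests.put('http://localhost:8000/image', data=f.read())
r.raise_for_status()  # a 503 here means the server's queue was full ("I'm too busy")
with open('output.jpg', 'wb') as f:
    f.write(r.content)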