request_batching_server.py

import sys
import asyncio
import itertools
import functools

from sanic import Sanic
from sanic.response import json, text
from sanic.log import logger
from sanic.exceptions import ServerError
import sanic

import threading
import PIL.Image
import io

import torch
import torchvision

from .cyclegan import get_pretrained_model

app = Sanic(__name__)

device = torch.device('cpu')
# we only run one inference at any time (one could schedule between several runners if desired)
MAX_QUEUE_SIZE = 3  # we accept a backlog of MAX_QUEUE_SIZE before handing out "Too busy" errors
MAX_BATCH_SIZE = 2  # we put at most MAX_BATCH_SIZE things in a single batch
MAX_WAIT = 1        # we wait at most MAX_WAIT seconds for more inputs to arrive before running a partial batch
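
# Request batching in a nutshell: each HTTP handler calls
# ModelRunner.process_input(), which appends its input to a shared queue and
# waits on a per-request event.  A single background task (model_runner) wakes
# up when a batch is full or MAX_WAIT has elapsed, stacks up to MAX_BATCH_SIZE
# inputs into one tensor, runs the model in a thread-pool executor, and then
# signals each waiting request with its slice of the result.  Requests beyond
# a backlog of MAX_QUEUE_SIZE are rejected with a 503.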


class HandlingError(Exception):
    def __init__(self, msg, code=500):
        super().__init__()
        self.handling_code = code
        self.handling_msg = msg


class ModelRunner:
    def __init__(self, model_name):
        self.model_name = model_name
        self.queue = []

        self.queue_lock = None

        self.model = get_pretrained_model(self.model_name,
                                          map_location=device)

        self.needs_processing = None
        self.needs_processing_timer = None

    def schedule_processing_if_needed(self):
        if len(self.queue) >= MAX_BATCH_SIZE:
            logger.debug("next batch ready when processing a batch")
            self.needs_processing.set()
        elif self.queue:
            logger.debug("queue nonempty when processing a batch, setting next timer")
            self.needs_processing_timer = app.loop.call_at(
                self.queue[0]["time"] + MAX_WAIT, self.needs_processing.set)

    async def process_input(self, input):
        our_task = {"done_event": asyncio.Event(loop=app.loop),
                    "input": input,
                    "time": app.loop.time()}
        async with self.queue_lock:
            if len(self.queue) >= MAX_QUEUE_SIZE:
                raise HandlingError("I'm too busy", code=503)
            self.queue.append(our_task)
            logger.debug("enqueued task. new queue size {}".format(len(self.queue)))
            self.schedule_processing_if_needed()

        await our_task["done_event"].wait()
        return our_task["output"]
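
    # The forward pass in run_model is the only blocking piece of work here;
    # model_runner hands it to the default thread-pool executor via
    # run_in_executor, so the event loop stays free to accept new requests
    # while inference runs.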

    def run_model(self, batch):  # runs in other thread
        return self.model(batch.to(device)).to('cpu')

    async def model_runner(self):
        self.queue_lock = asyncio.Lock(loop=app.loop)
        self.needs_processing = asyncio.Event(loop=app.loop)
        logger.info("started model runner for {}".format(self.model_name))
        while True:
            await self.needs_processing.wait()
            self.needs_processing.clear()
            if self.needs_processing_timer is not None:
                self.needs_processing_timer.cancel()
                self.needs_processing_timer = None
            async with self.queue_lock:
                if self.queue:
                    longest_wait = app.loop.time() - self.queue[0]["time"]
                else:  # oops
                    longest_wait = None
                logger.debug("launching processing. queue size: {}. longest wait: {}".format(
                    len(self.queue), longest_wait))
                to_process = self.queue[:MAX_BATCH_SIZE]
                del self.queue[:len(to_process)]
                self.schedule_processing_if_needed()
            # so here we copy, it would be neater to avoid this
            batch = torch.stack([t["input"] for t in to_process], dim=0)
            # we could delete inputs here...
            result = await app.loop.run_in_executor(
                None, functools.partial(self.run_model, batch)
            )
            for t, r in zip(to_process, result):
                t["output"] = r
                t["done_event"].set()
            del to_process
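

# The single runner instance shared by all requests.  The command-line
# argument is passed straight to get_pretrained_model() in cyclegan.py;
# presumably it names (or points to) the pretrained CycleGAN weights to load.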
style_transfer_runner = ModelRunner(sys.argv[1])


@app.route('/image', methods=['PUT'], stream=True)
async def image(request):
    try:
        print(request.headers)
        content_length = int(request.headers.get('content-length', '0'))
        MAX_SIZE = 2**22  # 4 MiB upload limit
        if content_length:
            if content_length > MAX_SIZE:
                raise HandlingError("Too large")
            data = bytearray(content_length)
        else:
            data = bytearray(MAX_SIZE)
        pos = 0
        while True:
            # so this still copies too much stuff.
            data_part = await request.stream.read()
            if data_part is None:
                break
            data[pos: len(data_part) + pos] = data_part
            pos += len(data_part)
            if pos > MAX_SIZE:
                raise HandlingError("Too large")

        # ideally, we would minimize preprocessing...
        im = PIL.Image.open(io.BytesIO(data))
        im = torchvision.transforms.functional.resize(im, (228, 228))
        im = torchvision.transforms.functional.to_tensor(im)
        im = im[:3]  # drop alpha channel if present
        if im.dim() != 3 or im.size(0) < 3 or im.size(0) > 4:
            raise HandlingError("need rgb image")
        out_im = await style_transfer_runner.process_input(im)
        out_im = torchvision.transforms.functional.to_pil_image(out_im)
        imgByteArr = io.BytesIO()
        out_im.save(imgByteArr, format='JPEG')
        return sanic.response.raw(imgByteArr.getvalue(), status=200,
                                  content_type='image/jpeg')
    except HandlingError as e:
        # we don't want these to be logged...
        return sanic.response.text(e.handling_msg, status=e.handling_code)


app.add_task(style_transfer_runner.model_runner())
app.run(host="0.0.0.0", port=8000, debug=True)
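

# Usage sketch: the package, weights, and image file names below are
# placeholders, not files shipped with this script.
#
#   Run the server as a module (the relative import of .cyclegan requires it),
#   passing the pretrained model as the only command-line argument:
#       python -m mypackage.request_batching_server weights.pth
#
#   Send an image with an HTTP PUT and save the returned JPEG, e.g. with curl:
#       curl -T horse.jpg http://localhost:8000/image --output out.jpg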