Skip to content

Commit 47d9d51

Browse files
committed
dev: add batching to litserve app
1 parent c0672e3 commit 47d9d51

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

litserve-app.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,15 @@ def setup(self, device):
1717
def decode_request(self, request):
1818
image_bytes = bytes.fromhex(request["image_bytes"])
1919
image = Image.open(BytesIO(image_bytes))
20-
return self.processor(image, return_tensors="pt").to(self.device)
20+
image_tensor = self.processor(image, return_tensors="pt")["pixel_values"]
21+
return image_tensor.to(self.device)
22+
23+
def batch(self, inputs):
24+
return torch.cat(inputs, dim=0)
2125

2226
def predict(self, inputs):
2327
with torch.no_grad():
24-
logits = self.model(**inputs).logits
28+
logits = self.model(inputs).logits
2529
return logits
2630

2731
def encode_response(self, logits):
@@ -32,5 +36,5 @@ def encode_response(self, logits):
3236

3337
if __name__ == "__main__":
3438
api = ResNetLitAPI()
35-
server = ls.LitServer(api, accelerator="gpu")
39+
server = ls.LitServer(api, accelerator="gpu", max_batch_size=8, batch_timeout=0.05)
3640
server.run(port=8000)

0 commit comments

Comments
 (0)