Add implementation of inference engine (mlc-ai#14)
yelite authored Oct 10, 2023
1 parent fc9876d commit a2cc226
Showing 6 changed files with 597 additions and 10 deletions.
13 changes: 12 additions & 1 deletion serve/mlc_serve/engine/__init__.py
@@ -1 +1,12 @@
from .types import (RequestId, Request, TextGenerationOutput, TextGenerationError, InferenceStepResult, InferenceEngine)
from .types import (
    InferenceEngine,
    InferenceStepResult,
    Request,
    RequestId,
    SamplingParams,
    SequenceGenerationRequest,
    SequenceGenerationResponse,
    StoppingCriteria,
    TextGenerationError,
    TextGenerationOutput,
)
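
For reference, a minimal sketch of what a few of the re-exported result types might look like, inferred only from how they are constructed in dummy.py and local.py below. The `RequestId = str` alias and the TextGenerationError field names (`type`, `message`) are assumptions, not copied from types.py:

# Hypothetical reconstruction of some types from serve/mlc_serve/engine/types.py,
# based solely on their usage in this commit; the real definitions may differ.
from dataclasses import dataclass, field
from typing import Optional

RequestId = str  # assumed alias: the engines use str(uuid4()) as request ids


@dataclass
class TextGenerationOutput:
    request_id: RequestId
    delta: str
    finish_reason: Optional[str] = None  # e.g. "length" or "cancelled"


@dataclass
class TextGenerationError:
    request_id: RequestId
    type: str     # assumed field name; second positional argument, e.g. "GenerationError"
    message: str  # assumed field name; carries the executor's error text


@dataclass
class InferenceStepResult:
    outputs: list[TextGenerationOutput] = field(default_factory=list)
    errors: list[TextGenerationError] = field(default_factory=list)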
71 changes: 71 additions & 0 deletions serve/mlc_serve/engine/dummy.py
@@ -0,0 +1,71 @@
from threading import Event, Lock
from typing import Dict
from uuid import uuid4

from .types import InferenceStepResult, Request, RequestId, TextGenerationOutput


class DummyInferenceEngine:
    def __init__(self):
        self.queue_lock = Lock()
        self.has_requests = Event()
        self.request_queue: Dict[RequestId, int] = {}

    def add(self, requests: list[Request]) -> list[RequestId]:
        ids = []
        requests_to_add = {}

        for r in requests:
            request_id = str(uuid4())
            ids.append(request_id)
            requests_to_add[request_id] = 5

        with self.queue_lock:
            self.request_queue.update(requests_to_add)
            self.has_requests.set()

        return ids

    def cancel(self, request_id: RequestId):
        """
        Cancel the generation of a request.
        """
        with self.queue_lock:
            del self.request_queue[request_id]
            if not self.request_queue:
                self.has_requests.clear()

    def step(self) -> InferenceStepResult:
        """
        InferenceStepResult contains the next token for each processed request
        and indicates whether generation for that request is finished. It is up
        to the InferenceEngine to choose which requests to work on, but it
        should guarantee that every request is eventually processed.

        If the engine has no requests in the queue, `step` blocks until a
        request comes in.
        """
        result = InferenceStepResult(outputs=[], errors=[])

        self.has_requests.wait()

        with self.queue_lock:
            for request_id, n in list(self.request_queue.items()):
                result.outputs.append(
                    TextGenerationOutput(
                        request_id=request_id,
                        delta=" test",
                        finish_reason="length" if n == 1 else None,
                    )
                )
                if n == 1:
                    del self.request_queue[request_id]
                else:
                    self.request_queue[request_id] -= 1

            if not self.request_queue:
                self.has_requests.clear()

        return result
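
A minimal driver sketch for the dummy engine: each added request yields the delta " test" five times and then finishes with reason "length". Building Request from just a prompt is an assumption here; its exact required fields live in types.py and are not shown in this commit.

from mlc_serve.engine import Request
from mlc_serve.engine.dummy import DummyInferenceEngine

engine = DummyInferenceEngine()
# Request(prompt=...) is an assumed constructor signature; adjust to the real
# definition in types.py.
[request_id] = engine.add([Request(prompt="Hello")])

finished = False
while not finished:
    result = engine.step()  # blocks until at least one request is queued
    for output in result.outputs:
        print(output.request_id, repr(output.delta), output.finish_reason)
        if output.request_id == request_id and output.finish_reason is not None:
            finished = True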
204 changes: 204 additions & 0 deletions serve/mlc_serve/engine/local.py
@@ -0,0 +1,204 @@
"""
A implementation of InferenceEngine that executes in the current process.
"""

from collections import deque
from dataclasses import dataclass
from threading import Condition, Lock
from uuid import uuid4

from .types import (
    InferenceStepResult,
    ModelExecutor,
    Request,
    RequestId,
    SamplingParams,
    SequenceGenerationRequest,
    SequenceGenerationResponse,
    StoppingCriteria,
    TextGenerationError,
    TextGenerationOutput,
    Tokenizer,
)


@dataclass
class RequestState:
    request_id: RequestId
    token_ids: list[int]
    output_text: str
    prompt_len: int
    next_start_position: int
    sampling_params: SamplingParams
    stopping_criteria: StoppingCriteria


class LocalProcessInferenceEngine:
    def __init__(self, executor: ModelExecutor, tokenizer: Tokenizer):
        self.queue_lock = Lock()
        self.queue = deque[RequestState]()
        self.has_new_requests = Condition(lock=self.queue_lock)
        self.requests_to_be_cancelled = set[RequestId]()

        self.current_batch = dict[RequestId, RequestState]()

        self.executor = executor
        self.tokenizer = tokenizer

    def add(self, requests: list[Request]) -> list[RequestId]:
        if not requests:
            return []

        new_request_states = []
        for req in requests:
            state = self._get_new_request_state(req)
            new_request_states.append(state)

        with self.queue_lock:
            self.queue.extend(new_request_states)
            self.has_new_requests.notify_all()

        return [s.request_id for s in new_request_states]

    def cancel(self, request_id: RequestId):
        with self.queue_lock:
            self.requests_to_be_cancelled.add(request_id)

    def step(self) -> InferenceStepResult:
        outputs = list[TextGenerationOutput]()
        errors = list[TextGenerationError]()

        previous_requests_to_be_cancelled = set(self.requests_to_be_cancelled)
        self._adjust_batch()

        for request_id in previous_requests_to_be_cancelled:
            if request_id not in self.requests_to_be_cancelled:
                outputs.append(
                    TextGenerationOutput(
                        request_id=request_id,
                        delta="",
                        finish_reason="cancelled",
                    )
                )

        requests = [
            SequenceGenerationRequest(
                request_id=state.request_id,
                token_ids=state.token_ids[state.next_start_position :],
                start_position=state.next_start_position,
                sampling_params=state.sampling_params,
            )
            for state in self.current_batch.values()
        ]

        for req in requests:
            if req.start_position > 0:
                self.executor.extend(req.request_id, len(req.token_ids))
        responses = self.executor.generate(requests)

        for res in responses:
            if res.error is not None:
                errors.append(
                    TextGenerationError(res.request_id, "GenerationError", res.error)
                )
                del self.current_batch[res.request_id]
                continue

            state = self.current_batch[res.request_id]
            state.next_start_position = len(state.token_ids)
            state.token_ids.extend(res.token_ids)

            delta = self._decode_last_output(state)
            state.output_text += delta

            output = TextGenerationOutput(res.request_id, delta)
            if self._should_stop_by_length(state):
                output.finish_reason = "length"
                self.current_batch.pop(state.request_id)
                self.executor.free(state.request_id)

            outputs.append(output)

        return InferenceStepResult(outputs=outputs, errors=errors)

    def _adjust_batch(self):
        with self.queue_lock:
            for request_id in list(self.requests_to_be_cancelled):
                if request_id in self.current_batch:
                    state = self.current_batch.pop(request_id)
                    self.executor.free(state.request_id)
                    self.requests_to_be_cancelled.remove(request_id)

            while self.executor.get_max_new_tokens() < 1:
                request_to_remove = min(
                    self.current_batch.values(), key=lambda s: len(s.token_ids)
                )
                del self.current_batch[request_to_remove.request_id]
                self.executor.free(request_to_remove.request_id)
                self.queue.appendleft(request_to_remove)

            while True:
                self._discard_cancelled_requests_from_queue()
                if len(self.queue) != 0 or len(self.current_batch) != 0:
                    break
                self.has_new_requests.wait()

            if not self._should_process_new_request():
                return

            # TODO: make this 15 into config
            while self.queue and self.executor.get_max_new_tokens() > 15:
                state = self.queue[0]
                num_tokens = len(state.token_ids)
                if self.executor.get_free_space() <= 1.5 * num_tokens:
                    break

                self.queue.popleft()
                self.executor.allocate(state.request_id, num_tokens)
                self.current_batch[state.request_id] = state

            self._discard_cancelled_requests_from_queue()

    def _should_process_new_request(self):
        return self.executor.get_free_space() * 1.6 > self.executor.get_kv_cache_size()

    def _discard_cancelled_requests_from_queue(self):
        """
        Requires the self.queue_lock to be held before calling this function.
        """
        while self.queue and self.queue[0].request_id in self.requests_to_be_cancelled:
            state = self.queue.popleft()
            self.requests_to_be_cancelled.remove(state.request_id)

    def _get_new_request_state(self, request: Request) -> RequestState:
        request_id = str(uuid4())

        prompt_tokens = self.tokenizer.encode(request.prompt)

        return RequestState(
            request_id=request_id,
            token_ids=prompt_tokens,
            prompt_len=len(prompt_tokens),
            next_start_position=0,
            sampling_params=request.sampling_params,
            stopping_criteria=request.stopping_criteria,
            output_text="",
        )

    def _decode_last_output(self, state: RequestState) -> str:
        prefix_idx = max(0, state.next_start_position - 6)
        if prefix_idx == 0:
            return self.tokenizer.decode(state.token_ids)

        # Decode a few tokens of already-emitted prefix together with the new
        # tokens so the boundary detokenizes correctly, then strip the prefix.
        prefix = self.tokenizer.decode(
            state.token_ids[prefix_idx : state.next_start_position]
        )
        full = self.tokenizer.decode(state.token_ids[prefix_idx:])

        return full[len(prefix) :]

    def _should_stop_by_length(self, state: RequestState) -> bool:
        # TODO: put to config
        max_tokens = 4096
        if state.stopping_criteria.max_tokens is not None:
            max_tokens = min(max_tokens, state.stopping_criteria.max_tokens)

        return len(state.token_ids) - state.prompt_len >= max_tokens
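
Because LocalProcessInferenceEngine.step() blocks inside _adjust_batch until there is work, a single dedicated thread can drive the whole batch while other threads submit requests via add() and cancel(). A serving-loop sketch under the assumption that `executor` and `tokenizer` implement the ModelExecutor and Tokenizer interfaces from types.py (their construction is outside this commit):

from threading import Thread

from mlc_serve.engine import InferenceStepResult
from mlc_serve.engine.local import LocalProcessInferenceEngine


def start_engine_loop(executor, tokenizer) -> LocalProcessInferenceEngine:
    # `executor` and `tokenizer` are assumed to satisfy ModelExecutor and
    # Tokenizer from types.py; this sketch does not build them.
    engine = LocalProcessInferenceEngine(executor, tokenizer)

    def loop() -> None:
        while True:
            result: InferenceStepResult = engine.step()
            for output in result.outputs:
                # Route output.delta (and output.finish_reason, once set) back
                # to whoever submitted output.request_id.
                print(output.request_id, repr(output.delta), output.finish_reason)
            for error in result.errors:
                print("generation error:", error)

    Thread(target=loop, daemon=True).start()
    return engine

Requests added from other threads wake the loop through the has_new_requests condition, and a cancellation shows up on a subsequent step as an output with finish_reason "cancelled".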