Add implementation of inference engine (mlc-ai#14)
Showing 6 changed files with 597 additions and 10 deletions.
@@ -1 +1,12 @@
-from .types import (RequestId, Request, TextGenerationOutput, TextGenerationError, InferenceStepResult, InferenceEngine)
+from .types import (
+    InferenceEngine,
+    InferenceStepResult,
+    Request,
+    RequestId,
+    SamplingParams,
+    SequenceGenerationRequest,
+    SequenceGenerationResponse,
+    StoppingCriteria,
+    TextGenerationError,
+    TextGenerationOutput,
+)
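
The re-exported names above form the public surface of the serving module: callers build a Request, hand it to an engine with add, then poll step for incremental TextGenerationOutput deltas. A minimal driving loop might look like the sketch below; it is illustrative only and not part of the commit. It assumes SamplingParams is default-constructible and that Request and StoppingCriteria accept the keyword arguments shown (field names taken from how the engines in this commit use them).

# Illustrative sketch only, not part of the commit. SamplingParams() being
# default-constructible is an assumption; the Request, StoppingCriteria and
# TextGenerationOutput fields follow how the engines in this commit use them.
from .types import InferenceEngine, Request, SamplingParams, StoppingCriteria


def generate_text(engine: InferenceEngine, prompt: str) -> str:
    request = Request(
        prompt=prompt,
        sampling_params=SamplingParams(),
        stopping_criteria=StoppingCriteria(max_tokens=64),
    )
    request_id = engine.add([request])[0]

    text = ""
    while True:
        # step() reports incremental deltas for every request it worked on
        # (error handling omitted for brevity).
        for output in engine.step().outputs:
            if output.request_id != request_id:
                continue
            text += output.delta
            if output.finish_reason is not None:
                return text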
@@ -0,0 +1,71 @@
from threading import Event, Lock
from typing import Dict
from uuid import uuid4

from .types import InferenceStepResult, Request, RequestId, TextGenerationOutput


class DummyInferenceEngine:
    def __init__(self):
        self.queue_lock = Lock()
        self.has_requests = Event()
        self.request_queue: Dict[RequestId, int] = {}

    def add(self, requests: list[Request]) -> list[RequestId]:
        ids = []
        requests_to_add = {}

        for r in requests:
            request_id = str(uuid4())
            ids.append(request_id)
            # Each request gets a fixed budget of 5 dummy generation steps.
            requests_to_add[request_id] = 5

        with self.queue_lock:
            self.request_queue.update(requests_to_add)
            self.has_requests.set()

        return ids

    def cancel(self, request_id: RequestId):
        """
        Cancel the generation of a request.
        """
        with self.queue_lock:
            del self.request_queue[request_id]
            if not self.request_queue:
                self.has_requests.clear()

    def step(self) -> InferenceStepResult:
        """
        InferenceStepResult contains the next token for processed results,
        and indicates whether the generation for a request is finished.
        It's up to the InferenceEngine to choose which requests
        to work on, while it should be guaranteed that all requests will be
        processed eventually.
        If the engine has no requests in the queue, `step` will block until
        a request comes in.
        """
        result = InferenceStepResult(outputs=[], errors=[])

        self.has_requests.wait()

        with self.queue_lock:
            for request_id, n in list(self.request_queue.items()):
                result.outputs.append(
                    TextGenerationOutput(
                        request_id=request_id,
                        delta=" test",
                        finish_reason="length" if n == 1 else None,
                    )
                )
                if n == 1:
                    del self.request_queue[request_id]
                else:
                    self.request_queue[request_id] -= 1

            if not self.request_queue:
                self.has_requests.clear()

        return result
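
DummyInferenceEngine never looks at the request contents: every added request gets a counter of 5, and each step emits the delta " test" for every queued request, finishing with reason "length" when the counter runs out. A quick illustrative check, with the same hedged Request construction as in the sketch above:

# Illustrative only; Request/SamplingParams/StoppingCriteria construction is an
# assumption, and DummyInferenceEngine (defined above) ignores the request body.
from .types import Request, SamplingParams, StoppingCriteria

engine = DummyInferenceEngine()
request = Request(
    prompt="hello",
    sampling_params=SamplingParams(),
    stopping_criteria=StoppingCriteria(max_tokens=5),
)
[request_id] = engine.add([request])

steps = 0
while True:
    outputs = engine.step().outputs  # blocks if the queue is empty
    steps += 1
    if any(o.request_id == request_id and o.finish_reason == "length" for o in outputs):
        break

assert steps == 5  # five " test" deltas, the last one marked finish_reason="length"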
@@ -0,0 +1,204 @@
""" | ||
A implementation of InferenceEngine that executes in the current process. | ||
""" | ||
|
||
from collections import deque | ||
from dataclasses import dataclass | ||
from threading import Condition, Lock | ||
from uuid import uuid4 | ||
|
||
from .types import ( | ||
InferenceStepResult, | ||
ModelExecutor, | ||
Request, | ||
RequestId, | ||
SamplingParams, | ||
SequenceGenerationRequest, | ||
SequenceGenerationResponse, | ||
StoppingCriteria, | ||
TextGenerationError, | ||
TextGenerationOutput, | ||
Tokenizer, | ||
) | ||
|
||
|
||
@dataclass | ||
class RequestState: | ||
request_id: RequestId | ||
token_ids: list[int] | ||
output_text: str | ||
prompt_len: int | ||
next_start_position: int | ||
sampling_params: SamplingParams | ||
stopping_criteria: StoppingCriteria | ||
|
||
|
||
class LocalProcessInferenceEngine: | ||
def __init__(self, executor: ModelExecutor, tokenizer: Tokenizer): | ||
self.queue_lock = Lock() | ||
self.queue = deque[RequestState]() | ||
self.has_new_requests = Condition(lock=self.queue_lock) | ||
self.requests_to_be_cancelled = set[RequestId]() | ||
|
||
self.current_batch = dict[RequestId, RequestState]() | ||
|
||
self.executor = executor | ||
self.tokenizer = tokenizer | ||
|
||
def add(self, requests: list[Request]) -> list[RequestId]: | ||
if not requests: | ||
return [] | ||
|
||
new_request_states = [] | ||
for req in requests: | ||
state = self._get_new_request_state(req) | ||
new_request_states.append(state) | ||
|
||
with self.queue_lock: | ||
self.queue.extend(new_request_states) | ||
self.has_new_requests.notify_all() | ||
|
||
return [s.request_id for s in new_request_states] | ||
|
||
def cancel(self, request_id: RequestId): | ||
with self.queue_lock: | ||
self.requests_to_be_cancelled.add(request_id) | ||
|
||
def step(self) -> InferenceStepResult: | ||
outputs = list[TextGenerationOutput]() | ||
errors = list[TextGenerationError]() | ||
|
||
previous_requests_to_be_cancelled = set(self.requests_to_be_cancelled) | ||
self._adjust_batch() | ||
|
||
for request_id in previous_requests_to_be_cancelled: | ||
if request_id not in self.requests_to_be_cancelled: | ||
outputs.append( | ||
TextGenerationOutput( | ||
request_id=request_id, | ||
delta="", | ||
finish_reason="cancelled", | ||
) | ||
) | ||
|
||
requests = [ | ||
SequenceGenerationRequest( | ||
request_id=state.request_id, | ||
token_ids=state.token_ids[state.next_start_position :], | ||
start_position=state.next_start_position, | ||
sampling_params=state.sampling_params, | ||
) | ||
for state in self.current_batch.values() | ||
] | ||
|
||
for req in requests: | ||
if req.start_position > 0: | ||
self.executor.extend(req.request_id, len(req.token_ids)) | ||
responses = self.executor.generate(requests) | ||
|
||
for res in responses: | ||
if res.error is not None: | ||
errors.append( | ||
TextGenerationError(res.request_id, "GenerationError", res.error) | ||
) | ||
del self.current_batch[res.request_id] | ||
continue | ||
|
||
state = self.current_batch[res.request_id] | ||
state.next_start_position = len(state.token_ids) | ||
state.token_ids.extend(res.token_ids) | ||
|
||
delta = self._decode_last_output(state) | ||
state.output_text += delta | ||
|
||
output = TextGenerationOutput(res.request_id, delta) | ||
if self._should_stop_by_length(state): | ||
output.finish_reason = "length" | ||
self.current_batch.pop(state.request_id) | ||
self.executor.free(state.request_id) | ||
|
||
outputs.append(output) | ||
|
||
return InferenceStepResult(outputs=outputs, errors=errors) | ||
|
||
def _adjust_batch(self): | ||
with self.queue_lock: | ||
for request_id in list(self.requests_to_be_cancelled): | ||
if request_id in self.current_batch: | ||
state = self.current_batch.pop(request_id) | ||
self.executor.free(state.request_id) | ||
self.requests_to_be_cancelled.remove(request_id) | ||
|
||
while self.executor.get_max_new_tokens() < 1: | ||
request_to_remove = min( | ||
self.current_batch.values(), key=lambda s: len(s.token_ids) | ||
) | ||
del self.current_batch[request_to_remove.request_id] | ||
self.executor.free(request_to_remove.request_id) | ||
self.queue.appendleft(request_to_remove) | ||
|
||
while True: | ||
self._discard_cancelled_requests_from_queue() | ||
if len(self.queue) != 0 or len(self.current_batch) != 0: | ||
break | ||
self.has_new_requests.wait() | ||
|
||
if not self._should_process_new_request(): | ||
return | ||
|
||
# TODO: make this 15 into config | ||
while self.queue and self.executor.get_max_new_tokens() > 15: | ||
state = self.queue[0] | ||
num_tokens = len(state.token_ids) | ||
if self.executor.get_free_space() <= 1.5 * num_tokens: | ||
break | ||
|
||
self.queue.popleft() | ||
self.executor.allocate(state.request_id, num_tokens) | ||
self.current_batch[state.request_id] = state | ||
|
||
self._discard_cancelled_requests_from_queue() | ||
|
||
def _should_process_new_request(self): | ||
return self.executor.get_free_space() * 1.6 > self.executor.get_kv_cache_size() | ||
|
||
def _discard_cancelled_requests_from_queue(self): | ||
""" | ||
Requires the self.queue_lock to be held before calling this function. | ||
""" | ||
while self.queue and self.queue[0].request_id in self.requests_to_be_cancelled: | ||
state = self.queue.popleft() | ||
self.requests_to_be_cancelled.remove(state.request_id) | ||
|
||
def _get_new_request_state(self, request: Request) -> RequestState: | ||
request_id = str(uuid4()) | ||
|
||
prompt_tokens = self.tokenizer.encode(request.prompt) | ||
|
||
return RequestState( | ||
request_id=request_id, | ||
token_ids=prompt_tokens, | ||
prompt_len=len(prompt_tokens), | ||
next_start_position=0, | ||
sampling_params=request.sampling_params, | ||
stopping_criteria=request.stopping_criteria, | ||
output_text="", | ||
) | ||
|
||

    def _decode_last_output(self, state: RequestState) -> str:
        prefix_idx = max(0, state.next_start_position - 6)
        if prefix_idx == 0:
            return self.tokenizer.decode(state.token_ids)

        # Decode with a few tokens of preceding context so multi-token pieces
        # render correctly, then strip the already-emitted prefix text.
        prefix = self.tokenizer.decode(
            state.token_ids[prefix_idx : state.next_start_position]
        )
        full = self.tokenizer.decode(state.token_ids[prefix_idx:])

        return full[len(prefix) :]

    def _should_stop_by_length(self, state: RequestState) -> bool:
        # TODO: put to config
        max_tokens = 4096
        if state.stopping_criteria.max_tokens is not None:
            max_tokens = min(max_tokens, state.stopping_criteria.max_tokens)

        return len(state.token_ids) - state.prompt_len >= max_tokens
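
LocalProcessInferenceEngine leans on two collaborators that are only referenced here: a Tokenizer and a ModelExecutor. Their actual definitions live in .types and are not shown in this diff; the sketch below reconstructs, purely from the call sites above, what those interfaces would need to provide. Method names and signatures are therefore assumptions.

# Reconstructed from the call sites above, not from the real .types module.
from typing import Protocol

from .types import RequestId, SequenceGenerationRequest, SequenceGenerationResponse


class Tokenizer(Protocol):
    def encode(self, text: str) -> list[int]: ...
    def decode(self, token_ids: list[int]) -> str: ...


class ModelExecutor(Protocol):
    # Batched prefill/decoding over the requests in the current batch.
    def generate(
        self, requests: list[SequenceGenerationRequest]
    ) -> list[SequenceGenerationResponse]: ...

    # KV-cache bookkeeping, apparently in units of tokens.
    def allocate(self, request_id: RequestId, num_tokens: int) -> None: ...
    def extend(self, request_id: RequestId, num_new_tokens: int) -> None: ...
    def free(self, request_id: RequestId) -> None: ...
    def get_kv_cache_size(self) -> int: ...
    def get_free_space(self) -> int: ...
    def get_max_new_tokens(self) -> int: ...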