Add whisper tritonserver batch inference (#650) #683

Merged
merged 1 commit on Dec 13, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
args: [--max-line-length=80]

- repo: https://github.com/pycqa/isort
rev: 5.9.2
rev: 5.13.2
hooks:
- id: isort
args: [--profile=black, --line-length=80]
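For reference, the import ordering enforced by the bumped hook (isort 5.13.2 with --profile=black and --line-length=80) can be reproduced from isort's Python API. A minimal sketch; the sample input string is made up for illustration:

import isort

messy = (
    "import torch\n"
    "import json\n"
    "from pathlib import Path\n"
    "import numpy as np\n"
)
# Mirrors the hook arguments: --profile=black --line-length=80
print(isort.code(messy, profile="black", line_length=80))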
183 changes: 96 additions & 87 deletions triton/whisper/model_repo_whisper_trtllm/infer_bls/1/model.py
@@ -1,14 +1,18 @@
# -*- coding: utf-8 -*-
import triton_python_backend_utils as pb_utils
import numpy as np
# -*- coding: utf-8 -*-

import json
import torch
from torch.utils.dlpack import to_dlpack
import re
from .tokenizer import get_tokenizer
from collections import OrderedDict
from pathlib import Path

import numpy as np
import torch
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import to_dlpack

from .tokenizer import get_tokenizer


def read_config(component, engine_dir):
config_path = engine_dir / component / 'config.json'
with open(config_path, 'r') as f:
@@ -18,109 +22,114 @@ def read_config(component, engine_dir):
model_config.update(config['build_config'])
return model_config

class TritonPythonModel:
"""Your Python model must use the same class name. Every Python model
that is created must have "TritonPythonModel" as the class name.
"""

class TritonPythonModel:
def initialize(self, args):
"""`initialize` is called only once when the model is being loaded.
Implementing `initialize` function is optional. This function allows
the model to initialize any state associated with this model.

Parameters
----------
args : dict
Both keys and values are strings. The dictionary keys and values are:
* model_config: A JSON string containing the model configuration
* model_instance_kind: A string containing model instance kind
* model_instance_device_id: A string containing model instance device ID
* model_repository: Model repository path
* model_version: Model version
* model_name: Model name
"""
self.model_config = model_config = json.loads(args['model_config'])

# Get OUTPUT0 configuration
output0_config = pb_utils.get_output_config_by_name(
model_config, "TRANSCRIPTS")
# Convert Triton types to numpy types
self.out0_dtype = pb_utils.triton_string_to_numpy(
output0_config['data_type'])
encoder_config = read_config('encoder', Path(self.model_config['parameters']['engine_dir']["string_value"]))
self.tokenizer = get_tokenizer(num_languages=encoder_config['num_languages'])
self.blank = self.tokenizer.encode(" ", allowed_special=self.tokenizer.special_tokens_set)[0]

engine_dir = Path(
self.model_config['parameters']['engine_dir']["string_value"])
encoder_config = read_config('encoder', engine_dir)
self.tokenizer = get_tokenizer(
num_languages=encoder_config['num_languages']
)
self.blank = self.tokenizer.encode(
" ",
allowed_special=self.tokenizer.special_tokens_set
)[0]
self.device = torch.device("cuda")

def process_batch(self, wav, wav_len, prompt_id):
wav = torch.from_numpy(wav[0]).to(self.device)
wav_tensor = pb_utils.Tensor.from_dlpack("WAV", to_dlpack(wav.unsqueeze(0)))
wav_len_tensor = pb_utils.Tensor("WAV_LENS", np.array([[wav_len]], np.int32))
prompt_id = torch.tensor(prompt_id).unsqueeze(0)
def process_batch(self, wav_batch, wav_lens, prompt_id):
# Convert numpy arrays to torch tensors
wav_batch = torch.from_numpy(wav_batch).to(self.device)
wav_tensor = pb_utils.Tensor.from_dlpack(
"WAV",
to_dlpack(wav_batch)
)
wav_len_tensor = pb_utils.Tensor(
"WAV_LENS",
wav_lens.astype(np.int32)
)

# Replicate prompt_id for batch size
batch_size = wav_batch.shape[0]
prompt_ids = np.tile(prompt_id, (batch_size, 1))
prompt_ids_tensor = pb_utils.Tensor(
"DECODER_INPUT_IDS",
prompt_ids.astype(np.int32)
)

prompt_id = pb_utils.Tensor("DECODER_INPUT_IDS", prompt_id.numpy().astype(np.int32))
infer_request = pb_utils.InferenceRequest(
model_name="whisper",
requested_output_names=["OUTPUT_IDS"],
inputs=[wav_tensor, wav_len_tensor, prompt_id]
inputs=[wav_tensor, wav_len_tensor, prompt_ids_tensor]
)

inference_response = infer_request.exec()
if inference_response.has_error():
raise pb_utils.TritonModelException(inference_response.error().message())
else:
output_ids = pb_utils.get_output_tensor_by_name(inference_response, "OUTPUT_IDS")
return output_ids.as_numpy()

raise pb_utils.TritonModelException(
inference_response.error().message())

output_ids = pb_utils.get_output_tensor_by_name(
inference_response, "OUTPUT_IDS")
return output_ids.as_numpy()

def execute(self, requests):
"""`execute` must be implemented in every Python model. `execute`
function receives a list of pb_utils.InferenceRequest as the only
argument. This function is called when an inference is requested
for this model.

Parameters
----------
requests : list
A list of pb_utils.InferenceRequest

Returns
-------
list
A list of pb_utils.InferenceResponse. The length of this list must
be the same as `requests`
"""
# Every Python backend must iterate through list of requests and create
# an instance of pb_utils.InferenceResponse class for each of them. You
# should avoid storing any of the input Tensors in the class attributes
# as they will be overridden in subsequent inference requests. You can
# make a copy of the underlying NumPy array and store it if it is
# required.
responses = []

for request in requests:
# Perform inference on the request and append it to responses list...
in_0 = pb_utils.get_input_tensor_by_name(request, "TEXT_PREFIX")
prompt_ids = in_0.as_numpy().tolist()
prompt_ids = prompt_ids[0][0].decode('utf-8')
if prompt_ids == "":
prompt_ids = "<|startoftranscript|><|en|><|transcribe|><|notimestamps|>"
prompt_id = self.tokenizer.encode(prompt_ids, allowed_special=self.tokenizer.special_tokens_set)

wav = pb_utils.get_input_tensor_by_name(request, "WAV").as_numpy()
assert wav.shape[0] == 1, "Only support batch size 1 for now"
wav_len = pb_utils.get_input_tensor_by_name(request, "WAV_LENS").as_numpy()
wav_len = wav_len.item()

output_ids = self.process_batch(wav, wav_len, prompt_id)
s = self.tokenizer.decode(output_ids)
s = re.sub(r'<\|.*?\|>', '', s)
sentence = np.array([s])
out0 = pb_utils.Tensor("TRANSCRIPTS", sentence.astype(self.out0_dtype))
inference_response = pb_utils.InferenceResponse(output_tensors=[out0])
# Get batch inputs
text_prefix = pb_utils.get_input_tensor_by_name(
request, "TEXT_PREFIX").as_numpy()
wav_batch = pb_utils.get_input_tensor_by_name(
request, "WAV").as_numpy()
wav_lens = pb_utils.get_input_tensor_by_name(
request, "WAV_LENS").as_numpy()

# Use the same text_prefix for all items in the request
prefix = text_prefix[0][0].decode('utf-8')
if prefix == "":
prefix = (
"<|startoftranscript|><|ko|><|transcribe|><|notimestamps|>"
)
prompt_id = self.tokenizer.encode(
prefix,
allowed_special=self.tokenizer.special_tokens_set
)

# Process the entire batch
output_ids = self.process_batch(wav_batch, wav_lens, prompt_id)

# Decode outputs for each item in batch
transcripts = []

# Handle case where output_ids is 3-dimensional
# ([batch_size, beam_size, seq_len])
if len(output_ids.shape) == 3:
output_ids = output_ids[:, 0, :] # Remove beam_size dimension

for output_id in output_ids:
token_list = output_id.tolist()
s = self.tokenizer.decode(token_list)
s = re.sub(r'<\|.*?\|>', '', s)
transcripts.append(s)

# Create response tensor
out0 = pb_utils.Tensor(
"TRANSCRIPTS",
np.array(transcripts).astype(self.out0_dtype)
)
inference_response = pb_utils.InferenceResponse(
output_tensors=[out0]
)
responses.append(inference_response)

return responses

def finalize(self):
"""`finalize` is called only once when the model is being unloaded.
Implementing `finalize` function is optional. This function allows
the model to perform any necessary clean ups before exit.
"""
print('Cleaning up...')
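Outside of this diff, a minimal client-side sketch of how the batched infer_bls model could be exercised is shown below. It assumes the standard tritonclient gRPC API; the tensor names match the model above, but the shapes, dtypes, and server URL are assumptions, since config.pbtxt is not part of this PR:

import numpy as np
import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")

batch_size = 4
num_samples = 16000 * 30  # assume 30 s of 16 kHz audio per item
wav = np.zeros((batch_size, num_samples), dtype=np.float32)
wav_lens = np.full((batch_size, 1), num_samples, dtype=np.int32)
# An empty prefix makes the model fall back to its default prompt.
text_prefix = np.array([[b""]] * batch_size, dtype=np.object_)

inputs = [
    grpcclient.InferInput("TEXT_PREFIX", list(text_prefix.shape), "BYTES"),
    grpcclient.InferInput("WAV", list(wav.shape), "FP32"),
    grpcclient.InferInput("WAV_LENS", list(wav_lens.shape), "INT32"),
]
inputs[0].set_data_from_numpy(text_prefix)
inputs[1].set_data_from_numpy(wav)
inputs[2].set_data_from_numpy(wav_lens)

result = client.infer(
    model_name="infer_bls",
    inputs=inputs,
    outputs=[grpcclient.InferRequestedOutput("TRANSCRIPTS")],
)
print(result.as_numpy("TRANSCRIPTS"))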
124 changes: 74 additions & 50 deletions triton/whisper/model_repo_whisper_trtllm/whisper/1/model.py
100755 → 100644
@@ -1,4 +1,3 @@

# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
@@ -24,78 +23,103 @@
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
from pathlib import Path

from .fbank import FeatureExtractor
import torch
from torch.utils.dlpack import from_dlpack

import triton_python_backend_utils as pb_utils
from tensorrt_llm.runtime import ModelRunnerCpp
from tensorrt_llm.bindings import GptJsonConfig
from tensorrt_llm.runtime import ModelRunnerCpp
from torch.utils.dlpack import from_dlpack

from .fbank import FeatureExtractor


class TritonPythonModel:
def initialize(self, args):
parameters = json.loads(args['model_config'])['parameters']
for key,value in parameters.items():
for key, value in parameters.items():
parameters[key] = value["string_value"]
engine_dir = parameters["engine_dir"]
json_config = GptJsonConfig.parse_file(Path(engine_dir) / 'decoder' / 'config.json')
config_path = Path(engine_dir) / 'decoder' / 'config.json'
json_config = GptJsonConfig.parse_file(config_path)
assert json_config.model_config.supports_inflight_batching
runner_kwargs = dict(engine_dir=engine_dir,
is_enc_dec=True,
max_batch_size=64,
max_input_len=3000,
max_output_len=96,
max_beam_width=1,
debug_mode=False,
kv_cache_free_gpu_memory_fraction=0.5)
runner_kwargs = dict(
engine_dir=engine_dir,
is_enc_dec=True,
max_batch_size=64,
max_input_len=3000,
max_output_len=96,
max_beam_width=1,
debug_mode=False,
kv_cache_free_gpu_memory_fraction=0.5,
)
self.model_runner_cpp = ModelRunnerCpp.from_dir(**runner_kwargs)
self.feature_extractor = FeatureExtractor(n_mels = int(parameters["n_mels"]))
self.zero_pad = True if parameters["zero_pad"] == "true" else False
self.feature_extractor = FeatureExtractor(
n_mels=int(parameters["n_mels"])
)
self.zero_pad = parameters["zero_pad"] == "true"
self.eot_id = 50257

def execute(self, requests):
"""
This function receives a list of requests (`pb_utils.InferenceRequest`),
performs inference on every request and appends it to responses.
"""
responses, batch_mel_list, decoder_input_ids = [], [], []
responses = []

for request in requests:
wav_tensor = pb_utils.get_input_tensor_by_name(request, "WAV")
wav_len = pb_utils.get_input_tensor_by_name(request, "WAV_LENS").as_numpy().item()
prompt_ids = pb_utils.get_input_tensor_by_name(request, "DECODER_INPUT_IDS").as_numpy()
wav_lens = pb_utils.get_input_tensor_by_name(
request, "WAV_LENS").as_numpy()
prompt_ids = pb_utils.get_input_tensor_by_name(
request, "DECODER_INPUT_IDS").as_numpy()

# Move WAV data to GPU
wav = from_dlpack(wav_tensor.to_dlpack())
wav = wav[:, :wav_len]
batch_size = wav.shape[0]

padding = 0 if self.zero_pad else 3000
mel = self.feature_extractor.compute_feature(wav[0].to('cuda'), padding_target_len=padding).transpose(1, 2)
batch_mel_list.append(mel.squeeze(0))
decoder_input_ids.append(torch.tensor(prompt_ids, dtype=torch.int32, device='cuda').squeeze(0))

decoder_input_ids = torch.nn.utils.rnn.pad_sequence(decoder_input_ids, batch_first=True, padding_value=self.eot_id)
mel_input_lengths = torch.tensor([mel.shape[0] for mel in batch_mel_list], dtype=torch.int32, device='cuda')

outputs = self.model_runner_cpp.generate(
batch_input_ids=decoder_input_ids,
encoder_input_features=batch_mel_list,
encoder_output_lengths=mel_input_lengths // 2,
max_new_tokens=96,
end_id=self.eot_id,
pad_id=self.eot_id,
num_beams=1,
output_sequence_lengths=True,
return_dict=True)
torch.cuda.synchronize()

output_ids = outputs['output_ids'].cpu().numpy()

for i, output_id in enumerate(output_ids):
batch_mel_list = []

# Batch processing for each sample in the batch
for i in range(batch_size):
wav_i = wav[i:i+1, :int(wav_lens[i].item())]
mel = self.feature_extractor.compute_feature(
wav_i[0].to('cuda'),
padding_target_len=padding
).transpose(1, 2)
batch_mel_list.append(mel.squeeze(0))

# Move prompt IDs to GPU
decoder_input_ids = torch.tensor(
prompt_ids, dtype=torch.int32, device='cuda')

# Calculate mel lengths
mel_input_lengths = torch.tensor(
[mel.shape[0] for mel in batch_mel_list],
dtype=torch.int32,
device='cuda'
)

# Run batch inference
outputs = self.model_runner_cpp.generate(
batch_input_ids=decoder_input_ids,
encoder_input_features=batch_mel_list,
encoder_output_lengths=mel_input_lengths // 2,
max_new_tokens=96,
end_id=self.eot_id,
pad_id=self.eot_id,
num_beams=1,
output_sequence_lengths=True,
return_dict=True
)
torch.cuda.synchronize()

# Process outputs
output_ids = outputs['output_ids'].cpu().numpy()

# Create response for the request
response = pb_utils.InferenceResponse(output_tensors=[
pb_utils.Tensor("OUTPUT_IDS", output_id[0])
pb_utils.Tensor("OUTPUT_IDS", output_ids)
])
responses.append(response)
assert len(responses) == len(requests)
return responses

return responses
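As a standalone illustration of the new output handling in infer_bls, the sketch below mirrors the beam-dimension squeeze and the special-token regex from the diff above. The toy decode function and sample token IDs are invented for the example and stand in for the real Whisper tokenizer:

import re

import numpy as np


def strip_special_tokens(text: str) -> str:
    # Same regex as infer_bls: drop Whisper special tokens such as <|en|>.
    return re.sub(r'<\|.*?\|>', '', text)


def decode_batch(output_ids: np.ndarray, decode_fn) -> list:
    # The whisper model returns [batch_size, beam_size, seq_len];
    # keep the first (best) beam, as infer_bls does.
    if output_ids.ndim == 3:
        output_ids = output_ids[:, 0, :]
    return [strip_special_tokens(decode_fn(ids.tolist())) for ids in output_ids]


# Toy usage with a fake ID-to-text mapping.
fake_vocab = {1: "<|startoftranscript|>", 2: "hello", 3: " world", 4: "<|endoftext|>"}
fake_decode = lambda ids: "".join(fake_vocab.get(i, "") for i in ids)
batch = np.array([[[1, 2, 3, 4]], [[1, 2, 4, 4]]])  # [batch=2, beam=1, seq=4]
print(decode_batch(batch, fake_decode))  # ['hello world', 'hello']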