Add streaming modified beam search #142

Merged
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-conformer-test.yaml
@@ -52,7 +52,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "modified_beam_search"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
2 changes: 1 addition & 1 deletion .github/workflows/run-streaming-conv-emformer-test.yaml
@@ -52,7 +52,7 @@ jobs:
torch: ["1.10.0", "1.6.0"]
torchaudio: ["0.10.0", "0.6.0"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "modified_beam_search"]
exclude:
- torch: "1.10.0"
torchaudio: "0.6.0"
@@ -52,7 +52,7 @@ jobs:
torch: ["1.11.0", "1.7.1"]
torchaudio: ["0.11.0", "0.7.2"]
python-version: ["3.7", "3.8"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG"]
decoding: ["greedy_search", "fast_beam_search", "fast_beam_search_nbest", "fast_beam_search_nbest_LG", "modified_beam_search"]
exclude:
- torch: "1.11.0"
torchaudio: "0.7.2"
111 changes: 111 additions & 0 deletions sherpa/bin/conv_emformer_transducer_stateless2/beam_search.py
@@ -6,11 +6,14 @@

 from sherpa import (
     VALID_FAST_BEAM_SEARCH_METHOD,
+    Hypotheses,
+    Hypothesis,
     Lexicon,
     fast_beam_search_nbest,
     fast_beam_search_nbest_LG,
     fast_beam_search_one_best,
     streaming_greedy_search,
+    streaming_modified_beam_search,
 )


@@ -339,3 +342,111 @@ def get_texts(self, stream: Stream) -> str:
        return self.sp.decode(
            stream.hyp[self.beam_search_params["context_size"] :]
        )


class ModifiedBeamSearch:
    def __init__(self, beam_search_params: dict):
        self.beam_search_params = beam_search_params

    def init_stream(self, stream: Stream):
        """
        Attributes to add to each stream
        """
        hyp = [self.beam_search_params["blank_id"]] * self.beam_search_params[
            "context_size"
        ]
        stream.hyps = Hypotheses([Hypothesis(ys=hyp, log_prob=0.0)])

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and do search with the
        modified_beam_search method.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyps` are
            updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        batch_size = len(stream_list)
        chunk_length_pad = server.chunk_length_pad
        state_list, feature_list = [], []
        hyp_list = []
        processed_frames_list = []
        num_trailing_blank_frames_list = []

        for s in stream_list:
            hyp_list.append(s.hyps)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            f = s.features[:chunk_length_pad]
            s.features = s.features[chunk_length:]
            s.processed_frames += chunk_length

            b = torch.cat(f, dim=0)
            feature_list.append(b)

            num_trailing_blank_frames_list.append(s.num_trailing_blank_frames)

        features = torch.stack(feature_list, dim=0).to(device)
        states = stack_states(state_list)

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        num_processed_frames = torch.tensor(
            processed_frames_list,
            device=device,
        )

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            num_processed_frames=num_processed_frames,
            states=states,
        )

        # Note: There are no paddings for streaming ASR. Each stream
        # has the same input number of frames, i.e., server.chunk_length.
        next_hyps_list = streaming_modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            hyps=hyp_list,
            num_active_paths=self.beam_search_params["num_active_paths"],
        )

        next_state_list = unstack_states(next_states)
        for i, s in enumerate(stream_list):
            s.states = next_state_list[i]
            s.hyps = next_hyps_list[i]
            trailing_blanks = s.hyps.get_most_probable(True).num_trailing_blanks
            s.num_trailing_blank_frames = trailing_blanks

    def get_texts(self, stream: Stream) -> str:
        hyp = stream.hyps.get_most_probable(True).ys[
            self.beam_search_params["context_size"] :
        ]
        if hasattr(self, "sp"):
            result = self.sp.decode(hyp)
        else:
            result = [self.token_table[i] for i in hyp]
            result = "".join(result)

        return result
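For orientation, streaming_modified_beam_search (imported from sherpa above) implements the RNN-T beam search variant that emits at most one symbol per encoder frame, applied to every stream's hypothesis set in the batch. Below is a minimal, framework-free sketch of that per-frame pruning step, not the sherpa implementation: the Hyp dataclass and the log_probs tensor (one row of decoder+joiner output per active hypothesis) are illustrative assumptions.

from dataclasses import dataclass
from typing import List

import torch


@dataclass
class Hyp:
    ys: List[int]    # decoded token ids, including the blank/context prefix
    log_prob: float  # accumulated log probability of this path


def modified_beam_search_step(
    hyps: List[Hyp],
    log_probs: torch.Tensor,  # (len(hyps), vocab_size), log-softmax output
    blank_id: int,
    num_active_paths: int,
) -> List[Hyp]:
    """Expand each hypothesis by one frame; keep the best num_active_paths."""
    candidates: List[Hyp] = []
    for i, hyp in enumerate(hyps):
        # A real implementation would take a topk over the vocabulary
        # instead of scanning all of it.
        for token in range(log_probs.size(1)):
            score = hyp.log_prob + log_probs[i, token].item()
            if token == blank_id:
                # Emitting blank leaves the label sequence unchanged.
                candidates.append(Hyp(ys=hyp.ys, log_prob=score))
            else:
                candidates.append(Hyp(ys=hyp.ys + [token], log_prob=score))
    candidates.sort(key=lambda h: h.log_prob, reverse=True)
    return candidates[:num_active_paths]

A production implementation additionally merges candidates that share the same token sequence by log-adding their scores, and tracks trailing blanks for endpoint detection, which is what the num_trailing_blanks bookkeeping in the class above suggests.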
@@ -45,7 +45,7 @@
 import sentencepiece as spm
 import torch
 import websockets
-from beam_search import FastBeamSearch, GreedySearch
+from beam_search import FastBeamSearch, GreedySearch, ModifiedBeamSearch
 from stream import Stream

 from sherpa import (
@@ -275,6 +275,8 @@ def __init__(
                 beam_search_params,
                 device,
             )
+        elif decoding_method == "modified_beam_search":
+            self.beam_search = ModifiedBeamSearch(beam_search_params)
         else:
             raise ValueError(
                 f"Decoding method {decoding_method} is not supported."
121 changes: 121 additions & 0 deletions sherpa/bin/streaming_pruned_transducer_statelessX/beam_search.py
@@ -6,11 +6,14 @@

 from sherpa import (
     VALID_FAST_BEAM_SEARCH_METHOD,
+    Hypotheses,
+    Hypothesis,
     Lexicon,
     fast_beam_search_nbest,
     fast_beam_search_nbest_LG,
     fast_beam_search_one_best,
     streaming_greedy_search,
+    streaming_modified_beam_search,
 )


@@ -378,3 +381,121 @@ def get_texts(self, stream: Stream) -> str:
result = "".join(result).replace("▁", " ")

return result


class ModifiedBeamSearch:
    def __init__(self, beam_search_params: dict):
        self.beam_search_params = beam_search_params

    def init_stream(self, stream: Stream):
        """
        Attributes to add to each stream
        """
        hyp = [self.beam_search_params["blank_id"]] * self.beam_search_params[
            "context_size"
        ]
        stream.hyps = Hypotheses([Hypothesis(ys=hyp, log_prob=0.0)])

    @torch.no_grad()
    def process(
        self,
        server: "StreamingServer",
        stream_list: List[Stream],
    ) -> None:
        """Run the model on the given stream list and do modified_beam_search.
        Args:
          server:
            An instance of `StreamingServer`.
          stream_list:
            A list of streams to be processed. It is changed in-place.
            That is, the attributes `states` and `hyps` are
            updated in-place.
        """
        model = server.model
        device = model.device
        # Note: chunk_length is in frames before subsampling
        chunk_length = server.chunk_length
        subsampling_factor = server.subsampling_factor
        # Note: chunk_size, left_context and right_context are in frames
        # after subsampling
        chunk_size = server.decode_chunk_size
        left_context = server.decode_left_context
        right_context = server.decode_right_context

        batch_size = len(stream_list)

        state_list, feature_list, processed_frames_list = [], [], []
        hyp_list = []

        num_trailing_blank_frames_list = []

        for s in stream_list:
            hyp_list.append(s.hyps)
            state_list.append(s.states)
            processed_frames_list.append(s.processed_frames)
            f = s.features[:chunk_length]
            s.features = s.features[chunk_size * subsampling_factor :]
            b = torch.cat(f, dim=0)
            feature_list.append(b)

            num_trailing_blank_frames_list.append(s.num_trailing_blank_frames)

        features = torch.stack(feature_list, dim=0).to(device)

        states = [
            torch.stack([x[0] for x in state_list], dim=2),
            torch.stack([x[1] for x in state_list], dim=2),
        ]

        features_length = torch.full(
            (batch_size,),
            fill_value=features.size(1),
            device=device,
            dtype=torch.int64,
        )

        processed_frames = torch.tensor(processed_frames_list, device=device)

        (
            encoder_out,
            encoder_out_lens,
            next_states,
        ) = model.encoder_streaming_forward(
            features=features,
            features_length=features_length,
            states=states,
            processed_frames=processed_frames,
            left_context=left_context,
            right_context=right_context,
        )

        next_hyps_list = streaming_modified_beam_search(
            model=model,
            encoder_out=encoder_out,
            hyps=hyp_list,
            num_active_paths=self.beam_search_params["num_active_paths"],
        )

        next_state_list = [
            torch.unbind(next_states[0], dim=2),
            torch.unbind(next_states[1], dim=2),
        ]

        for i, s in enumerate(stream_list):
            s.states = [next_state_list[0][i], next_state_list[1][i]]
            s.processed_frames += encoder_out_lens[i]
            s.hyps = next_hyps_list[i]
            trailing_blanks = s.hyps.get_most_probable(True).num_trailing_blanks
            s.num_trailing_blank_frames = trailing_blanks

    def get_texts(self, stream: Stream) -> str:
        hyp = stream.hyps.get_most_probable(True).ys[
            self.beam_search_params["context_size"] :
        ]
        if hasattr(self, "sp"):
            result = self.sp.decode(hyp)
        else:
            result = [self.token_table[i] for i in hyp]
            result = "".join(result)

        return result
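The two recipes differ mainly in how per-stream encoder states are batched: the conv_emformer version relies on the stack_states/unstack_states helpers, while this recipe stacks its two state tensors along dim=2 directly. A toy round trip (the shapes are made up) shows the invariant process() relies on, namely that torch.unbind along the same dim recovers each stream's states:

import torch

num_streams = 3
# Hypothetical per-stream states: a [first_cache, second_cache] pair each.
state_list = [
    [torch.zeros(12, 64, 512), torch.zeros(12, 32, 512)]
    for _ in range(num_streams)
]

# Batch for the encoder: stack each component along dim=2, as in process().
states = [
    torch.stack([x[0] for x in state_list], dim=2),  # (12, 64, 3, 512)
    torch.stack([x[1] for x in state_list], dim=2),  # (12, 32, 3, 512)
]

# Unbatch after the forward pass: unbind along the same dim.
unbound = [torch.unbind(states[0], dim=2), torch.unbind(states[1], dim=2)]
per_stream = [[unbound[0][i], unbound[1][i]] for i in range(num_streams)]
assert torch.equal(per_stream[0][0], state_list[0][0])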
@@ -47,7 +47,7 @@
 import sentencepiece as spm
 import torch
 import websockets
-from beam_search import FastBeamSearch, GreedySearch
+from beam_search import FastBeamSearch, GreedySearch, ModifiedBeamSearch
 from stream import Stream

 from sherpa import (
@@ -313,6 +313,8 @@ def __init__(
                 beam_search_params,
                 device,
             )
+        elif decoding_method == "modified_beam_search":
+            self.beam_search = ModifiedBeamSearch(beam_search_params)
         else:
             raise ValueError(
                 f"Decoding method {decoding_method} is not supported."