
Ebranchformer #1951

Merged
merged 5 commits into from
Mar 4, 2025
Conversation

KarelVesely84
Contributor

Hello Fangyun @csukuangfj,
I have extended sherpa-onnx to support our EBranchformer encoder implementation,
which we currently use widely at Brno University of Technology.

The EBranchformer code is based on the Conformer model from transformers, but the internals are different:
https://github.com/BUTSpeechFIT/huggingface_asr/blob/streaming_karel/src/models/encoders/e_branchformer.py

This allows us to pre-train the encoder with the BestRQ algorithm and then fine-tune it with a modified icefall.
We would like to deploy it as a production system for streaming ASR (it already works for me locally).

It would therefore be good for us to have the support directly inside sherpa-onnx,
so that we can use the official sherpa-onnx builds.

On the other hand, it is a rather specific model, not yet widespread.
Would you agree to accept this extension into the codebase?
Or should we rely on our custom builds?

The encoder assumes a slightly different preset of input features, derived from
Speech2TextFeatureExtractor, hence the newly surfaced FBANK options:
normalize_samples and snip_edges.

Best regards,
Karel

- so EBranchformer feature extraction can be configured from Python
- the GlobCmvn is not needed, as it is a module in the OnnxEncoder
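The snip_edges option mentioned above follows Kaldi-style framing semantics. A minimal sketch of how it changes the number of emitted frames (assuming the standard Kaldi frame-counting rules; this is an illustration, not the actual sherpa-onnx code):

```cpp
#include <cstdint>

// Kaldi-style frame counting controlled by snip_edges.
// snip_edges = true  : count only frames that fit entirely inside the signal.
// snip_edges = false : frames are centered and edges padded, so the count
//                      depends only on the frame shift (rounded).
int32_t NumFrames(int32_t num_samples, int32_t frame_length,
                  int32_t frame_shift, bool snip_edges) {
  if (snip_edges) {
    if (num_samples < frame_length) return 0;
    return 1 + (num_samples - frame_length) / frame_shift;
  }
  return (num_samples + frame_shift / 2) / frame_shift;
}
```

With a 25 ms window (400 samples) and 10 ms shift (160 samples), a 400-sample signal yields 1 frame with snip_edges=true but 3 frames with snip_edges=false, which is why a model trained on one setting needs the same setting at inference time.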
@csukuangfj
Collaborator

Thanks! Will review it today.

Collaborator

@csukuangfj left a comment


Thanks!

Looks great to me. Left only some minor comments.

@@ -48,7 +48,9 @@ std::string FeatureExtractorConfig::ToString() const {
os << "feature_dim=" << feature_dim << ", ";
os << "low_freq=" << low_freq << ", ";
os << "high_freq=" << high_freq << ", ";
os << "dither=" << dither << ")";
os << "dither=" << dither << ", ";
os << "normalize_samples=" << normalize_samples << ", ";
Collaborator


Suggested change
os << "normalize_samples=" << normalize_samples << ", ";
os << "normalize_samples=" << (normalize_samples ? "True" : "False") << ", ";

os << "dither=" << dither << ")";
os << "dither=" << dither << ", ";
os << "normalize_samples=" << normalize_samples << ", ";
os << "snip_edges=" << snip_edges << ")";
Collaborator


Suggested change
os << "snip_edges=" << snip_edges << ")";
os << "snip_edges=" << (snip_edges ? "True" : "False") << ")";
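The two suggestions above apply the same pattern: stream the bools as the Python-style literals "True"/"False" instead of 1/0, so the printed config matches what a Python user would pass in. A standalone sketch of the resulting ToString (simplified; the real FeatureExtractorConfig has more fields):

```cpp
#include <sstream>
#include <string>

// Simplified mock of FeatureExtractorConfig::ToString with the
// reviewer's bool-to-"True"/"False" formatting applied.
std::string ConfigToString(bool normalize_samples, bool snip_edges) {
  std::ostringstream os;
  os << "FeatureExtractorConfig(";
  os << "normalize_samples=" << (normalize_samples ? "True" : "False") << ", ";
  os << "snip_edges=" << (snip_edges ? "True" : "False") << ")";
  return os.str();
}
```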

@@ -34,6 +51,9 @@ void PybindOnlineStream(py::module *m) {
py::arg("sample_rate"), py::arg("waveform"), kAcceptWaveformUsage,
py::call_guard<py::gil_scoped_release>())
.def("input_finished", &PyClass::InputFinished,
py::call_guard<py::gil_scoped_release>())
.def("get_frames", &PyClass::GetFrames,
py::arg("frame_index"), py::arg("n"),
Collaborator


Suggested change
py::arg("frame_index"), py::arg("n"),
py::arg("frame_index"), py::arg("n"), kGetFramesUsage

so that if you use help(OnlineStream.get_frames), you can view the help info in Python.

@@ -92,6 +96,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
const auto &model_type = config.model_type;
if (model_type == "conformer") {
return std::make_unique<OnlineConformerTransducerModel>(config);
} else if (model_type == "ebranchformer") {
Collaborator


Can you also update

if (model_type == "conformer") {
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);

@@ -115,6 +121,8 @@ std::unique_ptr<OnlineTransducerModel> OnlineTransducerModel::Create(
switch (model_type) {
case ModelType::kConformer:
return std::make_unique<OnlineConformerTransducerModel>(config);
case ModelType::kEbranchformer:
Collaborator


Can you also update

case ModelType::kConformer:
return std::make_unique<OnlineConformerTransducerModel>(mgr, config);

It is for Android and HarmonyOS.
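The point of the two comments above is that the factory has two Create() overloads (a plain one and one taking an asset manager for Android/HarmonyOS), and a new model_type must be added to both dispatch sites or the mobile builds fall through to the error branch. A simplified mock of that pattern (not the actual sherpa-onnx code; AssetManager here is a hypothetical stand-in):

```cpp
#include <memory>
#include <string>

struct AssetManager {};  // stand-in for the Android/HarmonyOS asset manager

struct Model {
  std::string name;
};

// Desktop overload: dispatch on model_type.
std::unique_ptr<Model> Create(const std::string &model_type) {
  if (model_type == "conformer") {
    return std::make_unique<Model>(Model{"conformer"});
  } else if (model_type == "ebranchformer") {
    return std::make_unique<Model>(Model{"ebranchformer"});
  }
  return nullptr;  // unknown model type
}

// Mobile overload: same dispatch table; in sherpa-onnx it additionally
// loads the model files through the asset manager.
std::unique_ptr<Model> Create(AssetManager * /*mgr*/,
                              const std::string &model_type) {
  return Create(model_type);
}
```

Keeping both overloads in sync is exactly the review request: the ebranchformer branch added to the first Create() also has to appear in the mgr-taking one.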

@csukuangfj
Collaborator

(Please ignore the failed CI tests.)

@KarelVesely84
Contributor Author

KarelVesely84 commented Mar 4, 2025

OK, good, thank you for the feedback.
The remarks are now integrated into the PR code.

Collaborator

@csukuangfj left a comment


Thank you for your contribution!

@csukuangfj csukuangfj merged commit 7740dbf into k2-fsa:master Mar 4, 2025
163 of 214 checks passed