From 6343a3638904bc50a1ee04822a0c351cf6abdf9d Mon Sep 17 00:00:00 2001 From: Xin Yang Date: Thu, 16 Mar 2023 14:55:10 -0700 Subject: [PATCH] [audio] Move WhisperTranslator to audio extension --- .../inference/whisper/WhisperModel.java | 3 +- .../audio/translator}/WhisperTranslator.java | 8 ++- .../translator/WhisperTranslatorFactory.java | 54 +++++++++++++++++++ 3 files changed, 63 insertions(+), 2 deletions(-) rename {examples/src/main/java/ai/djl/examples/inference/whisper => extensions/audio/src/main/java/ai/djl/audio/translator}/WhisperTranslator.java (93%) create mode 100644 extensions/audio/src/main/java/ai/djl/audio/translator/WhisperTranslatorFactory.java diff --git a/examples/src/main/java/ai/djl/examples/inference/whisper/WhisperModel.java b/examples/src/main/java/ai/djl/examples/inference/whisper/WhisperModel.java index 2cb7cb7a8e2..935a4b8c9c5 100644 --- a/examples/src/main/java/ai/djl/examples/inference/whisper/WhisperModel.java +++ b/examples/src/main/java/ai/djl/examples/inference/whisper/WhisperModel.java @@ -13,6 +13,7 @@ package ai.djl.examples.inference.whisper; import ai.djl.ModelException; +import ai.djl.audio.translator.WhisperTranslatorFactory; import ai.djl.inference.Predictor; import ai.djl.modality.audio.Audio; import ai.djl.modality.audio.AudioFactory; @@ -37,7 +38,7 @@ public WhisperModel() throws ModelException, IOException { .optModelUrls( "https://resources.djl.ai/demo/pytorch/whisper/whisper_en.zip") .optEngine("PyTorch") - .optTranslator(new WhisperTranslator()) + .optTranslatorFactory(new WhisperTranslatorFactory()) .build(); whisperModel = criteria.loadModel(); } diff --git a/examples/src/main/java/ai/djl/examples/inference/whisper/WhisperTranslator.java b/extensions/audio/src/main/java/ai/djl/audio/translator/WhisperTranslator.java similarity index 93% rename from examples/src/main/java/ai/djl/examples/inference/whisper/WhisperTranslator.java rename to extensions/audio/src/main/java/ai/djl/audio/translator/WhisperTranslator.java index 2116991b704..2736b8acc7f 100644 --- a/examples/src/main/java/ai/djl/examples/inference/whisper/WhisperTranslator.java +++ b/extensions/audio/src/main/java/ai/djl/audio/translator/WhisperTranslator.java @@ -10,7 +10,7 @@ * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions * and limitations under the License. */ -package ai.djl.examples.inference.whisper; +package ai.djl.audio.translator; import ai.djl.audio.processor.AudioProcessor; import ai.djl.audio.processor.LogMelSpectrogram; @@ -22,6 +22,7 @@ import ai.djl.ndarray.NDList; import ai.djl.ndarray.NDManager; import ai.djl.translate.NoBatchifyTranslator; +import ai.djl.translate.Translator; import ai.djl.translate.TranslatorContext; import ai.djl.util.JsonUtils; @@ -37,11 +38,16 @@ import java.util.List; import java.util.Map; +/** + * A {@link Translator} that process the {@link Audio} into {@link String} to get a text translation + * of the audio. + */ public class WhisperTranslator implements NoBatchifyTranslator { private List processors; private Vocabulary vocabulary; + /** Constructs a new instance of {@code WhisperTranslator}. */ public WhisperTranslator() { processors = new ArrayList<>(); } diff --git a/extensions/audio/src/main/java/ai/djl/audio/translator/WhisperTranslatorFactory.java b/extensions/audio/src/main/java/ai/djl/audio/translator/WhisperTranslatorFactory.java new file mode 100644 index 00000000000..4af42e249d0 --- /dev/null +++ b/extensions/audio/src/main/java/ai/djl/audio/translator/WhisperTranslatorFactory.java @@ -0,0 +1,54 @@ +/* + * Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance + * with the License. A copy of the License is located at + * + * http://aws.amazon.com/apache2.0/ + * + * or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES + * OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions + * and limitations under the License. + */ +package ai.djl.audio.translator; + +import ai.djl.Model; +import ai.djl.modality.audio.Audio; +import ai.djl.translate.Translator; +import ai.djl.translate.TranslatorFactory; +import ai.djl.util.Pair; + +import java.io.Serializable; +import java.lang.reflect.Type; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +/** A {@link TranslatorFactory} that creates a {@link WhisperTranslator} instance. */ +public class WhisperTranslatorFactory implements TranslatorFactory, Serializable { + + private static final long serialVersionUID = 1L; + + private static final Set> SUPPORTED_TYPES = new HashSet<>(); + + static { + SUPPORTED_TYPES.add(new Pair<>(Audio.class, String.class)); + } + + /** {@inheritDoc} */ + @Override + public Set> getSupportedTypes() { + return SUPPORTED_TYPES; + } + + /** {@inheritDoc} */ + @Override + @SuppressWarnings("unchecked") + public Translator newInstance( + Class input, Class output, Model model, Map arguments) { + if (input == Audio.class && output == String.class) { + return (Translator) new WhisperTranslator(); + } + throw new IllegalArgumentException("Unsupported input/output types."); + } +}