Skip to content

Commit

Permalink
[audio] Move WhisperTranslator to audio extension
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Mar 16, 2023
1 parent a9426cf commit 6343a36
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
package ai.djl.examples.inference.whisper;

import ai.djl.ModelException;
import ai.djl.audio.translator.WhisperTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.modality.audio.Audio;
import ai.djl.modality.audio.AudioFactory;
Expand All @@ -37,7 +38,7 @@ public WhisperModel() throws ModelException, IOException {
.optModelUrls(
"https://resources.djl.ai/demo/pytorch/whisper/whisper_en.zip")
.optEngine("PyTorch")
.optTranslator(new WhisperTranslator())
.optTranslatorFactory(new WhisperTranslatorFactory())
.build();
whisperModel = criteria.loadModel();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.examples.inference.whisper;
package ai.djl.audio.translator;

import ai.djl.audio.processor.AudioProcessor;
import ai.djl.audio.processor.LogMelSpectrogram;
Expand All @@ -22,6 +22,7 @@
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.JsonUtils;

Expand All @@ -37,11 +38,16 @@
import java.util.List;
import java.util.Map;

/**
* A {@link Translator} that process the {@link Audio} into {@link String} to get a text translation
* of the audio.
*/
public class WhisperTranslator implements NoBatchifyTranslator<Audio, String> {

private List<AudioProcessor> processors;
private Vocabulary vocabulary;

/** Constructs a new instance of {@code WhisperTranslator}. */
public WhisperTranslator() {
processors = new ArrayList<>();
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.audio.translator;

import ai.djl.Model;
import ai.djl.modality.audio.Audio;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.Pair;

import java.io.Serializable;
import java.lang.reflect.Type;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/** A {@link TranslatorFactory} that creates a {@link WhisperTranslator} instance. */
public class WhisperTranslatorFactory implements TranslatorFactory, Serializable {

private static final long serialVersionUID = 1L;

private static final Set<Pair<Type, Type>> SUPPORTED_TYPES = new HashSet<>();

static {
SUPPORTED_TYPES.add(new Pair<>(Audio.class, String.class));
}

/** {@inheritDoc} */
@Override
public Set<Pair<Type, Type>> getSupportedTypes() {
return SUPPORTED_TYPES;
}

/** {@inheritDoc} */
@Override
@SuppressWarnings("unchecked")
public <I, O> Translator<I, O> newInstance(
Class<I> input, Class<O> output, Model model, Map<String, ?> arguments) {
if (input == Audio.class && output == String.class) {
return (Translator<I, O>) new WhisperTranslator();
}
throw new IllegalArgumentException("Unsupported input/output types.");
}
}

0 comments on commit 6343a36

Please sign in to comment.