Skip to content

Commit

Permalink
[spark] Add audio predictors
Browse files Browse the repository at this point in the history
  • Loading branch information
xyang16 committed Mar 27, 2023
1 parent 6343a36 commit ba29c2a
Show file tree
Hide file tree
Showing 8 changed files with 220 additions and 5 deletions.
14 changes: 12 additions & 2 deletions docker/spark/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,32 +16,42 @@ LABEL maintainer="djl-dev@amazon.com"
USER root
ARG DJL_VERSION=0.21.0
ARG JNA_VERSION=5.12.1
ARG JAVACV_VERSION=1.5.8
ARG JAVACPP_VERSION=1.5.8
ARG FFMPEG_VERSION=5.1.2-1.5.8
ARG TENSORFLOW_CORE_VERSION=0.4.2
ARG PROTOBUF_VERSION=3.21.9

COPY extensions/spark/setup/dist/ dist/
RUN pip3 install --no-cache-dir dist/djl_spark-*-py3-none-any.whl && \
rm -rf dist
RUN pip3 install --no-cache-dir pillow pandas numpy
RUN pip3 install --no-cache-dir pillow pandas numpy pyarrow

ADD https://repo1.maven.org/maven2/ai/djl/api/${DJL_VERSION}/api-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-api-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/spark/spark/${DJL_VERSION}/spark-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-spark-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/huggingface/tokenizers/${DJL_VERSION}/tokenizers-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tokenizers-${DJL_VERSION}.jar

ADD https://repo1.maven.org/maven2/ai/djl/audio/audio/${DJL_VERSION}/audio-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-audio-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/org/bytedeco/javacv/${JAVACV_VERSION}/javacv-${JAVACV_VERSION}.jar /usr/lib/spark/jars/
ADD https://repo1.maven.org/maven2/org/bytedeco/javacpp/${JAVACPP_VERSION}/javacpp-${JAVACPP_VERSION}.jar /usr/lib/spark/jars/
ADD https://repo1.maven.org/maven2/org/bytedeco/ffmpeg/${FFMPEG_VERSION}/ffmpeg-${FFMPEG_VERSION}.jar /usr/lib/spark/jars/
ADD https://repo1.maven.org/maven2/org/bytedeco/ffmpeg/${FFMPEG_VERSION}/ffmpeg-${FFMPEG_VERSION}-linux-x86_64.jar /usr/lib/spark/jars/
ADD https://repo1.maven.org/maven2/org/bytedeco/ffmpeg-platform/${FFMPEG_VERSION}/ffmpeg-platform-${FFMPEG_VERSION}.jar /usr/lib/spark/jars/

ADD https://repo1.maven.org/maven2/ai/djl/pytorch/pytorch-engine/${DJL_VERSION}/pytorch-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-pytorch-engine-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/pytorch/pytorch-model-zoo/${DJL_VERSION}/pytorch-model-zoo-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-model-zoo-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/net/java/dev/jna/jna/${JNA_VERSION}/jna-${JNA_VERSION}.jar /usr/lib/spark/jars/

ADD https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-engine/${DJL_VERSION}/tensorflow-engine-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-engine-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/ai/djl/tensorflow/tensorflow-model-zoo/${DJL_VERSION}/tensorflow-model-zoo-${DJL_VERSION}.jar /usr/lib/spark/jars/djl-tensorflow-model-zoo-${DJL_VERSION}.jar
ADD https://repo1.maven.org/maven2/org/bytedeco/javacpp/${JAVACPP_VERSION}/javacpp-${JAVACPP_VERSION}.jar /usr/lib/spark/jars/
ADD https://repo1.maven.org/maven2/org/tensorflow/tensorflow-core-api/${TENSORFLOW_CORE_VERSION}/tensorflow-core-api-${TENSORFLOW_CORE_VERSION}.jar /usr/lib/spark/jars/
RUN rm /usr/lib/spark/jars/protobuf-java-*.jar
ADD https://repo1.maven.org/maven2/com/google/protobuf/protobuf-java/${PROTOBUF_VERSION}/protobuf-java-${PROTOBUF_VERSION}.jar /usr/lib/spark/jars/

RUN chmod -R +r /usr/lib/spark/jars/

RUN echo 'export SPARK_JAVA_OPTS="$SPARK_JAVA_OPTS -Dai.djl.pytorch.graph_optimizer=false"' >> /usr/lib/spark/conf/spark-env.sh

# Set environment
ENV PYTORCH_PRECXX11 true
ENV OMP_NUM_THREADS 1
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
/*
* Copyright 2022 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/

/** Contains translators for audio processing. */
package ai.djl.audio.translator;
1 change: 1 addition & 0 deletions extensions/spark/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ group "ai.djl.spark"
dependencies {
api project(":api")
api project(":extensions:tokenizers")
api project(":extensions:audio")
api "org.apache.spark:spark-core_2.12:${spark_version}"
api "org.apache.spark:spark-sql_2.12:${spark_version}"
api "org.apache.spark:spark-mllib_2.12:${spark_version}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,11 +86,12 @@ abstract class BasePredictor[A, B](override val uid: String) extends Transformer

setDefault(engine, null)
setDefault(modelUrl, null)
setDefault(batchifier, "none")

/** @inheritdoc */
override def transform(dataset: Dataset[_]): DataFrame = {
arguments.put("batchifier", $(batchifier))
if (isDefined(batchifier)) {
arguments.put("batchifier", $(batchifier))
}
model = new ModelLoader[A, B]($(engine), $(modelUrl), $(inputClass), $(outputClass), $(translatorFactory),
arguments)
validateInputType(dataset.schema)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.spark.task.audio

import ai.djl.modality.audio.Audio
import ai.djl.spark.task.BasePredictor
import org.apache.spark.ml.util.Identifiable

/**
* BaseAudioPredictor is the base class for audio predictors.
*
* @param uid An immutable unique ID for the object and its derivatives.
*/
abstract class BaseAudioPredictor[B](override val uid: String) extends BasePredictor[Audio, B] {

def this() = this(Identifiable.randomUID("BaseAudioPredictor"))

setDefault(inputClass, classOf[Audio])
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.spark.task.audio

import ai.djl.modality.audio.AudioFactory
import ai.djl.modality.audio.translator.SpeechRecognitionTranslatorFactory
import org.apache.spark.ml.param.IntParam
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.types.{BinaryType, StringType, StructField, StructType}

import java.io.ByteArrayInputStream

/**
* SpeechRecognizer performs speech recognition on audio.
*
* @param uid An immutable unique ID for the object and its derivatives.
*/
class SpeechRecognizer(override val uid: String) extends BaseAudioPredictor[String]
with HasInputCol with HasOutputCol {

def this() = this(Identifiable.randomUID("SpeechRecognizer"))

final val channels = new IntParam(this, "channels", "The number of channels")
final val sampleRate = new IntParam(this, "sampleRate", "The audio sample rate")
final val sampleFormat = new IntParam(this, "sampleFormat", "The audio sample format")

protected var inputColIndex: Int = _

/**
* Sets the inputCol parameter.
*
* @param value the value of the parameter
*/
def setInputCol(value: String): this.type = set(inputCol, value)

/**
* Sets the outputCol parameter.
*
* @param value the value of the parameter
*/
def setOutputCol(value: String): this.type = set(outputCol, value)

/**
* Sets the channels parameter.
*
* @param value the value of the parameter
*/
def setChannels(value: Int): this.type = set(channels, value)

/**
* Sets the sampleRate parameter.
*
* @param value the value of the parameter
*/
def setSampleRate(value: Int): this.type = set(sampleRate, value)

/**
* Sets the sampleFormat parameter.
*
* @param value the value of the parameter
*/
def setSampleFormat(value: Int): this.type = set(sampleFormat, value)

setDefault(outputClass, classOf[String])
setDefault(translatorFactory, new SpeechRecognitionTranslatorFactory())

/**
* Performs speech recognition on the provided dataset.
*
* @param dataset input dataset
* @return output dataset
*/
def recognize(dataset: Dataset[_]): DataFrame = {
transform(dataset)
}

/** @inheritdoc */
override def transform(dataset: Dataset[_]): DataFrame = {
inputColIndex = dataset.schema.fieldIndex($(inputCol))
super.transform(dataset)
}

/**
* Transforms the rows.
*
* @param iter the rows to transform
* @return the transformed rows
*/
override protected def transformRows(iter: Iterator[Row]): Iterator[Row] = {
val predictor = model.newPredictor()
iter.map(row => {
val data = row.getAs[Array[Byte]](inputColIndex)
val audioFactory = AudioFactory.newInstance
if (isDefined(channels)) {
audioFactory.setChannels($(channels))
}
if (isDefined(sampleRate)) {
audioFactory.setSampleRate($(sampleRate))
}
if (isDefined(sampleFormat)) {
audioFactory.setSampleFormat($(sampleFormat))
}
val audio = audioFactory.fromInputStream(new ByteArrayInputStream(data))
Row.fromSeq(row.toSeq :+ predictor.predict(audio))
})
}

/** @inheritdoc */
def validateInputType(schema: StructType): Unit = {
validateType(schema($(inputCol)), BinaryType)
}

/** @inheritdoc */
override def transformSchema(schema: StructType): StructType = {
val outputSchema = StructType(schema.fields :+ StructField($(outputCol), StringType))
outputSchema
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance
* with the License. A copy of the License is located at
*
* http://aws.amazon.com/apache2.0/
*
* or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES
* OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions
* and limitations under the License.
*/
package ai.djl.spark.task.audio

import ai.djl.audio.translator.WhisperTranslatorFactory
import org.apache.spark.ml.util.Identifiable

/**
* WhisperSpeechRecognizer is very similar to the SpeechRecognizer that performs speech recognition on audio,
* except that this API is specially tailored for OpenAI Whisper related models.
*
* @param uid An immutable unique ID for the object and its derivatives.
*/
class WhisperSpeechRecognizer(override val uid: String) extends SpeechRecognizer {

def this() = this(Identifiable.randomUID("WhisperSpeechRecognizer"))

setDefault(translatorFactory, new WhisperTranslatorFactory())
}

0 comments on commit ba29c2a

Please sign in to comment.