Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable GPT-4 in OpenAIPrompt #2248

Merged
merged 10 commits into from
Jul 16, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import com.microsoft.azure.synapse.ml.param.StringStringMapParam
import org.apache.spark.ml.param.{BooleanParam, Param, ParamMap, ParamValidators}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{DataType, StructType}
import org.apache.spark.sql.{Column, DataFrame, Dataset, functions => F, types => T}

Expand Down Expand Up @@ -62,18 +63,26 @@ class OpenAIPrompt(override val uid: String) extends Transformer
set(postProcessingOptions, v.asScala.toMap)

// Whether to drop the derived prompt column after templating (legacy completion models).
val dropPrompt = new BooleanParam(
  this, "dropPrompt", "whether to drop the column of prompts after templating (when using legacy models)")

def getDropPrompt: Boolean = $(dropPrompt)

def setDropPrompt(value: Boolean): this.type = set(dropPrompt, value)

// Whether to drop the derived messages column after templating (chat models, gpt-4 or higher).
val dropMessages = new BooleanParam(
  this, "dropMessages", "whether to drop the column of messages after templating (when using gpt-4 or higher)")

def getDropMessages: Boolean = $(dropMessages)

def setDropMessages(value: Boolean): this.type = set(dropMessages, value)

// Defaults: derived prompt/messages columns are dropped so output schema stays minimal.
setDefault(
  postProcessing -> "",
  postProcessingOptions -> Map.empty,
  outputCol -> (this.uid + "_output"),
  errorCol -> (this.uid + "_error"),
  dropPrompt -> true,
  dropMessages -> true,
  timeout -> 360.0
)

Expand All @@ -82,40 +91,75 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}

// Params handled by OpenAIPrompt itself; everything else is copied onto the inner completion.
private val localParamNames = Seq(
  "promptTemplate", "outputCol", "postProcessing", "postProcessingOptions", "dropPrompt", "dropMessages")

/**
  * Templates each row into either a raw prompt (legacy completion models) or a
  * system+user message pair (chat models such as gpt-4), runs the matching
  * OpenAI transformer, and parses the first choice with the configured parser.
  *
  * @param dataset input dataset; columns referenced by the prompt template must exist
  * @return the input with the parsed output column added (and the derived
  *         prompt/messages column dropped when dropPrompt/dropMessages is set)
  */
override def transform(dataset: Dataset[_]): DataFrame = {
  import com.microsoft.azure.synapse.ml.core.schema.DatasetExtensions._

  logTransform[DataFrame]({
    val df = dataset.toDF

    // openAICompletion picks OpenAIChatCompletion for gpt-4, OpenAICompletion otherwise.
    val completion = openAICompletion
    // Column expression that renders the prompt template against each row.
    val promptCol = Functions.template(getPromptTemplate)
    val systemPrompt = "You are an AI Chatbot. Only respond with a completion."
    // Chat models consume a message list, not a bare prompt string.
    val createMessagesUDF = udf((userMessage: String) => {
      Seq(
        OpenAIMessage("system", systemPrompt),
        OpenAIMessage("user", userMessage)
      )
    })
    completion match {
      case chatCompletion: OpenAIChatCompletion =>
        // Chat path: wrap the templated prompt into system/user messages.
        val messageColName = df.withDerivativeCol("messages")
        val dfTemplated = df.withColumn(messageColName, createMessagesUDF(promptCol))
        val completionNamed = chatCompletion.setMessagesCol(messageColName)

        // Chat responses nest the text under choices[1].message.content (element_at is 1-based).
        val results = completionNamed
          .transform(dfTemplated)
          .withColumn(getOutputCol,
            getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
              .getField("message").getField("content")))
          .drop(completionNamed.getOutputCol)

        if (getDropMessages) {
          results.drop(messageColName)
        } else {
          results
        }

      case completion: OpenAICompletion =>
        // Legacy path: feed the templated prompt column directly.
        val promptColName = df.withDerivativeCol("prompt")
        val dfTemplated = df.withColumn(promptColName, promptCol)
        val completionNamed = completion.setPromptCol(promptColName)

        // Completion responses expose the text under choices[1].text.
        val results = completionNamed
          .transform(dfTemplated)
          .withColumn(getOutputCol,
            getParser.parse(F.element_at(F.col(completionNamed.getOutputCol).getField("choices"), 1)
              .getField("text")))
          .drop(completionNamed.getOutputCol)

        if (getDropPrompt) {
          results.drop(promptColName)
        } else {
          results
        }
    }
  }, dataset.columns.length)
}

private def openAICompletion: OpenAICompletion = {
// apply template
val completion = new OpenAICompletion()
private def openAICompletion: OpenAIServicesBase = {

val completion: OpenAIServicesBase =
// use OpenAICompletion
if (getDeploymentName != "gpt-4") {
new OpenAICompletion()
}
else {
// use OpenAIChatCompletion
new OpenAIChatCompletion()
}
// apply all parameters
extractParamMap().toSeq
.filter(p => !localParamNames.contains(p.param.name))
Expand All @@ -136,10 +180,20 @@ class OpenAIPrompt(override val uid: String) extends Transformer
}
}

/**
  * Computes the output schema by delegating to whichever inner completion
  * transformer openAICompletion selects, then appending the parser's output field.
  */
override def transformSchema(schema: StructType): StructType = {
  openAICompletion match {
    case chatCompletion: OpenAIChatCompletion =>
      // The chat transformer requires a messages column to be set before it can validate a schema.
      chatCompletion.setMessagesCol("messages")
      chatCompletion
        .transformSchema(schema)
        .add(getPostProcessing, getParser.outputSchema)
    case completion: OpenAICompletion =>
      completion
        .transformSchema(schema)
        .add(getPostProcessing, getParser.outputSchema)
  }
}
}

trait OutputParser {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,44 @@ class OpenAIPromptSuite extends TransformerFuzzing[OpenAIPrompt] with OpenAIAPIK

test("Basic Usage JSON") {
  // Ask the model for a JSON answer, parse it with the "json" post-processor
  // against an explicit schema, and check every non-null result has a prefix.
  prompt.setPromptTemplate(
    """Split a word into prefix and postfix a respond in JSON
      |Cherry: {{"prefix": "Che", "suffix": "rry"}}
      |{text}:
      |""".stripMargin)
    .setPostProcessing("json")
    .setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
    .transform(df)
    .select("outParsed")
    .where(col("outParsed").isNotNull)
    .collect()
    .foreach(r => assert(r.getStruct(0).getString(0).nonEmpty))
}

// OpenAIPrompt fixture targeting the GPT-4 deployment; temperature 0 keeps responses deterministic.
lazy val promptGpt4: OpenAIPrompt = {
  val p = new OpenAIPrompt()
  p.setSubscriptionKey(openAIAPIKey)
  p.setDeploymentName(deploymentNameGpt4)
  p.setCustomServiceName(openAIServiceName)
  p.setOutputCol("outParsed")
  p.setTemperature(0)
  p
}

test("Basic Usage - Gpt 4") {
  // Run the CSV post-processor over all rows and count rows whose parsed list is non-null.
  val rows = promptGpt4
    .setPromptTemplate("here is a comma separated list of 5 {category}: {text}, ")
    .setPostProcessing("csv")
    .transform(df)
    .select("outParsed")
    .collect()
  val nonNullCount = rows.count(row => Option(row.getSeq[String](0)).isDefined)

  // Every one of the three input rows should yield a parsed result.
  assert(nonNullCount == 3)
}

test("Basic Usage JSON - Gpt 4") {
promptGpt4.setPromptTemplate(
"""Split a word into prefix and postfix a respond in JSON
|Cherry: {{"prefix": "Che", "suffix": "rry"}}
|{text}:
|""".stripMargin)
.setPostProcessing("json")
.setPostProcessingOptions(Map("jsonSchema" -> "prefix STRING, suffix STRING"))
.transform(df)
Expand Down
Loading