fix: fix cognitive service errors #1176

Merged 14 commits on Sep 10, 2021
Changes from 6 commits
@@ -91,7 +91,9 @@ trait HasServiceParams extends Params {
}

protected def shouldSkip(row: Row): Boolean = getRequiredParams.exists { p =>
- emptyParamData(row, p)
+ if (emptyParamData(row, p))
+ throw new NullPointerException(s"required param undefined: $p")
Collaborator

This will never return false if the error is thrown in here. What I'm thinking is that we check the required params in the transformSchema function to ensure each one is set with either a value or a column. We can then call transformSchema before we transform to add validation there too.

Collaborator

nit: Should be an IllegalArgumentException

Contributor Author

In CognitiveServiceBase, transformSchema and transform both call getInternalTransformer first. That function creates a new SimpleHTTPTransformer and calls setInputParser(getInternalInputParser(schema)). getInternalInputParser in turn calls inputFunc, which calls shouldSkip, which is why I added the check here. I know it will never return false if the error happens here, so do we want to silence the error, catch it later, and return null?

Collaborator

We could have the logic in getInternalTransformer. I think the logic should live in the "control plane", which executes on the head node, rather than in code that runs inside a mapPartitions, like shouldSkip.

+ else false
}

protected def getValueOpt[T](row: Row, p: ServiceParam[T]): Option[T] = {
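
The reviewer's suggestion to validate required params once on the driver, rather than per row inside shouldSkip, can be sketched roughly as follows. This is a minimal standalone illustration, not the code that was merged: `RequiredParamCheck`, `Setting`, and `validate` are hypothetical names, and the params are modeled as a plain `Map` instead of the real `ServiceParam` machinery.

```scala
object RequiredParamCheck {
  // Hypothetical stand-in for a service param setting:
  // Left(value) is a scalar applied to every row, Right(column) names an input column.
  type Setting = Either[Any, String]

  /** Fails fast, before any rows are processed, if a required param is unset
    * or refers to a column that is missing from the input schema. */
  def validate(required: Seq[String],
               settings: Map[String, Setting],
               columns: Set[String]): Unit = {
    required.foreach { name =>
      settings.get(name) match {
        case None =>
          throw new IllegalArgumentException(s"required param undefined: $name")
        case Some(Right(col)) if !columns.contains(col) =>
          throw new IllegalArgumentException(s"column '$col' for param '$name' not found")
        case _ => () // set to a scalar value or to an existing column
      }
    }
  }
}
```

Calling something like this from transformSchema (or getInternalTransformer) would surface the problem as an IllegalArgumentException on the head node, and shouldSkip could go back to returning a plain Boolean.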
@@ -312,8 +312,8 @@ class RecognizeText(override val uid: String)
"printed text recognition is performed. If 'Handwritten' is specified," +
" handwriting recognition is performed",
{
- case Left(_) => true
- case Right(s) => Set("Printed", "Handwritten")(s)
+ case Left(s) => Set("Printed", "Handwritten")(s)
+ case Right(_) => true
}, isURLParam = true)

def getMode: String = getScalarParam(mode)
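
The swap above suggests the convention that `Left` carries a literal value (set through setScalarParam) and `Right` carries a column name (set through setVectorParam), so only the `Left` side can be checked against the allowed set up front. A small standalone sketch of that idea, under that assumption; `ScalarOrColumn` and `isValidMode` are illustrative names only:

```scala
object ScalarOrColumn {
  // Assumed convention: Left(value) is a literal set by the user,
  // Right(columnName) defers the value to a DataFrame column.
  val allowedModes: Set[String] = Set("Printed", "Handwritten")

  val isValidMode: Either[String, String] => Boolean = {
    case Left(value) => allowedModes(value) // literal: can be checked immediately
    case Right(_)    => true                // column name: values are only known per row
  }
}
```

With the old ordering, a column name such as "modeCol" would have been tested against the allowed values and rejected, while an invalid literal would have passed.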
@@ -361,8 +361,8 @@ class ReadImage(override val uid: String)
" so only provide a language code if you would like to force the documented" +
" to be processed as that specific language.",
{
- case Left(_) => true
- case Right(s) => Set("en", "nl", "fr", "de", "it", "pt", "es")(s)
+ case Left(s) => Set("en", "nl", "fr", "de", "it", "pt", "es")(s)
+ case Right(_) => true
}, isURLParam = true)

def setLanguage(v: String): this.type = setScalarParam(language, v)
@@ -71,8 +71,8 @@ trait HasModelID extends HasServiceParams {
trait HasLocale extends HasServiceParams {
val locale = new ServiceParam[String](this, "locale", "Locale of the receipt. Supported" +
" locales: en-AU, en-CA, en-GB, en-IN, en-US.", {
- case Left(_) => true
- case Right(s) => Set("en-AU", "en-CA", "en-GB", "en-IN", "en-US")(s)
+ case Left(s) => Set("en-AU", "en-CA", "en-GB", "en-IN", "en-US")(s)
+ case Right(_) => true
}, isURLParam = true)

def setLocale(v: String): this.type = setScalarParam(locale, v)
@@ -38,6 +38,8 @@ trait HasTextInput extends HasServiceParams {

def setText(v: Seq[String]): this.type = setScalarParam(text, v)

+ def setText(v: String): this.type = setScalarParam(text, Seq(v))
+
def getTextCol: String = getVectorParam(text)

def setTextCol(v: String): this.type = setVectorParam(text, v)
@@ -76,9 +78,15 @@ trait TextAsOnlyEntity extends HasTextInput with HasCognitiveServiceInput {

override protected def prepareEntity: Row => Option[AbstractHttpEntity] = {
r =>
- Some(new StringEntity(
- getValueOpt(r, text)
- .map(x => x.map(y => Map("Text" -> y))).toJson.compactPrint, ContentType.APPLICATION_JSON))
+ val textVal = getValueOpt(r, text)
+ if (textVal.nonEmpty) {
+ val content = textVal.get.getClass.getName match {
Collaborator

Chris's second issue was caused by having nulls within a batch. Perhaps we should add that case as a test and ensure this new function can handle it. In my TA batching logic, handling that unfortunately got pretty hairy. Perhaps we can use similar logic. We might also want to consider a better solution than just batching elements into single arrays, but we can push that to a later PR since that is the current behavior in TA.

Contributor Author

For nulls in text, the result will be null, since the request itself becomes invalid. But if toLanguage is set to null, it will trigger the "required param undefined" error above. What behavior are we expecting exactly? I noticed that TA has a reshapeToArray step that turns strings into arrays, but I didn't understand Chris's problem with the output schema (TA has unpackBatchUDF to deal with the returned JSON). Could you explain more about that issue?

Collaborator

I think the error occurs when there are null text values that would ordinarily be skipped, but because they are in a batch they mess up the whole batch.

+ case "java.lang.String" => Seq(Map("Text" -> textVal.get.asInstanceOf[String])).toJson.compactPrint
+ case _ => textVal.get.map(x => Map("Text" -> x)).toJson.compactPrint
+ }
+ Some(new StringEntity(content, ContentType.APPLICATION_JSON))
+ }
+ else Some(new StringEntity(Map("Text" -> "").toJson.compactPrint, ContentType.APPLICATION_JSON))
}
}
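
One possible way to handle both a single string and a batch that contains nulls, as raised in the thread above, is to match on the value's type rather than its class name and to remember each entry's original position. The sketch below is only an illustration under those assumptions: `TextBody` and `toEntries` are hypothetical, the thread leaves the desired null behavior open, and JSON serialization is left to whatever library the caller already uses.

```scala
object TextBody {
  /** Normalizes a single string or a batch of strings into indexed "Text" entries.
    * Null elements are left out of the request, but the original index of every
    * kept entry is retained so responses can be mapped back to their rows. */
  def toEntries(value: Any): Seq[(Int, Map[String, String])] = value match {
    case null      => Seq.empty
    case s: String => Seq(0 -> Map("Text" -> s))
    case xs: Seq[_] =>
      // A type pattern never matches null, so null elements are skipped here.
      xs.zipWithIndex.collect { case (s: String, i) => i -> Map("Text" -> s) }
    case other =>
      throw new IllegalArgumentException(s"unsupported text value: ${other.getClass.getName}")
  }
}
```

Keeping the index is one way to stop a null from shifting every later result in the batch; whether skipped rows should then yield null outputs is a separate decision.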

@@ -162,6 +170,38 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)

def urlPath: String = "translate"

+ override protected def prepareUrl: Row => String = {
+ val urlParams: Array[ServiceParam[Any]] =
+ getUrlParams.asInstanceOf[Array[ServiceParam[Any]]];
+
+ // This semicolon is needed to avoid argument confusion
+ def replaceName(s: String): String = {
+ if (s == "fromLanguage") {
+ "from"
+ } else if (s == "toLanguage") {
+ "to"
+ } else {
+ s
+ }
+ }
+ { row: Row =>
+ val base = getUrl + "?api-version=3.0"
+ val appended = if (!urlParams.isEmpty) {
+ "&" + URLEncodingUtils.format(urlParams.flatMap(p =>
+ getValueOpt(row, p).map {
+ v =>
+ if (p.name == "toLanguage" & v.getClass.getName == "java.lang.String")
+ replaceName(p.name) -> p.toValueString(Seq(v))
+ else replaceName(p.name) -> p.toValueString(v)
+ }
+ ).toMap)
+ } else {
+ ""
+ }
+ base + appended
+ }
+ }
+
val toLanguage = new ServiceParam[Seq[String]](this, "toLanguage", "Specifies the language of the output" +
" text. The target language must be one of the supported languages included in the translation scope." +
" For example, use to=de to translate to German. It's possible to translate to multiple languages simultaneously" +
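
The renaming step in prepareUrl (fromLanguage to "from", toLanguage to "to") can also be expressed as a lookup table. The following is a standalone sketch using only `java.net.URLEncoder`; `TranslateUrl`, `queryNames`, and `build` are hypothetical names, and the params are assumed to be already converted to strings, which the real code does through `toValueString`.

```scala
import java.net.URLEncoder

object TranslateUrl {
  // Param names as exposed on the transformer -> names used in the query string.
  private val queryNames = Map("fromLanguage" -> "from", "toLanguage" -> "to")

  private def enc(s: String): String = URLEncoder.encode(s, "UTF-8")

  /** Builds "<base>?api-version=3.0&k=v..." from already-stringified params.
    * Params without an entry in queryNames keep their own name. */
  def build(base: String, params: Seq[(String, String)]): String = {
    val query = params.map { case (name, value) =>
      s"${enc(queryNames.getOrElse(name, name))}=${enc(value)}"
    }
    (s"$base?api-version=3.0" +: query).mkString("&")
  }
}
```

For example, `TranslateUrl.build(url, Seq("toLanguage" -> "zh-Hans"))` appends `&to=zh-Hans` to the versioned base URL. Taking a `Seq` of pairs instead of a `Map` is a deliberate choice in this sketch, since it allows the same query name to appear more than once.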
@@ -171,6 +211,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)

def setToLanguage(v: Seq[String]): this.type = setScalarParam(toLanguage, v)

+ def setToLanguage(v: String): this.type = setScalarParam(toLanguage, Seq(v))
Collaborator

Nice! We might need to add this to python methods too

Contributor Author (@serena-ruan, Sep 2, 2021)

I guess this will work automatically? We're calling self._java_obj = self._java_obj.setXXX(value), and the Java object will find the corresponding set function that matches the value type?

Collaborator

You're right!


def setToLanguageCol(v: String): this.type = setVectorParam(toLanguage, v)

val fromLanguage = new ServiceParam[String](this, "fromLanguage", "Specifies the language of the input" +
@@ -186,8 +228,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
val textType = new ServiceParam[String](this, "textType", "Defines whether the text being" +
" translated is plain text or HTML text. Any HTML needs to be a well-formed, complete element. Possible values" +
" are: plain (default) or html.", {
- case Left(_) => true
- case Right(s) => Set("plain", "html")(s)
+ case Left(s) => Set("plain", "html")(s)
+ case Right(_) => true
}, isURLParam = true)

def setTextType(v: String): this.type = setScalarParam(textType, v)
@@ -206,8 +248,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)
val profanityAction = new ServiceParam[String](this, "profanityAction", "Specifies how" +
" profanities should be treated in translations. Possible values are: NoAction (default), Marked or Deleted. ",
{
- case Left(_) => true
- case Right(s) => Set("NoAction", "Marked", "Deleted")(s)
+ case Left(s) => Set("NoAction", "Marked", "Deleted")(s)
+ case Right(_) => true
}, isURLParam = true)

def setProfanityAction(v: String): this.type = setScalarParam(profanityAction, v)
@@ -216,8 +258,8 @@ class Translate(override val uid: String) extends TextTranslatorBase(uid)

val profanityMarker = new ServiceParam[String](this, "profanityMarker", "Specifies how" +
" profanities should be marked in translations. Possible values are: Asterisk (default) or Tag.", {
- case Left(_) => true
- case Right(s) => Set("Asterisk", "Tag")(s)
+ case Left(s) => Set("Asterisk", "Tag")(s)
+ case Right(_) => true
}, isURLParam = true)

def setProfanityMarker(v: String): this.type = setScalarParam(profanityMarker, v)
@@ -493,6 +493,11 @@ class DescribeImageSuite extends TransformerFuzzing[DescribeImage]
assert(tags("person") && tags("glasses"))
}

+ override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
+ super.assertDFEq(df1.select("descriptions.description.tags", "descriptions.description.captions.text"),
+ df2.select("descriptions.description.tags", "descriptions.description.captions.text"))(eq)
+ }
+
override def testObjects(): Seq[TestObject[DescribeImage]] =
Seq(new TestObject(t, df))

@@ -34,6 +34,10 @@ trait TranslatorUtils extends TestBase {
lazy val textDf5: DataFrame = Seq(List("The word <mstrans:dictionary translation=wordomatic>word " +
"or phrase</mstrans:dictionary> is a dictionary entry.")).toDF("text")

+ lazy val textDf6: DataFrame = Seq("Hi, this is Synapse!", "Yes!").toDF("text")
+
+ lazy val textDf7: DataFrame = Seq(("Hi, this is Synapse!", "zh-Hans")).toDF("text", "language")
+
}

class TranslateSuite extends TransformerFuzzing[Translate]
@@ -64,6 +68,47 @@ class TranslateSuite extends TransformerFuzzing[Translate]
translate.setToLanguage(Seq("zh-Hans")), textDf2, "你好,你叫什么名字?\n再见"
)
)
+
+ assert(
+ translationTextTest(
+ translate.setToLanguage("zh-Hans"), textDf6, "嗨, 这是突触!"
+ )
+ )
+
+ val translate1: Translate = new Translate()
+ .setSubscriptionKey(translatorKey)
+ .setLocation("eastus")
+ .setText("Hi, this is Synapse!")
+ .setOutputCol("translation")
+ .setConcurrency(5)
+
+ assert(
+ translationTextTest(
+ translate1.setToLanguage("zh-Hans"), textDf6, "嗨, 这是突触!"
+ )
+ )
+
+ val translate2: Translate = new Translate()
+ .setSubscriptionKey(translatorKey)
+ .setLocation("eastus")
+ .setTextCol("text")
+ .setToLanguageCol("language")
+ .setOutputCol("translation")
+ .setConcurrency(5)
+
+ assert(
+ translationTextTest(
+ translate2, textDf7, "嗨, 这是突触!"
+ )
+ )
+ }
+
+ test("Translate triggers errors if required fields not set") {
+ try {
+ translate.transform(textDf2).collect()
+ } catch {
+ case e: Exception => assert(e.getCause.getMessage.contains("required param undefined"))
+ }
}

test("Translate with transliteration") {
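
As written, the "required fields not set" test above passes when no exception is thrown at all, because the assertion lives only in the catch branch. A stricter check could look something like the following standalone sketch, in which `ExpectFailure` and `failsWith` are hypothetical helpers rather than part of the test base. ScalaTest's own `intercept` or `assertThrows` would be the more usual choice inside the suite.

```scala
import scala.util.{Failure, Success, Try}

object ExpectFailure {
  /** Runs `body` and returns true only if it throws and the message of the
    * exception, or of any cause in its chain, contains `fragment`. */
  def failsWith(fragment: String)(body: => Any): Boolean = Try(body) match {
    case Success(_) => false // no exception: the check must not pass silently
    case Failure(e) =>
      // Walk the cause chain so wrapped job failures are still recognized.
      Iterator.iterate(e)(_.getCause).takeWhile(_ != null)
        .exists(t => Option(t.getMessage).exists(_.contains(fragment)))
  }
}
```

In the suite this could be used as `assert(ExpectFailure.failsWith("required param undefined") { translate.transform(textDf2).collect() })`, which fails both when nothing is thrown and when the message differs.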
@@ -146,7 +191,7 @@ class TranslateSuite extends TransformerFuzzing[Translate]
}

override def testObjects(): Seq[TestObject[Translate]] =
- Seq(new TestObject(translate, textDf1))
+ Seq(new TestObject(translate.setToLanguage(Seq("zh-Hans")), textDf1))

override def reader: MLReadable[_] = Translate
}