Skip to content

Commit

Permalink
fix: Fixing flaky unit tests (#1215)
Browse files Browse the repository at this point in the history
  • Loading branch information
memoryz authored Oct 19, 2021
1 parent 5d31e3e commit e23041f
Showing 1 changed file with 46 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, Transformer
import com.microsoft.azure.synapse.ml.core.utils.BreezeUtils._
import com.microsoft.azure.synapse.ml.explainers.LocalExplainer.LIME
import com.microsoft.azure.synapse.ml.explainers.VectorLIME
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.linalg.{DenseVector, SQLDataTypes, Vector, Vectors}
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.ml.util.MLReadable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types.{ArrayType, DataType}
import org.scalactic.Equality

class VectorLIMEExplainerSuite extends TestBase
Expand All @@ -22,9 +24,50 @@ class VectorLIMEExplainerSuite extends TestBase
import spark.implicits._

implicit val vectorEquality: Equality[BDV[Double]] = breezeVectorEq(1E-6)
implicit val matrixEquality: Equality[BDM[Double]] = breezeMatrixEq(1E-6)

override val sortInDataframeEquality = true
// The equality comparison for the weights matrices are quite complicated,
// creating an override here to handle the logic.
override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
assert(DataType.equalsStructurally(df1.schema, df2.schema))

val arrayOfVectorFields = df1.schema.fields.filter {
f => DataType.equalsStructurally(f.dataType, ArrayType(SQLDataTypes.VectorType))
}.map(_.name)

// Non-matrix columns, defer to super.assertDFEq
val otherFields = df1.schema.fields.map(_.name).filterNot(arrayOfVectorFields.contains)

super.assertDFEq(
df1.select(otherFields.map(col): _*),
df2.select(otherFields.map(col): _*)
)(eq)

val (a, b) = (df1.select(arrayOfVectorFields.map(col): _*), df2.select(arrayOfVectorFields.map(col): _*)) match {
case (x, y) =>
val sorted = if (sortInDataframeEquality) {
(x.sort(arrayOfVectorFields.map(col): _*), y.sort(arrayOfVectorFields.map(col): _*))
} else {
(x, y)
}

(sorted._1.collect, sorted._2.collect)
}

assert(a.length === b.length)
a.indices foreach { // For each row
i =>
arrayOfVectorFields.indices foreach {
j => // For each matrix type column
val x: Seq[DenseVector] = a(i).getAs[Seq[DenseVector]](j)
val y: Seq[DenseVector] = b(i).getAs[Seq[DenseVector]](j)
assert(x.length === y.length)
x.indices foreach {
k =>
x(k) === y(k)
}
}
}
}

val d1 = 3
val d2 = 1
Expand Down Expand Up @@ -54,7 +97,6 @@ class VectorLIMEExplainerSuite extends TestBase
.setNumSamples(1000)

test("VectorLIME can explain a model locally") {

val predicted = model.transform(df)
val weights = lime
.transform(predicted)
Expand Down

0 comments on commit e23041f

Please sign in to comment.