From e23041f47f3bad97435eb5564e0ca451fc70aee2 Mon Sep 17 00:00:00 2001 From: Jason Wang Date: Mon, 18 Oct 2021 17:01:46 -0700 Subject: [PATCH] fix: Fixing flaky unit tests (#1215) --- .../split1/VectorLIMEExplainerSuite.scala | 50 +++++++++++++++++-- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/VectorLIMEExplainerSuite.scala b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/VectorLIMEExplainerSuite.scala index 1d6a7dcdce..0ffe673ef6 100644 --- a/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/VectorLIMEExplainerSuite.scala +++ b/core/src/test/scala/com/microsoft/azure/synapse/ml/explainers/split1/VectorLIMEExplainerSuite.scala @@ -10,10 +10,12 @@ import com.microsoft.azure.synapse.ml.core.test.fuzzing.{TestObject, Transformer import com.microsoft.azure.synapse.ml.core.utils.BreezeUtils._ import com.microsoft.azure.synapse.ml.explainers.LocalExplainer.LIME import com.microsoft.azure.synapse.ml.explainers.VectorLIME -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, SQLDataTypes, Vector, Vectors} import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel} import org.apache.spark.ml.util.MLReadable import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.types.{ArrayType, DataType} import org.scalactic.Equality class VectorLIMEExplainerSuite extends TestBase @@ -22,9 +24,50 @@ class VectorLIMEExplainerSuite extends TestBase import spark.implicits._ implicit val vectorEquality: Equality[BDV[Double]] = breezeVectorEq(1E-6) - implicit val matrixEquality: Equality[BDM[Double]] = breezeMatrixEq(1E-6) - override val sortInDataframeEquality = true + // The equality comparison for the weights matrices are quite complicated, + // creating an override here to handle the logic. 
+  /** Asserts that two DataFrames are equal. Array-of-vector ("matrix") columns are compared
+    * vector by vector through `vectorEquality`; all other columns go to `super.assertDFEq`.
+    * @param df1 first DataFrame
+    * @param df2 second DataFrame; its schema must be structurally equal to that of `df1`
+    * @param eq  DataFrame equality forwarded to `super.assertDFEq`
+    */
+  override def assertDFEq(df1: DataFrame, df2: DataFrame)(implicit eq: Equality[DataFrame]): Unit = {
+    assert(DataType.equalsStructurally(df1.schema, df2.schema))
+
+    // "Matrix" columns are the array<vector> typed ones; they are compared vector by vector below.
+    val arrayOfVectorFields = df1.schema.fields.filter {
+      f => DataType.equalsStructurally(f.dataType, ArrayType(SQLDataTypes.VectorType))
+    }.map(_.name)
+
+    // Every non-matrix column is delegated to super.assertDFEq.
+    val otherFields = df1.schema.fields.map(_.name).filterNot(arrayOfVectorFields.contains)
+    super.assertDFEq(df1.select(otherFields.map(col): _*), df2.select(otherFields.map(col): _*))(eq)
+
+    // Collect the matrix columns of each side, sorted by those columns when sortInDataframeEquality is set.
+    val matrixCols = arrayOfVectorFields.map(col)
+    def matrixRows(df: DataFrame) = {
+      val selected = df.select(matrixCols: _*)
+      (if (sortInDataframeEquality) selected.sort(matrixCols: _*) else selected).collect()
+    }
+    val (a, b) = (matrixRows(df1), matrixRows(df2))
+
+    assert(a.length === b.length)
+    for (i <- a.indices; j <- arrayOfVectorFields.indices) { // each row, each matrix column
+      val x: Seq[DenseVector] = a(i).getAs[Seq[DenseVector]](j)
+      val y: Seq[DenseVector] = b(i).getAs[Seq[DenseVector]](j)
+      assert(x.length === y.length)
+      x.indices foreach { k =>
+        // A bare `x(k) === y(k)` statement has its result discarded, so nothing was being checked.
+        // Assert it instead, via the suite's vectorEquality (built with breezeVectorEq(1E-6)).
+        assert(
+          vectorEquality.areEqual(new BDV[Double](x(k).toArray), new BDV[Double](y(k).toArray)),
+          s"row $i, column ${arrayOfVectorFields(j)}, vector $k: ${x(k)} did not equal ${y(k)}"
+        )
+      }
+    }
+  }
 
   val d1 = 3
   val d2 = 1
@@ -54,7 +97,6 @@ class VectorLIMEExplainerSuite extends TestBase
     .setNumSamples(1000)
 
   test("VectorLIME can explain a model locally") {
-
     val predicted = model.transform(df)
     val weights = lime
       .transform(predicted)