Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add VowpalWabbit ngram support #696

Merged
merged 10 commits into from
Oct 9, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package com.microsoft.ml.spark.vw
import com.microsoft.ml.spark.core.contracts.{HasInputCols, HasOutputCol, Wrappable}
import com.microsoft.ml.spark.vw.featurizer._
import org.apache.spark.ml.{ComplexParamsReadable, ComplexParamsWritable, Transformer}
import org.apache.spark.ml.param.{IntParam, ParamMap, StringArrayParam}
import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, StringArrayParam}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.{col, struct, udf}
Expand Down Expand Up @@ -41,51 +41,79 @@ class VowpalWabbitFeaturizer(override val uid: String) extends Transformer
// Accessors for the stringSplitInputCols parameter.
def getStringSplitInputCols: Array[String] = getOrDefault(stringSplitInputCols)

def setStringSplitInputCols(value: Array[String]): this.type = {
  set(stringSplitInputCols, value)
}

// Number of bits reserved for an ordering counter that transform() ORs into every feature index.
// The default of 0 disables order preservation (transform() only prefixes indices when the value is > 0).
val preserveOrderNumBits = new IntParam(this, "preserveOrderNumBits",
  "Number of bits used to preserve the feature order. This will reduce the hash size. " +
    "Needs to be large enough to fit count the maximum number of words")
setDefault(preserveOrderNumBits -> 0)

def getPreserveOrderNumBits: Int = $(preserveOrderNumBits)

/** Sets the number of bits used to preserve feature order.
  *
  * @param value number of bits, 0 to 28 inclusive; 0 switches order preservation off
  * @return this instance, for call chaining
  * @throws IllegalArgumentException if value is outside 0..28
  */
def setPreserveOrderNumBits(value: Int): this.type = {
  // 0 is the declared default, so it must be accepted here; otherwise a caller
  // could never switch the feature back off once it had been enabled.
  if (value < 0 || value > 28)
    throw new IllegalArgumentException("preserveOrderNumBits must be between 0 and 28 bits")
  set(preserveOrderNumBits, value)
}

// Controls whether getFeaturizer hands the column name or an empty string to the featurizers.
// Enabled unless explicitly switched off.
val prefixStringsWithColumnName = new BooleanParam(
  this,
  "prefixStringsWithColumnName",
  "Prefix string features with column name")
setDefault(prefixStringsWithColumnName, true)

def getPrefixStringsWithColumnName: Boolean = getOrDefault(prefixStringsWithColumnName)

def setPrefixStringsWithColumnName(value: Boolean): this.type = {
  set(prefixStringsWithColumnName, value)
}

// Every column fed to the featurizer: the regular input columns first, then the string-split ones.
private def getAllInputCols = Array.concat(getInputCols, getStringSplitInputCols)

/** Picks the Featurizer implementation for a single input column based on its Spark SQL data type.
  *
  * NOTE(review): this listing is a rendered diff without +/- markers. The case lines that pass
  * `name` appear to be the pre-change versions of the adjacent lines that pass `prefixName`;
  * confirm against the real file before relying on either.
  *
  * @param name          column name; also checked against the string-split input columns
  * @param dataType      data type of the column, used to choose the featurizer
  * @param idx           position of the column in the row handed to the featurizer
  * @param namespaceHash hash value passed through unchanged to the featurizer
  * @return the featurizer matching the column's data type
  * @throws RuntimeException if the data type, array element type, map key type or map value type
  *                          is not supported
  */
private def getFeaturizer(name: String, dataType: DataType, idx: Int, namespaceHash: Int): Featurizer = {
val stringSplitInputCols = getStringSplitInputCols

// featurizers receive the column name only when prefixing is enabled, otherwise an empty string
val prefixName = if (getPrefixStringsWithColumnName) name else ""

dataType match {
case DoubleType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getDouble(idx))
case FloatType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getFloat(idx).toDouble)
case IntegerType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getInt(idx).toDouble)
case LongType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getLong(idx).toDouble)
case ShortType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getShort(idx).toDouble)
case ByteType => new NumericFeaturizer(idx, name, namespaceHash, getMask, r => r.getByte(idx).toDouble)
case BooleanType => new BooleanFeaturizer(idx, name, namespaceHash, getMask)
case DoubleType => new NumericFeaturizer(idx, prefixName, namespaceHash, getMask, r => r.getDouble(idx))
case FloatType => new NumericFeaturizer(idx, prefixName, namespaceHash, getMask, r => r.getFloat(idx).toDouble)
case IntegerType => new NumericFeaturizer(idx, prefixName, namespaceHash, getMask, r => r.getInt(idx).toDouble)
case LongType => new NumericFeaturizer(idx, prefixName, namespaceHash, getMask, r => r.getLong(idx).toDouble)
case ShortType => new NumericFeaturizer(idx, prefixName, namespaceHash, getMask, r => r.getShort(idx).toDouble)
case ByteType => new NumericFeaturizer(idx, prefixName, namespaceHash, getMask, r => r.getByte(idx).toDouble)
case BooleanType => new BooleanFeaturizer(idx, prefixName, namespaceHash, getMask)
case StringType =>
// the split-vs-whole decision is keyed on the real column name, not on prefixName
if (stringSplitInputCols.contains(name))
new StringSplitFeaturizer(idx, name, namespaceHash, getMask)
else new StringFeaturizer(idx, name, namespaceHash, getMask)
new StringSplitFeaturizer(idx, prefixName, namespaceHash, getMask)
else new StringFeaturizer(idx, prefixName, namespaceHash, getMask)
case arr: ArrayType =>
// only arrays of strings are supported
if (arr.elementType != DataTypes.StringType)
throw new RuntimeException(s"Unsupported array element type: $dataType")
new StringArrayFeaturizer(idx, name, namespaceHash, getMask)
new StringArrayFeaturizer(idx, prefixName, namespaceHash, getMask)

case m: MapType =>
// map keys must be strings; the value type selects the concrete map featurizer
if (m.keyType != DataTypes.StringType)
throw new RuntimeException(s"Unsupported map key type: $dataType")

m.valueType match {
case StringType => new MapStringFeaturizer(idx, name, namespaceHash, getMask)
case DoubleType => new MapFeaturizer[Double](idx, name, namespaceHash, getMask, v => v)
case FloatType => new MapFeaturizer[Float](idx, name, namespaceHash, getMask, v => v.toDouble)
case IntegerType => new MapFeaturizer[Int](idx, name, namespaceHash, getMask, v => v.toDouble)
case LongType => new MapFeaturizer[Long](idx, name, namespaceHash, getMask, v => v.toDouble)
case ShortType => new MapFeaturizer[Short](idx, name, namespaceHash, getMask, v => v.toDouble)
case ByteType => new MapFeaturizer[Byte](idx, name, namespaceHash, getMask, v => v.toDouble)
case StringType => new MapStringFeaturizer(idx, prefixName, namespaceHash, getMask)
case DoubleType => new MapFeaturizer[Double](idx, prefixName, namespaceHash, getMask, v => v)
case FloatType => new MapFeaturizer[Float](idx, prefixName, namespaceHash, getMask, v => v.toDouble)
case IntegerType => new MapFeaturizer[Int](idx, prefixName, namespaceHash, getMask, v => v.toDouble)
case LongType => new MapFeaturizer[Long](idx, prefixName, namespaceHash, getMask, v => v.toDouble)
case ShortType => new MapFeaturizer[Short](idx, prefixName, namespaceHash, getMask, v => v.toDouble)
case ByteType => new MapFeaturizer[Byte](idx, prefixName, namespaceHash, getMask, v => v.toDouble)
case _ => throw new RuntimeException(s"Unsupported map value type: $dataType")
}
case m: Any =>
if (m == VectorType) // unfortunately the type is private
new VectorFeaturizer(idx, name, getMask)
new VectorFeaturizer(idx, prefixName, getMask)
else
throw new RuntimeException(s"Unsupported data type: $dataType")
}
}

override def transform(dataset: Dataset[_]): DataFrame = {
// Fail fast when the hashing bits and the order-preserving bits together exceed 30.
// A sum of exactly 30 is accepted, so the message says "must not exceed" rather than "less than".
if (getPreserveOrderNumBits + getNumBits > 30)
  throw new IllegalArgumentException(
    s"Number of bits used for hashing ($getNumBits) and " +
      s"number of bits used for order preserving ($getPreserveOrderNumBits) must not exceed 30 in total")

val maxFeaturesForOrdering = 1 << getPreserveOrderNumBits

val inputColsList = getAllInputCols
val namespaceHash: Int = VowpalWabbitMurmur.hash(this.getOutputCol, this.getSeed)

Expand Down Expand Up @@ -116,13 +144,34 @@ class VowpalWabbitFeaturizer(override val uid: String) extends Transformer
if (!r.isNullAt(f.fieldIdx))
f.featurize(r, indices, values)

val indicesArray = indices.result
if (getPreserveOrderNumBits > 0) {
var idxPrefixBits = 30 - getPreserveOrderNumBits

if (indicesArray.length > maxFeaturesForOrdering)
throw new IllegalArgumentException(
s"Too many features ${indicesArray.length} for " +
s"number of bits used for order preserving (${getPreserveOrderNumBits})")

// prefix every feature index with a counter value
// will be stripped when passing to VW
for (i <- 0 until indicesArray.length) {
val idxPrefix = i << idxPrefixBits
indicesArray(i) = indicesArray(i) | idxPrefix
}
}

// if we use the highest order bits to preserve the ordering
// the maximum index size is larger
val size = if(getPreserveOrderNumBits > 0) 1 << 30 else 1 << getNumBits

// sort by indices and remove duplicate values
// Warning:
// - due to SparseVector limitations (which doesn't allow duplicates) we need filter
// - VW command line allows for duplicate features with different values (just updates twice)
val (indicesSorted, valuesSorted) = VectorUtils.sortAndDistinct(indices.result, values.result, getSumCollisions)
val (indicesSorted, valuesSorted) = VectorUtils.sortAndDistinct(indicesArray, values.result, getSumCollisions)

Vectors.sparse(1 << getNumBits, indicesSorted, valuesSorted)
Vectors.sparse(size, indicesSorted, valuesSorted)
})

dataset.toDF.withColumn(getOutputCol, mode.apply(struct(fieldSubset.map(f => col(f.name)): _*)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,32 @@ class VerifyVowpalWabbitClassifier extends Benchmarks with EstimatorFuzzing[Vowp
println(labelOneCnt1)
}

// Row shape for the ngram classifier test below: an integer label plus a single input column named `in`.
case class ClassificationInput[T](label: Int, in: T)

test("Verify VowpalWabbit Classifier w/ ngrams") {
  // Two labelled sentences, coalesced into a single partition.
  val sentences = Seq(
    ClassificationInput[String](1, "marie markus fun"),
    ClassificationInput[String](0, "marie markus no fun"))
  val dataset = session.createDataFrame(sentences).coalesce(1)

  // Split the "in" column into words, reserve 2 bits for word order, and drop the column-name prefix.
  val orderedFeaturizer = new VowpalWabbitFeaturizer()
    .setStringSplitInputCols(Array("in"))
    .setPreserveOrderNumBits(2)
    .setNumBits(18)
    .setPrefixStringsWithColumnName(false)
    .setOutputCol("features")
  val featurized = orderedFeaturizer.transform(dataset)

  val ngramModel = new VowpalWabbitClassifier()
    .setArgs("--ngram f2 -a")
    .fit(featurized)

  // first row:  3 (words) + 2 (ngrams) + 1 (constant) = 6
  // second row: 4 (words) + 3 (ngrams) + 1 (constant) = 8
  val totalFeatures = ngramModel.getPerformanceStatistics
    .select("totalNumberOfFeatures")
    .head
    .get(0)
  assert (totalFeatures == 14)
}

/** Reads a CSV file given the file name and file location.
* @param fileName The name of the csv file.
* @param fileLocation The full path to the csv file.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,33 @@ class VerifyVowpalWabbitFeaturizer extends TestBase with TransformerFuzzing[Vowp

val namespaceFeatures = VowpalWabbitMurmur.hash("features", 0)

test("Verify order preserving") {
  val sentence = "marie markus fun"
  val expectedWords = Seq("marie", "markus", "fun")

  val orderedFeaturizer = new VowpalWabbitFeaturizer()
    .setStringSplitInputCols(Array("in"))
    .setPreserveOrderNumBits(2)
    .setNumBits(18)
    .setPrefixStringsWithColumnName(false)
    .setOutputCol("features")

  val input = session.createDataFrame(Seq(Input[String](sentence)))
  val featureVector = orderedFeaturizer
    .transform(input)
    .select(col("features"))
    .collect
    .apply(0)
    .getAs[SparseVector](0)

  assert(featureVector.numNonzeros == 3)

  // keep only the low 18 bits (the configured numBits) before comparing against the word hashes
  val hashBits = (1 << 18) - 1

  // the order is the same as in the string above
  for ((word, position) <- expectedWords.zipWithIndex) {
    val expectedHash = hashBits & VowpalWabbitMurmur.hash(word, namespaceFeatures)
    assert((hashBits & featureVector.indices(position)) == expectedHash)
  }

  // every word contributes a value of 1.0
  for (position <- expectedWords.indices)
    assert(featureVector.values(position) == 1.0)
}

test("Verify VowpalWabbit Featurizer can be run with seq and string") {
val featurizer1 = new VowpalWabbitFeaturizer()
.setInputCols(Array("str", "seq"))
Expand Down