From 881ab0c7e51166c43202109d6879d9356a7fb1e3 Mon Sep 17 00:00:00 2001 From: Markus Cozowicz Date: Sat, 7 Mar 2020 06:15:04 +0100 Subject: [PATCH 1/2] support numeric types (not just double) for weight/label --- .../microsoft/ml/spark/vw/VowpalWabbitBase.scala | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala b/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala index f7fde998a9..8cd6dc9c63 100644 --- a/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala +++ b/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala @@ -166,13 +166,23 @@ trait VowpalWabbitBase extends Wrappable protected def createLabelSetter(schema: StructType) = { val labelColIdx = schema.fieldIndex(getLabelCol) + // support numeric types as input + def getAsFloat(idx: Int) = + schema.fields(idx).dataType match { + case _: DoubleType => (row: Row) => row.getDouble(idx).toFloat + case _: FloatType => (row: Row) => row.getFloat(idx) + case _: IntegerType => (row: Row) => row.getInt(idx).toFloat + case _: LongType => (row: Row) => row.getLong(idx).toFloat + } + + val labelGetter = getAsFloat(labelColIdx) if (get(weightCol).isDefined) { - val weightColIdx = schema.fieldIndex(getWeightCol) + val weightGetter = getAsFloat(schema.fieldIndex(getWeightCol)) (row: Row, ex: VowpalWabbitExample) => - ex.setLabel(row.getDouble(weightColIdx).toFloat, row.getDouble(labelColIdx).toFloat) + ex.setLabel(weightGetter(row), labelGetter(row)) } else - (row: Row, ex: VowpalWabbitExample) => ex.setLabel(row.getDouble(labelColIdx).toFloat) + (row: Row, ex: VowpalWabbitExample) => ex.setLabel(labelGetter(row)) } /** From 09b0dac8706c86a7b0710e5f6f6ded59ad70e39e Mon Sep 17 00:00:00 2001 From: Markus Cozowicz Date: Sat, 7 Mar 2020 06:51:48 +0100 Subject: [PATCH 2/2] added warning when casting double to float fixed style (remove space *sigh*) --- .../scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala b/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala index 8cd6dc9c63..1e455f2c1a 100644 --- a/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala +++ b/src/main/scala/com/microsoft/ml/spark/vw/VowpalWabbitBase.scala @@ -166,10 +166,13 @@ trait VowpalWabbitBase extends Wrappable protected def createLabelSetter(schema: StructType) = { val labelColIdx = schema.fieldIndex(getLabelCol) - // support numeric types as input + // support numeric types as input def getAsFloat(idx: Int) = schema.fields(idx).dataType match { - case _: DoubleType => (row: Row) => row.getDouble(idx).toFloat + case _: DoubleType => { + log.warn(s"Casting column '${schema.fields(idx).name}' to float. Loss of precision.") + (row: Row) => row.getDouble(idx).toFloat + } case _: FloatType => (row: Row) => row.getFloat(idx) case _: IntegerType => (row: Row) => row.getInt(idx).toFloat case _: LongType => (row: Row) => row.getLong(idx).toFloat