Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create CalibratedPredictor instead of SchemaBindableCalibratedPredictor #338

Merged
merged 2 commits into from
Jun 13, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions src/Microsoft.ML.Data/Prediction/Calibrator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -746,13 +746,10 @@ private static bool NeedCalibration(IHostEnvironment env, IChannel ch, ICalibrat
/// <param name="trainer">The trainer used to train the predictor.</param>
/// <param name="predictor">The predictor that needs calibration.</param>
/// <param name="data">The examples to used for calibrator training.</param>
/// <param name="needValueMapper">Indicates whether the predictor returned needs to be an <see cref="IValueMapper"/>.
/// This parameter is needed for OVA that uses the predictors as <see cref="IValueMapper"/>s. If it is false,
/// The predictor returned is an an <see cref="ISchemaBindableMapper"/>.</param>
/// <returns>The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.</returns>
public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel ch, ICalibratorTrainer calibrator,
int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
int maxRows, ITrainer trainer, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
Expand All @@ -763,7 +760,7 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
if (!NeedCalibration(env, ch, calibrator, trainer, predictor, data.Schema))
return predictor;

return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data, needValueMapper);
return TrainCalibrator(env, ch, calibrator, maxRows, predictor, data);
}

/// <summary>
Expand All @@ -775,13 +772,10 @@ public static IPredictor TrainCalibratorIfNeeded(IHostEnvironment env, IChannel
/// <param name="maxRows">The maximum rows to use for calibrator training.</param>
/// <param name="predictor">The predictor that needs calibration.</param>
/// <param name="data">The examples to used for calibrator training.</param>
/// <param name="needValueMapper">Indicates whether the predictor returned needs to be an <see cref="IValueMapper"/>.
/// This parameter is needed for OVA that uses the predictors as <see cref="IValueMapper"/>s. If it is false,
/// The predictor returned is an an <see cref="ISchemaBindableMapper"/>.</param>
/// <returns>The original predictor, if no calibration is needed,
/// or a metapredictor that wraps the original predictor and the newly trained calibrator.</returns>
public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICalibratorTrainer caliTrainer,
int maxRows, IPredictor predictor, RoleMappedData data, bool needValueMapper = false)
int maxRows, IPredictor predictor, RoleMappedData data)
{
Contracts.CheckValue(env, nameof(env));
env.CheckValue(ch, nameof(ch));
Expand Down Expand Up @@ -834,10 +828,10 @@ public static IPredictor TrainCalibrator(IHostEnvironment env, IChannel ch, ICal
}
}
var cali = caliTrainer.FinishTraining(ch);
return CreateCalibratedPredictor(env, (IPredictorProducing<Float>)predictor, cali, needValueMapper);
return CreateCalibratedPredictor(env, (IPredictorProducing<Float>)predictor, cali);
}

public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Float> predictor, ICalibrator cali, bool needValueMapper = false)
public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironment env, IPredictorProducing<Float> predictor, ICalibrator cali)
{
Contracts.Assert(predictor != null);
if (cali == null)
Expand All @@ -853,7 +847,7 @@ public static IPredictorProducing<Float> CreateCalibratedPredictor(IHostEnvironm
var predWithFeatureScores = predictor as IPredictorWithFeatureWeights<Float>;
if (predWithFeatureScores != null && predictor is IParameterMixer<Float> && cali is IParameterMixer)
return new ParameterMixingCalibratedPredictor(env, predWithFeatureScores, cali);
if (needValueMapper)
if (predictor is IValueMapper)
return new CalibratedPredictor(env, predictor, cali);
return new SchemaBindableCalibratedPredictor(env, predictor, cali);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ private TScalarPredictor TrainOne(IChannel ch, TScalarTrainer trainer, RoleMappe
else
calibrator = Args.Calibrator.CreateInstance(Host);
var res = CalibratorUtils.TrainCalibratorIfNeeded(Host, ch, calibrator, Args.MaxCalibrationExamples,
trainer, predictor, td, true);
trainer, predictor, td);
predictor = res as TScalarPredictor;
Host.Check(predictor != null, "Calibrated predictor does not implement the expected interface");
}
Expand Down
59 changes: 59 additions & 0 deletions test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -798,5 +798,64 @@ public void TestOvaMacro()
}
}
}

[Fact]
public void TestOvaMacroWithUncalibratedLearner()
{
var dataPath = GetDataPath(@"iris.txt");
using (var env = new TlcEnvironment(42))
{
// Specify subgraph for OVA
var subGraph = env.CreateExperiment();
var learnerInput = new Trainers.AveragedPerceptronBinaryClassifier { Shuffle = false };
var learnerOutput = subGraph.Add(learnerInput);
// Create pipeline with OVA and multiclass scoring.
var experiment = env.CreateExperiment();
var importInput = new ML.Data.TextLoader(dataPath);
importInput.Arguments.Column = new TextLoaderColumn[]
{
new TextLoaderColumn { Name = "Label", Source = new[] { new TextLoaderRange(0) } },
new TextLoaderColumn { Name = "Features", Source = new[] { new TextLoaderRange(1,4) } }
};
var importOutput = experiment.Add(importInput);
var oneVersusAll = new Models.OneVersusAll
{
TrainingData = importOutput.Data,
Nodes = subGraph,
UseProbabilities = true,
};
var ovaOutput = experiment.Add(oneVersusAll);
var scoreInput = new ML.Transforms.DatasetScorer
{
Data = importOutput.Data,
PredictorModel = ovaOutput.PredictorModel
};
var scoreOutput = experiment.Add(scoreInput);
var evalInput = new ML.Models.ClassificationEvaluator
{
Data = scoreOutput.ScoredData
};
var evalOutput = experiment.Add(evalInput);
experiment.Compile();
experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false));
experiment.Run();

var data = experiment.GetOutput(evalOutput.OverallMetrics);
var schema = data.Schema;
var b = schema.TryGetColumnIndex(MultiClassClassifierEvaluator.AccuracyMacro, out int accCol);
Assert.True(b);
using (var cursor = data.GetRowCursor(col => col == accCol))
{
var getter = cursor.GetGetter<double>(accCol);
b = cursor.MoveNext();
Assert.True(b);
double acc = 0;
getter(ref acc);
Assert.Equal(0.71, acc, 2);
b = cursor.MoveNext();
Assert.False(b);
}
}
}
}
}