diff --git a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestSplit.cs b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestSplit.cs index 8b199045fa..40909ad108 100644 --- a/src/Microsoft.ML/Runtime/EntryPoints/TrainTestSplit.cs +++ b/src/Microsoft.ML/Runtime/EntryPoints/TrainTestSplit.cs @@ -93,7 +93,8 @@ public static string CreateStratificationColumn(IHost host, ref IDataView data, new HashJoinTransform.Arguments { Column = new[] { new HashJoinTransform.Column { Name = stratCol, Source = stratificationColumn } }, - Join = true + Join = true, + HashBits = 30 }, data); } diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs index 66e241163f..439dc069f4 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestCSharpApi.cs @@ -330,5 +330,73 @@ public void TestCrossValidationMacro() } } } + + [Fact] + public void TestCrossValidationMacroWithStratification() + { + var dataPath = GetDataPath(@"breast-cancer.txt"); + using (var env = new TlcEnvironment()) + { + var subGraph = env.CreateExperiment(); + + var nop = new ML.Transforms.NoOperation(); + var nopOutput = subGraph.Add(nop); + + var learnerInput = new ML.Trainers.StochasticDualCoordinateAscentBinaryClassifier + { + TrainingData = nopOutput.OutputData, + NumThreads = 1 + }; + var learnerOutput = subGraph.Add(learnerInput); + + var modelCombine = new ML.Transforms.ManyHeterogeneousModelCombiner + { + TransformModels = new ArrayVar(nopOutput.Model), + PredictorModel = learnerOutput.PredictorModel + }; + var modelCombineOutput = subGraph.Add(modelCombine); + + var experiment = env.CreateExperiment(); + var importInput = new ML.Data.TextLoader(dataPath); + importInput.Arguments.Column = new ML.Data.TextLoaderColumn[] + { + new ML.Data.TextLoaderColumn { Name = "Label", Source = new[] { new ML.Data.TextLoaderRange(0) } }, + new ML.Data.TextLoaderColumn { Name = "Strat", Source = new[] { new ML.Data.TextLoaderRange(1) } }, + new ML.Data.TextLoaderColumn { Name = "Features", Source = new[] { new ML.Data.TextLoaderRange(2, 9) } } + }; + var importOutput = experiment.Add(importInput); + + var crossValidate = new ML.Models.CrossValidator + { + Data = importOutput.Data, + Nodes = subGraph, + TransformModel = null, + StratificationColumn = "Strat" + }; + crossValidate.Inputs.Data = nop.Data; + crossValidate.Outputs.Model = modelCombineOutput.PredictorModel; + var crossValidateOutput = experiment.Add(crossValidate); + + experiment.Compile(); + experiment.SetInput(importInput.InputFile, new SimpleFileHandle(env, dataPath, false, false)); + experiment.Run(); + var data = experiment.GetOutput(crossValidateOutput.OverallMetrics[0]); + + var schema = data.Schema; + var b = schema.TryGetColumnIndex("AUC", out int metricCol); + Assert.True(b); + using (var cursor = data.GetRowCursor(col => col == metricCol)) + { + var getter = cursor.GetGetter(metricCol); + b = cursor.MoveNext(); + Assert.True(b); + double val = 0; + getter(ref val); + Assert.Equal(0.99, val, 2); + b = cursor.MoveNext(); + Assert.False(b); + } + } + } } }