-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding support for training metrics in PipelineSweeperMacro + new graph variable outputs #152
Changes from 1 commit
05482eb
b789159
76dd270
483a50b
d42a522
fbee51c
0285c87
6f92693
e006ee1
90d96a3
1d1c303
ce9792f
a828266
c642f68
e066db0
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,16 +21,24 @@ public sealed class PipelineResultRow | |
{ | ||
public string GraphJson { get; } | ||
public double MetricValue { get; } | ||
public double TrainingMetricValue { get; } | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
can you add some comments to this property. What's the difference between MetricValue and TrainingMetricValue? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I've added some comments above the properties, explaining them. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. XML In reply to: 188105054 [](ancestors = 188105054) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'll make that change. #Closed |
||
public string PipelineId { get; } | ||
public string FirstInput { get; } | ||
public string PredictorModel { get; } | ||
|
||
public PipelineResultRow() | ||
{ } | ||
|
||
public PipelineResultRow(string graphJson, double metricValue, string pipelineId) | ||
public PipelineResultRow(string graphJson, double metricValue, | ||
string pipelineId, double trainingMetricValue, string firstInput, | ||
string predictorModel) | ||
{ | ||
GraphJson = graphJson; | ||
MetricValue = metricValue; | ||
PipelineId = pipelineId; | ||
TrainingMetricValue = trainingMetricValue; | ||
FirstInput = firstInput; | ||
PredictorModel = predictorModel; | ||
} | ||
} | ||
|
||
|
@@ -111,7 +119,8 @@ public AutoInference.EntryPointGraphDef ToEntryPointGraph(Experiment experiment | |
public bool Equals(PipelinePattern obj) => obj != null && UniqueId == obj.UniqueId; | ||
|
||
// REVIEW: We may want to allow for sweeping with CV in the future, so we will need to add new methods like this, or refactor these in that case. | ||
public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testData, MacroUtils.TrainerKinds trainerKind, out Models.TrainTestEvaluator.Output resultsOutput) | ||
public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testData, MacroUtils.TrainerKinds trainerKind, | ||
bool includeTrainingMetrics, out Models.TrainTestEvaluator.Output resultsOutput) | ||
{ | ||
var graphDef = ToEntryPointGraph(); | ||
var subGraph = graphDef.Graph; | ||
|
@@ -136,7 +145,8 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD | |
Model = finalOutput | ||
}, | ||
PipelineId = UniqueId.ToString("N"), | ||
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind) | ||
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind), | ||
IncludeTrainingMetrics = includeTrainingMetrics | ||
}; | ||
|
||
var experiment = _env.CreateExperiment(); | ||
|
@@ -150,7 +160,7 @@ public Experiment CreateTrainTestExperiment(IDataView trainData, IDataView testD | |
} | ||
|
||
public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData, Var<IDataView> testData, | ||
MacroUtils.TrainerKinds trainerKind, Experiment experiment = null) | ||
MacroUtils.TrainerKinds trainerKind, Experiment experiment = null, bool includeTrainingMetrics = false) | ||
{ | ||
experiment = experiment ?? _env.CreateExperiment(); | ||
var graphDef = ToEntryPointGraph(experiment); | ||
|
@@ -174,7 +184,8 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData, | |
TrainingData = trainData, | ||
TestingData = testData, | ||
Kind = MacroUtils.TrainerKindApiValue<Models.MacroUtilsTrainerKinds>(trainerKind), | ||
PipelineId = UniqueId.ToString("N") | ||
PipelineId = UniqueId.ToString("N"), | ||
IncludeTrainingMetrics = includeTrainingMetrics | ||
}; | ||
var trainTestOutput = experiment.Add(trainTestInput); | ||
return trainTestOutput; | ||
|
@@ -183,34 +194,58 @@ public Models.TrainTestEvaluator.Output AddAsTrainTest(Var<IDataView> trainData, | |
/// <summary> | ||
/// Runs a train-test experiment on the current pipeline, through entrypoints. | ||
/// </summary> | ||
public double RunTrainTestExperiment(IDataView trainData, IDataView testData, AutoInference.SupportedMetric metric, MacroUtils.TrainerKinds trainerKind) | ||
public void RunTrainTestExperiment(IDataView trainData, IDataView testData, | ||
AutoInference.SupportedMetric metric, MacroUtils.TrainerKinds trainerKind, out double testMetricValue, | ||
out double trainMetricValue) | ||
{ | ||
var experiment = CreateTrainTestExperiment(trainData, testData, trainerKind, out var trainTestOutput); | ||
var experiment = CreateTrainTestExperiment(trainData, testData, trainerKind, true, out var trainTestOutput); | ||
experiment.Run(); | ||
|
||
var dataOut = experiment.GetOutput(trainTestOutput.OverallMetrics); | ||
var schema = dataOut.Schema; | ||
schema.TryGetColumnIndex(metric.Name, out var metricCol); | ||
double metricValue = 0; | ||
double trainingMetricValue = 0; | ||
|
||
using (var cursor = dataOut.GetRowCursor(col => col == metricCol)) | ||
{ | ||
var getter = cursor.GetGetter<double>(metricCol); | ||
double metricValue = 0; | ||
cursor.MoveNext(); | ||
getter(ref metricValue); | ||
return metricValue; | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you use this code pattern twice here, and once in AutoMlUtils, maybe it make sense to refactor it to a method? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good suggestion. I'll make that change. |
||
|
||
dataOut = experiment.GetOutput(trainTestOutput.TrainingOverallMetrics); | ||
schema = dataOut.Schema; | ||
schema.TryGetColumnIndex(metric.Name, out metricCol); | ||
|
||
using (var cursor = dataOut.GetRowCursor(col => col == metricCol)) | ||
{ | ||
var getter = cursor.GetGetter<double>(metricCol); | ||
cursor.MoveNext(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
We should validate that this works. We should also validate that the next call to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Made changes, will push new version. |
||
getter(ref trainingMetricValue); | ||
testMetricValue = metricValue; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
why it's part of this code block and not one where you set value for metric value? #Closed There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Training metrics are optional, so they will not always be present. |
||
trainMetricValue = trainingMetricValue; | ||
} | ||
} | ||
|
||
public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data, string graphColName, string metricColName, string idColName) | ||
public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView data, | ||
string graphColName, string metricColName, string idColName, string trainingMetricColName, | ||
string firstInputColName, string predictorModelColName) | ||
{ | ||
var results = new List<PipelineResultRow>(); | ||
var schema = data.Schema; | ||
if (!schema.TryGetColumnIndex(graphColName, out var graphCol)) | ||
throw env.ExceptNotSupp($"Column name {graphColName} not found"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'm not sure "Not supported" is an appropriate exception here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will make these changes. Thanks. |
||
if (!schema.TryGetColumnIndex(metricColName, out var metricCol)) | ||
throw env.ExceptNotSupp($"Column name {metricColName} not found"); | ||
if (!schema.TryGetColumnIndex(trainingMetricColName, out var trainingMetricCol)) | ||
throw env.ExceptNotSupp($"Column name {trainingMetricColName} not found"); | ||
if (!schema.TryGetColumnIndex(idColName, out var pipelineIdCol)) | ||
throw env.ExceptNotSupp($"Column name {idColName} not found"); | ||
if (!schema.TryGetColumnIndex(firstInputColName, out var firstInputCol)) | ||
throw env.ExceptNotSupp($"Column name {firstInputColName} not found"); | ||
if (!schema.TryGetColumnIndex(predictorModelColName, out var predictorModelCol)) | ||
throw env.ExceptNotSupp($"Column name {predictorModelColName} not found"); | ||
|
||
using (var cursor = data.GetRowCursor(col => true)) | ||
{ | ||
|
@@ -225,15 +260,33 @@ public static PipelineResultRow[] ExtractResults(IHostEnvironment env, IDataView | |
var getter3 = cursor.GetGetter<DvText>(pipelineIdCol); | ||
DvText pipelineId = new DvText(); | ||
getter3(ref pipelineId); | ||
results.Add(new PipelineResultRow(graphJson.ToString(), metricValue, pipelineId.ToString())); | ||
var getter4 = cursor.GetGetter<double>(trainingMetricCol); | ||
double trainingMetricValue = 0; | ||
getter4(ref trainingMetricValue); | ||
var getter5 = cursor.GetGetter<DvText>(firstInputCol); | ||
DvText firstInput = new DvText(); | ||
getter5(ref firstInput); | ||
var getter6 = cursor.GetGetter<DvText>(predictorModelCol); | ||
DvText predictorModel = new DvText(); | ||
getter6(ref predictorModel); | ||
|
||
results.Add(new PipelineResultRow(graphJson.ToString(), | ||
metricValue, pipelineId.ToString(), trainingMetricValue, | ||
firstInput.ToString(), predictorModel.ToString())); | ||
} | ||
} | ||
|
||
return results.ToArray(); | ||
} | ||
|
||
public PipelineResultRow ToResultRow() => | ||
new PipelineResultRow(ToEntryPointGraph().Graph.ToJsonString(), | ||
PerformanceSummary?.MetricValue ?? -1d, UniqueId.ToString("N")); | ||
public PipelineResultRow ToResultRow() { | ||
var graphDef = ToEntryPointGraph(); | ||
|
||
return new PipelineResultRow($"{{'Nodes' : [{graphDef.Graph.ToJsonString()}]}}", | ||
PerformanceSummary?.MetricValue ?? -1d, UniqueId.ToString("N"), | ||
PerformanceSummary?.TrainingMetricValue ?? -1d, | ||
graphDef.GetSubgraphFirstNodeDataVarName(_env), | ||
graphDef.ModelOutput.VarName); | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering if it's possible to tighten up this method a bit, in terms of its handling of inputs other than "happy path" inputs.
To give an example: if
metricColumnName
is not a column intrainResult
, thenTryGetColumnIndex
will return false. However, we merely assume it succeeds, buttrainingMetricCol
will hold some undertermined value -- I guessdefault(int)
. And it will happily extract that (assuming it was of typedouble
).Since these a public methods on a public class these would be
env.Check*
style checks. (If it were non-public though we'd still prefer to have at least asserts.)Ideally we'd also make the other methods in this a bit more robust, but perhaps that's a bit beyond the scope of the PR. Might as well start here though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks. I'll make the corresponding changes.