Skip to content

Commit

Permalink
Add caching (dotnet#249)
Browse files Browse the repository at this point in the history
  • Loading branch information
daholste authored Mar 7, 2019
1 parent 551a7a1 commit 3326539
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 24 deletions.
8 changes: 7 additions & 1 deletion src/Microsoft.ML.Auto/API/ExperimentSettings.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@ public class ExperimentSettings
public uint MaxExperimentTimeInSeconds { get; set; } = 24 * 60 * 60;
public CancellationToken CancellationToken { get; set; } = default;

internal bool EnableCaching;
/// <summary>
/// Controls whether an AutoML experiment makes use of ML.NET-provided caching.
/// If set to true, caching is forced on for all pipelines. If set to false, caching is forced off for all pipelines.
/// If set to null (the default), AutoML decides per pipeline, enabling caching only when the pipeline's trainer reports that it wants caching.
/// </summary>
public bool? EnableCaching = null;

internal int MaxModels = int.MaxValue;
internal IDebugLogger DebugLogger;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Microsoft.ML.Auto/Experiment/Experiment.cs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ public List<RunResult<T>> Execute()
var getPiplelineStopwatch = Stopwatch.StartNew();

// get next pipeline
pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist);
pipeline = PipelineSuggester.GetNextInferredPipeline(_context, _history, columns, _task, _optimizingMetricInfo.IsMaximizing, _trainerWhitelist, _experimentSettings.EnableCaching);

getPiplelineStopwatch.Stop();

Expand Down
18 changes: 13 additions & 5 deletions src/Microsoft.ML.Auto/Experiment/SuggestedPipeline.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Data.DataView;
using Microsoft.ML.Data;

namespace Microsoft.ML.Auto
Expand All @@ -16,20 +15,24 @@ namespace Microsoft.ML.Auto
/// </summary>
internal class SuggestedPipeline
{
private readonly MLContext _context;
public readonly IList<SuggestedTransform> Transforms;
public readonly SuggestedTrainer Trainer;

private readonly MLContext _context;
private readonly bool? _enableCaching;

public SuggestedPipeline(IEnumerable<SuggestedTransform> transforms,
SuggestedTrainer trainer,
MLContext context,
bool? enableCaching,
bool autoNormalize = true)
{
Transforms = transforms.Select(t => t.Clone()).ToList();
Trainer = trainer.Clone();
_context = context;
_enableCaching = enableCaching;

if(autoNormalize)
if (autoNormalize)
{
AddNormalizationTransforms();
}
Expand Down Expand Up @@ -88,7 +91,7 @@ public static SuggestedPipeline FromPipeline(MLContext context, Pipeline pipelin
}
}

return new SuggestedPipeline(transforms, trainer, context, false);
return new SuggestedPipeline(transforms, trainer, context, null);
}

public IEstimator<ITransformer> ToEstimator()
Expand All @@ -107,6 +110,11 @@ public IEstimator<ITransformer> ToEstimator()
// Get learner
var learner = Trainer.BuildTrainer();

if (_enableCaching == true || (_enableCaching == null && learner.Info.WantCaching))
{
pipeline = pipeline.AppendCacheCheckpoint(_context);
}

// Append learner to pipeline
pipeline = pipeline.Append(learner);

Expand All @@ -128,4 +136,4 @@ private void AddNormalizationTransforms()
Transforms.Add(transform);
}
}
}
}
12 changes: 7 additions & 5 deletions src/Microsoft.ML.Auto/PipelineSuggesters/PipelineSuggester.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
(string, DataViewType, ColumnPurpose, ColumnDimensions)[] columns,
TaskKind task,
bool isMaximizingMetric,
IEnumerable<TrainerName> trainerWhitelist = null)
IEnumerable<TrainerName> trainerWhitelist = null,
bool? _enableCaching = null)
{
var availableTrainers = RecipeInference.AllowedTrainers(context, task,
ColumnInformationUtil.BuildColumnInfo(columns), trainerWhitelist);
Expand All @@ -40,7 +41,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
// if we haven't run all pipelines once
if (history.Count() < availableTrainers.Count())
{
return GetNextFirstStagePipeline(context, history, availableTrainers, transforms);
return GetNextFirstStagePipeline(context, history, availableTrainers, transforms, _enableCaching);
}

// get top trainers from stage 1 runs
Expand Down Expand Up @@ -71,7 +72,7 @@ public static SuggestedPipeline GetNextInferredPipeline(MLContext context,
break;
}

var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer, context);
var suggestedPipeline = new SuggestedPipeline(transforms, newTrainer, context, _enableCaching);

// make sure we have not seen this pipeline before
if (!visitedPipelines.Contains(suggestedPipeline))
Expand Down Expand Up @@ -117,10 +118,11 @@ private static IEnumerable<SuggestedTrainer> OrderTrainersByNumTrials(IEnumerabl
private static SuggestedPipeline GetNextFirstStagePipeline(MLContext context,
IEnumerable<SuggestedPipelineResult> history,
IEnumerable<SuggestedTrainer> availableTrainers,
IEnumerable<SuggestedTransform> transforms)
IEnumerable<SuggestedTransform> transforms,
bool? _enableCaching)
{
var trainer = availableTrainers.ElementAt(history.Count());
return new SuggestedPipeline(transforms, trainer, context);
return new SuggestedPipeline(transforms, trainer, context, _enableCaching);
}

private static IValueGenerator[] ConvertToValueGenerators(IEnumerable<SweepableParam> hps)
Expand Down
20 changes: 10 additions & 10 deletions src/Test/InferredPipelineTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,43 +22,43 @@ public void InferredPipelinesHashTest()
var trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
var transforms1 = new List<SuggestedTransform>();
var transforms2 = new List<SuggestedTransform>();
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
Assert.AreEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());

// same learners, one with hyperparams set and one with empty hyperparams, should have different hash codes
var hyperparams1 = new ParameterSet(new List<IParameterValue>() { new LongParameterValue("NumLeaves", 2) });
trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams1);
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());

// same learners with different hyperparams
hyperparams1 = new ParameterSet(new List<IParameterValue>() { new LongParameterValue("NumLeaves", 2) });
var hyperparams2 = new ParameterSet(new List<IParameterValue>() { new LongParameterValue("NumLeaves", 6) });
trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams1);
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo, hyperparams2);
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());

// same learners with same transforms
trainer1 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
Assert.AreEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());

// same transforms with different learners
trainer1 = new SuggestedTrainer(context, new SdcaBinaryExtension(), columnInfo);
trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), columnInfo);
transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);
Assert.AreNotEqual(inferredPipeline1.GetHashCode(), inferredPipeline2.GetHashCode());
}
}
Expand Down
4 changes: 2 additions & 2 deletions src/mlnet.Test/ApprovalTests/ConsoleCodeGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,8 @@ public void GeneratedHelperCodeTest()
var trainer2 = new SuggestedTrainer(context, new LightGbmBinaryExtension(), new ColumnInformation(), hyperparams2);
var transforms1 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
var transforms2 = new List<SuggestedTransform>() { ColumnConcatenatingExtension.CreateSuggestedTransform(context, new[] { "In" }, "Out") };
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context);
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context);
var inferredPipeline1 = new SuggestedPipeline(transforms1, trainer1, context, null);
var inferredPipeline2 = new SuggestedPipeline(transforms2, trainer2, context, null);

this.pipeline = inferredPipeline1.ToPipeline();
var textLoaderArgs = new TextLoader.Options()
Expand Down

0 comments on commit 3326539

Please sign in to comment.