Skip to content

Commit

Permalink
Modify API for advanced settings. (SDCA) (#2093)
Browse files Browse the repository at this point in the history
* dummy change to test build

* StochasticGradientDescentClassificationTrainer

* SdcaBinaryTrainer, SdcaMultiClassTrainer, SdcaRegressionTrainer

* review comments

* review comments

* review comments

* re-enable the TF tests + add more checks for options argument

* merge with latest master to pick up overload of

* review comments

* review comments
  • Loading branch information
abgoswam authored Jan 18, 2019
1 parent f81ef79 commit b65d725
Show file tree
Hide file tree
Showing 35 changed files with 561 additions and 255 deletions.
17 changes: 8 additions & 9 deletions docs/samples/Microsoft.ML.Samples/Dynamic/SDCA.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Linq;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;

namespace Microsoft.ML.Samples.Dynamic
{
Expand Down Expand Up @@ -59,15 +60,13 @@ public static void SDCA_BinaryClassification()
// If we wanted to specify more advanced parameters for the algorithm,
// we could do so by tweaking the 'advancedSetting'.
var advancedPipeline = mlContext.Transforms.Text.FeaturizeText("SentimentText", "Features")
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent
(labelColumn: "Sentiment",
featureColumn: "Features",
advancedSettings: s=>
{
s.ConvergenceTolerance = 0.01f; // The learning rate for adjusting bias from being regularized
s.NumThreads = 2; // Degree of lock-free parallelism
})
);
.Append(mlContext.BinaryClassification.Trainers.StochasticDualCoordinateAscent(
new SdcaBinaryTrainer.Options {
LabelColumn = "Sentiment",
FeatureColumn = "Features",
ConvergenceTolerance = 0.01f, // The learning rate for adjusting bias from being regularized
NumThreads = 2, // Degree of lock-free parallelism
}));

// Run Cross-Validation on this second pipeline.
var cvResults_advancedPipeline = mlContext.BinaryClassification.CrossValidate(data, pipeline, labelColumn: "Sentiment", numFolds: 3);
Expand Down
12 changes: 12 additions & 0 deletions src/Microsoft.ML.FastTree/TreeTrainersCatalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public static FastTreeRegressionTrainer FastTree(this RegressionContext.Regressi
FastTreeRegressionTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeRegressionTrainer(env, options);
}
Expand Down Expand Up @@ -85,6 +87,8 @@ public static FastTreeBinaryClassificationTrainer FastTree(this BinaryClassifica
FastTreeBinaryClassificationTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeBinaryClassificationTrainer(env, options);
}
Expand Down Expand Up @@ -125,6 +129,8 @@ public static FastTreeRankingTrainer FastTree(this RankingContext.RankingTrainer
FastTreeRankingTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeRankingTrainer(env, options);
}
Expand Down Expand Up @@ -213,6 +219,8 @@ public static FastTreeTweedieTrainer FastTreeTweedie(this RegressionContext.Regr
FastTreeTweedieTrainer.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new FastTreeTweedieTrainer(env, options);
}
Expand Down Expand Up @@ -251,6 +259,8 @@ public static FastForestRegression FastForest(this RegressionContext.RegressionT
FastForestRegression.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new FastForestRegression(env, options);
}
Expand Down Expand Up @@ -289,6 +299,8 @@ public static FastForestClassification FastForest(this BinaryClassificationConte
FastForestClassification.Options options)
{
Contracts.CheckValue(ctx, nameof(ctx));
Contracts.CheckValue(options, nameof(options));

var env = CatalogUtils.GetEnvironment(ctx);
return new FastForestClassification(env, options);
}
Expand Down
143 changes: 63 additions & 80 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaBinary.cs

Large diffs are not rendered by default.

33 changes: 14 additions & 19 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaMultiClass.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
using Microsoft.ML.Training;
using Float = System.Single;

[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Arguments),
[assembly: LoadableClass(SdcaMultiClassTrainer.Summary, typeof(SdcaMultiClassTrainer), typeof(SdcaMultiClassTrainer.Options),
new[] { typeof(SignatureMultiClassClassifierTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SdcaMultiClassTrainer.UserNameValue,
SdcaMultiClassTrainer.LoadNameValue,
Expand All @@ -29,14 +29,14 @@ namespace Microsoft.ML.Trainers
{
// SDCA linear multiclass trainer.
/// <include file='doc.xml' path='doc/members/member[@name="SDCA"]/*' />
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Arguments, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
public class SdcaMultiClassTrainer : SdcaTrainerBase<SdcaMultiClassTrainer.Options, MulticlassPredictionTransformer<MulticlassLogisticRegressionModelParameters>, MulticlassLogisticRegressionModelParameters>
{
public const string LoadNameValue = "SDCAMC";
public const string UserNameValue = "Fast Linear Multi-class Classification (SA-SDCA)";
public const string ShortName = "sasdcamc";
internal const string Summary = "The SDCA linear multi-class classification trainer.";

public sealed class Arguments : ArgumentsBase
public sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportSdcaClassificationLossFactory LossFunction = new LogLossFactory();
Expand All @@ -57,41 +57,36 @@ public sealed class Arguments : ArgumentsBase
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct method signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public SdcaMultiClassTrainer(IHostEnvironment env,
internal SdcaMultiClassTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
ISupportSdcaClassificationLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null,
Action<Arguments> advancedSettings = null)
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings,
l2Const, l1Threshold, maxIterations)
int? maxIterations = null)
: base(env, featureColumn, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights),
l2Const, l1Threshold, maxIterations)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
_loss = loss ?? Args.LossFunction.CreateComponent(env);
Loss = _loss;
}

internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args,
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options,
string featureColumn, string labelColumn, string weightColumn = null)
: base(env, args, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
: base(env, options, TrainerUtils.MakeU4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
Host.CheckValue(labelColumn, nameof(labelColumn));
Host.CheckValue(featureColumn, nameof(featureColumn));

_loss = args.LossFunction.CreateComponent(env);
_loss = options.LossFunction.CreateComponent(env);
Loss = _loss;
}

internal SdcaMultiClassTrainer(IHostEnvironment env, Arguments args)
: this(env, args, args.FeatureColumn, args.LabelColumn)
internal SdcaMultiClassTrainer(IHostEnvironment env, Options options)
: this(env, options, options.FeatureColumn, options.LabelColumn)
{
}

Expand Down Expand Up @@ -455,14 +450,14 @@ public static partial class Sdca
ShortName = SdcaMultiClassTrainer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentClassifier""]/*' />" })]
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Arguments input)
public static CommonOutputs.MulticlassClassificationOutput TrainMultiClass(IHostEnvironment env, SdcaMultiClassTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSDCA");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Arguments, CommonOutputs.MulticlassClassificationOutput>(host, input,
return LearnerEntryPointsUtils.Train<SdcaMultiClassTrainer.Options, CommonOutputs.MulticlassClassificationOutput>(host, input,
() => new SdcaMultiClassTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}
Expand Down
35 changes: 15 additions & 20 deletions src/Microsoft.ML.StandardLearners/Standard/SdcaRegression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
using Microsoft.ML.Trainers;
using Microsoft.ML.Training;

[assembly: LoadableClass(SdcaRegressionTrainer.Summary, typeof(SdcaRegressionTrainer), typeof(SdcaRegressionTrainer.Arguments),
[assembly: LoadableClass(SdcaRegressionTrainer.Summary, typeof(SdcaRegressionTrainer), typeof(SdcaRegressionTrainer.Options),
new[] { typeof(SignatureRegressorTrainer), typeof(SignatureTrainer), typeof(SignatureFeatureScorerTrainer) },
SdcaRegressionTrainer.UserNameValue,
SdcaRegressionTrainer.LoadNameValue,
Expand All @@ -24,19 +24,19 @@
namespace Microsoft.ML.Trainers
{
/// <include file='doc.xml' path='doc/members/member[@name="SDCA"]/*' />
public sealed class SdcaRegressionTrainer : SdcaTrainerBase<SdcaRegressionTrainer.Arguments, RegressionPredictionTransformer<LinearRegressionModelParameters>, LinearRegressionModelParameters>
public sealed class SdcaRegressionTrainer : SdcaTrainerBase<SdcaRegressionTrainer.Options, RegressionPredictionTransformer<LinearRegressionModelParameters>, LinearRegressionModelParameters>
{
internal const string LoadNameValue = "SDCAR";
internal const string UserNameValue = "Fast Linear Regression (SA-SDCA)";
internal const string ShortName = "sasdcar";
internal const string Summary = "The SDCA linear regression trainer.";

public sealed class Arguments : ArgumentsBase
public sealed class Options : ArgumentsBase
{
[Argument(ArgumentType.Multiple, HelpText = "Loss Function", ShortName = "loss", SortOrder = 50)]
public ISupportSdcaRegressionLossFactory LossFunction = new SquaredLossFactory();

public Arguments()
public Options()
{
// Using a higher default tolerance for better RMS.
ConvergenceTolerance = 0.01f;
Expand All @@ -61,40 +61,35 @@ public Arguments()
/// <param name="l2Const">The L2 regularization hyperparameter.</param>
/// <param name="l1Threshold">The L1 regularization hyperparameter. Higher values will tend to lead to more sparse model.</param>
/// <param name="maxIterations">The maximum number of passes to perform over the data.</param>
/// <param name="advancedSettings">A delegate to set more settings.
/// The settings here will override the ones provided in the direct method signature,
/// if both are present and have different values.
/// The columns names, however need to be provided directly, not through the <paramref name="advancedSettings"/>.</param>
public SdcaRegressionTrainer(IHostEnvironment env,
internal SdcaRegressionTrainer(IHostEnvironment env,
string labelColumn = DefaultColumnNames.Label,
string featureColumn = DefaultColumnNames.Features,
string weights = null,
ISupportSdcaRegressionLoss loss = null,
float? l2Const = null,
float? l1Threshold = null,
int? maxIterations = null,
Action<Arguments> advancedSettings = null)
: base(env, featureColumn, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights), advancedSettings,
l2Const, l1Threshold, maxIterations)
int? maxIterations = null)
: base(env, featureColumn, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weights),
l2Const, l1Threshold, maxIterations)
{
Host.CheckNonEmpty(featureColumn, nameof(featureColumn));
Host.CheckNonEmpty(labelColumn, nameof(labelColumn));
_loss = loss ?? Args.LossFunction.CreateComponent(env);
Loss = _loss;
}

internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args, string featureColumn, string labelColumn, string weightColumn = null)
: base(env, args, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
internal SdcaRegressionTrainer(IHostEnvironment env, Options options, string featureColumn, string labelColumn, string weightColumn = null)
: base(env, options, TrainerUtils.MakeR4ScalarColumn(labelColumn), TrainerUtils.MakeR4ScalarWeightColumn(weightColumn))
{
Host.CheckValue(labelColumn, nameof(labelColumn));
Host.CheckValue(featureColumn, nameof(featureColumn));

_loss = args.LossFunction.CreateComponent(env);
_loss = options.LossFunction.CreateComponent(env);
Loss = _loss;
}

internal SdcaRegressionTrainer(IHostEnvironment env, Arguments args)
: this(env, args, args.FeatureColumn, args.LabelColumn)
internal SdcaRegressionTrainer(IHostEnvironment env, Options options)
: this(env, options, options.FeatureColumn, options.LabelColumn)
{
}

Expand Down Expand Up @@ -178,14 +173,14 @@ public static partial class Sdca
ShortName = SdcaRegressionTrainer.ShortName,
XmlInclude = new[] { @"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/member[@name=""SDCA""]/*' />",
@"<include file='../Microsoft.ML.StandardLearners/Standard/doc.xml' path='doc/members/example[@name=""StochasticDualCoordinateAscentRegressor""]/*' />" })]
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, SdcaRegressionTrainer.Arguments input)
public static CommonOutputs.RegressionOutput TrainRegression(IHostEnvironment env, SdcaRegressionTrainer.Options input)
{
Contracts.CheckValue(env, nameof(env));
var host = env.Register("TrainSDCA");
host.CheckValue(input, nameof(input));
EntryPointUtils.CheckInputArgs(host, input);

return LearnerEntryPointsUtils.Train<SdcaRegressionTrainer.Arguments, CommonOutputs.RegressionOutput>(host, input,
return LearnerEntryPointsUtils.Train<SdcaRegressionTrainer.Options, CommonOutputs.RegressionOutput>(host, input,
() => new SdcaRegressionTrainer(host, input),
() => LearnerEntryPointsUtils.FindColumn(host, input.TrainingData.Schema, input.LabelColumn));
}
Expand Down
Loading

0 comments on commit b65d725

Please sign in to comment.