Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[AutoML] Generated project - FastTree nuget package inclusion dynamically #3567

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 19 additions & 12 deletions src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ internal class CodeGenerator : IProjectGenerator
private readonly ColumnInferenceResults columnInferenceResult;
private readonly HashSet<string> LightGBMTrainers = new HashSet<string>() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() };
private readonly HashSet<string> mklComponentsTrainers = new HashSet<string>() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() };
private readonly HashSet<string> FastTreeTrainers = new HashSet<string>() { TrainerName.FastForestBinary.ToString(), TrainerName.FastForestRegression.ToString(), TrainerName.FastTreeBinary.ToString(), TrainerName.FastTreeRegression.ToString(), TrainerName.FastTreeTweedieRegression.ToString() };


internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings)
{
Expand All @@ -36,15 +38,16 @@ public void GenerateOutput()

bool includeLightGbmPackage = false;
bool includeMklComponentsPackage = false;
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage);
bool includeFastTreeePackage = false;
SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage, ref includeFastTreeePackage);

// Get Namespace
var namespaceValue = Utils.Normalize(settings.OutputName);
var labelType = columnInferenceResult.TextLoaderOptions.Columns.Where(t => t.Name == columnInferenceResult.ColumnInformation.LabelColumnName).First().DataKind;
Type labelTypeCsharp = Utils.GetCSharpType(labelType);

// Generate Model Project
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage);

// Write files to disk.
var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model");
Expand All @@ -56,7 +59,7 @@ public void GenerateOutput()
Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir);

// Generate ConsoleApp Project
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage);
var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage);

// Write files to disk.
var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp");
Expand All @@ -74,7 +77,7 @@ public void GenerateOutput()
Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath);
}

private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage)
private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage, ref bool includeFastTreePackage)
{
foreach (var node in trainerNodes)
{
Expand All @@ -92,15 +95,19 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
{
includeMklComponentsPackage = true;
}
else if (FastTreeTrainers.Contains(currentNode.Name))
{
includeFastTreePackage = true;
}
}
}

internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue);
predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent);

var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage);
var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage);

var transformsAndTrainers = GenerateTransformsAndTrainers();
var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name);
Expand All @@ -109,14 +116,14 @@ private void SetRequiredNugetPackages(IEnumerable<PipelineNode> trainerNodes, re
return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent);
}

internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage)
internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
var classLabels = this.GenerateClassLabels();
var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels);
observationCSFileContent = Utils.FormatCode(observationCSFileContent);
var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue);
predictionCSFileContent = Utils.FormatCode(predictionCSFileContent);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage);
var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage);
return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent);
}

Expand Down Expand Up @@ -248,9 +255,9 @@ internal IList<string> GenerateClassLabels()
}

#region Model project
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage)
private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage };
ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeFastTreePackage = includeFastTreePackage };
return modelProject.TransformText();
}

Expand All @@ -268,9 +275,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList<str
#endregion

#region Predict Project
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage)
private static string GeneratPredictProjectFileContent(string namespaceValue, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage)
{
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage };
var predictProjectFileContent = new PredictProject() { Namespace = namespaceValue, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeLightGBMPackage = includeLightGbmPackage, IncludeFastTreePackage = includeFastTreePackage };
return predictProjectFileContent.TransformText();
}

Expand Down
1 change: 0 additions & 1 deletion src/mlnet/CodeGenerator/CSharp/TrainerGeneratorFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
// See the LICENSE file in the project root for more information.

using System;
using System.Linq;
using Microsoft.ML.Auto;
using static Microsoft.ML.CLI.CodeGenerator.CSharp.TrainerGenerators;

Expand Down
2 changes: 0 additions & 2 deletions src/mlnet/CodeGenerator/CSharp/TrainerGenerators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -555,8 +555,6 @@ public override string[] GenerateUsings()
{
return binaryTrainerUsings;
}

}

}
}
6 changes: 3 additions & 3 deletions src/mlnet/Telemetry/MlTelemetry.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public class MlTelemetry

public void SetCommandAndParameters(string command, IEnumerable<string> parameters)
{
if(parameters != null)
if (parameters != null)
{
_parameters.AddRange(parameters);
}
Expand All @@ -28,7 +28,7 @@ public void LogAutoTrainMlCommand(string dataFileName, string task, long dataFil
{
CheckFistTimeUse();

if(!_enabled)
if (!_enabled)
{
return;
}
Expand Down Expand Up @@ -71,7 +71,7 @@ private void CheckFistTimeUse()
@"Welcome to the ML.NET CLI!
--------------------------
Learn more about ML.NET CLI: https://aka.ms/mlnet-cli
Use 'dotnet ml --help' to see available commands or visit: https://aka.ms/mlnet-cli-docs
Use 'mlnet --help' to see available commands or visit: https://aka.ms/mlnet-cli-docs

Telemetry
---------
Expand Down
17 changes: 16 additions & 1 deletion src/mlnet/Templates/Console/ModelProject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,20 @@ public virtual string TransformText()
#line 18 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
}

#line default
#line hidden

#line 19 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
if(IncludeFastTreePackage){

#line default
#line hidden
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"1.0.0-preview\" />\r" +
"\n");

#line 21 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
}

#line default
#line hidden
this.Write(" </ItemGroup>\r\n\r\n <ItemGroup>\r\n <None Update=\"MLModel.zip\">\r\n <CopyToOu" +
Expand All @@ -65,10 +79,11 @@ public virtual string TransformText()
return this.GenerationEnvironment.ToString();
}

#line 28 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"
#line 31 "E:\src\machinelearning\src\mlnet\Templates\Console\ModelProject.tt"

public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}


#line default
Expand Down
4 changes: 4 additions & 0 deletions src/mlnet/Templates/Console/ModelProject.tt
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
<#}#>
<# if(IncludeMklComponentsPackage){ #>
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
<#}#>
<# if(IncludeFastTreePackage){ #>
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
<#}#>
</ItemGroup>

Expand All @@ -28,4 +31,5 @@
<#+
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
#>
5 changes: 5 additions & 0 deletions src/mlnet/Templates/Console/PredictProject.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ public virtual string TransformText()
if(IncludeMklComponentsPackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.Mkl.Components\" Version=\"1.0.0-previe" +
"w\" />\r\n");
}
if(IncludeFastTreePackage){
this.Write(" <PackageReference Include=\"Microsoft.ML.FastTree\" Version=\"1.0.0-preview\" />\r" +
"\n");
}
this.Write(" </ItemGroup>\r\n <ItemGroup>\r\n <ProjectReference Include=\"..\\");
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
Expand All @@ -49,6 +53,7 @@ public virtual string TransformText()
public string Namespace {get;set;}
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}

}
#region Base class
Expand Down
4 changes: 4 additions & 0 deletions src/mlnet/Templates/Console/PredictProject.tt
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
<#}#>
<# if(IncludeMklComponentsPackage){ #>
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
<#}#>
<# if(IncludeFastTreePackage){ #>
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
<#}#>
</ItemGroup>
<ItemGroup>
Expand All @@ -27,4 +30,5 @@
public string Namespace {get;set;}
public bool IncludeLightGBMPackage {get;set;}
public bool IncludeMklComponentsPackage {get;set;}
public bool IncludeFastTreePackage {get;set;}
#>
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
<PackageReference Include="Microsoft.ML" Version="1.0.0-preview" />
<PackageReference Include="Microsoft.ML.LightGBM" Version="1.0.0-preview" />
<PackageReference Include="Microsoft.ML.Mkl.Components" Version="1.0.0-preview" />
<PackageReference Include="Microsoft.ML.FastTree" Version="1.0.0-preview" />
</ItemGroup>

<ItemGroup>
Expand Down
16 changes: 8 additions & 8 deletions test/mlnet.Tests/ApprovalTests/ConsoleCodeGeneratorTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public void ConsoleAppModelBuilderCSFileContentOvaTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.modelBuilderCSFileContent);
}
Expand All @@ -65,7 +65,7 @@ public void ConsoleAppModelBuilderCSFileContentBinaryTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.modelBuilderCSFileContent);
}
Expand All @@ -88,7 +88,7 @@ public void ConsoleAppModelBuilderCSFileContentRegressionTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.modelBuilderCSFileContent);
}
Expand All @@ -111,7 +111,7 @@ public void ModelProjectFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, true);

Approvals.Verify(result.ModelProjectFileContent);
}
Expand All @@ -134,7 +134,7 @@ public void ObservationCSFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.ObservationCSFileContent);
}
Expand All @@ -158,7 +158,7 @@ public void PredictionCSFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateModelProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.PredictionCSFileContent);
}
Expand All @@ -181,7 +181,7 @@ public void ConsoleAppProgramCSFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.ConsoleAppProgramCSFileContent);
}
Expand All @@ -204,7 +204,7 @@ public void ConsoleAppProjectFileContentTest()
LabelName = "Label",
ModelPath = "x:\\models\\model.zip"
});
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true);
var result = consoleCodeGen.GenerateConsoleAppProjectContents(namespaceValue, typeof(float), true, true, false);

Approvals.Verify(result.ConsoleAppProjectFileContent);
}
Expand Down