From 824e36a30dc376dfbaf690794f066dea61b62927 Mon Sep 17 00:00:00 2001 From: Srujan Saggam Date: Wed, 24 Apr 2019 13:28:04 -0700 Subject: [PATCH 1/6] added support for fast tree nuget pack inclusion in generated project --- .../CodeGenerator/CSharp/CodeGenerator.cs | 31 ++++++++++++------- src/mlnet/Templates/Console/ModelProject.cs | 17 +++++++++- src/mlnet/Templates/Console/ModelProject.tt | 4 +++ src/mlnet/Templates/Console/PredictProject.cs | 5 +++ src/mlnet/Templates/Console/PredictProject.tt | 4 +++ .../ConsoleCodeGeneratorTests.cs | 16 +++++----- 6 files changed, 56 insertions(+), 21 deletions(-) diff --git a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs index bb0fc079d5..c816852f0a 100644 --- a/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs +++ b/src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs @@ -21,6 +21,8 @@ internal class CodeGenerator : IProjectGenerator private readonly ColumnInferenceResults columnInferenceResult; private readonly HashSet LightGBMTrainers = new HashSet() { TrainerName.LightGbmBinary.ToString(), TrainerName.LightGbmMulti.ToString(), TrainerName.LightGbmRegression.ToString() }; private readonly HashSet mklComponentsTrainers = new HashSet() { TrainerName.OlsRegression.ToString(), TrainerName.SymbolicSgdLogisticRegressionBinary.ToString() }; + private readonly HashSet FastTreeTrainers = new HashSet() { TrainerName.FastForestBinary.ToString(), TrainerName.FastForestRegression.ToString(), TrainerName.FastTreeBinary.ToString(), TrainerName.FastTreeRegression.ToString(), TrainerName.FastTreeTweedieRegression.ToString() }; + internal CodeGenerator(Pipeline pipeline, ColumnInferenceResults columnInferenceResult, CodeGeneratorSettings settings) { @@ -36,7 +38,8 @@ public void GenerateOutput() bool includeLightGbmPackage = false; bool includeMklComponentsPackage = false; - SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage); + bool includeFastTreeePackage = false; + SetRequiredNugetPackages(trainerNodes, ref includeLightGbmPackage, ref includeMklComponentsPackage, ref includeFastTreeePackage); // Get Namespace var namespaceValue = Utils.Normalize(settings.OutputName); @@ -44,7 +47,7 @@ public void GenerateOutput() Type labelTypeCsharp = Utils.GetCSharpType(labelType); // Generate Model Project - var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage); + var modelProjectContents = GenerateModelProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage); // Write files to disk. var modelprojectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.Model"); @@ -56,7 +59,7 @@ public void GenerateOutput() Utils.WriteOutputToFiles(modelProjectContents.ModelProjectFileContent, modelProjectName, modelprojectDir); // Generate ConsoleApp Project - var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage); + var consoleAppProjectContents = GenerateConsoleAppProjectContents(namespaceValue, labelTypeCsharp, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreeePackage); // Write files to disk. var consoleAppProjectDir = Path.Combine(settings.OutputBaseDir, $"{settings.OutputName}.ConsoleApp"); @@ -74,7 +77,7 @@ public void GenerateOutput() Utils.AddProjectsToSolution(modelprojectDir, modelProjectName, consoleAppProjectDir, consoleAppProjectName, solutionPath); } - private void SetRequiredNugetPackages(IEnumerable trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage) + private void SetRequiredNugetPackages(IEnumerable trainerNodes, ref bool includeLightGbmPackage, ref bool includeMklComponentsPackage, ref bool includeFastTreePackage) { foreach (var node in trainerNodes) { @@ -92,15 +95,19 @@ private void SetRequiredNugetPackages(IEnumerable trainerNodes, re { includeMklComponentsPackage = true; } + else if (FastTreeTrainers.Contains(currentNode.Name)) + { + includeFastTreePackage = true; + } } } - internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage) + internal (string ConsoleAppProgramCSFileContent, string ConsoleAppProjectFileContent, string modelBuilderCSFileContent) GenerateConsoleAppProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage) { var predictProgramCSFileContent = GeneratePredictProgramCSFileContent(namespaceValue); predictProgramCSFileContent = Utils.FormatCode(predictProgramCSFileContent); - var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage); + var predictProjectFileContent = GeneratPredictProjectFileContent(namespaceValue, includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage); var transformsAndTrainers = GenerateTransformsAndTrainers(); var modelBuilderCSFileContent = GenerateModelBuilderCSFileContent(transformsAndTrainers.Usings, transformsAndTrainers.TrainerMethod, transformsAndTrainers.PreTrainerTransforms, transformsAndTrainers.PostTrainerTransforms, namespaceValue, pipeline.CacheBeforeTrainer, labelTypeCsharp.Name); @@ -109,14 +116,14 @@ private void SetRequiredNugetPackages(IEnumerable trainerNodes, re return (predictProgramCSFileContent, predictProjectFileContent, modelBuilderCSFileContent); } - internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage) + internal (string ObservationCSFileContent, string PredictionCSFileContent, string ModelProjectFileContent) GenerateModelProjectContents(string namespaceValue, Type labelTypeCsharp, bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage) { var classLabels = this.GenerateClassLabels(); var observationCSFileContent = GenerateObservationCSFileContent(namespaceValue, classLabels); observationCSFileContent = Utils.FormatCode(observationCSFileContent); var predictionCSFileContent = GeneratePredictionCSFileContent(labelTypeCsharp.Name, namespaceValue); predictionCSFileContent = Utils.FormatCode(predictionCSFileContent); - var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage); + var modelProjectFileContent = GenerateModelProjectFileContent(includeLightGbmPackage, includeMklComponentsPackage, includeFastTreePackage); return (observationCSFileContent, predictionCSFileContent, modelProjectFileContent); } @@ -248,9 +255,9 @@ internal IList GenerateClassLabels() } #region Model project - private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage) + private static string GenerateModelProjectFileContent(bool includeLightGbmPackage, bool includeMklComponentsPackage, bool includeFastTreePackage) { - ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage }; + ModelProject modelProject = new ModelProject() { IncludeLightGBMPackage = includeLightGbmPackage, IncludeMklComponentsPackage = includeMklComponentsPackage, IncludeFastTreePackage = includeFastTreePackage }; return modelProject.TransformText(); } @@ -268,9 +275,9 @@ private string GenerateObservationCSFileContent(string namespaceValue, IList