Skip to content

Commit

Permalink
Format the generated code + bunch of misc tasks (dotnet#152)
Browse files Browse the repository at this point in the history
* added formatting and minor changes for reordering cv

* fixing the template

* minor changes

* formatting changes

* fixed approval test

* removed unused nuget

* added missing value replacing

* added test for new transform

* fix test

* Update src/mlnet/Templates/Console/MLCodeGen.cs

Co-Authored-By: srsaggam <41802116+srsaggam@users.noreply.github.com>
  • Loading branch information
srsaggam authored and Dmitry-A committed Aug 22, 2019
1 parent 0aeb75f commit b92039a
Show file tree
Hide file tree
Showing 8 changed files with 123 additions and 69 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ using Microsoft.Data.DataView;
using Microsoft.ML.LightGBM;



namespace MyNamespace
{
class Program
Expand Down Expand Up @@ -41,30 +40,29 @@ namespace MyNamespace
// Data loading
IDataView trainingDataView = mlContext.Data.ReadFromTextFile<SampleObservation>(
path: TrainDataPath,
hasHeader : true,
separatorChar : ',',
allowQuotedStrings : true,
trimWhitespace : false ,
supportSparse : true);
hasHeader: true,
separatorChar: ',',
allowQuotedStrings: true,
trimWhitespace: false,
supportSparse: true);
IDataView testDataView = mlContext.Data.ReadFromTextFile<SampleObservation>(
path: TestDataPath,
hasHeader : true,
separatorChar : ',',
allowQuotedStrings : true,
trimWhitespace : false ,
supportSparse : true);

// Common data process configuration with pipeline data transformations
hasHeader: true,
separatorChar: ',',
allowQuotedStrings: true,
trimWhitespace: false,
supportSparse: true);

var dataProcessPipeline = mlContext.Transforms.Concatenate("Out",new []{"In"});
// Common data process configuration with pipeline data transformations
var dataProcessPipeline = mlContext.Transforms.Concatenate("Out", new[] { "In" });

// Set the training algorithm, then create and config the modelBuilder
var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options(){NumLeaves=2,Booster=new Options.TreeBooster.Arguments(){},LabelColumn="Label",FeatureColumn="Features"});

var trainer = mlContext.BinaryClassification.Trainers.LightGbm(new Options() { NumLeaves = 2, Booster = new Options.TreeBooster.Arguments() { }, LabelColumn = "Label", FeatureColumn = "Features" });
var trainingPipeline = dataProcessPipeline.Append(trainer);

// Train the model fitting to the DataSet
var trainingPipeline = dataProcessPipeline.Append(trainer);
var trainedModel = trainingPipeline.Fit(trainingDataView);

// Evaluate the model and show accuracy stats
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
var predictions = trainedModel.Transform(testDataView);
Expand All @@ -86,11 +84,11 @@ namespace MyNamespace
//Load data to test. Could be any test data. For demonstration purpose train data is used here.
IDataView trainingDataView = mlContext.Data.ReadFromTextFile<SampleObservation>(
path: TrainDataPath,
hasHeader : true,
separatorChar : ',',
allowQuotedStrings : true,
trimWhitespace : false ,
supportSparse : true);
hasHeader: true,
separatorChar: ',',
allowQuotedStrings: true,
trimWhitespace: false,
supportSparse: true);

var sample = mlContext.CreateEnumerable<SampleObservation>(trainingDataView, false).First();

Expand All @@ -101,7 +99,7 @@ namespace MyNamespace
}

// Create prediction engine related to the loaded trained model
var predEngine= trainedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);
var predEngine = trainedModel.CreatePredictionEngine<SampleObservation, SamplePrediction>(mlContext);

//Score
var resultprediction = predEngine.Predict(sample);
Expand All @@ -115,29 +113,29 @@ namespace MyNamespace

public class SampleObservation
{
[ColumnName("Label"), LoadColumn(0)]
public bool Label{get; set;}

[ColumnName("Label"), LoadColumn(0)]
public bool Label { get; set; }


[ColumnName("col1"), LoadColumn(1)]
public float Col1 { get; set; }


[ColumnName("col2"), LoadColumn(0)]
public float Col2 { get; set; }


[ColumnName("col3"), LoadColumn(0)]
public string Col3 { get; set; }

[ColumnName("col1"), LoadColumn(1)]
public float Col1{get; set;}


[ColumnName("col2"), LoadColumn(0)]
public float Col2{get; set;}

[ColumnName("col4"), LoadColumn(0)]
public int Col4 { get; set; }

[ColumnName("col3"), LoadColumn(0)]
public string Col3{get; set;}


[ColumnName("col4"), LoadColumn(0)]
public int Col4{get; set;}

[ColumnName("col5"), LoadColumn(0)]
public uint Col5 { get; set; }

[ColumnName("col5"), LoadColumn(0)]
public uint Col5{get; set;}


}

Expand Down
15 changes: 15 additions & 0 deletions src/mlnet.Test/CodeGenTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,21 @@ public void TrainerComplexParameterTest()
}

#region Transform Tests
[TestMethod]
public void MissingValueReplacingTest()
{
    // Arrange: a pipeline with a single MissingValueReplacing transform node that
    // reads from and writes to the same column, with no extra element properties.
    var elementProperties = new Dictionary<string, object>();
    PipelineNode node = new PipelineNode(
        "MissingValueReplacing",
        PipelineNodeType.Transform,
        new string[] { "categorical_column_1" },
        new string[] { "categorical_column_1" },
        elementProperties);
    Pipeline pipeline = new Pipeline(new PipelineNode[] { node });
    CodeGenerator codeGenerator = new CodeGenerator(pipeline, (null, null), null);

    var expectedTransform = "ReplaceMissingValues(new []{new MissingValueReplacingTransformer.ColumnInfo(\"categorical_column_1\",\"categorical_column_1\")})";
    string expectedUsings = "using Microsoft.ML.Transforms;\r\n";

    // Act
    var actual = codeGenerator.GenerateTransformsAndUsings();

    // Assert: the first generated entry carries the transform call (Item1)
    // and the using directive it requires (Item2).
    Assert.AreEqual(expectedTransform, actual[0].Item1);
    Assert.AreEqual(expectedUsings, actual[0].Item2);
}

[TestMethod]
public void OneHotEncodingTest()
{
Expand Down
14 changes: 10 additions & 4 deletions src/mlnet/CodeGenerator/CSharp/CodeGenerator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
using System.IO;
using System.Linq;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.ML.Auto;
using Microsoft.ML.CLI.Templates.Console;
using static Microsoft.ML.Data.TextLoader;
Expand Down Expand Up @@ -50,15 +53,18 @@ public void GenerateOutput()
var namespaceValue = Normalize(options.OutputName);

// Generate code for training and scoring
var trainScoreCode = GenerateTrainCode(usings, trainer, transforms, columns, classLabels, namespaceValue);
var trainFileContent = GenerateTrainCode(usings, trainer, transforms, columns, classLabels, namespaceValue);
var tree = CSharpSyntaxTree.ParseText(trainFileContent);
var syntaxNode = tree.GetRoot();
trainFileContent = Formatter.Format(syntaxNode, new AdhocWorkspace()).ToFullString();

// Generate csproj
var projectSourceCode = GeneratProjectCode();
var projectFileContent = GeneratProjectCode();

// Generate Helper class
var consoleHelperCode = GenerateConsoleHelper(namespaceValue);
var consoleHelperFileContent = GenerateConsoleHelper(namespaceValue);

return (trainScoreCode, projectSourceCode, consoleHelperCode);
return (trainFileContent, projectFileContent, consoleHelperFileContent);
}

internal void WriteOutputToFiles(string trainScoreCode, string projectSourceCode, string consoleHelperCode)
Expand Down
4 changes: 3 additions & 1 deletion src/mlnet/CodeGenerator/CSharp/TransformGeneratorFactory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,9 @@ internal static ITransformGenerator GetInstance(PipelineNode node)
case EstimatorName.MissingValueIndicating:
result = new MissingValueIndicator(node);
break;
//todo : add missing value replacing too.
case EstimatorName.MissingValueReplacing:
result = new MissingValueReplacer(node);
break;
case EstimatorName.OneHotHashEncoding:
result = new OneHotHashEncoding(node);
break;
Expand Down
36 changes: 36 additions & 0 deletions src/mlnet/CodeGenerator/CSharp/TransformGenerators.cs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,42 @@ public override string GenerateTransformer()
}
}

/// <summary>
/// Generates the source text of a <c>ReplaceMissingValues(...)</c> transform call
/// for a pipeline node.
/// </summary>
internal class MissingValueReplacer : TransformGeneratorBase
{
    public MissingValueReplacer(PipelineNode node) : base(node)
    {
    }

    internal override string MethodName => "ReplaceMissingValues";

    // Type name constructed once per column pair inside the generated call.
    private const string ArgumentsName = "MissingValueReplacingTransformer.ColumnInfo";

    internal override string Usings => "using Microsoft.ML.Transforms;\r\n";

    /// <summary>
    /// Builds the call text in the form
    /// <c>ReplaceMissingValues(new []{new MissingValueReplacingTransformer.ColumnInfo(out,in),...})</c>.
    /// </summary>
    /// <returns>The generated transform call as a string.</returns>
    public override string GenerateTransformer()
    {
        // Column names are appended verbatim; any quoting must already be present in
        // the stored names (presumably applied by the base class -- TODO confirm).
        // NOTE(review): the loop is bounded by inputColumns but also indexes
        // outputColumns; this assumes both arrays have the same length -- confirm.
        StringBuilder sb = new StringBuilder();
        sb.Append(MethodName);
        sb.Append("(");
        sb.Append("new []{");
        for (int i = 0; i < inputColumns.Length; i++)
        {
            // Write the separator before every element except the first. Trimming a
            // trailing comma afterwards would delete the opening '{' when there are
            // no columns at all.
            if (i > 0)
            {
                sb.Append(",");
            }

            sb.Append("new ");
            sb.Append(ArgumentsName);
            sb.Append("(");
            sb.Append(outputColumns[i]);
            sb.Append(",");
            sb.Append(inputColumns[i]);
            sb.Append(")");
        }

        sb.Append("}");
        sb.Append(")");
        return sb.ToString();
    }
}

internal class OneHotHashEncoding : TransformGeneratorBase
{
public OneHotHashEncoding(PipelineNode node) : base(node)
Expand Down
23 changes: 11 additions & 12 deletions src/mlnet/Templates/Console/MLCodeGen.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public virtual string TransformText()
" Microsoft.ML.Core.Data;\r\nusing Microsoft.ML.Data;\r\nusing Microsoft.Data.DataVie" +
"w;\r\n");
this.Write(this.ToStringHelper.ToStringWithCulture(GeneratedUsings));
this.Write("\r\n\r\n\r\nnamespace ");
this.Write("\r\n\r\nnamespace ");
this.Write(this.ToStringHelper.ToStringWithCulture(Namespace));
this.Write("\r\n{\r\n class Program\r\n {\r\n private static string TrainDataPath = @\"");
this.Write(this.ToStringHelper.ToStringWithCulture(Path));
Expand Down Expand Up @@ -93,7 +93,7 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
this.Write("\r\n");
if(Transforms.Count >0 ) {
this.Write(" // Common data process configuration with pipeline data transformatio" +
"ns \r\n\r\n var dataProcessPipeline = ");
"ns\r\n var dataProcessPipeline = ");
for(int i=0;i<Transforms.Count;i++)
{
if(i>0)
Expand All @@ -111,7 +111,13 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
this.Write(this.ToStringHelper.ToStringWithCulture(TaskType));
this.Write(".Trainers.");
this.Write(this.ToStringHelper.ToStringWithCulture(Trainer));
this.Write(";\r\n\r\n");
this.Write(";\r\n");
if (Transforms.Count > 0) {
this.Write(" var trainingPipeline = dataProcessPipeline.Append(trainer);\r\n");
}
else{
this.Write(" var trainingPipeline = trainer;\r\n");
}
if(string.IsNullOrEmpty(TestPath)){
this.Write(@"
// Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
Expand All @@ -135,15 +141,8 @@ private static ITransformer BuildTrainEvaluateAndSaveModel(MLContext mlContext)
"rics(trainer.ToString(), crossValidationResults);\r\n");
}
}
this.Write("\r\n // Train the model fitting to the DataSet\r\n");
if(Transforms.Count >0 ) {
this.Write(" var trainingPipeline = dataProcessPipeline.Append(trainer);\r\n " +
" var trainedModel = trainingPipeline.Fit(trainingDataView);\r\n");
}
else{
this.Write(" var trainingPipeline = trainer;\r\n var trainedModel = train" +
"ingPipeline.Fit(trainingDataView);\r\n");
}
this.Write("\r\n // Train the model fitting to the DataSet\r\n var trainedM" +
"odel = trainingPipeline.Fit(trainingDataView);\r\n\r\n");
if(!string.IsNullOrEmpty(TestPath)){
this.Write(" // Evaluate the model and show accuracy stats\r\n Console.Wr" +
"iteLine(\"===== Evaluating Model\'s accuracy with Test data =====\");\r\n " +
Expand Down
19 changes: 8 additions & 11 deletions src/mlnet/Templates/Console/MLCodeGen.tt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ using Microsoft.ML.Data;
using Microsoft.Data.DataView;
<#= GeneratedUsings #>


namespace <#= Namespace #>
{
class Program
Expand Down Expand Up @@ -63,8 +62,7 @@ namespace <#= Namespace #>
<# } #>

<# if(Transforms.Count >0 ) {#>
// Common data process configuration with pipeline data transformations

// Common data process configuration with pipeline data transformations
var dataProcessPipeline = <# for(int i=0;i<Transforms.Count;i++)
{
if(i>0)
Expand All @@ -79,7 +77,12 @@ namespace <#= Namespace #>

// Set the training algorithm, then create and config the modelBuilder
var trainer = mlContext.<#= TaskType #>.Trainers.<#= Trainer #>;

<# if(Transforms.Count >0 ) {#>
var trainingPipeline = dataProcessPipeline.Append(trainer);
<# }
else{#>
var trainingPipeline = trainer;
<#}#>
<# if(string.IsNullOrEmpty(TestPath)){ #>

// Cross-Validate with single dataset (since we don't have two datasets, one for training and for evaluate)
Expand All @@ -95,14 +98,8 @@ namespace <#= Namespace #>
} #>

// Train the model fitting to the DataSet
<# if(Transforms.Count >0 ) {#>
var trainingPipeline = dataProcessPipeline.Append(trainer);
var trainedModel = trainingPipeline.Fit(trainingDataView);
<# }
else{#>
var trainingPipeline = trainer;
var trainedModel = trainingPipeline.Fit(trainingDataView);
<#}#>

<# if(!string.IsNullOrEmpty(TestPath)){ #>
// Evaluate the model and show accuracy stats
Console.WriteLine("===== Evaluating Model's accuracy with Test data =====");
Expand Down
1 change: 1 addition & 0 deletions src/mlnet/mlnet.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Microsoft.CodeAnalysis" Version="2.10.0" />
<PackageReference Include="NLog" Version="4.5.11" />
<PackageReference Include="NLog.Config" Version="4.5.11" />
<PackageReference Include="System.CommandLine.Experimental" Version="0.1.0-alpha-63728-01" />
Expand Down

0 comments on commit b92039a

Please sign in to comment.