From e3830910531f00013c27391914233a085a1394a4 Mon Sep 17 00:00:00 2001 From: Senja Filipi Date: Tue, 29 Jan 2019 13:27:48 -0800 Subject: [PATCH] Input output swap (#2239) * All the code changes in, and most of the tests updated. * all the tests pass * 1 - Changing the "source" parameter name and field in the Columninfo classes, to be "sourceColumnName", as suggested. Changing the "name" parameter to "outputColumnName" in the: - estimator extension APIs - estimator ctors - column pairs expressed through tuples, because in context it reads better than name. Note: in the columnInfo classes i left it to "name" because "outputColumnName" makes no sense. 2 - Nit on standartizing the XML comments. 3 - Arranging the order of the parameters to be: outputColumnName, required parameters, nullable sourceColumnName. --- .../Dynamic/FeatureSelectionTransform.cs | 4 +- .../Dynamic/KeyToValue_Term.cs | 4 +- .../Dynamic/NgramExtraction.cs | 6 +- .../Dynamic/Normalizer.cs | 6 +- .../Dynamic/OnnxTransform.cs | 2 +- .../Dynamic/TensorFlowTransform.cs | 4 +- .../Dynamic/TextTransform.cs | 4 +- src/Microsoft.ML.Core/Data/IEstimator.cs | 18 +- .../Evaluators/AnomalyDetectionEvaluator.cs | 8 +- .../Evaluators/BinaryClassifierEvaluator.cs | 8 +- .../Evaluators/MamlEvaluator.cs | 4 +- .../MultiClassClassifierEvaluator.cs | 2 +- src/Microsoft.ML.Data/TrainCatalog.cs | 4 +- .../ColumnConcatenatingEstimator.cs | 16 +- .../ColumnConcatenatingTransformer.cs | 82 +++--- .../Transforms/ColumnCopying.cs | 47 ++-- .../ConversionsExtensionsCatalog.cs | 45 ++-- .../Transforms/DropSlotsTransform.cs | 50 ++-- .../Transforms/ExtensionsCatalog.cs | 20 +- ...FeatureContributionCalculationTransform.cs | 8 +- src/Microsoft.ML.Data/Transforms/Hashing.cs | 75 +++--- .../Transforms/KeyToValue.cs | 34 +-- .../Transforms/KeyToVector.cs | 79 +++--- .../Transforms/LabelConvertTransform.cs | 8 +- .../Transforms/NormalizeColumn.cs | 32 +-- .../Transforms/Normalizer.cs | 125 ++++----- 
.../Transforms/NormalizerCatalog.cs | 11 +- .../Transforms/OneToOneTransformerBase.cs | 28 +- .../Transforms/TransformBase.cs | 6 +- .../Transforms/TypeConverting.cs | 69 +++-- .../Transforms/ValueMapping.cs | 56 ++-- .../Transforms/ValueToKeyMappingEstimator.cs | 16 +- .../ValueToKeyMappingTransformer.cs | 53 ++-- .../ValueToKeyMappingTransformerImpl.cs | 12 +- .../AlexNetExtension.cs | 16 +- .../ResNet101Extension.cs | 16 +- .../ResNet18Extension.cs | 16 +- .../ResNet50Extension.cs | 16 +- .../FeatureCombiner.cs | 14 +- .../ScoreColumnSelector.cs | 6 +- src/Microsoft.ML.FastTree/FastTree.cs | 2 +- .../VectorWhiteningStaticExtensions.cs | 2 +- .../HalLearnersCatalog.cs | 10 +- .../VectorWhitening.cs | 60 ++--- .../ExtensionsCatalog.cs | 22 +- .../ImageGrayscaleTransform.cs | 20 +- .../ImageLoaderTransform.cs | 22 +- .../ImagePixelExtractorTransform.cs | 75 +++--- .../ImageResizerTransform.cs | 62 ++--- .../DnnImageFeaturizerStaticExtensions.cs | 2 +- .../OnnxStaticExtensions.cs | 2 +- .../DnnImageFeaturizerTransform.cs | 14 +- src/Microsoft.ML.OnnxTransform/OnnxCatalog.cs | 10 +- .../OnnxTransform.cs | 30 +-- src/Microsoft.ML.PCA/PCACatalog.cs | 10 +- src/Microsoft.ML.PCA/PcaTransformer.cs | 60 +++-- .../CategoricalHashStaticExtensions.cs | 2 +- .../CategoricalStaticExtensions.cs | 2 +- .../ImageTransformsStatic.cs | 24 +- .../LdaStaticExtensions.cs | 3 +- .../LpNormalizerStaticExtensions.cs | 4 +- .../NormalizerStaticExtensions.cs | 14 +- .../StaticPipeUtils.cs | 12 +- .../TextStaticExtensions.cs | 26 +- .../TrainerEstimatorReconciler.cs | 8 +- .../TransformsStatic.cs | 38 +-- .../WordEmbeddingsStaticExtensions.cs | 2 +- .../TensorFlowStaticExtensions.cs | 8 +- .../TensorflowCatalog.cs | 20 +- .../TensorflowTransform.cs | 64 ++--- .../TimeSeriesStatic.cs | 8 +- .../ExponentialAverageTransform.cs | 2 +- .../IidChangePointDetector.cs | 14 +- .../IidSpikeDetector.cs | 12 +- .../MovingAverageTransform.cs | 2 +- .../PValueTransform.cs | 2 +- 
.../PercentileThresholdTransform.cs | 2 +- ...SequentialAnomalyDetectionTransformBase.cs | 2 +- .../SequentialTransformBase.cs | 20 +- .../SequentialTransformerBase.cs | 8 +- .../SlidingWindowTransformBase.cs | 2 +- .../SsaChangePointDetector.cs | 27 +- .../SsaSpikeDetector.cs | 20 +- .../CategoricalCatalog.cs | 20 +- .../ConversionsCatalog.cs | 10 +- .../CountFeatureSelection.cs | 49 ++-- .../ExtensionsCatalog.cs | 34 +-- .../FeatureSelectionCatalog.cs | 20 +- src/Microsoft.ML.Transforms/GcnTransform.cs | 92 +++---- src/Microsoft.ML.Transforms/GroupTransform.cs | 8 +- .../HashJoiningTransform.cs | 2 +- .../KeyToVectorMapping.cs | 63 ++--- .../MissingValueDroppingTransformer.cs | 30 +-- .../MissingValueHandlingTransformer.cs | 14 +- .../MissingValueIndicatorTransformer.cs | 44 ++-- .../MissingValueReplacing.cs | 93 +++---- .../MutualInformationFeatureSelection.cs | 36 +-- src/Microsoft.ML.Transforms/OneHotEncoding.cs | 34 +-- .../OneHotHashEncoding.cs | 39 +-- .../ProduceIdTransform.cs | 2 +- .../ProjectionCatalog.cs | 26 +- .../RandomFourierFeaturizing.cs | 55 ++-- .../SerializableLambdaTransform.cs | 4 +- .../Text/LdaTransform.cs | 62 ++--- .../Text/NgramHashingTransformer.cs | 115 +++++---- .../Text/NgramTransform.cs | 66 ++--- .../Text/SentimentAnalyzingTransform.cs | 12 +- .../Text/StopWordsRemovingTransformer.cs | 112 ++++---- .../Text/TextCatalog.cs | 241 +++++++++--------- .../Text/TextFeaturizingEstimator.cs | 44 ++-- .../Text/TextNormalizing.cs | 45 ++-- .../Text/TokenizingByCharacters.cs | 39 +-- .../Text/WordBagTransform.cs | 7 +- .../Text/WordEmbeddingsExtractor.cs | 78 +++--- .../Text/WordHashBagProducingTransform.cs | 10 +- .../Text/WordTokenizing.cs | 62 ++--- .../Text/WrappedTextTransformers.cs | 74 +++--- test/Microsoft.ML.Benchmarks/HashBench.cs | 2 +- .../PredictionEngineBench.cs | 2 +- test/Microsoft.ML.Benchmarks/README.md | 21 ++ test/Microsoft.ML.Benchmarks/RffTransform.cs | 2 +- ...sticDualCoordinateAscentClassifierBench.cs | 4 +- 
.../UnitTests/TestEntryPoints.cs | 16 +- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 6 +- .../DnnImageFeaturizerTest.cs | 6 +- .../OnnxTransformTests.cs | 8 +- .../TestPredictors.cs | 2 +- test/Microsoft.ML.Tests/CachingTests.cs | 12 +- test/Microsoft.ML.Tests/ImagesTests.cs | 90 +++---- test/Microsoft.ML.Tests/OnnxConversionTest.cs | 6 +- test/Microsoft.ML.Tests/RangeFilterTests.cs | 2 +- .../CookbookSamplesDynamicApi.cs | 32 +-- .../Api/Estimators/CrossValidation.cs | 2 +- .../Scenarios/Api/Estimators/Evaluation.cs | 2 +- .../Api/Estimators/FileBasedSavingOfData.cs | 2 +- .../Api/Estimators/IntrospectiveTraining.cs | 2 +- .../Api/Estimators/MultithreadedPrediction.cs | 2 +- .../Estimators/ReconfigurablePrediction.cs | 2 +- .../Api/Estimators/SimpleTrainAndPredict.cs | 4 +- .../Estimators/TrainSaveModelAndPredict.cs | 2 +- .../Estimators/TrainWithInitialPredictor.cs | 2 +- .../Api/Estimators/TrainWithValidationSet.cs | 2 +- .../Scenarios/Api/Estimators/Visibility.cs | 2 +- ...PlantClassificationWithStringLabelTests.cs | 4 +- .../Scenarios/TensorflowTests.cs | 8 +- .../TensorflowTests.cs | 62 ++--- .../TensorFlowEstimatorTests.cs | 4 +- test/Microsoft.ML.Tests/TermEstimatorTests.cs | 28 +- .../TrainerEstimators/TrainerEstimators.cs | 6 +- .../Transformers/CategoricalHashTests.cs | 36 +-- .../Transformers/CategoricalTests.cs | 56 ++-- .../Transformers/CharTokenizeTests.cs | 4 +- .../Transformers/ConvertTests.cs | 46 ++-- .../Transformers/CopyColumnEstimatorTests.cs | 12 +- .../Transformers/FeatureSelectionTests.cs | 40 +-- .../Transformers/HashTests.cs | 34 +-- .../KeyToBinaryVectorEstimatorTest.cs | 42 +-- .../Transformers/KeyToValueTests.cs | 14 +- .../Transformers/KeyToVectorEstimatorTests.cs | 60 ++--- .../Transformers/NAIndicatorTests.cs | 10 +- .../Transformers/NAReplaceTests.cs | 16 +- .../Transformers/NormalizerTests.cs | 94 +++---- .../Transformers/PcaTests.cs | 6 +- .../Transformers/RffTests.cs | 8 +- .../Transformers/TextFeaturizerTests.cs | 26 +- 
.../Transformers/TextNormalizer.cs | 14 +- .../Transformers/ValueMappingTests.cs | 38 +-- .../Transformers/WordEmbeddingsTests.cs | 2 +- .../Transformers/WordTokenizeTests.cs | 8 +- .../TimeSeriesDirectApi.cs | 4 +- .../TimeSeriesEstimatorTests.cs | 14 +- 171 files changed, 2216 insertions(+), 2138 deletions(-) diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs index fdeaf6f42c..cb326ecca6 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/FeatureSelectionTransform.cs @@ -48,14 +48,14 @@ public static void FeatureSelectionTransform() // In this example we define a CountFeatureSelectingEstimator, that selects slots in a feature vector that have more non-default // values than the specified count. This transformation can be used to remove slots with too many missing values. var countSelectEst = ml.Transforms.FeatureSelection.SelectFeaturesBasedOnCount( - inputColumn: "Features", outputColumn: "FeaturesCountSelect", count: 695); + outputColumnName: "FeaturesCountSelect", inputColumnName: "Features", count: 695); // We also define a MutualInformationFeatureSelectingEstimator that selects the top k slots in a feature // vector based on highest mutual information between that slot and a specified label. Notice that it is possible to // specify the parameter `numBins', which controls the number of bins used in the approximation of the mutual information // between features and label. var mutualInfoEst = ml.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation( - inputColumn: "FeaturesCountSelect", outputColumn: "FeaturesMISelect", labelColumn: "Label", slotsInOutput: 5); + outputColumnName: "FeaturesMISelect", inputColumnName: "FeaturesCountSelect", labelColumn: "Label", slotsInOutput: 5); // Now, we can put the previous two transformations together in a pipeline. 
var pipeline = countSelectEst.Append(mutualInfoEst); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs index 039920df19..882000a773 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/KeyToValue_Term.cs @@ -32,7 +32,7 @@ public static void KeyToValue_Term() string defaultColumnName = "DefaultKeys"; // REVIEW create through the catalog extension var default_pipeline = new WordTokenizingEstimator(ml, "Review") - .Append(new ValueToKeyMappingEstimator(ml, "Review", defaultColumnName)); + .Append(new ValueToKeyMappingEstimator(ml, defaultColumnName, "Review")); // Another pipeline, that customizes the advanced settings of the TermEstimator. // We can change the maxNumTerm to limit how many keys will get generated out of the set of words, @@ -40,7 +40,7 @@ public static void KeyToValue_Term() // to value/alphabetically. string customizedColumnName = "CustomizedKeys"; var customized_pipeline = new WordTokenizingEstimator(ml, "Review") - .Append(new ValueToKeyMappingEstimator(ml, "Review", customizedColumnName, maxNumTerms: 10, sort: ValueToKeyMappingTransformer.SortOrder.Value)); + .Append(new ValueToKeyMappingEstimator(ml,customizedColumnName, "Review", maxNumTerms: 10, sort: ValueToKeyMappingTransformer.SortOrder.Value)); // The transformed data. 
var transformedData_default = default_pipeline.Fit(trainData).Transform(trainData); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs index a476ae0e3e..1f19d6ec21 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/NgramExtraction.cs @@ -26,9 +26,9 @@ public static void NgramTransform() // A pipeline to tokenize text as characters and then combine them together into ngrams // The pipeline uses the default settings to featurize. - var charsPipeline = ml.Transforms.Text.TokenizeCharacters("SentimentText", "Chars", useMarkerCharacters:false); - var ngramOnePipeline = ml.Transforms.Text.ProduceNgrams("Chars", "CharsUnigrams", ngramLength:1); - var ngramTwpPipeline = ml.Transforms.Text.ProduceNgrams("Chars", "CharsTwograms"); + var charsPipeline = ml.Transforms.Text.TokenizeCharacters("Chars", "SentimentText", useMarkerCharacters:false); + var ngramOnePipeline = ml.Transforms.Text.ProduceNgrams("CharsUnigrams", "Chars", ngramLength:1); + var ngramTwpPipeline = ml.Transforms.Text.ProduceNgrams("CharsTwograms", "Chars"); var oneCharsPipeline = charsPipeline.Append(ngramOnePipeline); var twoCharsPipeline = charsPipeline.Append(ngramTwpPipeline); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs index 1ff32b66a3..02a4594a82 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/Normalizer.cs @@ -33,7 +33,7 @@ public static void Normalizer() var transformer = pipeline.Fit(trainData); var modelParams = transformer.Columns - .First(x => x.Output == "Induced") + .First(x => x.Name == "Induced") .ModelParameters as NormalizingTransformer.AffineNormalizerModelParameters; Console.WriteLine($"The normalization parameters are: Scale = {modelParams.Scale} and Offset = {modelParams.Offset}"); 
@@ -66,7 +66,7 @@ public static void Normalizer() // Composing a different pipeline if we wanted to normalize more than one column at a time. // Using log scale as the normalization mode. - var multiColPipeline = ml.Transforms.Normalize(NormalizingEstimator.NormalizerMode.LogMeanVariance, new[] { ("Induced", "LogInduced"), ("Spontaneous", "LogSpontaneous") }); + var multiColPipeline = ml.Transforms.Normalize(NormalizingEstimator.NormalizerMode.LogMeanVariance, new[] { ("LogInduced", "Induced"), ("LogSpontaneous", "Spontaneous") }); // The transformed data. var multiColtransformer = multiColPipeline.Fit(trainData); var multiColtransformedData = multiColtransformer.Transform(trainData); @@ -97,7 +97,7 @@ public static void Normalizer() // Inspect the weights of normalizing the columns var multiColModelParams = multiColtransformer.Columns - .First(x=> x.Output == "LogInduced") + .First(x=> x.Name == "LogInduced") .ModelParameters as NormalizingTransformer.CdfNormalizerModelParameters; Console.WriteLine($"The normalization parameters are: Mean = {multiColModelParams.Mean} and Stddev = {multiColModelParams.Stddev}"); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs index 034427fcc6..7549d276e5 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/OnnxTransform.cs @@ -35,7 +35,7 @@ public static void OnnxTransformSample() var mlContext = new MLContext(); var data = GetTensorData(); var idv = mlContext.Data.ReadFromEnumerable(data); - var pipeline = new OnnxScoringEstimator(mlContext, modelPath, new[] { inputInfo.Key }, new[] { outputInfo.Key }); + var pipeline = new OnnxScoringEstimator(mlContext, new[] { outputInfo.Key }, new[] { inputInfo.Key }, modelPath); // Run the pipeline and get the transformed values var transformedValues = pipeline.Fit(idv).Transform(idv); diff --git 
a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs index fcc570c35c..bd8c0ee6d1 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TensorFlowTransform.cs @@ -22,8 +22,8 @@ public static void TensorFlowScoringSample() // Create a ML pipeline. var pipeline = mlContext.Transforms.ScoreTensorFlowModel( modelLocation, - new[] { nameof(TensorData.input) }, - new[] { nameof(OutputScores.output) }); + new[] { nameof(OutputScores.output) }, + new[] { nameof(TensorData.input) }); // Run the pipeline and get the transformed values. var estimator = pipeline.Fit(idv); diff --git a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs index d1eda44dd0..c1583d56d1 100644 --- a/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs +++ b/docs/samples/Microsoft.ML.Samples/Dynamic/TextTransform.cs @@ -27,11 +27,11 @@ public static void TextTransform() // A pipeline for featurization of the "SentimentText" column, and placing the output in a new column named "DefaultTextFeatures" // The pipeline uses the default settings to featurize. string defaultColumnName = "DefaultTextFeatures"; - var default_pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", defaultColumnName); + var default_pipeline = ml.Transforms.Text.FeaturizeText(defaultColumnName , "SentimentText"); // Another pipeline, that customizes the advanced settings of the FeaturizeText transformer. 
string customizedColumnName = "CustomizedTextFeatures"; - var customized_pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", customizedColumnName, s => + var customized_pipeline = ml.Transforms.Text.FeaturizeText(customizedColumnName, "SentimentText", s => { s.KeepPunctuations = false; s.KeepNumbers = false; diff --git a/src/Microsoft.ML.Core/Data/IEstimator.cs b/src/Microsoft.ML.Core/Data/IEstimator.cs index 45464af3b9..6ebef55827 100644 --- a/src/Microsoft.ML.Core/Data/IEstimator.cs +++ b/src/Microsoft.ML.Core/Data/IEstimator.cs @@ -75,29 +75,29 @@ internal Column(string name, VectorKind vecKind, ColumnType itemType, bool isKey } /// - /// Returns whether is a valid input, if this object represents a + /// Returns whether is a valid input, if this object represents a /// requirement. /// /// Namely, it returns true iff: /// - The , , , fields match. - /// - The columns of of is a superset of our columns. + /// - The columns of of is a superset of our columns. /// - Each such metadata column is itself compatible with the input metadata column. 
/// [BestFriend] - internal bool IsCompatibleWith(Column inputColumn) + internal bool IsCompatibleWith(Column source) { - Contracts.Check(inputColumn.IsValid, nameof(inputColumn)); - if (Name != inputColumn.Name) + Contracts.Check(source.IsValid, nameof(source)); + if (Name != source.Name) return false; - if (Kind != inputColumn.Kind) + if (Kind != source.Kind) return false; - if (!ItemType.Equals(inputColumn.ItemType)) + if (!ItemType.Equals(source.ItemType)) return false; - if (IsKey != inputColumn.IsKey) + if (IsKey != source.IsKey) return false; foreach (var metaCol in Metadata) { - if (!inputColumn.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol)) + if (!source.Metadata.TryFindColumn(metaCol.Name, out var inputMetaCol)) return false; if (!metaCol.IsCompatibleWith(inputMetaCol)) return false; diff --git a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs index b3adf4c13d..143ecb44b5 100644 --- a/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/AnomalyDetectionEvaluator.cs @@ -708,11 +708,11 @@ private protected override void PrintFoldResultsCore(IChannel ch, Dictionary(); diff --git a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs index ed5b1bc84b..44209089ef 100644 --- a/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MamlEvaluator.cs @@ -230,7 +230,7 @@ private IDataView WrapPerInstance(RoleMappedData perInst) // Make a list of column names that Maml outputs as part of the per-instance data view, and then wrap // the per-instance data computed by the evaluator in a SelectColumnsTransform. - var cols = new List<(string Source, string Name)>(); + var cols = new List<(string name, string source)>(); var colsToKeep = new List(); // If perInst is the result of cross-validation and contains a fold Id column, include it. 
@@ -241,7 +241,7 @@ private IDataView WrapPerInstance(RoleMappedData perInst) // Maml always outputs a name column, if it doesn't exist add a GenerateNumberTransform. if (perInst.Schema.Name?.Name is string nameName) { - cols.Add((nameName, "Instance")); + cols.Add(("Instance", nameName)); colsToKeep.Add("Instance"); } else diff --git a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs index 7b87a7a475..c6447281a4 100644 --- a/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs +++ b/src/Microsoft.ML.Data/Evaluators/MultiClassClassifierEvaluator.cs @@ -950,7 +950,7 @@ private protected override IDataView GetOverallResultsCore(IDataView overall) private IDataView ChangeTopKAccColumnName(IDataView input) { - input = new ColumnCopyingTransformer(Host, (MultiClassClassifierEvaluator.TopKAccuracy, string.Format(TopKAccuracyFormat, _outputTopKAcc))).Transform(input); + input = new ColumnCopyingTransformer(Host, (string.Format(TopKAccuracyFormat, _outputTopKAcc), MultiClassClassifierEvaluator.TopKAccuracy)).Transform(input); return ColumnSelectingTransformer.CreateDrop(Host, input, MultiClassClassifierEvaluator.TopKAccuracy); } diff --git a/src/Microsoft.ML.Data/TrainCatalog.cs b/src/Microsoft.ML.Data/TrainCatalog.cs index 81234dc764..3b886ae908 100644 --- a/src/Microsoft.ML.Data/TrainCatalog.cs +++ b/src/Microsoft.ML.Data/TrainCatalog.cs @@ -154,9 +154,9 @@ private void EnsureStratificationColumn(ref IDataView data, ref string stratific stratificationColumn = string.Format("{0}_{1:000}", origStratCol, ++inc); HashingTransformer.ColumnInfo columnInfo; if (seed.HasValue) - columnInfo = new HashingTransformer.ColumnInfo(origStratCol, stratificationColumn, 30, seed.Value); + columnInfo = new HashingTransformer.ColumnInfo(stratificationColumn, origStratCol, 30, seed.Value); else - columnInfo = new HashingTransformer.ColumnInfo(origStratCol, stratificationColumn, 30); + columnInfo 
= new HashingTransformer.ColumnInfo(stratificationColumn, origStratCol, 30); data = new HashingEstimator(Host, columnInfo).Fit(data).Transform(data); } } diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs index 128d23e31e..e901375f83 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingEstimator.cs @@ -21,20 +21,20 @@ public sealed class ColumnConcatenatingEstimator : IEstimator /// Initializes a new instance of /// /// The local instance of . - /// The name of the resulting column. - /// The columns to concatenate together. - public ColumnConcatenatingEstimator (IHostEnvironment env, string outputColumn, params string[] inputColumns) + /// The name of the resulting column. + /// The columns to concatenate together. + public ColumnConcatenatingEstimator(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames) { Contracts.CheckValue(env, nameof(env)); _host = env.Register("ColumnConcatenatingEstimator "); - _host.CheckNonEmpty(outputColumn, nameof(outputColumn)); - _host.CheckValue(inputColumns, nameof(inputColumns)); - _host.CheckParam(!inputColumns.Any(r => string.IsNullOrEmpty(r)), nameof(inputColumns), + _host.CheckNonEmpty(outputColumnName, nameof(outputColumnName)); + _host.CheckValue(inputColumnNames, nameof(inputColumnNames)); + _host.CheckParam(!inputColumnNames.Any(r => string.IsNullOrEmpty(r)), nameof(inputColumnNames), "Contained some null or empty items"); - _name = outputColumn; - _source = inputColumns; + _name = outputColumnName; + _source = inputColumnNames; } public ITransformer Fit(IDataView input) diff --git a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs index faa9d93e74..a15ee34609 100644 --- 
a/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnConcatenatingTransformer.cs @@ -70,7 +70,7 @@ public sealed class TaggedColumn // tag if it is non empty. For vector columns, the slot names will be 'ColumnName.SlotName' if the // tag is empty, 'Tag.SlotName' if tag is non empty, and simply the slot name if tag is non empty // and equal to the column name. - [Argument(ArgumentType.Multiple, HelpText = "Name of the source column", ShortName = "src")] + [Argument(ArgumentType.Multiple, HelpText = "Names of the source columns", ShortName = "src")] public KeyValuePair[] Source; internal static TaggedColumn Parse(string str) @@ -126,43 +126,43 @@ public sealed class TaggedArguments public sealed class ColumnInfo { - public readonly string Output; - private readonly (string name, string alias)[] _inputs; - public IReadOnlyList<(string name, string alias)> Inputs => _inputs.AsReadOnly(); + public readonly string Name; + private readonly (string name, string alias)[] _sources; + public IReadOnlyList<(string name, string alias)> Sources => _sources.AsReadOnly(); /// - /// This denotes a concatenation of all into column called . + /// This denotes a concatenation of all into column called . 
/// - public ColumnInfo(string outputName, params string[] inputNames) - : this(outputName, GetPairs(inputNames)) + public ColumnInfo(string name, params string[] inputColumnNames) + : this(name, GetPairs(inputColumnNames)) { } - private static IEnumerable<(string name, string alias)> GetPairs(string[] inputNames) + private static IEnumerable<(string name, string alias)> GetPairs(string[] inputColumnNames) { - Contracts.CheckValue(inputNames, nameof(inputNames)); - return inputNames.Select(name => (name, (string)null)); + Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); + return inputColumnNames.Select(name => (name, (string)null)); } /// - /// This denotes a concatenation of input columns into one column called . + /// This denotes a concatenation of input columns into one column called . /// For each input column, an 'alias' can be specified, to be used in constructing the resulting slot names. /// If the alias is not specified, it defaults to be column name. /// - public ColumnInfo(string outputName, IEnumerable<(string name, string alias)> inputs) + public ColumnInfo(string name, IEnumerable<(string name, string alias)> inputColumnNames) { - Contracts.CheckNonEmpty(outputName, nameof(outputName)); - Contracts.CheckValue(inputs, nameof(inputs)); - Contracts.CheckParam(inputs.Any(), nameof(inputs), "Can not be empty"); + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); + Contracts.CheckParam(inputColumnNames.Any(), nameof(inputColumnNames), "Can not be empty"); - foreach (var (name, alias) in inputs) + foreach (var (output, alias) in inputColumnNames) { - Contracts.CheckNonEmpty(name, nameof(inputs)); + Contracts.CheckNonEmpty(output, nameof(inputColumnNames)); Contracts.CheckValueOrNull(alias); } - Output = outputName; - _inputs = inputs.ToArray(); + Name = name; + _sources = inputColumnNames.ToArray(); } public void Save(ModelSaveContext ctx) @@ -175,10 +175,10 @@ public void 
Save(ModelSaveContext ctx) // int: id of name // int: id of alias - ctx.SaveNonEmptyString(Output); - Contracts.Assert(_inputs.Length > 0); - ctx.Writer.Write(_inputs.Length); - foreach (var (name, alias) in _inputs) + ctx.SaveNonEmptyString(Name); + Contracts.Assert(_sources.Length > 0); + ctx.Writer.Write(_sources.Length); + foreach (var (name, alias) in _sources) { ctx.SaveNonEmptyString(name); ctx.SaveStringOrNull(alias); @@ -195,15 +195,15 @@ internal ColumnInfo(ModelLoadContext ctx) // int: id of name // int: id of alias - Output = ctx.LoadNonEmptyString(); + Name = ctx.LoadNonEmptyString(); int n = ctx.Reader.ReadInt32(); Contracts.CheckDecode(n > 0); - _inputs = new (string name, string alias)[n]; + _sources = new (string name, string alias)[n]; for (int i = 0; i < n; i++) { var name = ctx.LoadNonEmptyString(); var alias = ctx.LoadStringOrNull(); - _inputs[i] = (name, alias); + _sources[i] = (name, alias); } } } @@ -213,12 +213,12 @@ internal ColumnInfo(ModelLoadContext ctx) public IReadOnlyCollection Columns => _columns.AsReadOnly(); /// - /// Concatename columns in into one column . + /// Concatename columns in into one column . /// Original columns are also preserved. /// The column types must match, and the output column type is always a vector. 
/// - public ColumnConcatenatingTransformer(IHostEnvironment env, string outputName, params string[] inputNames) - : this(env, new ColumnInfo(outputName, inputNames)) + public ColumnConcatenatingTransformer(IHostEnvironment env, string outputColumnName, params string[] inputColumnNames) + : this(env, new ColumnInfo(outputColumnName, inputColumnNames)) { } @@ -432,7 +432,7 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo) Contracts.Assert(0 <= iinfo && iinfo < _parent._columns.Length); ColumnType itemType = null; - int[] sources = new int[_parent._columns[iinfo].Inputs.Count]; + int[] sources = new int[_parent._columns[iinfo].Sources.Count]; // Go through the columns, and establish the following: // - indices of input columns in the input schema. Throw if they are not there. // - output type. Throw if the types of inputs are not the same. @@ -449,9 +449,9 @@ private BoundColumn MakeColumn(Schema inputSchema, int iinfo) bool isNormalized = true; bool hasSlotNames = false; bool hasCategoricals = false; - for (int i = 0; i < _parent._columns[iinfo].Inputs.Count; i++) + for (int i = 0; i < _parent._columns[iinfo].Sources.Count; i++) { - var (srcName, srcAlias) = _parent._columns[iinfo].Inputs[i]; + var (srcName, srcAlias) = _parent._columns[iinfo].Sources[i]; if (!inputSchema.TryGetColumnIndex(srcName, out int srcCol)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); sources[i] = srcCol; @@ -556,7 +556,7 @@ public Schema.DetachedColumn MakeSchemaColumn() if (_isIdentity) { var inputCol = _inputSchema[SrcIndices[0]]; - return new Schema.DetachedColumn(_columnInfo.Output, inputCol.Type, inputCol.Metadata); + return new Schema.DetachedColumn(_columnInfo.Name, inputCol.Type, inputCol.Metadata); } var metadata = new MetadataBuilder(); @@ -567,7 +567,7 @@ public Schema.DetachedColumn MakeSchemaColumn() if (_hasCategoricals) metadata.Add(MetadataUtils.Kinds.CategoricalSlotRanges, _categoricalRangeType, 
(ValueGetter>)GetCategoricalSlotRanges); - return new Schema.DetachedColumn(_columnInfo.Output, OutputType, metadata.GetMetadata()); + return new Schema.DetachedColumn(_columnInfo.Name, OutputType, metadata.GetMetadata()); } private void GetIsNormalized(ref bool value) => value = _isNormalized; @@ -616,9 +616,9 @@ private void GetSlotNames(ref VBuffer> dst) { int colSrc = SrcIndices[i]; var typeSrc = _srcTypes[i]; - Contracts.Assert(_columnInfo.Inputs[i].alias != ""); + Contracts.Assert(_columnInfo.Sources[i].alias != ""); var colName = _inputSchema[colSrc].Name; - var nameSrc = _columnInfo.Inputs[i].alias ?? colName; + var nameSrc = _columnInfo.Sources[i].alias ?? colName; if (!(typeSrc is VectorType vectorTypeSrc)) { bldr.AddFeature(slot++, nameSrc.AsMemory()); @@ -636,7 +636,7 @@ private void GetSlotNames(ref VBuffer> dst) { inputMetadata.GetValue(MetadataUtils.Kinds.SlotNames, ref names); sb.Clear(); - if (_columnInfo.Inputs[i].alias != colName) + if (_columnInfo.Sources[i].alias != colName) sb.Append(nameSrc).Append("."); int len = sb.Length; foreach (var kvp in names.Items()) @@ -787,7 +787,7 @@ private Delegate MakeGetter(Row input) public KeyValuePair SavePfaInfo(BoundPfaContext ctx) { Contracts.AssertValue(ctx); - string outName = _columnInfo.Output; + string outName = _columnInfo.Name; if (!OutputType.IsKnownSize) // Do not attempt variable length. 
return new KeyValuePair(outName, null); @@ -795,7 +795,7 @@ public KeyValuePair SavePfaInfo(BoundPfaContext ctx) bool[] srcPrimitive = new bool[SrcIndices.Length]; for (int i = 0; i < SrcIndices.Length; ++i) { - var srcName = _columnInfo.Inputs[i].name; + var srcName = _columnInfo.Sources[i].name; if ((srcTokens[i] = ctx.TokenOrNullForName(srcName)) == null) return new KeyValuePair(outName, null); srcPrimitive[i] = _srcTypes[i] is PrimitiveType; @@ -888,7 +888,7 @@ public void SaveAsOnnx(OnnxContext ctx) var colInfo = _parent._columns[iinfo]; var boundCol = _columns[iinfo]; - string outName = colInfo.Output; + string outName = colInfo.Name; var outColType = boundCol.OutputType; if (!outColType.IsKnownSize) { @@ -899,7 +899,7 @@ public void SaveAsOnnx(OnnxContext ctx) List> inputList = new List>(); for (int i = 0; i < boundCol.SrcIndices.Length; ++i) { - var srcName = colInfo.Inputs[i].name; + var srcName = colInfo.Sources[i].name; if (!ctx.ContainsColumn(srcName)) { ctx.RemoveColumn(outName, false); diff --git a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs index 3bffecf68f..cc3e407ad5 100644 --- a/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs +++ b/src/Microsoft.ML.Data/Transforms/ColumnCopying.cs @@ -35,12 +35,12 @@ namespace Microsoft.ML.Transforms { public sealed class ColumnCopyingEstimator : TrivialEstimator { - public ColumnCopyingEstimator(IHostEnvironment env, string input, string output) : - this(env, (input, output)) + public ColumnCopyingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName) : + this(env, (outputColumnName, inputColumnName)) { } - public ColumnCopyingEstimator(IHostEnvironment env, params (string input, string output)[] columns) + public ColumnCopyingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingEstimator)), new 
ColumnCopyingTransformer(env, columns)) { } @@ -50,12 +50,12 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) Host.CheckValue(inputSchema, nameof(inputSchema)); var resultDic = inputSchema.ToDictionary(x => x.Name); - foreach (var (Source, Name) in Transformer.Columns) + foreach (var (outputColumnName, inputColumnName) in Transformer.Columns) { - if (!inputSchema.TryFindColumn(Source, out var originalColumn)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Source); - var col = new SchemaShape.Column(Name, originalColumn.Kind, originalColumn.ItemType, originalColumn.IsKey, originalColumn.Metadata); - resultDic[Name] = col; + if (!inputSchema.TryFindColumn(inputColumnName, out var originalColumn)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName); + var col = new SchemaShape.Column(outputColumnName, originalColumn.Kind, originalColumn.ItemType, originalColumn.IsKey, originalColumn.Metadata); + resultDic[outputColumnName] = col; } return new SchemaShape(resultDic.Values); } @@ -63,12 +63,13 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) public sealed class ColumnCopyingTransformer : OneToOneTransformerBase { - public const string LoaderSignature = "CopyTransform"; + [BestFriend] + internal const string LoaderSignature = "CopyTransform"; internal const string Summary = "Copy a source column to a new column."; internal const string UserName = "Copy Columns Transform"; internal const string ShortName = "Copy"; - public IReadOnlyCollection<(string Input, string Output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); private static VersionInfo GetVersionInfo() { @@ -81,7 +82,7 @@ private static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(ColumnCopyingTransformer).Assembly.FullName); } - public ColumnCopyingTransformer(IHostEnvironment env, params (string input, 
string output)[] columns) + public ColumnCopyingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ColumnCopyingTransformer)), columns) { } @@ -117,7 +118,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat Contracts.CheckValue(env, nameof(env)); env.CheckValue(args, nameof(args)); - var transformer = new ColumnCopyingTransformer(env, args.Column.Select(x => (x.Source, x.Name)).ToArray()); + var transformer = new ColumnCopyingTransformer(env, args.Column.Select(x => (x.Name, x.Source)).ToArray()); return transformer.MakeDataTransform(input); } @@ -135,11 +136,11 @@ private static ColumnCopyingTransformer Create(IHostEnvironment env, ModelLoadCo // string: input column name var length = ctx.Reader.ReadInt32(); - var columns = new (string Input, string Output)[length]; + var columns = new (string outputColumnName, string inputColumnName)[length]; for (int i = 0; i < length; i++) { - columns[i].Output = ctx.LoadNonEmptyString(); - columns[i].Input = ctx.LoadNonEmptyString(); + columns[i].outputColumnName = ctx.LoadNonEmptyString(); + columns[i].inputColumnName = ctx.LoadNonEmptyString(); } return new ColumnCopyingTransformer(env, columns); } @@ -164,11 +165,11 @@ private protected override IRowMapper MakeRowMapper(Schema inputSchema) private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx { private readonly Schema _schema; - private readonly (string Input, string Output)[] _columns; + private readonly (string outputColumnName, string inputColumnName)[] _columns; public bool CanSaveOnnx(OnnxContext ctx) => ctx.GetOnnxVersion() == OnnxVersion.Experimental; - internal Mapper(ColumnCopyingTransformer parent, Schema inputSchema, (string Input, string Output)[] columns) + internal Mapper(ColumnCopyingTransformer parent, Schema inputSchema, (string outputColumnName, string inputColumnName)[] columns) : 
base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) { _schema = inputSchema; @@ -184,7 +185,7 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act Delegate MakeGetter(Row row, int index) => input.GetGetter(index); - input.Schema.TryGetColumnIndex(_columns[iinfo].Input, out int colIndex); + input.Schema.TryGetColumnIndex(_columns[iinfo].inputColumnName, out int colIndex); var type = input.Schema[colIndex].Type; return Utils.MarshalInvoke(MakeGetter, type.RawType, input, colIndex); } @@ -194,8 +195,8 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_columns.Length]; for (int i = 0; i < _columns.Length; i++) { - var srcCol = _schema[_columns[i].Input]; - result[i] = new Schema.DetachedColumn(_columns[i].Output, srcCol.Type, srcCol.Metadata); + var srcCol = _schema[_columns[i].inputColumnName]; + result[i] = new Schema.DetachedColumn(_columns[i].outputColumnName, srcCol.Type, srcCol.Metadata); } return result; } @@ -206,9 +207,9 @@ public void SaveAsOnnx(OnnxContext ctx) foreach (var column in _columns) { - var srcVariableName = ctx.GetVariableName(column.Input); - _schema.TryGetColumnIndex(column.Input, out int colIndex); - var dstVariableName = ctx.AddIntermediateVariable(_schema[colIndex].Type, column.Output); + var srcVariableName = ctx.GetVariableName(column.inputColumnName); + _schema.TryGetColumnIndex(column.inputColumnName, out int colIndex); + var dstVariableName = ctx.AddIntermediateVariable(_schema[colIndex].Type, column.outputColumnName); var node = ctx.CreateNode(opType, srcVariableName, dstVariableName, ctx.GetNodeName(opType)); node.AddAttribute("type", LoaderSignature); } diff --git a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs index b8f9688307..aeb6e23743 100644 --- a/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs +++ 
b/src/Microsoft.ML.Data/Transforms/ConversionsExtensionsCatalog.cs @@ -21,16 +21,15 @@ public static class ConversionsExtensionsCatalog /// Hashes the values in the input column. /// /// The transform's catalog. - /// Name of the input column. - /// Name of the column to be transformed. If this is null '' will be used. - /// Number of bits to hash into. Must be between 1 and 31, inclusive. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 31, inclusive. /// During hashing we constuct mappings between original values and the produced hash values. /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, string inputColumn, string outputColumn = null, + public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms catalog, string outputColumnName, string inputColumnName = null, int hashBits = HashDefaults.HashBits, int invertHash = HashDefaults.InvertHash) - => new HashingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, hashBits, invertHash); + => new HashingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, hashBits, invertHash); /// /// Hashes the values in the input column. @@ -44,12 +43,12 @@ public static HashingEstimator Hash(this TransformsCatalog.ConversionTransforms /// Changes column type of the input column. /// /// The transform's catalog. - /// Name of the input column. - /// Name of the column to be transformed. 
If this is null '' will be used. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 31, inclusive. - public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, string inputColumn, string outputColumn = null, + public static TypeConvertingEstimator ConvertType(this TransformsCatalog.ConversionTransforms catalog, string outputColumnName, string inputColumnName = null, DataKind outputKind = ConvertDefaults.DefaultOutputKind) - => new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind); + => new TypeConvertingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, outputKind); /// /// Changes column type of the input column. @@ -63,9 +62,9 @@ public static TypeConvertingEstimator ConvertType(this TransformsCatalog.Convers /// Convert the key types back to their original values. /// /// The categorical transform's catalog. - /// Name of the input column. - public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, string inputColumn) - => new KeyToValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn); + /// Name of the column to transform. + public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, string inputColumnName) + => new KeyToValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumnName); /// /// Convert the key types (name of the column specified in the first item of the tuple) back to their original values @@ -73,7 +72,7 @@ public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.Co /// /// The categorical transform's catalog /// The pairs of input and output columns. 
- public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, params (string input, string output)[] columns) + public static KeyToValueMappingEstimator MapKeyToValue(this TransformsCatalog.ConversionTransforms catalog, params (string outputColumnName, string inputColumnName)[] columns) => new KeyToValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// @@ -89,28 +88,28 @@ public static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog. /// Convert the key types back to their original vectors. /// /// The categorical transform's catalog. - /// The name of the input column. - /// The name of the output column. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Whether bagging is used for the conversion. public static KeyToVectorMappingEstimator MapKeyToVector(this TransformsCatalog.ConversionTransforms catalog, - string inputColumn, string outputColumn = null, bool bag = KeyToVectorMappingEstimator.Defaults.Bag) - => new KeyToVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, bag); + string outputColumnName, string inputColumnName = null, bool bag = KeyToVectorMappingEstimator.Defaults.Bag) + => new KeyToVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, bag); /// /// Converts value types into . /// /// The categorical transform's catalog. - /// Name of the column to be transformed. - /// Name of the output column. If this is null '' will be used. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Maximum number of keys to keep per column when auto-training. /// How items should be ordered when vectorized. If choosen they will be in the order encountered. 
/// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). public static ValueToKeyMappingEstimator MapValueToKey(this TransformsCatalog.ConversionTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort) - => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, maxNumTerms, sort); + => new ValueToKeyMappingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, maxNumTerms, sort); /// /// Converts value types into , optionally loading the keys to use from . @@ -139,7 +138,7 @@ public static ValueMappingEstimator ValueMap keys, IEnumerable values, - params (string source, string name)[] columns) + params (string outputColumnName, string inputColumnName)[] columns) => new ValueMappingEstimator(CatalogUtils.GetEnvironment(catalog), keys, values, columns); } } diff --git a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs index 5f61154d35..240c994f34 100644 --- a/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/DropSlotsTransform.cs @@ -192,22 +192,24 @@ public bool IsValid() /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly (int min, int? max)[] Slots; /// /// Describes how the transformer handles one input-output column pair. /// - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . 
+ /// Name of the column to transform. + /// If set to , the value of the will be used as source. /// Ranges of indices in the input column to be dropped. Setting max in to null sets max to int.MaxValue. - public ColumnInfo(string input, string output = null, params (int min, int? max)[] slots) + public ColumnInfo(string name, string inputColumnName = null, params (int min, int? max)[] slots) { - Input = input; - Contracts.CheckValue(Input, nameof(Input)); - Output = output ?? input; - Contracts.CheckValue(Output, nameof(Output)); + Name = name; + Contracts.CheckValue(Name, nameof(Name)); + InputColumnName = inputColumnName ?? name; + Contracts.CheckValue(InputColumnName, nameof(InputColumnName)); + // By default drop everything. Slots = (slots.Length > 0) ? slots : new (int min, int? max)[1]; foreach (var (min, max) in Slots) @@ -216,10 +218,10 @@ public ColumnInfo(string input, string output = null, params (int min, int? max) internal ColumnInfo(Column column) { - Input = column.Source ?? column.Name; - Contracts.CheckValue(Input, nameof(Input)); - Output = column.Name; - Contracts.CheckValue(Output, nameof(Output)); + Name = column.Name; + Contracts.CheckValue(Name, nameof(Name)); + InputColumnName = column.Source ?? column.Name; + Contracts.CheckValue(InputColumnName, nameof(InputColumnName)); Slots = column.Slots.Select(range => (range.Min, range.Max)).ToArray(); foreach (var (min, max) in Slots) Contracts.Assert(min >= 0 && (max == null || min <= max)); @@ -251,12 +253,12 @@ private static VersionInfo GetVersionInfo() /// Initializes a new object. /// /// The environment to use. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Specifies the lower bound of the range of slots to be dropped. The lower bound is inclusive. 
/// Specifies the upper bound of the range of slots to be dropped. The upper bound is exclusive. - public SlotsDroppingTransformer(IHostEnvironment env, string input, string output = null, int min = default, int? max = null) - : this(env, new ColumnInfo(input, output, (min, max))) + public SlotsDroppingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int min = default, int? max = null) + : this(env, new ColumnInfo(outputColumnName, inputColumnName, (min, max))) { } @@ -408,8 +410,8 @@ private static void GetSlotsMinMax(ColumnInfo[] columns, out int[][] slotsMin, o } } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) - => columns.Select(c => (c.Input, c.Output ?? c.Input)).ToArray(); + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + => columns.Select(c => (c.Name, c.InputColumnName ?? c.Name)).ToArray(); private static bool AreRangesValid(int[][] slotsMin, int[][] slotsMax) { @@ -461,14 +463,14 @@ public Mapper(SlotsDroppingTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _cols[i])) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _cols[i])) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); _srcTypes[i] = inputSchema[_cols[i]].Type; VectorType srcVectorType = _srcTypes[i] as VectorType; ColumnType itemType = srcVectorType?.ItemType ?? _srcTypes[i]; if (!IsValidColumnType(itemType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); int valueCount = srcVectorType?.Size ?? 
1; _slotDropper[i] = new SlotDropper(valueCount, _parent.SlotsMin[i], _parent.SlotsMax[i]); @@ -819,7 +821,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() // Avoid closure when adding metadata. int iinfo = i; - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[iinfo].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[iinfo].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); var builder = new MetadataBuilder(); @@ -859,7 +861,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() // Add isNormalize and KeyValues metadata. builder.Add(InputSchema[_cols[iinfo]].Metadata, x => x == MetadataUtils.Kinds.KeyValues || x == MetadataUtils.Kinds.IsNormalized); - result[iinfo] = new Schema.DetachedColumn(_parent.ColumnPairs[iinfo].output, _dstTypes[iinfo], builder.GetMetadata()); + result[iinfo] = new Schema.DetachedColumn(_parent.ColumnPairs[iinfo].outputColumnName, _dstTypes[iinfo], builder.GetMetadata()); } return result; } diff --git a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs index 18393ba934..df37aaaae9 100644 --- a/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/ExtensionsCatalog.cs @@ -13,13 +13,13 @@ namespace Microsoft.ML public static class TransformExtensionsCatalog { /// - /// Copies the input column to another column named as specified in . + /// Copies the input column to another column named as specified in . /// /// The transform's catalog. - /// Name of the input column. - /// Name of the new column, resulting from copying. - public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, string inputColumn, string outputColumn) - => new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn); + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. 
+ public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, string outputColumnName, string inputColumnName) + => new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName); /// /// Copies the input column, name specified in the first item of the tuple, @@ -27,17 +27,17 @@ public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, /// /// The transform's catalog /// The pairs of input and output columns. - public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, params (string source, string name)[] columns) + public static ColumnCopyingEstimator CopyColumns(this TransformsCatalog catalog, params (string outputColumnName, string inputColumnName)[] columns) => new ColumnCopyingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// /// Concatenates two columns together. /// /// The transform's catalog. - /// The name of the output column. - /// The names of the columns to concatenate together. - public static ColumnConcatenatingEstimator Concatenate(this TransformsCatalog catalog, string outputColumn, params string[] inputColumns) - => new ColumnConcatenatingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumn, inputColumns); + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. + public static ColumnConcatenatingEstimator Concatenate(this TransformsCatalog catalog, string outputColumnName, params string[] inputColumnNames) + => new ColumnConcatenatingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnNames); /// /// DropColumns is used to select a list of columns that user wants to drop from a given input. 
Any column not specified will diff --git a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs index f7c4b7b33f..66f3f92aca 100644 --- a/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/FeatureContributionCalculationTransform.cs @@ -133,7 +133,7 @@ public FeatureContributionCalculatingTransformer(IHostEnvironment env, ICalculat int numPositiveContributions = FeatureContributionCalculatingEstimator.Defaults.NumPositiveContributions, int numNegativeContributions = FeatureContributionCalculatingEstimator.Defaults.NumNegativeContributions, bool normalize = FeatureContributionCalculatingEstimator.Defaults.Normalize) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), new[] { (input: featureColumn, output: DefaultColumnNames.FeatureContributions) }) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(FeatureContributionCalculatingTransformer)), new[] { (name: DefaultColumnNames.FeatureContributions, source: featureColumn) }) { Host.CheckValue(modelParameters, nameof(modelParameters)); Host.CheckNonEmpty(featureColumn, nameof(featureColumn)); @@ -221,11 +221,11 @@ public Mapper(FeatureContributionCalculatingTransformer parent, Schema schema) _parent = parent; // Check that the featureColumn is present and has the expected type. 
- if (!schema.TryGetColumnIndex(_parent.ColumnPairs[0].input, out _featureColumnIndex)) - throw Host.ExceptSchemaMismatch(nameof(schema), "input", _parent.ColumnPairs[0].input); + if (!schema.TryGetColumnIndex(_parent.ColumnPairs[0].inputColumnName, out _featureColumnIndex)) + throw Host.ExceptSchemaMismatch(nameof(schema), "input", _parent.ColumnPairs[0].inputColumnName); _featureColumnType = schema[_featureColumnIndex].Type as VectorType; if (_featureColumnType == null || _featureColumnType.ItemType != NumberType.R4) - throw Host.ExceptSchemaMismatch(nameof(schema), "feature column", _parent.ColumnPairs[0].input, "Expected type is vector of float.", _featureColumnType.ItemType.ToString()); + throw Host.ExceptSchemaMismatch(nameof(schema), "feature column", _parent.ColumnPairs[0].inputColumnName, "Expected type is vector of float.", _featureColumnType.ItemType.ToString()); if (InputSchema[_featureColumnIndex].HasSlotNames(_featureColumnType.Size)) InputSchema[_featureColumnIndex].Metadata.GetValue(MetadataUtils.Kinds.SlotNames, ref _slotNames); diff --git a/src/Microsoft.ML.Data/Transforms/Hashing.cs b/src/Microsoft.ML.Data/Transforms/Hashing.cs index f315a1b884..351bf44ad5 100644 --- a/src/Microsoft.ML.Data/Transforms/Hashing.cs +++ b/src/Microsoft.ML.Data/Transforms/Hashing.cs @@ -117,8 +117,8 @@ internal bool TryUnparse(StringBuilder sb) public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly int HashBits; public readonly uint Seed; public readonly bool Ordered; @@ -127,8 +127,8 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. 
If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 31, inclusive. /// Hashing seed. /// Whether the position of each term should be included in the hash. @@ -136,8 +136,8 @@ public sealed class ColumnInfo /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public ColumnInfo(string input, - string output = null, + public ColumnInfo(string name, + string inputColumnName = null, int hashBits = HashingEstimator.Defaults.HashBits, uint seed = HashingEstimator.Defaults.Seed, bool ordered = HashingEstimator.Defaults.Ordered, @@ -147,19 +147,19 @@ public ColumnInfo(string input, throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger"); if (invertHash != 0 && hashBits >= 31) throw Contracts.ExceptParam(nameof(hashBits), $"Cannot support invertHash for a {0} bit hash. 30 is the maximum possible.", hashBits); - Contracts.CheckNonWhiteSpace(input, nameof(input)); - Input = input; - Output = output ?? input; + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Name = name; + InputColumnName = inputColumnName ?? name; HashBits = hashBits; Seed = seed; Ordered = ordered; InvertHash = invertHash; } - internal ColumnInfo(string input, string output, ModelLoadContext ctx) + internal ColumnInfo(string name, string inputColumnName, ModelLoadContext ctx) { - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName; // *** Binary format *** // int: HashBits // uint: HashSeed @@ -190,7 +190,7 @@ internal void Save(ModelSaveContext ctx) internal const string Summary = "Converts column values into hashes. This transform accepts text and keys as inputs. 
It works on single- and vector-valued columns, " + "and hashes each slot in a vector separately."; - public const string LoaderSignature = "HashTransform"; + internal const string LoaderSignature = "HashTransform"; private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -214,16 +214,16 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol throw Host.ExceptParam(nameof(inputSchema), HashingEstimator.ExpectedColumnType); } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckNonEmpty(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } private ColumnType GetOutputType(Schema inputSchema, ColumnInfo column) { var keyCount = (ulong)1 << column.HashBits; - inputSchema.TryGetColumnIndex(column.Input, out int srcCol); + inputSchema.TryGetColumnIndex(column.InputColumnName, out int srcCol); var itemType = new KeyType(typeof(uint), keyCount); var srcType = inputSchema[srcCol].Type; if (srcType is VectorType vectorType) @@ -258,9 +258,9 @@ internal HashingTransformer(IHostEnvironment env, IDataView input, params Column var sourceColumnsForInvertHash = new List(); for (int i = 0; i < _columns.Length; i++) { - Schema.Column? srcCol = input.Schema.GetColumnOrNull(ColumnPairs[i].input); + Schema.Column? 
srcCol = input.Schema.GetColumnOrNull(ColumnPairs[i].inputColumnName); if (srcCol == null) - throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].inputColumnName); CheckInputColumn(input.Schema, i, srcCol.Value.Index); types[i] = GetOutputType(input.Schema, _columns[i]); @@ -317,7 +317,7 @@ private Delegate GetGetterCore(Row input, int iinfo, out Action disposer) Host.AssertValue(input); Host.Assert(0 <= iinfo && iinfo < _columns.Length); disposer = null; - input.Schema.TryGetColumnIndex(_columns[iinfo].Input, out int srcCol); + input.Schema.TryGetColumnIndex(_columns[iinfo].InputColumnName, out int srcCol); var srcType = input.Schema[srcCol].Type; if (!(srcType is VectorType vectorType)) return ComposeGetterOne(input, iinfo, srcCol, srcType); @@ -344,7 +344,7 @@ private HashingTransformer(IHost host, ModelLoadContext ctx) var columnsLength = ColumnPairs.Length; _columns = new ColumnInfo[columnsLength]; for (int i = 0; i < columnsLength; i++) - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, ctx); TextModelHelper.LoadAll(Host, ctx, columnsLength, out _keyValues, out _kvTypes); } @@ -388,8 +388,9 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData { var item = args.Column[i]; var kind = item.InvertHash ?? args.InvertHash; - cols[i] = new ColumnInfo(item.Source ?? item.Name, + cols[i] = new ColumnInfo( item.Name, + item.Source ?? item.Name, item.HashBits ?? args.HashBits, item.Seed ?? args.Seed, item.Ordered ?? 
args.Ordered, @@ -858,13 +859,13 @@ private sealed class Mapper : OneToOneMapperBase private sealed class ColInfo { public readonly string Name; - public readonly string Source; + public readonly string InputColumnName; public readonly ColumnType TypeSrc; - public ColInfo(string name, string source, ColumnType type) + public ColInfo(string outputColumnName, string inputColumnName, ColumnType type) { - Name = name; - Source = source; + Name = outputColumnName; + InputColumnName = inputColumnName; TypeSrc = type; } } @@ -886,14 +887,14 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); var meta = new MetadataBuilder(); meta.Add(InputSchema[colIndex].Metadata, name => name == MetadataUtils.Kinds.SlotNames); if (_parent._kvTypes != null && _parent._kvTypes[i] != null) AddMetaKeyValues(i, meta); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], meta.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], meta.GetMetadata()); } return result; } @@ -921,7 +922,7 @@ private InvertHashHelper(Row row, ColumnInfo ex) { Contracts.AssertValue(row); Row = row; - row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); + row.Schema.TryGetColumnIndex(ex.InputColumnName, out int srcCol); _srcCol = srcCol; _srcType = row.Schema[srcCol].Type; _ex = ex; @@ -940,8 +941,7 @@ private InvertHashHelper(Row row, ColumnInfo ex) /// A hash getter, built on top of . 
public static InvertHashHelper Create(Row row, ColumnInfo ex, int invertHashMaxCount, Delegate dstGetter) { - row.Schema.TryGetColumnIndex(ex.Input, out int srcCol); - + row.Schema.TryGetColumnIndex(ex.InputColumnName, out int srcCol); ColumnType typeSrc = row.Schema[srcCol].Type; VectorType vectorTypeSrc = typeSrc as VectorType; @@ -1206,16 +1206,17 @@ internal static bool IsColumnTypeValid(ColumnType type) /// Initializes a new instance of . /// /// Host Environment. - /// Name of the column to be transformed. - /// Name of the output column. If this is null '' will be used. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. + /// If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 31, inclusive. /// During hashing we constuct mappings between original values and the produced hash values. /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public HashingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + public HashingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int hashBits = Defaults.HashBits, int invertHash = Defaults.InvertHash) - : this(env, new HashingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, hashBits: hashBits, invertHash: invertHash)) + : this(env, new HashingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? 
outputColumnName, hashBits: hashBits, invertHash: invertHash)) { } @@ -1239,8 +1240,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!IsColumnTypeValid(col.ItemType)) throw _host.ExceptParam(nameof(inputSchema), ExpectedColumnType); var metadata = new List(); @@ -1248,7 +1249,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) metadata.Add(slotMeta); if (colInfo.InvertHash != 0) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.ItemType is VectorType ? SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(metadata)); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, col.ItemType is VectorType ? 
SchemaShape.Column.VectorKind.Vector : SchemaShape.Column.VectorKind.Scalar, NumberType.U4, true, new SchemaShape(metadata)); } return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs index 3b2cb0788b..f0bba8d894 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToValue.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToValue.cs @@ -64,10 +64,12 @@ public sealed class Arguments : TransformInputBase public Column[] Column; } - public const string LoaderSignature = "KeyToValueTransform"; - public const string UserName = "Key To Value Transform"; + internal const string LoaderSignature = "KeyToValueTransform"; - public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + [BestFriend] + internal const string UserName = "Key To Value Transform"; + + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); private static VersionInfo GetVersionInfo() { @@ -91,7 +93,7 @@ public KeyToValueMappingTransformer(IHostEnvironment env, string columnName) /// /// Create a that takes multiple pairs of columns. /// - public KeyToValueMappingTransformer(IHostEnvironment env, params (string input, string output)[] columns) + public KeyToValueMappingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToValueMappingTransformer)), columns) { } @@ -107,7 +109,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat env.CheckValue(input, nameof(input)); env.CheckNonEmpty(args.Column, nameof(args.Column)); - var transformer = new KeyToValueMappingTransformer(env, args.Column.Select(c => (c.Source ?? c.Name, c.Name)).ToArray()); + var transformer = new KeyToValueMappingTransformer(env, args.Column.Select(c => (c.Name, c.Source ?? 
c.Name)).ToArray()); return transformer.MakeDataTransform(input); } @@ -176,7 +178,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var meta = new MetadataBuilder(); meta.Add(InputSchema[ColMapNewToOld[i]].Metadata, name => name == MetadataUtils.Kinds.SlotNames); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], meta.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], meta.GetMetadata()); } return result; } @@ -191,20 +193,20 @@ public void SaveAsPfa(BoundPfaContext ctx) for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; ++iinfo) { var info = _parent.ColumnPairs[iinfo]; - var srcName = info.input; + var srcName = info.inputColumnName; string srcToken = ctx.TokenOrNullForName(srcName); if (srcToken == null) { - toHide.Add(info.output); + toHide.Add(info.outputColumnName); continue; } var result = _kvMaps[iinfo].SavePfa(ctx, srcToken); if (result == null) { - toHide.Add(info.output); + toHide.Add(info.outputColumnName); continue; } - toDeclare.Add(new KeyValuePair(info.output, result)); + toDeclare.Add(new KeyValuePair(info.outputColumnName, result)); } ctx.Hide(toHide.ToArray()); ctx.DeclareVar(toDeclare.ToArray()); @@ -508,7 +510,7 @@ public KeyToValueMappingEstimator(IHostEnvironment env, string columnName) { } - public KeyToValueMappingEstimator(IHostEnvironment env, params (string input, string output)[] columns) + public KeyToValueMappingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(KeyToValueMappingEstimator)), new KeyToValueMappingTransformer(env, columns)) { } @@ -519,19 +521,19 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.input, out var col)) - throw 
Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!inputSchema.TryFindColumn(colInfo.inputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName); if (!col.IsKey) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, "key type", col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName, "key type", col.GetTypeString()); if (!col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMetaCol)) - throw Host.ExceptParam(nameof(inputSchema), $"Input column '{colInfo.input}' doesn't contain key values metadata"); + throw Host.ExceptParam(nameof(inputSchema), $"Input column '{colInfo.inputColumnName}' doesn't contain key values metadata"); SchemaShape metadata = null; if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotCol)) metadata = new SchemaShape(new[] { slotCol }); - result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind, keyMetaCol.ItemType, keyMetaCol.IsKey, metadata); + result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, col.Kind, keyMetaCol.ItemType, keyMetaCol.IsKey, metadata); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs index e41cc919ac..46c68ba7c5 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVector.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVector.cs @@ -95,21 +95,21 @@ public sealed class Arguments /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly bool Bag; /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of the column resulting from the transformation of . Null means is replaced. 
+ /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Whether to combine multiple indicator vectors into a single bag vector instead of concatenating them. This is only relevant when the input column is a vector. - public ColumnInfo(string input, string output = null, bool bag = KeyToVectorMappingEstimator.Defaults.Bag) + public ColumnInfo(string name, string inputColumnName = null, bool bag = KeyToVectorMappingEstimator.Defaults.Bag) { - Contracts.CheckNonWhiteSpace(input, nameof(input)); - Input = input; - Output = output ?? input; + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Name = name; + InputColumnName = inputColumnName ?? name; Bag = bag; } } @@ -119,10 +119,10 @@ public ColumnInfo(string input, string output = null, bool bag = KeyToVectorMapp public IReadOnlyCollection Columns => _columns.AsReadOnly(); private readonly ColumnInfo[] _columns; - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } private string TestIsKey(ColumnType type) @@ -137,7 +137,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol var type = inputSchema[srcCol].Type; string reason = TestIsKey(type); if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, reason, type.ToString()); } public KeyToVectorMappingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : @@ -146,8 +146,8 @@ public KeyToVectorMappingTransformer(IHostEnvironment env, params 
ColumnInfo[] c _columns = columns.ToArray(); } - public const string LoaderSignature = "KeyToVectorTransform"; - public const string UserName = "KeyToVectorTransform"; + internal const string LoaderSignature = "KeyToVectorTransform"; + internal const string UserName = "KeyToVectorTransform"; internal const string Summary = "Converts a key column to an indicator vector."; private static VersionInfo GetVersionInfo() @@ -207,7 +207,7 @@ private KeyToVectorMappingTransformer(IHost host, ModelLoadContext ctx) _columns = new ColumnInfo[columnsLength]; for (int i = 0; i < columnsLength; i++) - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, bags[i]); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, bags[i]); } // Factory method for SignatureDataTransform. @@ -223,8 +223,9 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source ?? item.Name, + cols[i] = new ColumnInfo( item.Name, + item.Source ?? item.Name, item.Bag ?? 
args.Bag); }; return new KeyToVectorMappingTransformer(env, cols).MakeDataTransform(input); @@ -245,13 +246,13 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx, ISaveAsPfa private sealed class ColInfo { public readonly string Name; - public readonly string Source; + public readonly string InputColumnName; public readonly ColumnType TypeSrc; - public ColInfo(string name, string source, ColumnType type) + public ColInfo(string outputColumnName, string inputColumnName, ColumnType type) { - Name = name; - Source = source; + Name = outputColumnName; + InputColumnName = inputColumnName; TypeSrc = type; } } @@ -283,11 +284,11 @@ private ColInfo[] CreateInfos(Schema inputSchema) var infos = new ColInfo[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); var type = inputSchema[colSrc].Type; _parent.CheckInputColumn(inputSchema, i, colSrc); - infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + infos[i] = new ColInfo(_parent.ColumnPairs[i].outputColumnName, _parent.ColumnPairs[i].inputColumnName, type); } return infos; } @@ -297,18 +298,18 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); var builder = new MetadataBuilder(); AddMetadata(i, builder); - result[i] = new 
Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.GetMetadata()); } return result; } private void AddMetadata(int iinfo, MetadataBuilder builder) { - InputSchema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + InputSchema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol); var inputMetadata = InputSchema[srcCol].Metadata; var srcType = _infos[iinfo].TypeSrc; @@ -379,7 +380,7 @@ private void GetSlotNames(int iinfo, ref VBuffer> dst) // Get the source slot names, defaulting to empty text. var namesSlotSrc = default(VBuffer>); - var inputMetadata = InputSchema[_infos[iinfo].Source].Metadata; + var inputMetadata = InputSchema[_infos[iinfo].InputColumnName].Metadata; Contracts.AssertValue(inputMetadata); var typeSlotSrc = inputMetadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type as VectorType; if (typeSlotSrc != null && typeSlotSrc.Size == typeSrc.Size && typeSlotSrc.ItemType is TextType) @@ -477,7 +478,7 @@ private ValueGetter> MakeGetterOne(Row input, int iinfo) int size = keyTypeSrc.GetCountAsInt32(Host); Host.Assert(size == _types[iinfo].Size); Host.Assert(size > 0); - input.Schema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + input.Schema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol); Host.Assert(srcCol >= 0); var getSrc = RowCursorUtils.GetGetterAs(NumberType.U4, input, srcCol); var src = default(uint); @@ -518,7 +519,7 @@ private ValueGetter> MakeGetterBag(Row input, int iinfo) int cv = srcVectorType.Size; Host.Assert(cv >= 0); - input.Schema.TryGetColumnIndex(info.Source, out int srcCol); + input.Schema.TryGetColumnIndex(info.InputColumnName, out int srcCol); Host.Assert(srcCol >= 0); var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); var src = default(VBuffer); @@ -565,7 +566,7 @@ private ValueGetter> MakeGetterInd(Row input, int 
iinfo) int cv = srcVectorType.Size; Host.Assert(cv >= 0); Host.Assert(_types[iinfo].Size == size * cv); - input.Schema.TryGetColumnIndex(info.Source, out int srcCol); + input.Schema.TryGetColumnIndex(info.InputColumnName, out int srcCol); var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); var src = default(VBuffer); return @@ -623,14 +624,14 @@ public void SaveAsOnnx(OnnxContext ctx) for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { ColInfo info = _infos[iinfo]; - string sourceColumnName = info.Source; - if (!ctx.ContainsColumn(sourceColumnName)) + string inputColumnName = info.InputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) { ctx.RemoveColumn(info.Name, false); continue; } - if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(_types[iinfo], info.Name))) { ctx.RemoveColumn(info.Name, true); @@ -648,7 +649,7 @@ public void SaveAsPfa(BoundPfaContext ctx) for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { var info = _infos[iinfo]; - var srcName = info.Source; + var srcName = info.InputColumnName; string srcToken = ctx.TokenOrNullForName(srcName); if (srcToken == null) { @@ -753,8 +754,8 @@ public KeyToVectorMappingEstimator(IHostEnvironment env, params KeyToVectorMappi { } - public KeyToVectorMappingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, bool bag = Defaults.Bag) - : this(env, new KeyToVectorMappingTransformer(env, new KeyToVectorMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, bag))) + public KeyToVectorMappingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, bool bag = Defaults.Bag) + : this(env, new KeyToVectorMappingTransformer(env, new KeyToVectorMappingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? 
outputColumnName, bag))) { } @@ -769,10 +770,10 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!col.ItemType.IsStandardScalar()) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) @@ -783,7 +784,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) if (!colInfo.Bag || (col.Kind == SchemaShape.Column.VectorKind.Scalar)) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs index 307c0304f9..c7ca1bb2f6 100644 --- a/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/LabelConvertTransform.cs @@ -71,10 +71,10 @@ private static VersionInfo GetVersionInfo() /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the input column. 
If this is null '' will be used. - public LabelConvertTransform(IHostEnvironment env, IDataView input, string name, string source = null) - : this(env, new Arguments() { Column = new[] { new Column() { Source = source ?? name, Name = name } } }, input) + /// Name of the output column. + /// Name of the input column. If this is null '' will be used. + public LabelConvertTransform(IHostEnvironment env, IDataView input, string outputColumnName, string inputColumnName = null) + : this(env, new Arguments() { Column = new[] { new Column() { Source = inputColumnName ?? outputColumnName, Name = outputColumnName } } }, input) { } diff --git a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs index 12755fe949..25f52eabca 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizeColumn.cs @@ -269,13 +269,13 @@ public sealed class SupervisedBinArguments : BinArgumentsBase /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. - public static IDataView CreateMinMaxNormalizer(IHostEnvironment env, IDataView input, string name, string source = null) + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. + public static IDataView CreateMinMaxNormalizer(IHostEnvironment env, IDataView input, string outputColumnName, string inputColumnName = null) { Contracts.CheckValue(env, nameof(env)); - var normalizer = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumn(source ?? name, name)); + var normalizer = new NormalizingEstimator(env, new NormalizingEstimator.MinMaxColumn(outputColumnName, inputColumnName ?? 
outputColumnName)); return normalizer.Fit(input).MakeDataTransform(input); } @@ -290,8 +290,8 @@ internal static IDataTransform Create(IHostEnvironment env, MinMaxArguments args var columns = args.Column .Select(col => new NormalizingEstimator.MinMaxColumn( - col.Source ?? col.Name, col.Name, + col.Source ?? col.Name, col.MaxTrainingExamples ?? args.MaxTrainingExamples, col.FixZero ?? args.FixZero)) .ToArray(); @@ -308,8 +308,8 @@ internal static IDataTransform Create(IHostEnvironment env, MeanVarArguments arg var columns = args.Column .Select(col => new NormalizingEstimator.MeanVarColumn( - col.Source ?? col.Name, col.Name, + col.Source ?? col.Name, col.MaxTrainingExamples ?? args.MaxTrainingExamples, col.FixZero ?? args.FixZero)) .ToArray(); @@ -328,8 +328,8 @@ internal static IDataTransform Create(IHostEnvironment env, LogMeanVarArguments var columns = args.Column .Select(col => new NormalizingEstimator.LogMeanVarColumn( - col.Source ?? col.Name, col.Name, + col.Source ?? col.Name, col.MaxTrainingExamples ?? args.MaxTrainingExamples, args.UseCdf)) .ToArray(); @@ -348,8 +348,8 @@ internal static IDataTransform Create(IHostEnvironment env, BinArguments args, I var columns = args.Column .Select(col => new NormalizingEstimator.BinningColumn( - col.Source ?? col.Name, col.Name, + col.Source ?? col.Name, col.MaxTrainingExamples ?? args.MaxTrainingExamples, col.FixZero ?? args.FixZero, col.NumBins ?? args.NumBins)) @@ -919,8 +919,8 @@ public static IColumnFunctionBuilder CreateBuilder(MinMaxArguments args, IHost h host.AssertValue(args); return CreateBuilder(new NormalizingEstimator.MinMaxColumn( - args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].Name, + args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples, args.Column[icol].FixZero ?? 
args.FixZero), host, srcIndex, srcType, cursor); } @@ -955,8 +955,8 @@ public static IColumnFunctionBuilder CreateBuilder(MeanVarArguments args, IHost host.AssertValue(args); return CreateBuilder(new NormalizingEstimator.MeanVarColumn( - args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].Name, + args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples, args.Column[icol].FixZero ?? args.FixZero, args.UseCdf), host, srcIndex, srcType, cursor); @@ -995,8 +995,8 @@ public static IColumnFunctionBuilder CreateBuilder(LogMeanVarArguments args, IHo host.AssertValue(args); return CreateBuilder(new NormalizingEstimator.LogMeanVarColumn( - args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].Name, + args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples, args.UseCdf), host, srcIndex, srcType, cursor); } @@ -1021,7 +1021,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.LogMeanV if (vectorType.ItemType == NumberType.R8) return Dbl.MeanVarVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter>(srcIndex)); } - throw host.ExceptUserArg(nameof(column), "Wrong column type for column {0}. Expected: R4, R8, Vec or Vec. Got: {1}.", column.Input, srcType.ToString()); + throw host.ExceptUserArg(nameof(column), "Wrong column type for column {0}. Expected: R4, R8, Vec or Vec. Got: {1}.", column.InputColumnName, srcType.ToString()); } } @@ -1034,8 +1034,8 @@ public static IColumnFunctionBuilder CreateBuilder(BinArguments args, IHost host host.AssertValue(args); return CreateBuilder(new NormalizingEstimator.BinningColumn( - args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].Name, + args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples, args.Column[icol].FixZero ?? args.FixZero, args.Column[icol].NumBins ?? 
args.NumBins), host, srcIndex, srcType, cursor); @@ -1060,7 +1060,7 @@ public static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.BinningC if (vectorType.ItemType == NumberType.R8) return Dbl.BinVecColumnFunctionBuilder.Create(column, host, vectorType, cursor.GetGetter>(srcIndex)); } - throw host.ExceptParam(nameof(column), "Wrong column type for column {0}. Expected: R4, R8, Vec or Vec. Got: {1}.", column.Input, srcType.ToString()); + throw host.ExceptParam(nameof(column), "Wrong column type for column {0}. Expected: R4, R8, Vec or Vec. Got: {1}.", column.InputColumnName, srcType.ToString()); } } @@ -1083,8 +1083,8 @@ public static IColumnFunctionBuilder CreateBuilder(SupervisedBinArguments args, return CreateBuilder( new NormalizingEstimator.SupervisedBinningColumn( - args.Column[icol].Source ?? args.Column[icol].Name, args.Column[icol].Name, + args.Column[icol].Source ?? args.Column[icol].Name, args.LabelColumn ?? DefaultColumnNames.Label, args.Column[icol].MaxTrainingExamples ?? args.MaxTrainingExamples, args.Column[icol].FixZero ?? args.FixZero, @@ -1121,7 +1121,7 @@ private static IColumnFunctionBuilder CreateBuilder(NormalizingEstimator.Supervi } throw host.ExceptParam(nameof(column), "Wrong column type for column {0}. Expected: R4, R8, Vec or Vec. 
Got: {1}.", - column.Input, + column.InputColumnName, srcType.ToString()); } diff --git a/src/Microsoft.ML.Data/Transforms/Normalizer.cs b/src/Microsoft.ML.Data/Transforms/Normalizer.cs index a5614331b4..e5091f30cf 100644 --- a/src/Microsoft.ML.Data/Transforms/Normalizer.cs +++ b/src/Microsoft.ML.Data/Transforms/Normalizer.cs @@ -68,37 +68,37 @@ public enum NormalizerMode public abstract class ColumnBase { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly long MaxTrainingExamples; - private protected ColumnBase(string input, string output, long maxTrainingExamples) + private protected ColumnBase(string name, string inputColumnName, long maxTrainingExamples) { - Contracts.CheckNonEmpty(input, nameof(input)); - Contracts.CheckNonEmpty(output, nameof(output)); + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckNonEmpty(inputColumnName, nameof(inputColumnName)); Contracts.CheckParam(maxTrainingExamples > 1, nameof(maxTrainingExamples), "Must be greater than 1"); - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName; MaxTrainingExamples = maxTrainingExamples; } internal abstract IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, ColumnType srcType, RowCursor cursor); - internal static ColumnBase Create(string input, string output, NormalizerMode mode) + internal static ColumnBase Create(string outputColumnName, string inputColumnName, NormalizerMode mode) { switch (mode) { case NormalizerMode.MinMax: - return new MinMaxColumn(input, output); + return new MinMaxColumn(outputColumnName, inputColumnName); case NormalizerMode.MeanVariance: - return new MeanVarColumn(input, output); + return new MeanVarColumn(outputColumnName, inputColumnName); case NormalizerMode.LogMeanVariance: - return new LogMeanVarColumn(input, output); + return new LogMeanVarColumn(outputColumnName, inputColumnName); case 
NormalizerMode.Binning: - return new BinningColumn(input, output); + return new BinningColumn(outputColumnName, inputColumnName); case NormalizerMode.SupervisedBinning: - return new SupervisedBinningColumn(input, output); + return new SupervisedBinningColumn(outputColumnName, inputColumnName); default: throw Contracts.ExceptParam(nameof(mode), "Unknown normalizer mode"); } @@ -109,8 +109,8 @@ public abstract class FixZeroColumnBase : ColumnBase { public readonly bool FixZero; - private protected FixZeroColumnBase(string input, string output, long maxTrainingExamples, bool fixZero) - : base(input, output, maxTrainingExamples) + private protected FixZeroColumnBase(string outputColumnName, string inputColumnName, long maxTrainingExamples, bool fixZero) + : base(outputColumnName, inputColumnName, maxTrainingExamples) { FixZero = fixZero; } @@ -118,8 +118,8 @@ private protected FixZeroColumnBase(string input, string output, long maxTrainin public sealed class MinMaxColumn : FixZeroColumnBase { - public MinMaxColumn(string input, string output = null, long maxTrainingExamples = Defaults.MaxTrainingExamples, bool fixZero = Defaults.FixZero) - : base(input, output ?? input, maxTrainingExamples, fixZero) + public MinMaxColumn(string outputColumnName, string inputColumnName = null, long maxTrainingExamples = Defaults.MaxTrainingExamples, bool fixZero = Defaults.FixZero) + : base(outputColumnName, inputColumnName ?? outputColumnName, maxTrainingExamples, fixZero) { } @@ -131,9 +131,9 @@ public sealed class MeanVarColumn : FixZeroColumnBase { public readonly bool UseCdf; - public MeanVarColumn(string input, string output = null, + public MeanVarColumn(string outputColumnName, string inputColumnName = null, long maxTrainingExamples = Defaults.MaxTrainingExamples, bool fixZero = Defaults.FixZero, bool useCdf = Defaults.MeanVarCdf) - : base(input, output ?? input, maxTrainingExamples, fixZero) + : base(outputColumnName, inputColumnName ?? 
outputColumnName, maxTrainingExamples, fixZero) { UseCdf = useCdf; } @@ -146,9 +146,9 @@ public sealed class LogMeanVarColumn : ColumnBase { public readonly bool UseCdf; - public LogMeanVarColumn(string input, string output = null, + public LogMeanVarColumn(string outputColumnName, string inputColumnName = null, long maxTrainingExamples = Defaults.MaxTrainingExamples, bool useCdf = Defaults.LogMeanVarCdf) - : base(input, output ?? input, maxTrainingExamples) + : base(outputColumnName, inputColumnName ?? outputColumnName, maxTrainingExamples) { UseCdf = useCdf; } @@ -161,9 +161,9 @@ public sealed class BinningColumn : FixZeroColumnBase { public readonly int NumBins; - public BinningColumn(string input, string output = null, + public BinningColumn(string outputColumnName, string inputColumnName = null, long maxTrainingExamples = Defaults.MaxTrainingExamples, bool fixZero = true, int numBins = Defaults.NumBins) - : base(input, output ?? input, maxTrainingExamples, fixZero) + : base(outputColumnName, inputColumnName ?? outputColumnName, maxTrainingExamples, fixZero) { NumBins = numBins; } @@ -178,13 +178,13 @@ public sealed class SupervisedBinningColumn : FixZeroColumnBase public readonly string LabelColumn; public readonly int MinBinSize; - public SupervisedBinningColumn(string input, string output = null, + public SupervisedBinningColumn(string outputColumnName, string inputColumnName = null, string labelColumn = DefaultColumnNames.Label, long maxTrainingExamples = Defaults.MaxTrainingExamples, bool fixZero = true, int numBins = Defaults.NumBins, int minBinSize = Defaults.MinBinSize) - : base(input, output ?? input, maxTrainingExamples, fixZero) + : base(outputColumnName, inputColumnName ?? outputColumnName, maxTrainingExamples, fixZero) { NumBins = numBins; LabelColumn = labelColumn; @@ -202,11 +202,12 @@ internal override IColumnFunctionBuilder MakeBuilder(IHost host, int srcIndex, C /// Initializes a new instance of . /// /// Host Environment. 
- /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. + /// If set to , the value of the will be used as source. /// The indicating how to the old values are mapped to the new values. - public NormalizingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, NormalizerMode mode = NormalizerMode.MinMax) - : this(env, mode, (inputColumn, outputColumn ?? inputColumn)) + public NormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerMode mode = NormalizerMode.MinMax) + : this(env, mode, (outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -215,13 +216,13 @@ public NormalizingEstimator(IHostEnvironment env, string inputColumn, string out /// /// The private instance of . /// The indicating how to the old values are mapped to the new values. - /// An array of (inputColumn, outputColumn) tuples. - public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string inputColumn, string outputColumn)[] columns) + /// An array of (outputColumnName, inputColumnName) tuples. 
+ public NormalizingEstimator(IHostEnvironment env, NormalizerMode mode, params (string outputColumnName, string inputColumnName)[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(NormalizingEstimator)); _host.CheckValue(columns, nameof(columns)); - _columns = columns.Select(x => ColumnBase.Create(x.inputColumn, x.outputColumn, mode)).ToArray(); + _columns = columns.Select(x => ColumnBase.Create(x.outputColumnName, x.inputColumnName, mode)).ToArray(); } /// @@ -251,13 +252,13 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (col.Kind == SchemaShape.Column.VectorKind.VariableVector) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "fixed-size vector or scalar", col.GetTypeString()); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, "fixed-size vector or scalar", col.GetTypeString()); if (!col.ItemType.Equals(NumberType.R4) && !col.ItemType.Equals(NumberType.R8)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "vector or scalar of R4 or R8", col.GetTypeString()); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, "vector or scalar of R4 or R8", col.GetTypeString()); var isNormalizedMeta = new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false); @@ -265,7 +266,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) newMetadataKinds.Add(slotMeta); var meta = new SchemaShape(newMetadataKinds); - 
result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, col.ItemType, col.IsKey, meta); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, col.Kind, col.ItemType, col.IsKey, meta); } return new SchemaShape(result.Values); @@ -305,16 +306,16 @@ private static VersionInfo GetVersionInfo() public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly NormalizerModelParametersBase ModelParameters; internal readonly ColumnType InputType; internal readonly IColumnFunction ColumnFunction; - internal ColumnInfo(string input, string output, ColumnType inputType, IColumnFunction columnFunction) + internal ColumnInfo(string name, string inputColumnName, ColumnType inputType, IColumnFunction columnFunction) { - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName; InputType = inputType; ColumnFunction = columnFunction; ModelParameters = columnFunction.GetNormalizerModelParams(); @@ -380,7 +381,7 @@ public ColumnFunctionAccessor(ImmutableArray infos) public readonly ImmutableArray Columns; private NormalizingTransformer(IHostEnvironment env, ColumnInfo[] columns) - : base(env.Register(nameof(NormalizingTransformer)), columns.Select(x => (x.Input, x.Output)).ToArray()) + : base(env.Register(nameof(NormalizingTransformer)), columns.Select(x => (x.Name, x.InputColumnName)).ToArray()) { Columns = ImmutableArray.Create(columns); ColumnFunctions = new ColumnFunctionAccessor(Columns); @@ -399,9 +400,9 @@ public static NormalizingTransformer Train(IHostEnvironment env, IDataView data, for (int i = 0; i < columns.Length; i++) { var info = columns[i]; - bool success = data.Schema.TryGetColumnIndex(info.Input, out srcCols[i]); + bool success = data.Schema.TryGetColumnIndex(info.InputColumnName, out srcCols[i]); if (!success) - throw env.ExceptSchemaMismatch(nameof(data), "input", info.Input); + throw 
env.ExceptSchemaMismatch(nameof(data), "input", info.InputColumnName); srcTypes[i] = data.Schema[srcCols[i]].Type; activeCols.Add(data.Schema[srcCols[i]]); @@ -459,7 +460,7 @@ public static NormalizingTransformer Train(IHostEnvironment env, IDataView data, for (int i = 0; i < columns.Length; i++) { var func = functionBuilders[i].CreateColumnFunction(); - result[i] = new ColumnInfo(columns[i].Input, columns[i].Output, srcTypes[i], func); + result[i] = new ColumnInfo(columns[i].Name, columns[i].InputColumnName, srcTypes[i], func); } return new NormalizingTransformer(env, result); @@ -482,7 +483,7 @@ private NormalizingTransformer(IHost host, ModelLoadContext ctx) var dir = string.Format("Normalizer_{0:000}", iinfo); var typeSrc = ColumnInfo.LoadType(ctx); ctx.LoadModel(Host, out var function, dir, Host, typeSrc); - cols[iinfo] = new ColumnInfo(ColumnPairs[iinfo].input, ColumnPairs[iinfo].output, typeSrc, function); + cols[iinfo] = new ColumnInfo(ColumnPairs[iinfo].outputColumnName, ColumnPairs[iinfo].inputColumnName, typeSrc, function); } Columns = ImmutableArray.Create(cols); @@ -501,9 +502,9 @@ private NormalizingTransformer(IHost host, ModelLoadContext ctx, IDataView input for (int iinfo = 0; iinfo < ColumnPairs.Length; iinfo++) { var dir = string.Format("Normalizer_{0:000}", iinfo); - var typeSrc = input.Schema[ColumnPairs[iinfo].input].Type; + var typeSrc = input.Schema[ColumnPairs[iinfo].inputColumnName].Type; ctx.LoadModel(Host, out var function, dir, Host, typeSrc); - cols[iinfo] = new ColumnInfo(ColumnPairs[iinfo].input, ColumnPairs[iinfo].output, typeSrc, function); + cols[iinfo] = new ColumnInfo(ColumnPairs[iinfo].outputColumnName, ColumnPairs[iinfo].inputColumnName, typeSrc, function); } Columns = ImmutableArray.Create(cols); @@ -562,10 +563,10 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol var colType = inputSchema[srcCol].Type; VectorType vectorType = colType as VectorType; if (vectorType != null && 
!vectorType.IsKnownSize) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, expectedType, "variable-size vector"); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, expectedType, "variable-size vector"); ColumnType itemType = vectorType?.ItemType ?? colType; if (!itemType.Equals(NumberType.R4) && !itemType.Equals(NumberType.R8)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, expectedType, colType.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, expectedType, colType.ToString()); } // Temporary: enables SignatureDataTransform factory methods. @@ -591,7 +592,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var result = new Schema.DetachedColumn[_parent.Columns.Length]; for (int i = 0; i < _parent.Columns.Length; i++) - result[i] = new Schema.DetachedColumn(_parent.Columns[i].Output, _parent.Columns[i].InputType, MakeMetadata(i)); + result[i] = new Schema.DetachedColumn(_parent.Columns[i].Name, _parent.Columns[i].InputType, MakeMetadata(i)); return result; } @@ -623,17 +624,17 @@ public void SaveAsOnnx(OnnxContext ctx) for (int iinfo = 0; iinfo < _parent.Columns.Length; ++iinfo) { var info = _parent.Columns[iinfo]; - string sourceColumnName = info.Input; - if (!ctx.ContainsColumn(sourceColumnName)) + string inputColumnName = info.InputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) { - ctx.RemoveColumn(info.Output, false); + ctx.RemoveColumn(info.Name, false); continue; } - if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), - ctx.AddIntermediateVariable(info.InputType, info.Output))) + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName), + ctx.AddIntermediateVariable(info.InputType, info.Name))) { - ctx.RemoveColumn(info.Output, true); + ctx.RemoveColumn(info.Name, true); } } } @@ -648,20 +649,20 @@ 
public void SaveAsPfa(BoundPfaContext ctx) for (int iinfo = 0; iinfo < _parent.Columns.Length; ++iinfo) { var info = _parent.Columns[iinfo]; - var srcName = info.Input; + var srcName = info.InputColumnName; string srcToken = ctx.TokenOrNullForName(srcName); if (srcToken == null) { - toHide.Add(info.Output); + toHide.Add(info.Name); continue; } var result = SaveAsPfaCore(ctx, iinfo, info, srcToken); if (result == null) { - toHide.Add(info.Output); + toHide.Add(info.Name); continue; } - toDeclare.Add(new KeyValuePair(info.Output, result)); + toDeclare.Add(new KeyValuePair(info.Name, result)); } ctx.Hide(toHide.ToArray()); ctx.DeclareVar(toDeclare.ToArray()); diff --git a/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs b/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs index a89a8a0ee2..c1d8a951d1 100644 --- a/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs +++ b/src/Microsoft.ML.Data/Transforms/NormalizerCatalog.cs @@ -16,8 +16,8 @@ public static class NormalizerCatalog /// Normalize (rescale) the column according to the specified . /// /// The transform catalog - /// The column name - /// The column name + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The used to map the old values in the new scale. /// /// @@ -27,10 +27,9 @@ public static class NormalizerCatalog /// /// public static NormalizingEstimator Normalize(this TransformsCatalog catalog, - string inputName, - string outputName = null, + string outputColumnName, string inputColumnName = null, NormalizingEstimator.NormalizerMode mode = NormalizingEstimator.NormalizerMode.MinMax) - => new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), inputName, outputName, mode); + => new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName ?? outputColumnName, mode); /// /// Normalize (rescale) several columns according to the specified . 
@@ -47,7 +46,7 @@ public static NormalizingEstimator Normalize(this TransformsCatalog catalog, /// public static NormalizingEstimator Normalize(this TransformsCatalog catalog, NormalizingEstimator.NormalizerMode mode, - params (string input, string output)[] columns) + params (string outputColumnName, string inputColumnName)[] columns) => new NormalizingEstimator(CatalogUtils.GetEnvironment(catalog), mode, columns); /// diff --git a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs index 45f069049f..f1cf1f52b3 100644 --- a/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs +++ b/src/Microsoft.ML.Data/Transforms/OneToOneTransformerBase.cs @@ -14,19 +14,19 @@ namespace Microsoft.ML.Data /// public abstract class OneToOneTransformerBase : RowToRowTransformerBase { - protected readonly (string input, string output)[] ColumnPairs; + protected readonly (string outputColumnName, string inputColumnName)[] ColumnPairs; - protected OneToOneTransformerBase(IHost host, params (string input, string output)[] columns) : base(host) + protected OneToOneTransformerBase(IHost host, params (string outputColumnName, string inputColumnName)[] columns) : base(host) { host.CheckValue(columns, nameof(columns)); var newNames = new HashSet(); foreach (var column in columns) { - host.CheckNonEmpty(column.input, nameof(columns)); - host.CheckNonEmpty(column.output, nameof(columns)); + host.CheckNonEmpty(column.inputColumnName, nameof(columns)); + host.CheckNonEmpty(column.outputColumnName, nameof(columns)); - if (!newNames.Add(column.output)) - throw Contracts.ExceptParam(nameof(columns), $"Output column '{column.output}' specified multiple times"); + if (!newNames.Add(column.outputColumnName)) + throw Contracts.ExceptParam(nameof(columns), $"Name of the result column '{column.outputColumnName}' specified multiple times"); } ColumnPairs = columns; @@ -41,12 +41,12 @@ protected OneToOneTransformerBase(IHost host, 
ModelLoadContext ctx) : base(host) // int: id of input column name int n = ctx.Reader.ReadInt32(); - ColumnPairs = new (string input, string output)[n]; + ColumnPairs = new (string outputColumnName, string inputColumnName)[n]; for (int i = 0; i < n; i++) { - string output = ctx.LoadNonEmptyString(); - string input = ctx.LoadNonEmptyString(); - ColumnPairs[i] = (input, output); + string outputColumnName = ctx.LoadNonEmptyString(); + string inputColumnName = ctx.LoadNonEmptyString(); + ColumnPairs[i] = (outputColumnName, inputColumnName); } } @@ -63,8 +63,8 @@ protected void SaveColumns(ModelSaveContext ctx) ctx.Writer.Write(ColumnPairs.Length); for (int i = 0; i < ColumnPairs.Length; i++) { - ctx.SaveNonEmptyString(ColumnPairs[i].output); - ctx.SaveNonEmptyString(ColumnPairs[i].input); + ctx.SaveNonEmptyString(ColumnPairs[i].outputColumnName); + ctx.SaveNonEmptyString(ColumnPairs[i].inputColumnName); } } @@ -73,8 +73,8 @@ private void CheckInput(Schema inputSchema, int col, out int srcCol) Contracts.AssertValue(inputSchema); Contracts.Assert(0 <= col && col < ColumnPairs.Length); - if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].input, out srcCol)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input); + if (!inputSchema.TryGetColumnIndex(ColumnPairs[col].inputColumnName, out srcCol)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName); CheckInputColumn(inputSchema, col, srcCol); } diff --git a/src/Microsoft.ML.Data/Transforms/TransformBase.cs b/src/Microsoft.ML.Data/Transforms/TransformBase.cs index 925e6f1c90..8a34204b1e 100644 --- a/src/Microsoft.ML.Data/Transforms/TransformBase.cs +++ b/src/Microsoft.ML.Data/Transforms/TransformBase.cs @@ -584,14 +584,14 @@ void ISaveAsOnnx.SaveAsOnnx(OnnxContext ctx) for (int iinfo = 0; iinfo < Infos.Length; ++iinfo) { ColInfo info = Infos[iinfo]; - string sourceColumnName = Source.Schema[info.Source].Name; - if 
(!ctx.ContainsColumn(sourceColumnName)) + string inputColumnName = Source.Schema[info.Source].Name; + if (!ctx.ContainsColumn(inputColumnName)) { ctx.RemoveColumn(info.Name, false); continue; } - if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(OutputSchema[_bindings.MapIinfoToCol(iinfo)].Type, info.Name))) { ctx.RemoveColumn(info.Name, true); diff --git a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs index a3279e510c..aa72fbe28b 100644 --- a/src/Microsoft.ML.Data/Transforms/TypeConverting.cs +++ b/src/Microsoft.ML.Data/Transforms/TypeConverting.cs @@ -177,22 +177,22 @@ private static VersionInfo GetVersionInfo() /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly DataKind OutputKind; public readonly KeyCount OutputKeyCount; /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of output column. + /// Name of the column resulting from the transformation of . /// The expected kind of the converted column. - /// New key count, if we work with key type. - public ColumnInfo(string input, string output, DataKind outputKind, KeyCount outputKeyCount = null) + /// Name of column to transform. If set to , the value of the will be used as source. + /// New key range, if we work with key type. + public ColumnInfo(string name, DataKind outputKind, string inputColumnName, KeyCount outputKeyCount = null) { - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName ?? 
name; OutputKind = outputKind; OutputKeyCount = outputKeyCount; } @@ -200,22 +200,22 @@ public ColumnInfo(string input, string output, DataKind outputKind, KeyCount out private readonly ColumnInfo[] _columns; - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckNonEmpty(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } /// /// Convinence constructor for simple one column case. /// /// Host Environment. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. /// The expected type of the converted column. /// New key count if we work with key type. - public TypeConvertingTransformer(IHostEnvironment env, string inputColumn, string outputColumn, DataKind outputKind, KeyCount outputKeyCount = null) - : this(env, new ColumnInfo(inputColumn, outputColumn, outputKind, outputKeyCount)) + public TypeConvertingTransformer(IHostEnvironment env, string outputColumnName, DataKind outputKind, string inputColumnName = null, KeyCount outputKeyCount = null) + : this(env, new ColumnInfo(outputColumnName, outputKind, inputColumnName ?? 
outputColumnName, outputKeyCount)) { } @@ -312,7 +312,7 @@ private TypeConvertingTransformer(IHost host, ModelLoadContext ctx) keyCount = new KeyCount(count); } - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, kind, keyCount); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, kind, ColumnPairs[i].inputColumnName, keyCount); } } @@ -360,8 +360,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { kind = tempResultType.Value; } - - cols[i] = new ColumnInfo(item.Source ?? item.Name, item.Name, kind, keyCount); + cols[i] = new ColumnInfo(item.Name, kind, item.Source ?? item.Name, keyCount); }; return new TypeConvertingTransformer(env, cols).MakeDataTransform(input); } @@ -423,13 +422,13 @@ public Mapper(TypeConvertingTransformer parent, Schema inputSchema) _srcCols = new int[_parent._columns.Length]; for (int i = 0; i < _parent._columns.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]); var srcCol = inputSchema[_srcCols[i]]; if (!CanConvertToType(Host, srcCol.Type, _parent._columns[i].OutputKind, _parent._columns[i].OutputKeyCount, out PrimitiveType itemType, out _types[i])) { throw Host.ExceptParam(nameof(inputSchema), "source column '{0}' with item type '{1}' is not compatible with destination type '{2}'", - _parent._columns[i].Input, srcCol.Type, itemType); + _parent._columns[i].InputColumnName, srcCol.Type, itemType); } } } @@ -484,7 +483,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() ValueGetter getter = (ref bool dst) => dst = true; builder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, getter); } - result[i] = new Schema.DetachedColumn(_parent._columns[i].Output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent._columns[i].Name, _types[i], builder.GetMetadata()); } return result; } @@ 
-505,17 +504,17 @@ public void SaveAsOnnx(OnnxContext ctx) for (int iinfo = 0; iinfo < _parent._columns.Length; ++iinfo) { - string sourceColumnName = _parent._columns[iinfo].Input; - if (!ctx.ContainsColumn(sourceColumnName)) + string inputColumnName = _parent._columns[iinfo].InputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) { - ctx.RemoveColumn(_parent._columns[iinfo].Output, false); + ctx.RemoveColumn(_parent._columns[iinfo].Name, false); continue; } - if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(sourceColumnName), - ctx.AddIntermediateVariable(_types[iinfo], _parent._columns[iinfo].Output))) + if (!SaveAsOnnxCore(ctx, iinfo, ctx.GetVariableName(inputColumnName), + ctx.AddIntermediateVariable(_types[iinfo], _parent._columns[iinfo].Name))) { - ctx.RemoveColumn(_parent._columns[iinfo].Output, true); + ctx.RemoveColumn(_parent._columns[iinfo].Name, true); } } } @@ -550,13 +549,13 @@ internal sealed class Defaults /// Convinence constructor for simple one column case. /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The expected type of the converted column. public TypeConvertingEstimator(IHostEnvironment env, - string inputColumn, string outputColumn = null, + string outputColumnName, string inputColumnName = null, DataKind outputKind = Defaults.DefaultOutputKind) - : this(env, new TypeConvertingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind)) + : this(env, new TypeConvertingTransformer.ColumnInfo(outputColumnName, outputKind, inputColumnName ?? 
outputColumnName)) { } @@ -574,12 +573,12 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!TypeConvertingTransformer.GetNewType(Host, col.ItemType, colInfo.OutputKind, colInfo.OutputKeyCount, out PrimitiveType newType)) - throw Host.ExceptParam(nameof(inputSchema), $"Can't convert {colInfo.Input} into {newType.ToString()}"); + throw Host.ExceptParam(nameof(inputSchema), $"Can't convert {colInfo.InputColumnName} into {newType.ToString()}"); if (!Data.Conversion.Conversions.Instance.TryGetStandardConversion(col.ItemType, newType, out Delegate del, out bool identity)) - throw Host.ExceptParam(nameof(inputSchema), $"Don't know how to convert {colInfo.Input} into {newType.ToString()}"); + throw Host.ExceptParam(nameof(inputSchema), $"Don't know how to convert {colInfo.InputColumnName} into {newType.ToString()}"); var metadata = new List(); if (col.ItemType is BoolType && newType is NumberType) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); @@ -592,7 +591,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normMeta)) if (col.ItemType is NumberType && newType is NumberType) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, normMeta.ItemType, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, newType, false, col.Metadata); + result[colInfo.Name] = new 
SchemaShape.Column(colInfo.Name, col.Kind, newType, false, col.Metadata); } return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs index 0539c520e6..70b46fad01 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueMapping.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueMapping.cs @@ -39,7 +39,7 @@ namespace Microsoft.ML.Transforms.Conversions /// public class ValueMappingEstimator : TrivialEstimator { - private readonly (string input, string output)[] _columns; + private readonly (string outputColumnName, string inputColumnName)[] _columns; /// /// Constructs the ValueMappingEstimator, key type -> value type mapping @@ -49,7 +49,7 @@ public class ValueMappingEstimator : TrivialEstimator /// Name of the key column in . /// Name of the value column in . /// The list of names of the input columns to apply the transformation, and the name of the resulting column. - public ValueMappingEstimator(IHostEnvironment env, IDataView lookupMap, string keyColumn, string valueColumn, params (string input, string output)[] columns) + public ValueMappingEstimator(IHostEnvironment env, IDataView lookupMap, string keyColumn, string valueColumn, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingEstimator)), new ValueMappingTransformer(env, lookupMap, keyColumn, valueColumn, columns)) { @@ -71,17 +71,17 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var columnType = (isKey) ? 
NumberType.U4 : Transformer.ValueColumnType; var metadataShape = SchemaShape.Create(Transformer.ValueColumnMetadata.Schema); - foreach (var (Input, Output) in _columns) + foreach (var (outputColumnName, inputColumnName) in _columns) { - if (!inputSchema.TryFindColumn(Input, out var originalColumn)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", Input); + if (!inputSchema.TryFindColumn(inputColumnName, out var originalColumn)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName); if ((originalColumn.Kind == SchemaShape.Column.VectorKind.VariableVector || originalColumn.Kind == SchemaShape.Column.VectorKind.Vector) && Transformer.ValueColumnType is VectorType) - throw Host.ExceptNotSupp("Column '{0}' cannot be mapped to values when the column and the map values are both vector type.", Input); + throw Host.ExceptNotSupp("Column '{0}' cannot be mapped to values when the column and the map values are both vector type.", inputColumnName); // Create the Value column - var col = new SchemaShape.Column(Output, vectorKind, columnType, isKey, metadataShape); - resultDic[Output] = col; + var col = new SchemaShape.Column(outputColumnName, vectorKind, columnType, isKey, metadataShape); + resultDic[outputColumnName] = col; } return new SchemaShape(resultDic.Values); } @@ -96,7 +96,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) /// Specifies the value type. public sealed class ValueMappingEstimator : ValueMappingEstimator { - private (string input, string output)[] _columns; + private (string outputColumnName, string inputColumnName)[] _columns; /// /// Constructs the ValueMappingEstimator, key type -> value type mapping @@ -105,7 +105,7 @@ public sealed class ValueMappingEstimator : ValueMappingEstimator /// The list of keys of TKey. /// The list of values of TValue. /// The list of columns to apply. 
- public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string input, string output)[] columns) + public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string outputColumnName, string inputColumnName)[] columns) : base(env, DataViewHelper.CreateDataView(env, keys, values, ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName, false), ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName, columns) { _columns = columns; @@ -119,7 +119,7 @@ public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnum /// The list of values of TValue. /// Specifies to treat the values as a . /// The list of columns to apply. - public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyType, params (string input, string output)[] columns) + public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, bool treatValuesAsKeyType, params (string outputColumnName, string inputColumnName)[] columns) : base(env, DataViewHelper.CreateDataView(env, keys, values, ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName, treatValuesAsKeyType), ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName, columns) { _columns = columns; @@ -132,7 +132,7 @@ public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnum /// The list of keys of TKey. /// The list of values of TValue[]. /// The list of columns to apply. 
- public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string input, string output)[] columns) + public ValueMappingEstimator(IHostEnvironment env, IEnumerable keys, IEnumerable values, params (string outputColumnName, string inputColumnName)[] columns) : base(env, DataViewHelper.CreateDataView(env, keys, values, ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName), ValueMappingTransformer.KeyColumnName, ValueMappingTransformer.ValueColumnName, columns) { _columns = columns; @@ -386,7 +386,7 @@ public sealed class Arguments } internal ValueMappingTransformer(IHostEnvironment env, IDataView lookupMap, - string keyColumn, string valueColumn, (string input, string output)[] columns) + string keyColumn, string valueColumn, (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ValueMappingTransformer)), columns) { Host.CheckNonEmpty(keyColumn, nameof(keyColumn), "A key column must be specified when passing in an IDataView for the value mapping"); @@ -498,7 +498,7 @@ private static ValueMappingTransformer CreateTransformInvoke(IHost string keyColumnName, string valueColumnName, bool treatValuesAsKeyTypes, - (string input, string output)[] columns) + (string outputColumnName, string inputColumnName)[] columns) { // Read in the data // scan the input to create convert the values as key types @@ -633,7 +633,7 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData env.Assert(loader.Schema.TryGetColumnIndex(valueColumnName, out int valueColumnIndex)); ValueMappingTransformer transformer = null; - (string Source, string Name)[] columns = args.Column.Select(x => (x.Source, x.Name)).ToArray(); + (string outputColumnName, string inputColumnName)[] columns = args.Column.Select(x => (x.Name, x.Source)).ToArray(); transformer = new ValueMappingTransformer(env, loader, keyColumnName, valueColumnName, columns); return 
transformer.MakeDataTransform(input); } @@ -672,11 +672,11 @@ protected static ValueMappingTransformer Create(IHostEnvironment env, ModelLoadC // Binary stream of mapping var length = ctx.Reader.ReadInt32(); - var columns = new (string Source, string Name)[length]; + var columns = new (string outputColumnName, string inputColumnName)[length]; for (int i = 0; i < length; i++) { - columns[i].Name = ctx.LoadNonEmptyString(); - columns[i].Source = ctx.LoadNonEmptyString(); + columns[i].outputColumnName = ctx.LoadNonEmptyString(); + columns[i].inputColumnName = ctx.LoadNonEmptyString(); } byte[] rgb = null; @@ -928,14 +928,14 @@ private static byte[] GetBytesFromDataView(IHost host, IDataView lookup, string if (!schema.GetColumnOrNull(valueColumn).HasValue) throw host.ExceptUserArg(nameof(Arguments.ValueColumn), $"Value column not found: '{valueColumn}'"); - var cols = new List<(string Source, string Name)>() + var cols = new List<(string outputColumnName, string inputColumnName)>() { - (keyColumn, KeyColumnName), - (valueColumn, ValueColumnName) + (KeyColumnName, keyColumn), + (ValueColumnName, valueColumn) }; var view = new ColumnCopyingTransformer(host, cols.ToArray()).Transform(lookup); - view = ColumnSelectingTransformer.CreateKeep(host, view, cols.Select(x => x.Name).ToArray()); + view = ColumnSelectingTransformer.CreateKeep(host, view, cols.Select(x => x.outputColumnName).ToArray()); var saver = new BinarySaver(host, new BinarySaver.Arguments()); using (var strm = new MemoryStream()) @@ -964,14 +964,14 @@ private sealed class Mapper : OneToOneMapperBase private readonly Schema _inputSchema; private readonly ValueMap _valueMap; private readonly Schema.Metadata _valueMetadata; - private readonly (string Source, string Name)[] _columns; + private readonly (string outputColumnName, string inputColumnName)[] _columns; private readonly ValueMappingTransformer _parent; internal Mapper(ValueMappingTransformer transform, Schema inputSchema, ValueMap valueMap, 
Schema.Metadata valueMetadata, - (string input, string output)[] columns) + (string outputColumnName, string inputColumnName)[] columns) : base(transform.Host.Register(nameof(Mapper)), transform, inputSchema) { _inputSchema = inputSchema; @@ -995,12 +995,12 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_columns.Length]; for (int i = 0; i < _columns.Length; i++) { - if (_inputSchema[_columns[i].Source].Type is VectorType && _valueMap.ValueType is VectorType) - throw _parent.Host.ExceptNotSupp("Column '{0}' cannot be mapped to values when the column and the map values are both vector type.", _columns[i].Source); + if (_inputSchema[_columns[i].inputColumnName].Type is VectorType && _valueMap.ValueType is VectorType) + throw _parent.Host.ExceptNotSupp("Column '{0}' cannot be mapped to values when the column and the map values are both vector type.", _columns[i].inputColumnName); var colType = _valueMap.ValueType; - if (_inputSchema[_columns[i].Source].Type is VectorType) + if (_inputSchema[_columns[i].inputColumnName].Type is VectorType) colType = new VectorType(ColumnTypeExtensions.PrimitiveTypeFromType(_valueMap.ValueType.GetItemType().RawType)); - result[i] = new Schema.DetachedColumn(_columns[i].Name, colType, _valueMetadata); + result[i] = new Schema.DetachedColumn(_columns[i].outputColumnName, colType, _valueMetadata); } return result; } diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs index 337aabf34b..18718ae1f2 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingEstimator.cs @@ -26,13 +26,13 @@ public static class Defaults /// Initializes a new instance of . /// /// Host Environment. - /// Name of the column to be transformed. - /// Name of the output column. If this is null '' will be used. 
+ /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Maximum number of keys to keep per column when auto-training. /// How items should be ordered when vectorized. If choosen they will be in the order encountered. /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). - public ValueToKeyMappingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int maxNumTerms = Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = Defaults.Sort) : - this(env, new[] { new ValueToKeyMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, maxNumTerms, sort) }) + public ValueToKeyMappingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int maxNumTerms = Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = Defaults.Sort) : + this(env, new [] { new ValueToKeyMappingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? 
outputColumnName, maxNumTerms, sort) }) { } @@ -61,11 +61,11 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!col.ItemType.IsStandardScalar()) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); SchemaShape metadata; // In the event that we are transforming something that is of type key, we will get their type of key value @@ -81,7 +81,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) metadata = new SchemaShape(new[] { slotMeta, kv }); else metadata = new SchemaShape(new[] { kv }); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, NumberType.U4, true, metadata); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, col.Kind, NumberType.U4, true, metadata); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs index 0be46f1dfa..0a2676cc70 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformer.cs @@ -152,13 +152,13 @@ public sealed class Arguments : ArgumentsBase internal sealed class ColInfo { public readonly string Name; - public readonly string Source; + public readonly string InputColumnName; public readonly ColumnType TypeSrc; - public ColInfo(string name, string source, ColumnType type) + public ColInfo(string name, string inputColumnName, ColumnType type) { Name = name; 
- Source = source; + InputColumnName = inputColumnName; TypeSrc = type; } } @@ -168,8 +168,8 @@ public ColInfo(string name, string source, ColumnType type) /// public class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly SortOrder Sort; public readonly int MaxNumTerms; public readonly string[] Term; @@ -180,23 +180,23 @@ public class ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Maximum number of terms to keep per column when auto-training. /// How items should be ordered when vectorized. If choosen they will be in the order encountered. /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). /// List of terms. /// Whether key value metadata should be text, regardless of the actual input type. - public ColumnInfo(string input, string output = null, + public ColumnInfo(string name, string inputColumnName = null, int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort, string[] term = null, bool textKeyValues = false ) { - Contracts.CheckNonWhiteSpace(input, nameof(input)); - Input = input; - Output = output ?? input; + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Name = name; + InputColumnName = inputColumnName ?? 
name; Sort = sort; MaxNumTerms = maxNumTerms; Term = term; @@ -261,10 +261,10 @@ private static VersionInfo GetTermManagerVersionInfo() private readonly bool[] _textMetadata; private const string RegistrationName = "Term"; - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } private string TestIsKnownDataKind(ColumnType type) @@ -283,13 +283,13 @@ private ColInfo[] CreateInfos(Schema inputSchema) var infos = new ColInfo[ColumnPairs.Length]; for (int i = 0; i < ColumnPairs.Length; i++) { - if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input); + if (!inputSchema.TryGetColumnIndex(ColumnPairs[i].inputColumnName, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].inputColumnName); var type = inputSchema[colSrc].Type; string reason = TestIsKnownDataKind(type); if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].input, reason, type.ToString()); - infos[i] = new ColInfo(ColumnPairs[i].output, ColumnPairs[i].input, type); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[i].inputColumnName, reason, type.ToString()); + infos[i] = new ColInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, type); } return infos; } @@ -342,8 +342,9 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat if (!Enum.IsDefined(typeof(SortOrder), sortOrder)) throw env.ExceptUserArg(nameof(args.Sort), "Undefined sorting criteria '{0}' detected for column '{1}'", sortOrder, item.Name); - cols[i] = new 
ColumnInfo(item.Source ?? item.Name, + cols[i] = new ColumnInfo( item.Name, + item.Source ?? item.Name, item.MaxNumTerms ?? args.MaxNumTerms, sortOrder, item.Term, @@ -608,7 +609,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info // Auto train this column. Leave the term map null for now, but set the lim appropriately. lims[iinfo] = columns[iinfo].MaxNumTerms; ch.CheckUserArg(lims[iinfo] > 0, nameof(Column.MaxNumTerms), "Must be positive"); - Contracts.Check(trainingData.Schema.TryGetColumnIndex(infos[iinfo].Source, out int colIndex)); + Contracts.Check(trainingData.Schema.TryGetColumnIndex(infos[iinfo].InputColumnName, out int colIndex)); Utils.Add(ref toTrain, colIndex); ++trainsNeeded; } @@ -636,7 +637,7 @@ private static TermMap[] Train(IHostEnvironment env, IChannel ch, ColInfo[] info continue; var bldr = Builder.Create(infos[iinfo].TypeSrc, columns[iinfo].Sort); trainerInfo[itrainer] = iinfo; - trainingData.Schema.TryGetColumnIndex(infos[iinfo].Source, out int colIndex); + trainingData.Schema.TryGetColumnIndex(infos[iinfo].InputColumnName, out int colIndex); trainer[itrainer++] = Trainer.Create(cursor, colIndex, false, lims[iinfo], bldr); } ch.Assert(itrainer == trainer.Length); @@ -776,13 +777,13 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); var builder = new MetadataBuilder(); _termMap[i].AddMetadata(builder); builder.Add(InputSchema[colIndex].Metadata, name => name == MetadataUtils.Kinds.SlotNames); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, 
_types[i], builder.GetMetadata()); } return result; } @@ -824,14 +825,14 @@ public void SaveAsOnnx(OnnxContext ctx) for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { ColInfo info = _infos[iinfo]; - string sourceColumnName = info.Source; - if (!ctx.ContainsColumn(sourceColumnName)) + string inputColumnName = info.InputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) { ctx.RemoveColumn(info.Name, false); continue; } - if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(_types[iinfo], info.Name))) { ctx.RemoveColumn(info.Name, true); @@ -849,7 +850,7 @@ public void SaveAsPfa(BoundPfaContext ctx) for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { var info = _infos[iinfo]; - var srcName = info.Source; + var srcName = info.InputColumnName; string srcToken = ctx.TokenOrNullForName(srcName); if (srcToken == null) { diff --git a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs index aa7a703532..e0f5b0f9fb 100644 --- a/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs +++ b/src/Microsoft.ML.Data/Transforms/ValueToKeyMappingTransformerImpl.cs @@ -906,7 +906,7 @@ public override Delegate GetMappingGetter(Row input) var info = _infos[_iinfo]; T src = default(T); Contracts.Assert(!(info.TypeSrc is VectorType)); - input.Schema.TryGetColumnIndex(info.Source, out int colIndex); + input.Schema.TryGetColumnIndex(info.InputColumnName, out int colIndex); _host.Assert(input.IsColumnActive(colIndex)); var getSrc = input.GetGetter(colIndex); ValueGetter retVal = @@ -927,7 +927,7 @@ public override Delegate GetMappingGetter(Row input) ValueMapper map = TypedMap.GetKeyMapper(); var info = _infos[_iinfo]; // First test whether default maps to default. If so this is sparsity preserving. 
- input.Schema.TryGetColumnIndex(info.Source, out int colIndex); + input.Schema.TryGetColumnIndex(info.InputColumnName, out int colIndex); _host.Assert(input.IsColumnActive(colIndex)); var getSrc = input.GetGetter>(colIndex); VBuffer src = default(VBuffer); @@ -1086,7 +1086,7 @@ public override void AddMetadata(MetadataBuilder builder) if (TypedMap.Count == 0) return; - _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); + _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol); VectorType srcMetaType = _schema[srcCol].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type as VectorType; if (srcMetaType == null || srcMetaType.Size != TypedMap.ItemType.GetKeyCountAsInt32(_host) || TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(AddMetadataCore, srcMetaType.ItemType.RawType, srcMetaType.ItemType, builder)) @@ -1110,7 +1110,7 @@ private bool AddMetadataCore(ColumnType srcMetaType, MetadataBuilder buil // If we can't convert this type to U4, don't try to pass along the metadata. 
if (!convInst.TryGetStandardConversion(srcType, dstType, out conv, out identity)) return false; - _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); + _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol); ValueGetter> getter = (ref VBuffer dst) => @@ -1167,7 +1167,7 @@ public override void WriteTextTerms(TextWriter writer) if (TypedMap.Count == 0) return; - _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); + _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol); VectorType srcMetaType = _schema[srcCol].Metadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.KeyValues)?.Type as VectorType; if (srcMetaType == null || srcMetaType.Size != TypedMap.ItemType.GetKeyCountAsInt32(_host) || TypedMap.ItemType.GetKeyCountAsInt32(_host) == 0 || !Utils.MarshalInvoke(WriteTextTermsCore, srcMetaType.ItemType.RawType, srcMetaType.ItemType, writer)) @@ -1190,7 +1190,7 @@ private bool WriteTextTermsCore(PrimitiveType srcMetaType, TextWriter wri // If we can't convert this type to U4, don't try. if (!convInst.TryGetStandardConversion(srcType, dstType, out conv, out identity)) return false; - _schema.TryGetColumnIndex(_infos[_iinfo].Source, out int srcCol); + _schema.TryGetColumnIndex(_infos[_iinfo].InputColumnName, out int srcCol); VBuffer srcMeta = default(VBuffer); _schema[srcCol].GetKeyValues(ref srcMeta); diff --git a/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs b/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs index e59a1da6e6..06152729d0 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.AlexNet/AlexNetExtension.cs @@ -20,9 +20,9 @@ public static class AlexNetExtension /// This assumes both of the models are in the same location as the file containing this method, which they will be if used through the NuGet. /// This should be the default way to use AlexNet if importing the model from a NuGet. 
/// - public static EstimatorChain AlexNet(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn) + public static EstimatorChain AlexNet(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName) { - return AlexNet(dnnModelContext, env, inputColumn, outputColumn, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); + return AlexNet(dnnModelContext, env, outputColumnName, inputColumnName, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); } /// @@ -31,17 +31,17 @@ public static EstimatorChain AlexNet(this DnnImageMode /// must be in a directory all by themsleves for the OnnxTransform to work, this method appends a AlexNetOnnx/AlexNetPrepOnnx subdirectory /// to the passed in directory to prevent having to make that directory manually each time. /// - public static EstimatorChain AlexNet(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn, string modelDir) + public static EstimatorChain AlexNet(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName, string modelDir) { var modelChain = new EstimatorChain(); - var inputRename = new ColumnCopyingEstimator(env, new[] { (inputColumn, "OriginalInput") }); - var midRename = new ColumnCopyingEstimator(env, new[] { ("PreprocessedInput", "Input140") }); - var endRename = new ColumnCopyingEstimator(env, new[] { ("Dropout234_Output_0", outputColumn) }); + var inputRename = new ColumnCopyingEstimator(env, new[] { ("OriginalInput", inputColumnName) }); + var midRename = new ColumnCopyingEstimator(env, new[] { ("Input140", "PreprocessedInput") }); + var endRename = new ColumnCopyingEstimator(env, new[] { (outputColumnName, "Dropout234_Output_0") }); // There are two estimators created below. 
The first one is for image preprocessing and the second one is the actual DNN model. - var prepEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "AlexNetPrepOnnx", "AlexNetPreprocess.onnx"), new[] { "OriginalInput" }, new[] { "PreprocessedInput" }); - var mainEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "AlexNetOnnx", "AlexNet.onnx"), new[] { "Input140" }, new[] { "Dropout234_Output_0" }); + var prepEstimator = new OnnxScoringEstimator(env, new[] { "PreprocessedInput" }, new[] { "OriginalInput" }, Path.Combine(modelDir, "AlexNetPrepOnnx", "AlexNetPreprocess.onnx")); + var mainEstimator = new OnnxScoringEstimator(env, new[] { "Dropout234_Output_0" }, new[] { "Input140" }, Path.Combine(modelDir, "AlexNetOnnx", "AlexNet.onnx")); modelChain = modelChain.Append(inputRename); var modelChain2 = modelChain.Append(prepEstimator); modelChain = modelChain2.Append(midRename); diff --git a/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs b/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs index c5978fe198..9a78482fdf 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.ResNet101/ResNet101Extension.cs @@ -20,9 +20,9 @@ public static class ResNet101Extension /// This assumes both of the models are in the same location as the file containing this method, which they will be if used through the NuGet. /// This should be the default way to use ResNet101 if importing the model from a NuGet. 
/// - public static EstimatorChain ResNet101(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn) + public static EstimatorChain ResNet101(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName) { - return ResNet101(dnnModelContext, env, inputColumn, outputColumn, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); + return ResNet101(dnnModelContext, env, outputColumnName, inputColumnName, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); } /// @@ -31,17 +31,17 @@ public static EstimatorChain ResNet101(this DnnImageMo /// must be in a directory all by themsleves for the OnnxTransform to work, this method appends a ResNet101Onnx/ResNetPrepOnnx subdirectory /// to the passed in directory to prevent having to make that directory manually each time. /// - public static EstimatorChain ResNet101(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn, string modelDir) + public static EstimatorChain ResNet101(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName, string modelDir) { var modelChain = new EstimatorChain(); - var inputRename = new ColumnCopyingEstimator(env, new[] { (inputColumn, "OriginalInput") }); - var midRename = new ColumnCopyingEstimator(env, new[] { ("PreprocessedInput", "Input1600") }); - var endRename = new ColumnCopyingEstimator(env, new[] { ("Pooling2286_Output_0", outputColumn) }); + var inputRename = new ColumnCopyingEstimator(env, new[] { ("OriginalInput", inputColumnName) }); + var midRename = new ColumnCopyingEstimator(env, new[] { ("Input1600", "PreprocessedInput") }); + var endRename = new ColumnCopyingEstimator(env, new[] { (outputColumnName, "Pooling2286_Output_0") }); // There are two estimators created below. 
The first one is for image preprocessing and the second one is the actual DNN model. - var prepEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "ResNetPrepOnnx", "ResNetPreprocess.onnx"), new[] { "OriginalInput" }, new[] { "PreprocessedInput" }); - var mainEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "ResNet101Onnx", "ResNet101.onnx"), new[] { "Input1600" }, new[] { "Pooling2286_Output_0" }); + var prepEstimator = new OnnxScoringEstimator(env, new[] { "PreprocessedInput" }, new[] { "OriginalInput" }, Path.Combine(modelDir, "ResNetPrepOnnx", "ResNetPreprocess.onnx")); + var mainEstimator = new OnnxScoringEstimator(env, new[] { "Pooling2286_Output_0" }, new[] { "Input1600" }, Path.Combine(modelDir, "ResNet101Onnx", "ResNet101.onnx")); modelChain = modelChain.Append(inputRename); var modelChain2 = modelChain.Append(prepEstimator); modelChain = modelChain2.Append(midRename); diff --git a/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs b/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs index 01b8436950..bfd5fa6373 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.ResNet18/ResNet18Extension.cs @@ -20,9 +20,9 @@ public static class ResNet18Extension /// This assumes both of the models are in the same location as the file containing this method, which they will be if used through the NuGet. /// This should be the default way to use ResNet18 if importing the model from a NuGet. 
/// - public static EstimatorChain ResNet18(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn) + public static EstimatorChain ResNet18(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName) { - return ResNet18(dnnModelContext, env, inputColumn, outputColumn, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); + return ResNet18(dnnModelContext, env, outputColumnName, inputColumnName, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); } /// @@ -31,17 +31,17 @@ public static EstimatorChain ResNet18(this DnnImageMod /// must be in a directory all by themsleves for the OnnxTransform to work, this method appends a ResNet18Onnx/ResNetPrepOnnx subdirectory /// to the passed in directory to prevent having to make that directory manually each time. /// - public static EstimatorChain ResNet18(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn, string modelDir) + public static EstimatorChain ResNet18(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName, string modelDir) { var modelChain = new EstimatorChain(); - var inputRename = new ColumnCopyingEstimator(env, new[] { (inputColumn, "OriginalInput") }); - var midRename = new ColumnCopyingEstimator(env, new[] { ("PreprocessedInput", "Input247") }); - var endRename = new ColumnCopyingEstimator(env, new[] { ("Pooling395_Output_0", outputColumn) }); + var inputRename = new ColumnCopyingEstimator(env, new[] { ("OriginalInput", inputColumnName) }); + var midRename = new ColumnCopyingEstimator(env, new[] { ("Input247", "PreprocessedInput") }); + var endRename = new ColumnCopyingEstimator(env, new[] { (outputColumnName, "Pooling395_Output_0") }); // There are two estimators created below. 
The first one is for image preprocessing and the second one is the actual DNN model. - var prepEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "ResNetPrepOnnx", "ResNetPreprocess.onnx"), new[] { "OriginalInput" }, new[] { "PreprocessedInput" }); - var mainEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "ResNet18Onnx", "ResNet18.onnx"), new[] { "Input247" }, new[] { "Pooling395_Output_0" }); + var prepEstimator = new OnnxScoringEstimator(env, new[] { "PreprocessedInput" }, new[] { "OriginalInput" }, Path.Combine(modelDir, "ResNetPrepOnnx", "ResNetPreprocess.onnx")); + var mainEstimator = new OnnxScoringEstimator(env, new[] { "Pooling395_Output_0" }, new[] { "Input247" }, Path.Combine(modelDir, "ResNet18Onnx", "ResNet18.onnx")); modelChain = modelChain.Append(inputRename); var modelChain2 = modelChain.Append(prepEstimator); modelChain = modelChain2.Append(midRename); diff --git a/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs b/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs index c57e7f2560..6bd2ef3de8 100644 --- a/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs +++ b/src/Microsoft.ML.DnnImageFeaturizer.ResNet50/ResNet50Extension.cs @@ -20,9 +20,9 @@ public static class ResNet50Extension /// This assumes both of the models are in the same location as the file containing this method, which they will be if used through the NuGet. /// This should be the default way to use ResNet50 if importing the model from a NuGet. 
/// - public static EstimatorChain ResNet50(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn) + public static EstimatorChain ResNet50(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName) { - return ResNet50(dnnModelContext, env, inputColumn, outputColumn, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); + return ResNet50(dnnModelContext, env, outputColumnName, inputColumnName, Path.Combine(Path.GetDirectoryName(Assembly.GetExecutingAssembly().Location), "DnnImageModels")); } /// @@ -31,17 +31,17 @@ public static EstimatorChain ResNet50(this DnnImageMod /// must be in a directory all by themsleves for the OnnxTransform to work, this method appends a ResNet50Onnx/ResNetPrepOnnx subdirectory /// to the passed in directory to prevent having to make that directory manually each time. /// - public static EstimatorChain ResNet50(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string inputColumn, string outputColumn, string modelDir) + public static EstimatorChain ResNet50(this DnnImageModelSelector dnnModelContext, IHostEnvironment env, string outputColumnName, string inputColumnName, string modelDir) { var modelChain = new EstimatorChain(); - var inputRename = new ColumnCopyingEstimator(env, new[] { (inputColumn, "OriginalInput") }); - var midRename = new ColumnCopyingEstimator(env, new[] { ("PreprocessedInput", "Input750") }); - var endRename = new ColumnCopyingEstimator(env, new[] { ("Pooling1096_Output_0", outputColumn) }); + var inputRename = new ColumnCopyingEstimator(env, new[] { ("OriginalInput", inputColumnName) }); + var midRename = new ColumnCopyingEstimator(env, new[] { ("Input750", "PreprocessedInput") }); + var endRename = new ColumnCopyingEstimator(env, new[] { (outputColumnName, "Pooling1096_Output_0") }); // There are two estimators created below. 
The first one is for image preprocessing and the second one is the actual DNN model. - var prepEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "ResNetPrepOnnx", "ResNetPreprocess.onnx"), new[] { "OriginalInput" }, new[] { "PreprocessedInput" }); - var mainEstimator = new OnnxScoringEstimator(env, Path.Combine(modelDir, "ResNet50Onnx", "ResNet50.onnx"), new[] { "Input750" }, new[] { "Pooling1096_Output_0" }); + var prepEstimator = new OnnxScoringEstimator(env, new[] { "PreprocessedInput" }, new[] { "OriginalInput" }, Path.Combine(modelDir, "ResNetPrepOnnx", "ResNetPreprocess.onnx")); + var mainEstimator = new OnnxScoringEstimator(env, new[] { "Pooling1096_Output_0" }, new[] { "Input750" }, Path.Combine(modelDir, "ResNet50Onnx", "ResNet50.onnx")); modelChain = modelChain.Append(inputRename); var modelChain2 = modelChain.Append(prepEstimator); modelChain = modelChain2.Append(midRename); diff --git a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs index 8c2a4efa7f..10a724c327 100644 --- a/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs +++ b/src/Microsoft.ML.EntryPoints/FeatureCombiner.cs @@ -95,19 +95,19 @@ private static IDataView ApplyKeyToVec(List (x.Input, x.Output)).ToArray()) + viewTrain = new KeyToValueMappingTransformer(host, ktv.Select(x => (x.Name, x.InputColumnName)).ToArray()) .Transform(viewTrain); viewTrain = ValueToKeyMappingTransformer.Create(host, new ValueToKeyMappingTransformer.Arguments() { Column = ktv - .Select(c => new ValueToKeyMappingTransformer.Column() { Name = c.Output, Source = c.Output, Terms = GetTerms(viewTrain, c.Input) }) + .Select(c => new ValueToKeyMappingTransformer.Column() { Name = c.Name, Source = c.Name, Terms = GetTerms(viewTrain, c.InputColumnName) }) .ToArray(), TextKeyValues = true }, viewTrain); - viewTrain = new KeyToVectorMappingTransformer(host, ktv.Select(c => new KeyToVectorMappingTransformer.ColumnInfo(c.Output, 
c.Output)).ToArray()).Transform(viewTrain); + viewTrain = new KeyToVectorMappingTransformer(host, ktv.Select(c => new KeyToVectorMappingTransformer.ColumnInfo(c.Name, c.Name)).ToArray()).Transform(viewTrain); } return viewTrain; } @@ -145,7 +145,7 @@ private static IDataView ApplyConvert(List Contracts.AssertValue(viewTrain); Contracts.AssertValue(env); if (Utils.Size(cvt) > 0) - viewTrain = new TypeConvertingTransformer(env,cvt.ToArray()).Transform(viewTrain); + viewTrain = new TypeConvertingTransformer(env, cvt.ToArray()).Transform(viewTrain); return viewTrain; } @@ -174,7 +174,7 @@ private static IDataView ApplyConvert(List { var colName = GetUniqueName(); concatNames.Add(new KeyValuePair(col.Name, colName)); - Utils.Add(ref ktv, new KeyToVectorMappingTransformer.ColumnInfo(col.Name, colName)); + Utils.Add(ref ktv, new KeyToVectorMappingTransformer.ColumnInfo(colName, col.Name)); continue; } } @@ -185,7 +185,7 @@ private static IDataView ApplyConvert(List // This happens when the training is done on an XDF and the scoring is done on a data frame. 
var colName = GetUniqueName(); concatNames.Add(new KeyValuePair(col.Name, colName)); - Utils.Add(ref cvt, new TypeConvertingTransformer.ColumnInfo(col.Name, colName, DataKind.R4)); + Utils.Add(ref cvt, new TypeConvertingTransformer.ColumnInfo(colName, DataKind.R4, col.Name)); continue; } } @@ -312,7 +312,7 @@ public static CommonOutputs.TransformOutput PrepareRegressionLabel(IHostEnvironm } } }; - var xf = new TypeConvertingTransformer(host, new TypeConvertingTransformer.ColumnInfo(input.LabelColumn, input.LabelColumn, DataKind.R4)).Transform(input.Data); + var xf = new TypeConvertingTransformer(host, new TypeConvertingTransformer.ColumnInfo(input.LabelColumn, DataKind.R4, input.LabelColumn)).Transform(input.Data); return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, xf, input.Data), OutputData = xf }; } } diff --git a/src/Microsoft.ML.EntryPoints/ScoreColumnSelector.cs b/src/Microsoft.ML.EntryPoints/ScoreColumnSelector.cs index 97b896bc9d..3de88df465 100644 --- a/src/Microsoft.ML.EntryPoints/ScoreColumnSelector.cs +++ b/src/Microsoft.ML.EntryPoints/ScoreColumnSelector.cs @@ -81,7 +81,7 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I // Rename all the score columns. int colMax; var maxScoreId = input.Data.Schema.GetMaxMetadataKind(out colMax, MetadataUtils.Kinds.ScoreColumnSetId); - var copyCols = new List<(string Source, string Name)>(); + var copyCols = new List<(string name, string source)>(); for (int i = 0; i < input.Data.Schema.Count; i++) { if (input.Data.Schema[i].IsHidden) @@ -98,11 +98,11 @@ public static CommonOutputs.TransformOutput RenameBinaryPredictionScoreColumns(I } var source = input.Data.Schema[i].Name; var name = source + "." 
+ positiveClass; - copyCols.Add((source, name)); + copyCols.Add((name, source)); } var copyColumn = new ColumnCopyingTransformer(env, copyCols.ToArray()).Transform(input.Data); - var dropColumn = ColumnSelectingTransformer.CreateDrop(env, copyColumn, copyCols.Select(c => c.Source).ToArray()); + var dropColumn = ColumnSelectingTransformer.CreateDrop(env, copyColumn, copyCols.Select(c => c.source).ToArray()); return new CommonOutputs.TransformOutput { Model = new TransformModelImpl(env, dropColumn, input.Data), OutputData = dropColumn }; } } diff --git a/src/Microsoft.ML.FastTree/FastTree.cs b/src/Microsoft.ML.FastTree/FastTree.cs index 45d7412598..72700b9dc0 100644 --- a/src/Microsoft.ML.FastTree/FastTree.cs +++ b/src/Microsoft.ML.FastTree/FastTree.cs @@ -1382,7 +1382,7 @@ private Dataset Construct(RoleMappedData examples, ref int numExamples, int maxB } // Convert the group column, if one exists. if (examples.Schema.Group?.Name is string groupName) - data = new TypeConvertingTransformer(Host, new TypeConvertingTransformer.ColumnInfo(groupName, groupName, DataKind.U8)).Transform(data); + data = new TypeConvertingTransformer(Host, new TypeConvertingTransformer.ColumnInfo(groupName, DataKind.U8, groupName)).Transform(data); // Since we've passed it through a few transforms, reconstitute the mapping on the // newly transformed data. 
diff --git a/src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs b/src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs index 97e6005bc7..fad38dc3a5 100644 --- a/src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs +++ b/src/Microsoft.ML.HalLearners.StaticPipe/VectorWhiteningStaticExtensions.cs @@ -51,7 +51,7 @@ public override IEstimator Reconcile(IHostEnvironment env, var infos = new VectorWhiteningTransformer.ColumnInfo[toOutput.Length]; for (int i = 0; i < toOutput.Length; i++) - infos[i] = new VectorWhiteningTransformer.ColumnInfo(inputNames[((OutPipelineColumn)toOutput[i]).Input], outputNames[toOutput[i]], _kind, _eps, _maxRows, _pcaNum); + infos[i] = new VectorWhiteningTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[((OutPipelineColumn)toOutput[i]).Input], _kind, _eps, _maxRows, _pcaNum); return new VectorWhiteningEstimator(env, infos); } diff --git a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs index 68c8261209..9f6f16ab51 100644 --- a/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs +++ b/src/Microsoft.ML.HalLearners/HalLearnersCatalog.cs @@ -98,9 +98,8 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent( /// meaning that they are uncorrelated and each have variance 1. /// /// The transform's catalog. - /// Name of the input column. - /// Name of the column resulting from the transformation of . - /// Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Whitening kind (PCA/ZCA). /// Whitening constant, prevents division by zero. /// Maximum number of rows used to train the transform. 
@@ -112,13 +111,12 @@ public static SymSgdClassificationTrainer SymbolicStochasticGradientDescent( /// ]]> /// /// - public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, - string inputColumn, string outputColumn = null, + public static VectorWhiteningEstimator VectorWhiten(this TransformsCatalog.ProjectionTransforms catalog, string outputColumnName, string inputColumnName = null, WhiteningKind kind = VectorWhiteningTransformer.Defaults.Kind, float eps = VectorWhiteningTransformer.Defaults.Eps, int maxRows = VectorWhiteningTransformer.Defaults.MaxRows, int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum) - => new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, kind, eps, maxRows, pcaNum); + => new VectorWhiteningEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, kind, eps, maxRows, pcaNum); /// /// Takes columns filled with a vector of random variables with a known covariance matrix into a set of new variables whose diff --git a/src/Microsoft.ML.HalLearners/VectorWhitening.cs b/src/Microsoft.ML.HalLearners/VectorWhitening.cs index 021c6af3c9..af71981a30 100644 --- a/src/Microsoft.ML.HalLearners/VectorWhitening.cs +++ b/src/Microsoft.ML.HalLearners/VectorWhitening.cs @@ -118,8 +118,8 @@ internal bool TryUnparse(StringBuilder sb) public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly WhiteningKind Kind; public readonly float Epsilon; public readonly int MaxRow; @@ -129,19 +129,19 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one input-output column pair. /// - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. 
If set to , the value of the will be used as source. /// Whitening kind (PCA/ZCA). /// Whitening constant, prevents division by zero. /// Maximum number of rows used to train the transform. /// In case of PCA whitening, indicates the number of components to retain. - public ColumnInfo(string input, string output = null, WhiteningKind kind = Defaults.Kind, float eps = Defaults.Eps, + public ColumnInfo(string name, string inputColumnName = null, WhiteningKind kind = Defaults.Kind, float eps = Defaults.Eps, int maxRows = Defaults.MaxRows, int pcaNum = Defaults.PcaNum) { - Input = input; - Contracts.CheckValue(Input, nameof(Input)); - Output = output ?? input; - Contracts.CheckValue(Output, nameof(Output)); + Name = name; + Contracts.CheckValue(Name, nameof(Name)); + InputColumnName = inputColumnName ?? name; + Contracts.CheckValue(InputColumnName, nameof(InputColumnName)); Kind = kind; Contracts.CheckUserArg(Kind == WhiteningKind.Pca || Kind == WhiteningKind.Zca, nameof(Kind)); Epsilon = eps; @@ -155,10 +155,10 @@ public ColumnInfo(string input, string output = null, WhiteningKind kind = Defau internal ColumnInfo(Column item, Arguments args) { - Input = item.Source ?? item.Name; - Contracts.CheckValue(Input, nameof(Input)); - Output = item.Name; - Contracts.CheckValue(Output, nameof(Output)); + Name = item.Name; + Contracts.CheckValue(Name, nameof(Name)); + InputColumnName = item.Source ?? item.Name; + Contracts.CheckValue(InputColumnName, nameof(InputColumnName)); Kind = item.Kind ?? args.Kind; Contracts.CheckUserArg(Kind == WhiteningKind.Pca || Kind == WhiteningKind.Zca, nameof(item.Kind)); Epsilon = item.Eps ?? 
args.Eps; @@ -308,8 +308,8 @@ internal static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx internal static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Schema inputSchema) => Create(env, ctx).MakeRowMapper(inputSchema); - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) - => columns.Select(c => (c.Input, c.Output ?? c.Input)).ToArray(); + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) + => columns.Select(c => (c.Name, c.InputColumnName ?? c.Name)).ToArray(); protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { @@ -386,9 +386,9 @@ private static void GetColTypesAndIndex(IHostEnvironment env, IDataView inputDat for (int i = 0; i < columns.Length; i++) { - var col = inputSchema.GetColumnOrNull(columns[i].Input); + var col = inputSchema.GetColumnOrNull(columns[i].InputColumnName); if (!col.HasValue) - throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input); + throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].InputColumnName); cols[i] = col.Value.Index; srcTypes[i] = col.Value.Type; @@ -418,7 +418,7 @@ private static float[][] LoadDataAsDense(IHostEnvironment env, IChannel ch, IDat actualRowCounts[i] = (int)crowData; else { - ch.Info(MessageSensitivity.Schema, "Only {0:N0} rows of column '{1}' will be used for whitening transform.", ex.MaxRow, columns[i].Output); + ch.Info(MessageSensitivity.Schema, "Only {0:N0} rows of column '{1}' will be used for whitening transform.", ex.MaxRow, columns[i].Name); actualRowCounts[i] = ex.MaxRow; } @@ -427,7 +427,7 @@ private static float[][] LoadDataAsDense(IHostEnvironment env, IChannel ch, IDat if ((long)cslot * actualRowCounts[i] > int.MaxValue) { actualRowCounts[i] = int.MaxValue / cslot; - ch.Info(MessageSensitivity.Schema, "Only {0:N0} rows of column '{1}' will be used for whitening transform.", actualRowCounts[i], 
columns[i].Output); + ch.Info(MessageSensitivity.Schema, "Only {0:N0} rows of column '{1}' will be used for whitening transform.", actualRowCounts[i], columns[i].Name); } columnData[i] = new float[cslot * actualRowCounts[i]]; if (actualRowCounts[i] > maxActualRowCount) @@ -665,8 +665,8 @@ public Mapper(VectorWhiteningTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _cols[i])) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _cols[i])) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); _srcTypes[i] = inputSchema[_cols[i]].Type; ValidateModel(Host, _parent._models[i], _srcTypes[i]); if (_parent._columns[i].SaveInv) @@ -688,11 +688,11 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int iinfo = 0; iinfo < _parent.ColumnPairs.Length; iinfo++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[iinfo].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[iinfo].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); var info = _parent._columns[iinfo]; ColumnType outType = (info.Kind == WhiteningKind.Pca && info.PcaNum > 0) ? new VectorType(NumberType.Float, info.PcaNum) : _srcTypes[iinfo]; - result[iinfo] = new Schema.DetachedColumn(_parent.ColumnPairs[iinfo].output, outType, null); + result[iinfo] = new Schema.DetachedColumn(_parent.ColumnPairs[iinfo].outputColumnName, outType, null); } return result; } @@ -786,18 +786,18 @@ public VectorWhiteningEstimator(IHostEnvironment env, params VectorWhiteningTran /// /// The environment. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. 
+ /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Whitening kind (PCA/ZCA). /// Whitening constant, prevents division by zero when scaling the data by inverse of eigenvalues. /// Maximum number of rows used to train the transform. /// In case of PCA whitening, indicates the number of components to retain. - public VectorWhiteningEstimator(IHostEnvironment env, string inputColumn, string outputColumn, + public VectorWhiteningEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, WhiteningKind kind = VectorWhiteningTransformer.Defaults.Kind, float eps = VectorWhiteningTransformer.Defaults.Eps, int maxRows = VectorWhiteningTransformer.Defaults.MaxRows, int pcaNum = VectorWhiteningTransformer.Defaults.PcaNum) - : this(env, new VectorWhiteningTransformer.ColumnInfo(inputColumn, outputColumn, kind, eps, maxRows, pcaNum)) + : this(env, new VectorWhiteningTransformer.ColumnInfo(outputColumnName, inputColumnName, kind, eps, maxRows, pcaNum)) { } @@ -817,12 +817,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _infos) { - if (!inputSchema.TryFindColumn(colPair.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.Input); + if (!inputSchema.TryFindColumn(colPair.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.InputColumnName); var reason = VectorWhiteningTransformer.TestColumn(col.ItemType); if (reason != null) throw _host.ExceptUserArg(nameof(inputSchema), reason); - result[colPair.Output] = new SchemaShape.Column(colPair.Output, col.Kind, col.ItemType, col.IsKey, null); + result[colPair.Name] = new SchemaShape.Column(colPair.Name, col.Kind, col.ItemType, col.IsKey, null); } return new SchemaShape(result.Values); } diff --git 
a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs index 8000d42c6d..192d12bd72 100644 --- a/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.ImageAnalytics/ExtensionsCatalog.cs @@ -14,7 +14,7 @@ public static class ImageEstimatorsCatalog /// /// The transform's catalog. /// The name of the columns containing the image paths(first item of the tuple), and the name of the resulting output column (second item of the tuple). - public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, params (string input, string output)[] columns) + public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalog catalog, params (string outputColumnName, string inputColumnName)[] columns) => new ImageGrayscalingEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// @@ -23,20 +23,20 @@ public static ImageGrayscalingEstimator ConvertToGrayscale(this TransformsCatalo /// The transform's catalog. /// The images folder. /// The name of the columns containing the image paths(first item of the tuple), and the name of the resulting output column (second item of the tuple). - public static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, string imageFolder, params (string input, string output)[] columns) + public static ImageLoadingEstimator LoadImages(this TransformsCatalog catalog, string imageFolder, params (string outputColumnName, string inputColumnName)[] columns) => new ImageLoadingEstimator(CatalogUtils.GetEnvironment(catalog), imageFolder, columns); /// /// Loads the images from a given folder. /// /// The transform's catalog. - /// The name of the input column. - /// The name of the output column generated from the estimator. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// The color schema as defined in . 
/// - public static ImagePixelExtractingEstimator ExtractPixels(this TransformsCatalog catalog, string inputColumn, string outputColumn, + public static ImagePixelExtractingEstimator ExtractPixels(this TransformsCatalog catalog, string outputColumnName, string inputColumnName = null, ImagePixelExtractorTransformer.ColorBits colors = ImagePixelExtractorTransformer.ColorBits.Rgb, bool interleave = false) - => new ImagePixelExtractingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, colors, interleave); + => new ImagePixelExtractingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, colors, interleave); /// /// Loads the images from a given folder. @@ -50,16 +50,16 @@ public static ImagePixelExtractingEstimator ExtractPixels(this TransformsCatalog /// Resizes an image. /// /// The transform's catalog. - /// Name of the input column. - /// Name of the resulting output column. + /// Name of the input column. + /// Name of the resulting output column. /// The image width. /// The image height. /// The type of image resizing as specified in . /// Where to place the anchor, to start cropping. 
Options defined in /// - public static ImageResizingEstimator Resize(this TransformsCatalog catalog, string inputColumn, string outputColumn, - int imageWidth, int imageHeight, ImageResizerTransformer.ResizingKind resizing = ImageResizerTransformer.ResizingKind.IsoCrop, ImageResizerTransformer.Anchor cropAnchor = ImageResizerTransformer.Anchor.Center) - => new ImageResizingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, imageWidth, imageHeight, resizing, cropAnchor); + public static ImageResizingEstimator Resize(this TransformsCatalog catalog, string outputColumnName, int imageWidth, int imageHeight, + string inputColumnName = null, ImageResizerTransformer.ResizingKind resizing = ImageResizerTransformer.ResizingKind.IsoCrop, ImageResizerTransformer.Anchor cropAnchor = ImageResizerTransformer.Anchor.Center) + => new ImageResizingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, imageWidth, imageHeight, inputColumnName, resizing, cropAnchor); /// /// Resizes an image. diff --git a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs index 865ca2da83..b3cd495539 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageGrayscaleTransform.cs @@ -81,7 +81,7 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "ImageGrayscale"; - public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); /// /// Converts the images to grayscale. @@ -89,7 +89,7 @@ private static VersionInfo GetVersionInfo() /// The estimator's local . /// The name of the columns containing the image paths(first item of the tuple), and the name of the resulting output column (second item of the tuple). 
- public ImageGrayscaleTransformer(IHostEnvironment env, params (string input, string output)[] columns) + public ImageGrayscaleTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { } @@ -102,7 +102,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat env.CheckValue(input, nameof(input)); env.CheckValue(args.Column, nameof(args.Column)); - return new ImageGrayscaleTransformer(env, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()) + return new ImageGrayscaleTransformer(env, args.Column.Select(x => (x.Name, x.Source ?? x.Name)).ToArray()) .MakeDataTransform(input); } @@ -156,7 +156,7 @@ public override void Save(ModelSaveContext ctx) protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { if (!(inputSchema[srcCol].Type is ImageType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "image", inputSchema[srcCol].Type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, "image", inputSchema[srcCol].Type.ToString()); } private sealed class Mapper : OneToOneMapperBase @@ -170,7 +170,7 @@ public Mapper(ImageGrayscaleTransformer parent, Schema inputSchema) } protected override Schema.DetachedColumn[] GetOutputColumnsCore() - => _parent.ColumnPairs.Select((x, idx) => new Schema.DetachedColumn(x.output, InputSchema[ColMapNewToOld[idx]].Type, null)).ToArray(); + => _parent.ColumnPairs.Select((x, idx) => new Schema.DetachedColumn(x.outputColumnName, InputSchema[ColMapNewToOld[idx]].Type, null)).ToArray(); protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { @@ -227,7 +227,7 @@ public sealed class ImageGrayscalingEstimator : TrivialEstimator /// The estimator's local . 
/// The name of the columns containing the image paths(first item of the tuple), and the name of the resulting output column (second item of the tuple). - public ImageGrayscalingEstimator(IHostEnvironment env, params (string input, string output)[] columns) + public ImageGrayscalingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageGrayscalingEstimator)), new ImageGrayscaleTransformer(env, columns)) { } @@ -238,12 +238,12 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!inputSchema.TryFindColumn(colInfo.inputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName); if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, new ImageType().ToString(), col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName, new ImageType().ToString(), col.GetTypeString()); - result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind, col.ItemType, col.IsKey, col.Metadata); + result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, col.Kind, col.ItemType, col.IsKey, col.Metadata); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs index 38265ff30f..171cc5670f 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageLoaderTransform.cs @@ -70,7 +70,7 @@ public sealed class Arguments : 
TransformInputBase public readonly string ImageFolder; - public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); /// /// Load images in memory. @@ -78,7 +78,7 @@ public sealed class Arguments : TransformInputBase /// The host environment. /// Folder where to look for images. /// Names of input and output columns. - public ImageLoaderTransformer(IHostEnvironment env, string imageFolder = null, params (string input, string output)[] columns) + public ImageLoaderTransformer(IHostEnvironment env, string imageFolder = null, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImageLoaderTransformer)), columns) { ImageFolder = imageFolder; @@ -87,7 +87,7 @@ public ImageLoaderTransformer(IHostEnvironment env, string imageFolder = null, p // Factory method for SignatureDataTransform. internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data) { - return new ImageLoaderTransformer(env, args.ImageFolder, args.Column.Select(x => (x.Source ?? x.Name, x.Name)).ToArray()) + return new ImageLoaderTransformer(env, args.ImageFolder, args.Column.Select(x => (x.Name, x.Source ?? 
x.Name)).ToArray()) .MakeDataTransform(data); } @@ -122,7 +122,7 @@ private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, Sch protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { if (!(inputSchema[srcCol].Type is TextType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextType.Instance.ToString(), inputSchema[srcCol].Type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, TextType.Instance.ToString(), inputSchema[srcCol].Type.ToString()); } public override void Save(ModelSaveContext ctx) @@ -213,7 +213,7 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act } protected override Schema.DetachedColumn[] GetOutputColumnsCore() - => _parent.ColumnPairs.Select(x => new Schema.DetachedColumn(x.output, _imageType, null)).ToArray(); + => _parent.ColumnPairs.Select(x => new Schema.DetachedColumn(x.outputColumnName, _imageType, null)).ToArray(); } } @@ -230,7 +230,7 @@ public sealed class ImageLoadingEstimator : TrivialEstimatorThe host environment. /// Folder where to look for images. /// Names of input and output columns. 
- public ImageLoadingEstimator(IHostEnvironment env, string imageFolder, params (string input, string output)[] columns) + public ImageLoadingEstimator(IHostEnvironment env, string imageFolder, params (string outputColumnName, string inputColumnName)[] columns) : this(env, new ImageLoaderTransformer(env, imageFolder, columns)) { } @@ -245,14 +245,14 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) { Host.CheckValue(inputSchema, nameof(inputSchema)); var result = inputSchema.ToDictionary(x => x.Name); - foreach (var (input, output) in Transformer.Columns) + foreach (var (outputColumnName, inputColumnName) in Transformer.Columns) { - if (!inputSchema.TryFindColumn(input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); + if (!inputSchema.TryFindColumn(inputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName); if (!(col.ItemType is TextType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", input, TextType.Instance.ToString(), col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColumnName, TextType.Instance.ToString(), col.GetTypeString()); - result[output] = new SchemaShape.Column(output, SchemaShape.Column.VectorKind.Scalar, _imageType, false); + result[outputColumnName] = new SchemaShape.Column(outputColumnName, SchemaShape.Column.VectorKind.Scalar, _imageType, false); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs index f07d26f4e2..7f33af1855 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImagePixelExtractorTransform.cs @@ -145,8 +145,8 @@ internal static class Defaults /// public sealed class ColumnInfo { - public readonly string Input; - public 
readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly ColorBits Colors; public readonly byte Planes; @@ -165,8 +165,8 @@ internal ColumnInfo(Column item, Arguments args) Contracts.CheckValue(item, nameof(item)); Contracts.CheckValue(args, nameof(args)); - Input = item.Source ?? item.Name; - Output = item.Name; + Name = item.Name; + InputColumnName = item.Source ?? item.Name; if (item.UseAlpha ?? args.UseAlpha) { Colors |= ColorBits.Alpha; Planes++; } if (item.UseRed ?? args.UseRed) { Colors |= ColorBits.Red; Planes++; } @@ -194,25 +194,25 @@ internal ColumnInfo(Column item, Arguments args) /// /// Describes how the transformer handles one input-output column pair. /// - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// What colors to extract. /// /// Scale color pixel value by this amount. /// Offset color pixel value by this amount. /// Output array as float array. If false, output as byte array. - public ColumnInfo(string input, string output = null, + public ColumnInfo(string name, string inputColumnName = null, ColorBits colors = Defaults.Colors, bool interleave = Defaults.Interleave, float scale = Defaults.Scale, float offset = Defaults.Offset, bool asFloat = Defaults.Convert) { - Contracts.CheckNonWhiteSpace(input, nameof(input)); + Contracts.CheckNonWhiteSpace(name, nameof(name)); - Input = input; - Output = output ?? input; + Name = name; + InputColumnName = inputColumnName ?? 
name; Colors = colors; if ((Colors & ColorBits.Alpha) == ColorBits.Alpha) Planes++; @@ -238,14 +238,14 @@ public ColumnInfo(string input, string output = null, Contracts.CheckParam(FloatUtils.IsFiniteNonZero(Scale), nameof(scale)); } - internal ColumnInfo(string input, string output, ModelLoadContext ctx) + internal ColumnInfo(string name, string inputColumnName, ModelLoadContext ctx) { - Contracts.AssertNonEmpty(input); - Contracts.AssertNonEmpty(output); + Contracts.AssertNonEmpty(name); + Contracts.AssertNonEmpty(inputColumnName); Contracts.AssertValue(ctx); - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName; // *** Binary format *** // byte: colors @@ -329,17 +329,22 @@ private static VersionInfo GetVersionInfo() /// Extract pixels values from image and produce array of values. /// /// The host environment. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// What colors to extract. /// /// Scale color pixel value by this amount. /// Offset color pixel value by this amount. /// Output array as float array. If false, output as byte array. 
- public ImagePixelExtractorTransformer(IHostEnvironment env, string input, string output = null, - ColorBits colors = ColorBits.Rgb, bool interleave = Defaults.Interleave, float scale = Defaults.Scale, float offset = Defaults.Offset, + public ImagePixelExtractorTransformer(IHostEnvironment env, + string outputColumnName, + string inputColumnName = null, + ColorBits colors = ColorBits.Rgb, + bool interleave = Defaults.Interleave, + float scale = Defaults.Scale, + float offset = Defaults.Offset, bool asFloat = Defaults.Convert) - : this(env, new ColumnInfo(input, output, colors, interleave, scale, offset, asFloat)) + : this(env, new ColumnInfo(outputColumnName, inputColumnName, colors, interleave, scale, offset, asFloat)) { } @@ -354,10 +359,10 @@ public ImagePixelExtractorTransformer(IHostEnvironment env, params ColumnInfo[] _columns = columns.ToArray(); } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } // Factory method for SignatureDataTransform. @@ -402,7 +407,7 @@ private ImagePixelExtractorTransformer(IHost host, ModelLoadContext ctx) _columns = new ColumnInfo[ColumnPairs.Length]; for (int i = 0; i < _columns.Length; i++) - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, ctx); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, ctx); } // Factory method for SignatureLoadDataTransform. 
@@ -436,7 +441,7 @@ public override void Save(ModelSaveContext ctx) protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { - var inputColName = _columns[col].Input; + var inputColName = _columns[col].InputColumnName; var imageType = inputSchema[srcCol].Type as ImageType; if (imageType == null) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", inputColName, "image", inputSchema[srcCol].Type.ToString()); @@ -459,7 +464,7 @@ public Mapper(ImagePixelExtractorTransformer parent, Schema inputSchema) } protected override Schema.DetachedColumn[] GetOutputColumnsCore() - => _parent._columns.Select((x, idx) => new Schema.DetachedColumn(x.Output, _types[idx], null)).ToArray(); + => _parent._columns.Select((x, idx) => new Schema.DetachedColumn(x.Name, _types[idx], null)).ToArray(); protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { @@ -660,18 +665,20 @@ public sealed class ImagePixelExtractingEstimator : TrivialEstimator /// The host environment. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the input column. /// What colors to extract. /// /// Scale color pixel value by this amount. /// Offset color pixel value by this amount. /// Output array as float array. If false, output as byte array. 
- public ImagePixelExtractingEstimator(IHostEnvironment env, string input, - string output = null, ImagePixelExtractorTransformer.ColorBits colors = ImagePixelExtractorTransformer.Defaults.Colors, + public ImagePixelExtractingEstimator(IHostEnvironment env, + string outputColumnName, + string inputColumnName = null, + ImagePixelExtractorTransformer.ColorBits colors = ImagePixelExtractorTransformer.Defaults.Colors, bool interleave = ImagePixelExtractorTransformer.Defaults.Interleave, float scale = ImagePixelExtractorTransformer.Defaults.Scale, float offset = ImagePixelExtractorTransformer.Defaults.Offset, bool asFloat = ImagePixelExtractorTransformer.Defaults.Convert) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractingEstimator)), new ImagePixelExtractorTransformer(env, input, output, colors, interleave)) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(ImagePixelExtractingEstimator)), new ImagePixelExtractorTransformer(env, outputColumnName, inputColumnName, colors, interleave)) { } @@ -691,13 +698,13 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, new ImageType().ToString(), col.GetTypeString()); var itemType = colInfo.AsFloat ? 
NumberType.R4 : NumberType.U1; - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, itemType, false); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, itemType, false); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs index 76e5c0ef75..3630298eba 100644 --- a/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs +++ b/src/Microsoft.ML.ImageAnalytics/ImageResizerTransform.cs @@ -117,8 +117,8 @@ internal static class Defaults /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly int Width; public readonly int Height; @@ -129,23 +129,23 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one image resize column pair. /// - /// Name of the input column. - /// Name of the column resulting from the transformation of . + /// Name of the column resulting from the transformation of . /// Width of resized image. /// Height of resized image. + /// Name of column to transform. If set to , the value of the will be used as source. /// What to use. /// If set to what anchor to use for cropping. 
- public ColumnInfo(string input, string output, int width, int height, ResizingKind resizing = Defaults.Resizing, Anchor anchor = Defaults.CropAnchor) + public ColumnInfo(string name, int width, int height, string inputColumnName = null, ResizingKind resizing = Defaults.Resizing, Anchor anchor = Defaults.CropAnchor) { - Contracts.CheckNonEmpty(input, nameof(input)); - Contracts.CheckNonEmpty(output, nameof(output)); + Contracts.CheckNonEmpty(name, nameof(name)); + Contracts.CheckNonEmpty(inputColumnName, nameof(inputColumnName)); Contracts.CheckUserArg(width > 0, nameof(Column.ImageWidth)); Contracts.CheckUserArg(height > 0, nameof(Column.ImageHeight)); Contracts.CheckUserArg(Enum.IsDefined(typeof(ResizingKind), resizing), nameof(Column.Resizing)); Contracts.CheckUserArg(Enum.IsDefined(typeof(Anchor), anchor), nameof(Column.CropAnchor)); - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName; Width = width; Height = height; Resizing = resizing; @@ -182,15 +182,15 @@ private static VersionInfo GetVersionInfo() /// Resize image. /// /// The host environment. - /// Name of the input column. - /// Name of the column resulting from the transformation of . + /// Name of the column resulting from the transformation of . /// Width of resized image. /// Height of resized image. + /// Name of the input column. /// What to use. /// If set to what anchor to use for cropping. 
- public ImageResizerTransformer(IHostEnvironment env, string input, string output, - int imageWidth, int imageHeight, ResizingKind resizing = ResizingKind.IsoCrop, Anchor cropAnchor = Anchor.Center) - : this(env, new ColumnInfo(input, output, imageWidth, imageHeight, resizing, cropAnchor)) + public ImageResizerTransformer(IHostEnvironment env, string outputColumnName, + int imageWidth, int imageHeight, string inputColumnName = null, ResizingKind resizing = ResizingKind.IsoCrop, Anchor cropAnchor = Anchor.Center) + : this(env, new ColumnInfo(outputColumnName, imageWidth, imageHeight, inputColumnName, resizing, cropAnchor)) { } @@ -205,10 +205,10 @@ public ImageResizerTransformer(IHostEnvironment env, params ColumnInfo[] columns _columns = columns.ToArray(); } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } // Factory method for SignatureDataTransform. @@ -225,10 +225,10 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { var item = args.Column[i]; cols[i] = new ColumnInfo( - item.Source ?? item.Name, item.Name, item.ImageWidth ?? args.ImageWidth, item.ImageHeight ?? args.ImageHeight, + item.Source ?? item.Name, item.Resizing ?? args.Resizing, item.CropAnchor ?? 
args.CropAnchor); } @@ -271,7 +271,7 @@ private ImageResizerTransformer(IHost host, ModelLoadContext ctx) Host.CheckDecode(Enum.IsDefined(typeof(ResizingKind), scale)); var anchor = (Anchor)ctx.Reader.ReadByte(); Host.CheckDecode(Enum.IsDefined(typeof(Anchor), anchor)); - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, width, height, scale, anchor); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, width, height, ColumnPairs[i].inputColumnName, scale, anchor); } } @@ -317,7 +317,7 @@ public override void Save(ModelSaveContext ctx) protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { if (!(inputSchema[srcCol].Type is ImageType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].Input, "image", inputSchema[srcCol].Type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _columns[col].InputColumnName, "image", inputSchema[srcCol].Type.ToString()); } private sealed class Mapper : OneToOneMapperBase @@ -331,7 +331,7 @@ public Mapper(ImageResizerTransformer parent, Schema inputSchema) } protected override Schema.DetachedColumn[] GetOutputColumnsCore() - => _parent._columns.Select(x => new Schema.DetachedColumn(x.Output, x.Type, null)).ToArray(); + => _parent._columns.Select(x => new Schema.DetachedColumn(x.Name, x.Type, null)).ToArray(); protected override Delegate MakeGetter(Row input, int iinfo, Func activeOutput, out Action disposer) { @@ -462,16 +462,20 @@ public sealed class ImageResizingEstimator : TrivialEstimator /// The host environment. - /// Name of the input column. - /// Name of the column resulting from the transformation of . + /// Name of the column resulting from the transformation of . /// Width of resized image. /// Height of resized image. + /// Name of the input column. /// What to use. /// If set to what anchor to use for cropping. 
- public ImageResizingEstimator(IHostEnvironment env, string input, string output, - int imageWidth, int imageHeight, ImageResizerTransformer.ResizingKind resizing = ImageResizerTransformer.Defaults.Resizing, + public ImageResizingEstimator(IHostEnvironment env, + string outputColumnName, + int imageWidth, + int imageHeight, + string inputColumnName = null, + ImageResizerTransformer.ResizingKind resizing = ImageResizerTransformer.Defaults.Resizing, ImageResizerTransformer.Anchor cropAnchor = ImageResizerTransformer.Defaults.CropAnchor) - : this(env, new ImageResizerTransformer(env, input, output, imageWidth, imageHeight, resizing, cropAnchor)) + : this(env, new ImageResizerTransformer(env, outputColumnName, imageWidth, imageHeight, inputColumnName, resizing, cropAnchor)) { } @@ -496,12 +500,12 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!(col.ItemType is ImageType) || col.Kind != SchemaShape.Column.VectorKind.Scalar) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new ImageType().ToString(), col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, new ImageType().ToString(), col.GetTypeString()); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Scalar, colInfo.Type, false); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Scalar, colInfo.Type, false); } return new SchemaShape(result.Values); diff --git 
a/src/Microsoft.ML.OnnxTransform.StaticPipe/DnnImageFeaturizerStaticExtensions.cs b/src/Microsoft.ML.OnnxTransform.StaticPipe/DnnImageFeaturizerStaticExtensions.cs index 7f1e36118c..a048c1015c 100644 --- a/src/Microsoft.ML.OnnxTransform.StaticPipe/DnnImageFeaturizerStaticExtensions.cs +++ b/src/Microsoft.ML.OnnxTransform.StaticPipe/DnnImageFeaturizerStaticExtensions.cs @@ -42,7 +42,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; - return new DnnImageFeaturizerEstimator(env, _modelFactory, inputNames[outCol.Input], outputNames[outCol]); + return new DnnImageFeaturizerEstimator(env, outputNames[outCol], _modelFactory, inputNames[outCol.Input]); } } diff --git a/src/Microsoft.ML.OnnxTransform.StaticPipe/OnnxStaticExtensions.cs b/src/Microsoft.ML.OnnxTransform.StaticPipe/OnnxStaticExtensions.cs index 64695e9922..378501b808 100644 --- a/src/Microsoft.ML.OnnxTransform.StaticPipe/OnnxStaticExtensions.cs +++ b/src/Microsoft.ML.OnnxTransform.StaticPipe/OnnxStaticExtensions.cs @@ -42,7 +42,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; - return new OnnxScoringEstimator(env, _modelFile, new[] { inputNames[outCol.Input] }, new[] { outputNames[outCol] }); + return new OnnxScoringEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _modelFile); } } diff --git a/src/Microsoft.ML.OnnxTransform/DnnImageFeaturizerTransform.cs b/src/Microsoft.ML.OnnxTransform/DnnImageFeaturizerTransform.cs index cc7cb2ffbe..dbac203908 100644 --- a/src/Microsoft.ML.OnnxTransform/DnnImageFeaturizerTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/DnnImageFeaturizerTransform.cs @@ -31,11 +31,11 @@ public sealed class DnnImageFeaturizerInput public DnnImageModelSelector ModelSelector { get; } public string OutputColumn { get; } - public DnnImageFeaturizerInput(IHostEnvironment env, string 
inputColumn, string outputColumn, DnnImageModelSelector modelSelector) + public DnnImageFeaturizerInput(string outputColumnName, string inputColumnName, IHostEnvironment env, DnnImageModelSelector modelSelector) { Environment = env; - InputColumn = inputColumn; - OutputColumn = outputColumn; + InputColumn = inputColumnName; + OutputColumn = outputColumnName; ModelSelector = modelSelector; } } @@ -59,11 +59,11 @@ public sealed class DnnImageFeaturizerEstimator : IEstimators /// to allow arbitrary column naming, as the ONNXEstimators require very specific naming based on the models. /// For an example, see Microsoft.ML.DnnImageFeaturizer.ResNet18 - /// inputColumn column name. - /// Output column name. - public DnnImageFeaturizerEstimator(IHostEnvironment env, Func> modelFactory, string inputColumn, string outputColumn) + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + public DnnImageFeaturizerEstimator(IHostEnvironment env, string outputColumnName, Func> modelFactory, string inputColumnName = null) { - _modelChain = modelFactory( new DnnImageFeaturizerInput(env, inputColumn, outputColumn, new DnnImageModelSelector())); + _modelChain = modelFactory(new DnnImageFeaturizerInput(outputColumnName, inputColumnName ?? outputColumnName, env, new DnnImageModelSelector())); } /// diff --git a/src/Microsoft.ML.OnnxTransform/OnnxCatalog.cs b/src/Microsoft.ML.OnnxTransform/OnnxCatalog.cs index e688e03bc2..16b56d06ad 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxCatalog.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxCatalog.cs @@ -14,13 +14,13 @@ public static class OnnxCatalog /// /// The transform's catalog. /// The path of the file containing the ONNX model. - /// The input columns. - /// The output columns resulting from the transformation. + /// The input columns. + /// The output columns resulting from the transformation. 
public static OnnxScoringEstimator ApplyOnnxModel(this TransformsCatalog catalog, string modelFile, - string[] inputColumns, - string[] outputColumns) - => new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), modelFile, inputColumns, outputColumns); + string[] outputColumnNames, + string[] inputColumnNames) + => new OnnxScoringEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, modelFile); /// /// Initializes a new instance of . diff --git a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs index 6e19d188ee..a46e7c963e 100644 --- a/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs +++ b/src/Microsoft.ML.OnnxTransform/OnnxTransform.cs @@ -219,17 +219,17 @@ public OnnxTransformer(IHostEnvironment env, string modelFile, int? gpuDeviceId /// the model specification. Only 1 output column is generated. /// /// The environment to use. + /// Name of the column resulting from the transformation of . /// Model file path. - /// The name of the input data column. Must match model input name. - /// The output columns to generate. Names must match model specifications. Data types are inferred from model. + /// Name of column to transform. If set to , the value of the will be used as source. /// Optional GPU device ID to run execution on. Null for CPU. /// If GPU error, raise exception or fallback to CPU. - public OnnxTransformer(IHostEnvironment env, string modelFile, string inputColumn, string outputColumn, int? gpuDeviceId = null, bool fallbackToCpu = false) + public OnnxTransformer(IHostEnvironment env, string outputColumnName, string modelFile, string inputColumnName = null, int? gpuDeviceId = null, bool fallbackToCpu = false) : this(env, new Arguments() { ModelFile = modelFile, - InputColumns = new[] { inputColumn }, - OutputColumns = new[] { outputColumn }, + InputColumns = new[] { inputColumnName ?? 
outputColumnName }, + OutputColumns = new[] { outputColumnName }, GpuDeviceId = gpuDeviceId, FallbackToCpu = fallbackToCpu }) @@ -241,17 +241,17 @@ public OnnxTransformer(IHostEnvironment env, string modelFile, string inputColum /// all model input names. Only the output columns specified will be generated. /// /// The environment to use. + /// The output columns to generate. Names must match model specifications. Data types are inferred from model. + /// The name of the input data columns. Must match model's input names. /// Model file path. - /// The name of the input data columns. Must match model's input names. - /// The output columns to generate. Names must match model specifications. Data types are inferred from model. /// Optional GPU device ID to run execution on. Null for CPU. /// If GPU error, raise exception or fallback to CPU. - public OnnxTransformer(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns, int? gpuDeviceId = null, bool fallbackToCpu = false) + public OnnxTransformer(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false) : this(env, new Arguments() { ModelFile = modelFile, - InputColumns = inputColumns, - OutputColumns = outputColumns, + InputColumns = inputColumnNames, + OutputColumns = outputColumnNames, GpuDeviceId = gpuDeviceId, FallbackToCpu = fallbackToCpu }) @@ -523,7 +523,7 @@ public sealed class OnnxScoringEstimator : TrivialEstimator /// Optional GPU device ID to run execution on. Null for CPU. /// If GPU error, raise exception or fallback to CPU. public OnnxScoringEstimator(IHostEnvironment env, string modelFile, int? 
gpuDeviceId = null, bool fallbackToCpu = false) - : this(env, new OnnxTransformer(env, modelFile, new string[] { }, new string[] { }, gpuDeviceId, fallbackToCpu)) + : this(env, new OnnxTransformer(env, new string[] { }, new string[] { }, modelFile, gpuDeviceId, fallbackToCpu)) { } @@ -532,13 +532,13 @@ public OnnxScoringEstimator(IHostEnvironment env, string modelFile, int? gpuDevi /// all model input names. Only the output columns specified will be generated. /// /// The environment to use. + /// The output columns to generate. Names must match model specifications. Data types are inferred from model. + /// The name of the input data columns. Must match model's input names. /// Model file path. - /// The name of the input data columns. Must match model's input names. - /// The output columns to generate. Names must match model specifications. Data types are inferred from model. /// Optional GPU device ID to run execution on. Null for CPU. /// If GPU error, raise exception or fallback to CPU. - public OnnxScoringEstimator(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns, int? gpuDeviceId = null, bool fallbackToCpu = false) - : this(env, new OnnxTransformer(env, modelFile, inputColumns, outputColumns, gpuDeviceId, fallbackToCpu)) + public OnnxScoringEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, string modelFile, int? gpuDeviceId = null, bool fallbackToCpu = false) + : this(env, new OnnxTransformer(env, outputColumnNames, inputColumnNames, modelFile, gpuDeviceId, fallbackToCpu)) { } diff --git a/src/Microsoft.ML.PCA/PCACatalog.cs b/src/Microsoft.ML.PCA/PCACatalog.cs index 06bb6b0270..fca74a2524 100644 --- a/src/Microsoft.ML.PCA/PCACatalog.cs +++ b/src/Microsoft.ML.PCA/PCACatalog.cs @@ -12,23 +12,23 @@ public static class PcaCatalog /// Initializes a new instance of . /// The transform's catalog. - /// Input column to apply PrincipalComponentAnalysis on. - /// Optional output column. 
Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// The name of the weight column. /// The number of principal components. /// Oversampling parameter for randomized PrincipalComponentAnalysis training. /// If enabled, data is centered to be zero mean. /// The seed for random number generation. public static PrincipalComponentAnalysisEstimator ProjectToPrincipalComponents(this TransformsCatalog.ProjectionTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, string weightColumn = PrincipalComponentAnalysisEstimator.Defaults.WeightColumn, int rank = PrincipalComponentAnalysisEstimator.Defaults.Rank, int overSampling = PrincipalComponentAnalysisEstimator.Defaults.Oversampling, bool center = PrincipalComponentAnalysisEstimator.Defaults.Center, int? seed = null) => new PrincipalComponentAnalysisEstimator(CatalogUtils.GetEnvironment(catalog), - inputColumn, outputColumn, weightColumn, rank, overSampling, center, seed); + outputColumnName, inputColumnName, weightColumn, rank, overSampling, center, seed); /// Initializes a new instance of . /// The transform's catalog. diff --git a/src/Microsoft.ML.PCA/PcaTransformer.cs b/src/Microsoft.ML.PCA/PcaTransformer.cs index f960d52904..ec8704c5c7 100644 --- a/src/Microsoft.ML.PCA/PcaTransformer.cs +++ b/src/Microsoft.ML.PCA/PcaTransformer.cs @@ -98,8 +98,8 @@ internal bool TryUnparse(StringBuilder sb) public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly string WeightColumn; public readonly int Rank; public readonly int Oversampling; @@ -109,23 +109,24 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// The column to apply PCA to. 
- /// The output column that contains PCA values. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. + /// If set to , the value of the will be used as source. /// The name of the weight column. /// The number of components in the PCA. /// Oversampling parameter for randomized PCA training. /// If enabled, data is centered to be zero mean. /// The seed for random number generation. - public ColumnInfo(string input, - string output, + public ColumnInfo(string name, + string inputColumnName = null, string weightColumn = PrincipalComponentAnalysisEstimator.Defaults.WeightColumn, int rank = PrincipalComponentAnalysisEstimator.Defaults.Rank, int overSampling = PrincipalComponentAnalysisEstimator.Defaults.Oversampling, bool center = PrincipalComponentAnalysisEstimator.Defaults.Center, int? seed = null) { - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName ?? name; WeightColumn = weightColumn; Rank = rank; Oversampling = overSampling; @@ -254,7 +255,7 @@ internal PcaTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] colu { var colInfo = columns[i]; var sInfo = _schemaInfos[i] = new Mapper.ColumnSchemaInfo(ColumnPairs[i], input.Schema, colInfo.WeightColumn); - ValidatePcaInput(Host, colInfo.Input, sInfo.InputType); + ValidatePcaInput(Host, colInfo.InputColumnName, sInfo.InputType); _transformInfos[i] = new TransformInfo(colInfo.Rank, sInfo.InputType.GetValueCount()); } @@ -293,8 +294,8 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData env.CheckValue(input, nameof(input)); env.CheckValue(args.Column, nameof(args.Column)); var cols = args.Column.Select(item => new ColumnInfo( - item.Source, item.Name, + item.Source, item.WeightColumn, item.Rank ?? args.Rank, item.Oversampling ?? 
args.Oversampling, @@ -332,10 +333,10 @@ public override void Save(ModelSaveContext ctx) for (int i = 0; i < _transformInfos.Length; i++) _transformInfos[i].Save(ctx); } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } private void Train(ColumnInfo[] columns, TransformInfo[] transformInfos, IDataView trainingData) @@ -358,7 +359,7 @@ private void Train(ColumnInfo[] columns, TransformInfo[] transformInfos, IDataVi using (var ch = Host.Start("Memory usage")) { ch.Info("Estimate memory usage for transforming column {1}: {0:G2} GB. If running out of memory, reduce rank and oversampling factor.", - colMemoryUsageEstimate, ColumnPairs[iinfo].input); + colMemoryUsageEstimate, ColumnPairs[iinfo].inputColumnName); } } @@ -493,7 +494,7 @@ private void Project(IDataView trainingData, float[][] mean, float[][][] omega, for (int iinfo = 0; iinfo < _numColumns; iinfo++) { if (totalColWeight[iinfo] <= 0) - throw Host.Except("Empty data in column '{0}'", ColumnPairs[iinfo].input); + throw Host.Except("Empty data in column '{0}'", ColumnPairs[iinfo].inputColumnName); } for (int iinfo = 0; iinfo < _numColumns; iinfo++) @@ -561,11 +562,11 @@ public sealed class ColumnSchemaInfo public int InputIndex { get; } public int WeightColumnIndex { get; } - public ColumnSchemaInfo((string input, string output) columnPair, Schema schema, string weightColumn = null) + public ColumnSchemaInfo((string outputColumnName, string inputColumnName) columnPair, Schema schema, string weightColumn = null) { - schema.TryGetColumnIndex(columnPair.input, out int inputIndex); + schema.TryGetColumnIndex(columnPair.inputColumnName, out int inputIndex); InputIndex = inputIndex; - InputType = 
schema[columnPair.input].Type; + InputType = schema[columnPair.inputColumnName].Type; var weightIndex = -1; if (weightColumn != null) @@ -590,10 +591,10 @@ public Mapper(PcaTransformer parent, Schema inputSchema) { var colPair = _parent.ColumnPairs[i]; var colSchemaInfo = new ColumnSchemaInfo(colPair, inputSchema); - ValidatePcaInput(Host, colPair.input, colSchemaInfo.InputType); + ValidatePcaInput(Host, colPair.inputColumnName, colSchemaInfo.InputType); if (colSchemaInfo.InputType.GetVectorSize() != _parent._transformInfos[i].Dimension) { - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input, + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName, new VectorType(NumberType.R4, _parent._transformInfos[i].Dimension).ToString(), colSchemaInfo.InputType.ToString()); } } @@ -603,7 +604,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var result = new Schema.DetachedColumn[_numColumns]; for (int i = 0; i < _numColumns; i++) - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _parent._transformInfos[i].OutputType, null); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _parent._transformInfos[i].OutputType, null); return result; } @@ -676,18 +677,21 @@ internal static class Defaults /// /// The environment to use. - /// Input column to project to Principal Component. - /// Output column. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. + /// If set to , the value of the will be used as source. /// The name of the weight column. /// The number of components in the PCA. /// Oversampling parameter for randomized PCA training. /// If enabled, data is centered to be zero mean. /// The seed for random number generation. 
- public PrincipalComponentAnalysisEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + public PrincipalComponentAnalysisEstimator(IHostEnvironment env, + string outputColumnName, + string inputColumnName = null, string weightColumn = Defaults.WeightColumn, int rank = Defaults.Rank, int overSampling = Defaults.Oversampling, bool center = Defaults.Center, int? seed = null) - : this(env, new PcaTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, weightColumn, rank, overSampling, center, seed)) + : this(env, new PcaTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, weightColumn, rank, overSampling, center, seed)) { } @@ -709,13 +713,13 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (col.Kind != SchemaShape.Column.VectorKind.Vector || !col.ItemType.Equals(NumberType.R4)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } diff --git a/src/Microsoft.ML.StaticPipe/CategoricalHashStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/CategoricalHashStaticExtensions.cs index 4b46d32739..29526e6cc1 100644 --- a/src/Microsoft.ML.StaticPipe/CategoricalHashStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/CategoricalHashStaticExtensions.cs @@ -105,7 +105,7 @@ public override IEstimator 
Reconcile(IHostEnvironment env, Pipelin for (int i = 0; i < toOutput.Length; ++i) { var tcol = (ICategoricalCol)toOutput[i]; - infos[i] = new OneHotHashEncodingEstimator.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], (OneHotEncodingTransformer.OutputKind)tcol.Config.OutputKind, + infos[i] = new OneHotHashEncodingEstimator.ColumnInfo(outputNames[toOutput[i]], inputNames[tcol.Input], (OneHotEncodingTransformer.OutputKind)tcol.Config.OutputKind, tcol.Config.HashBits, tcol.Config.Seed, tcol.Config.Ordered, tcol.Config.InvertHash); } return new OneHotHashEncodingEstimator(env, infos); diff --git a/src/Microsoft.ML.StaticPipe/CategoricalStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/CategoricalStaticExtensions.cs index cb0688befc..05e5d4f5e2 100644 --- a/src/Microsoft.ML.StaticPipe/CategoricalStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/CategoricalStaticExtensions.cs @@ -114,7 +114,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Pipelin for (int i = 0; i < toOutput.Length; ++i) { var tcol = (ICategoricalCol)toOutput[i]; - infos[i] = new OneHotEncodingEstimator.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], (OneHotEncodingTransformer.OutputKind)tcol.Config.OutputKind, + infos[i] = new OneHotEncodingEstimator.ColumnInfo(outputNames[toOutput[i]], inputNames[tcol.Input], (OneHotEncodingTransformer.OutputKind)tcol.Config.OutputKind, tcol.Config.Max, (ValueToKeyMappingTransformer.SortOrder)tcol.Config.Order); if (tcol.Config.OnFit != null) { diff --git a/src/Microsoft.ML.StaticPipe/ImageTransformsStatic.cs b/src/Microsoft.ML.StaticPipe/ImageTransformsStatic.cs index 7109d5bce5..b7edd9f3d9 100644 --- a/src/Microsoft.ML.StaticPipe/ImageTransformsStatic.cs +++ b/src/Microsoft.ML.StaticPipe/ImageTransformsStatic.cs @@ -59,11 +59,11 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var cols = new (string input, string output)[toOutput.Length]; + 
var cols = new (string outputColumnName, string inputColumnName)[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var outCol = (OutPipelineColumn)toOutput[i]; - cols[i] = (inputNames[outCol._input], outputNames[outCol]); + cols[i] = (outputNames[outCol], inputNames[outCol._input]); } return new ImageLoadingEstimator(env, _relTo, cols); } @@ -109,11 +109,11 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var cols = new (string input, string output)[toOutput.Length]; + var cols = new (string outputColumnName, string inputColumnName)[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var outCol = (IColInput)toOutput[i]; - cols[i] = (inputNames[outCol.Input], outputNames[toOutput[i]]); + cols[i] = (outputNames[toOutput[i]], inputNames[outCol.Input]); } return new ImageGrayscalingEstimator(env, cols); } @@ -142,8 +142,8 @@ public OutPipelineColumn(PipelineColumn input, int width, int height, _cropAnchor = cropAnchor; } - private ImageResizerTransformer.ColumnInfo MakeColumnInfo(string input, string output) - => new ImageResizerTransformer.ColumnInfo(input, output, _width, _height, _resizing, _cropAnchor); + private ImageResizerTransformer.ColumnInfo MakeColumnInfo(string outputColumnName, string inputColumnName) + => new ImageResizerTransformer.ColumnInfo(outputColumnName, _width, _height, inputColumnName, _resizing, _cropAnchor); /// /// Reconciler to an for the . 
@@ -168,7 +168,7 @@ public override IEstimator Reconcile(IHostEnvironment env, for (int i = 0; i < toOutput.Length; ++i) { var outCol = (OutPipelineColumn)toOutput[i]; - cols[i] = outCol.MakeColumnInfo(inputNames[outCol._input], outputNames[outCol]); + cols[i] = outCol.MakeColumnInfo(outputNames[outCol], inputNames[outCol._input]); } return new ImageResizingEstimator(env, cols); } @@ -182,7 +182,7 @@ private interface IColInput { Custom Input { get; } - ImagePixelExtractorTransformer.ColumnInfo MakeColumnInfo(string input, string output); + ImagePixelExtractorTransformer.ColumnInfo MakeColumnInfo(string outputColumnName, string inputColumnName); } internal sealed class OutPipelineColumn : Vector, IColInput @@ -200,14 +200,14 @@ public OutPipelineColumn(Custom input, ImagePixelExtractorTransformer.Co _colParam = col; } - public ImagePixelExtractorTransformer.ColumnInfo MakeColumnInfo(string input, string output) + public ImagePixelExtractorTransformer.ColumnInfo MakeColumnInfo(string outputColumnName, string inputColumnName) { // In principle, the analyzer should only call the the reconciler once for these columns. 
Contracts.Assert(_colParam.Source == null); Contracts.Assert(_colParam.Name == null); - _colParam.Name = output; - _colParam.Source = input; + _colParam.Name = outputColumnName; + _colParam.Source = inputColumnName; return new ImagePixelExtractorTransformer.ColumnInfo(_colParam, _defaultArgs); } } @@ -237,7 +237,7 @@ public override IEstimator Reconcile(IHostEnvironment env, for (int i = 0; i < toOutput.Length; ++i) { var outCol = (IColInput)toOutput[i]; - cols[i] = outCol.MakeColumnInfo(inputNames[outCol.Input], outputNames[toOutput[i]]); + cols[i] = outCol.MakeColumnInfo(outputNames[toOutput[i]], inputNames[outCol.Input]); } return new ImagePixelExtractingEstimator(env, cols); } diff --git a/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs index 7a940e4b57..f87e5ef4a1 100644 --- a/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/LdaStaticExtensions.cs @@ -107,7 +107,8 @@ public override IEstimator Reconcile(IHostEnvironment env, { var tcol = (ILdaCol)toOutput[i]; - infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], + infos[i] = new LatentDirichletAllocationTransformer.ColumnInfo(outputNames[toOutput[i]], + inputNames[tcol.Input], tcol.Config.NumTopic, tcol.Config.AlphaSum, tcol.Config.Beta, diff --git a/src/Microsoft.ML.StaticPipe/LpNormalizerStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/LpNormalizerStaticExtensions.cs index a48d587e96..3b765dd18b 100644 --- a/src/Microsoft.ML.StaticPipe/LpNormalizerStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/LpNormalizerStaticExtensions.cs @@ -44,9 +44,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var pairs = new List<(string input, string output)>(); + var pairs = new List<(string outputColumnName, string inputColumnName)>(); foreach (var outCol in toOutput) - 
pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + pairs.Add((outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input])); return new LpNormalizingEstimator(env, pairs.ToArray(), _normKind, _subMean); } diff --git a/src/Microsoft.ML.StaticPipe/NormalizerStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/NormalizerStaticExtensions.cs index 65a6004d47..e5a812acaf 100644 --- a/src/Microsoft.ML.StaticPipe/NormalizerStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/NormalizerStaticExtensions.cs @@ -72,7 +72,7 @@ private static NormVector NormalizeByMinMaxCore(Vector input, bool fixZ { Contracts.CheckValue(input, nameof(input)); Contracts.CheckParam(maxTrainingExamples > 1, nameof(maxTrainingExamples), "Must be greater than 1"); - return new Impl(input, (src, name) => new NormalizingEstimator.MinMaxColumn(src, name, maxTrainingExamples, fixZero), AffineMapper(onFit)); + return new Impl(input, (name, src) => new NormalizingEstimator.MinMaxColumn(name, src, maxTrainingExamples, fixZero), AffineMapper(onFit)); } // We have a slightly different breaking up of categories of normalizers versus the dynamic API. 
Both the mean-var and @@ -171,11 +171,11 @@ private static NormVector NormalizeByMVCdfCore(Vector input, bool fixZe { Contracts.CheckValue(input, nameof(input)); Contracts.CheckParam(maxTrainingExamples > 1, nameof(maxTrainingExamples), "Must be greater than 1"); - return new Impl(input, (src, name) => + return new Impl(input, (name, src) => { if (useLog) - return new NormalizingEstimator.LogMeanVarColumn(src, name, maxTrainingExamples, useCdf); - return new NormalizingEstimator.MeanVarColumn(src, name, maxTrainingExamples, fixZero, useCdf); + return new NormalizingEstimator.LogMeanVarColumn(name, src, maxTrainingExamples, useCdf); + return new NormalizingEstimator.MeanVarColumn(name, src, maxTrainingExamples, fixZero, useCdf); }, onFit); } @@ -235,7 +235,7 @@ private static NormVector NormalizeByBinningCore(Vector input, int numB Contracts.CheckValue(input, nameof(input)); Contracts.CheckParam(numBins > 1, nameof(maxTrainingExamples), "Must be greater than 1"); Contracts.CheckParam(maxTrainingExamples > 1, nameof(maxTrainingExamples), "Must be greater than 1"); - return new Impl(input, (src, name) => new NormalizingEstimator.BinningColumn(src, name, maxTrainingExamples, fixZero, numBins), BinMapper(onFit)); + return new Impl(input, (name, src) => new NormalizingEstimator.BinningColumn(name, src, maxTrainingExamples, fixZero, numBins), BinMapper(onFit)); } /// @@ -271,7 +271,7 @@ private static NormVector NormalizeByBinningCore(Vector input, int numB public delegate void OnFitBinned(ImmutableArray upperBounds); #region Implementation support - private delegate NormalizingEstimator.ColumnBase CreateNormCol(string input, string name); + private delegate NormalizingEstimator.ColumnBase CreateNormCol(string outputColumnName, string inputColumnName); private sealed class Rec : EstimatorReconciler { @@ -287,7 +287,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Pipelin for (int i = 0; i < toOutput.Length; ++i) { var col = (INormColCreator)toOutput[i]; - 
cols[i] = col.CreateNormCol(inputNames[col.Input], outputNames[toOutput[i]]); + cols[i] = col.CreateNormCol(outputNames[toOutput[i]], inputNames[col.Input]); if (col.OnFit != null) Utils.Add(ref onFits, (i, col.OnFit)); } diff --git a/src/Microsoft.ML.StaticPipe/StaticPipeUtils.cs b/src/Microsoft.ML.StaticPipe/StaticPipeUtils.cs index ec893a0589..9a70f3ee11 100644 --- a/src/Microsoft.ML.StaticPipe/StaticPipeUtils.cs +++ b/src/Microsoft.ML.StaticPipe/StaticPipeUtils.cs @@ -134,7 +134,7 @@ internal static IDataReaderEstimator> } estimator = null; - var toCopy = new List<(string src, string dst)>(); + var toCopy = new List<(string dst, string src)>(); int tempNum = 0; // For all outputs, get potential name collisions with used inputs. Resolve by assigning the input a temporary name. @@ -147,7 +147,7 @@ internal static IDataReaderEstimator> ch.Assert(baseInputs.Contains(inputCol)); string tempName = $"#Temp_{tempNum++}"; ch.Trace($"Input/output name collision: Renaming '{p.Key}' to '{tempName}'."); - toCopy.Add((p.Key, tempName)); + toCopy.Add((tempName, p.Key)); nameMap[tempName] = nameMap[p.Key]; ch.Assert(!nameMap.ContainsKey(p.Key)); } @@ -255,8 +255,8 @@ internal static IDataReaderEstimator> if (keyDependsOn.Any(p => p.Value.Count > 0)) { // This might happen if the user does something incredibly strange, like, say, take some prior - // lambda, assign a column to a local variable, then re-use it downstream in a different lambdas. - // The user would have to go to some extraorindary effort to do that, but nonetheless we want to + // lambda, assign a column to a local variable, then re-use it downstream in a different lambda. + // The user would have to go to some extraordinary effort to do that, but nonetheless we want to // fail with a semi-sensible error message. throw ch.Except("There were some leftover columns with unresolved dependencies. 
" + "Did the caller use a " + nameof(PipelineColumn) + " from another delegate?"); @@ -272,8 +272,8 @@ internal static IDataReaderEstimator> string currentName = nameMap[p.Value]; if (currentName != p.Key) { - ch.Trace($"Will copy '{currentName}' to '{p.Key}'"); - toCopy.Add((currentName, p.Key)); + ch.Trace($"Will copy '{p.Key}' to '{currentName}'"); + toCopy.Add((p.Key, currentName)); } } diff --git a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs index 7aac7bb6e8..a2155330c6 100644 --- a/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/TextStaticExtensions.cs @@ -43,9 +43,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var pairs = new List<(string input, string output)>(); + var pairs = new List<(string outputColumnName, string inputColumnName)>(); foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + pairs.Add((outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input])); return new WordTokenizingEstimator(env, pairs.ToArray(), _separators); } @@ -97,9 +97,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var pairs = new List<(string input, string output)>(); + var pairs = new List<(string outputColumnName, string inputColumnName)>(); foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + pairs.Add((outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input])); return new TokenizingByCharactersEstimator(env, _useMarker, pairs.ToArray()); } @@ -153,7 +153,7 @@ public override IEstimator Reconcile(IHostEnvironment env, var columns = new List(); foreach (var outCol in toOutput) - columns.Add(new StopWordsRemovingTransformer.ColumnInfo(inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol], _language)); 
+ columns.Add(new StopWordsRemovingTransformer.ColumnInfo(outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input], _language)); return new StopWordsRemovingEstimator(env, columns.ToArray()); } @@ -216,9 +216,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var pairs = new List<(string input, string output)>(); + var pairs = new List<(string outputColumnName, string inputColumnName)>(); foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + pairs.Add((outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input])); return new TextNormalizingEstimator(env, _textCase, _keepDiacritics, _keepPunctuations, _keepNumbers, pairs.ToArray()); } @@ -295,9 +295,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var pairs = new List<(string[] inputs, string output)>(); + var pairs = new List<(string name, string[] sources)>(); foreach (var outCol in toOutput) - pairs.Add((new[] { inputNames[((OutPipelineColumn)outCol).Input] }, outputNames[outCol])); + pairs.Add((outputNames[outCol], new[] { inputNames[((OutPipelineColumn)outCol).Input] })); return new WordBagEstimator(env, pairs.ToArray(), _ngramLength, _skipLength, _allLengths, _maxNumTerms, _weighting); } @@ -385,9 +385,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var pairs = new List<(string[] inputs, string output)>(); + var pairs = new List<(string name, string[] sources)>(); foreach (var outCol in toOutput) - pairs.Add((new[] { inputNames[((OutPipelineColumn)outCol).Input] }, outputNames[outCol])); + pairs.Add((outputNames[outCol], new[] { inputNames[((OutPipelineColumn)outCol).Input] })); return new WordHashBagEstimator(env, pairs.ToArray(), _hashBits, _ngramLength, _skipLength, _allLengths, _seed, _ordered, _invertHash); } @@ -476,7 +476,7 @@ public override
IEstimator Reconcile(IHostEnvironment env, - var pairs = new List<(string inputs, string output)>(); + var pairs = new List<(string outputColumnName, string inputColumnName)>(); foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + pairs.Add((outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input])); return new NgramExtractingEstimator(env, pairs.ToArray(), _ngramLength, _skipLength, _allLengths, _maxNumTerms, _weighting); } @@ -561,7 +561,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Contracts.Assert(toOutput.Length == 1); var columns = new List(); foreach (var outCol in toOutput) - columns.Add(new NgramHashingTransformer.ColumnInfo(new[] { inputNames[((OutPipelineColumn)outCol).Input] }, outputNames[outCol], + columns.Add(new NgramHashingTransformer.ColumnInfo(outputNames[outCol], new[] { inputNames[((OutPipelineColumn)outCol).Input] }, _ngramLength, _skipLength, _allLengths, _hashBits, _seed, _ordered, _invertHash)); return new NgramHashingEstimator(env, columns.ToArray()); diff --git a/src/Microsoft.ML.StaticPipe/TrainerEstimatorReconciler.cs b/src/Microsoft.ML.StaticPipe/TrainerEstimatorReconciler.cs index 15753abedf..4aab5e3a83 100644 --- a/src/Microsoft.ML.StaticPipe/TrainerEstimatorReconciler.cs +++ b/src/Microsoft.ML.StaticPipe/TrainerEstimatorReconciler.cs @@ -101,7 +101,7 @@ public sealed override IEstimator Reconcile(IHostEnvironment env, newInputNames[p.Key] = old2New.ContainsKey(p.Value) ? old2New[p.Value] : p.Value; inputNames = newInputNames; } - result = new ColumnCopyingEstimator(env, old2New.Select(p => (p.Key, p.Value)).ToArray()); + result = new ColumnCopyingEstimator(env, old2New.Select(p => (p.Value, p.Key)).ToArray()); } // Map the inputs to the names. @@ -115,17 +115,17 @@ public sealed override IEstimator Reconcile(IHostEnvironment env, // OK. Now handle the final renamings from the fixed names, to the desired names, in the case // where the output was desired, and a renaming is even necessary.
- var toRename = new List<(string source, string name)>(); + var toRename = new List<(string outputColumnName, string inputColumnName)>(); foreach ((PipelineColumn outCol, string fixedName) in Outputs.Zip(_outputNames, (c, n) => (c, n))) { if (outputNames.TryGetValue(outCol, out string desiredName)) - toRename.Add((fixedName, desiredName)); + toRename.Add((desiredName, fixedName)); else env.Assert(!toOutput.Contains(outCol)); } // Finally if applicable handle the renaming back from the temp names to the original names. foreach (var p in old2New) - toRename.Add((p.Value, p.Key)); + toRename.Add((p.Key, p.Value)); if (toRename.Count > 0) result = result.Append(new ColumnCopyingEstimator(env, toRename.ToArray())); diff --git a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs index b98fc856a1..4c02d2c098 100644 --- a/src/Microsoft.ML.StaticPipe/TransformsStatic.cs +++ b/src/Microsoft.ML.StaticPipe/TransformsStatic.cs @@ -55,9 +55,9 @@ public override IEstimator Reconcile(IHostEnvironment env, { Contracts.Assert(toOutput.Length == 1); - var pairs = new List<(string input, string output)>(); + var pairs = new List<(string outputColumnName, string inputColumnName)>(); foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + pairs.Add((outputNames[outCol], inputNames[((OutPipelineColumn)outCol).Input])); return new GlobalContrastNormalizingEstimator(env, pairs.ToArray(), _subMean, _useStdDev, _scale); } @@ -118,9 +118,9 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var pairs = new List<(string input, string output)>(); + var pairs = new List<(string outputColumnName, string inputColumnName)>(); foreach (var outCol in toOutput) - pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); + pairs.Add((outputNames[outCol], 
inputNames[((OutPipelineColumn)outCol).Input])); return new MutualInformationFeatureSelectingEstimator(env, inputNames[_labelColumn], _slotsInOutput, _numBins, pairs.ToArray()); } @@ -270,7 +270,7 @@ public override IEstimator Reconcile(IHostEnvironment env, var infos = new CountFeatureSelectingEstimator.ColumnInfo[toOutput.Length]; for (int i = 0; i < toOutput.Length; i++) - infos[i] = new CountFeatureSelectingEstimator.ColumnInfo(inputNames[((OutPipelineColumn)toOutput[i]).Input], outputNames[toOutput[i]], _count); + infos[i] = new CountFeatureSelectingEstimator.ColumnInfo(outputNames[toOutput[i]], inputNames[((OutPipelineColumn)toOutput[i]).Input], _count); return new CountFeatureSelectingEstimator(env, infos); } @@ -396,7 +396,7 @@ public override IEstimator Reconcile(IHostEnvironment env, for (int i = 0; i < toOutput.Length; ++i) { var col = (IColInput)toOutput[i]; - infos[i] = new KeyToBinaryVectorMappingTransformer.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]]); + infos[i] = new KeyToBinaryVectorMappingTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[col.Input]); } return new KeyToBinaryVectorMappingEstimator(env, infos); } @@ -582,7 +582,7 @@ public override IEstimator Reconcile(IHostEnvironment env, for (int i = 0; i < toOutput.Length; ++i) { var col = (IColInput)toOutput[i]; - infos[i] = new KeyToVectorMappingTransformer.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Bag); + infos[i] = new KeyToVectorMappingTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[col.Input], col.Bag); } return new KeyToVectorMappingEstimator(env, infos); } @@ -807,7 +807,7 @@ public override IEstimator Reconcile(IHostEnvironment env, for (int i = 0; i < toOutput.Length; ++i) { var col = (IColInput)toOutput[i]; - infos[i] = new MissingValueReplacingTransformer.ColumnInfo(inputNames[col.Input], outputNames[toOutput[i]], col.Config.ReplacementMode, col.Config.ImputeBySlot); + infos[i] = new 
MissingValueReplacingTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[col.Input], col.Config.ReplacementMode, col.Config.ImputeBySlot); } return new MissingValueReplacingEstimator(env, infos); } @@ -937,7 +937,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Pipelin for (int i = 0; i < toOutput.Length; ++i) { var tcol = (IConvertCol)toOutput[i]; - infos[i] = new TypeConvertingTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], tcol.Kind); + infos[i] = new TypeConvertingTransformer.ColumnInfo(outputNames[toOutput[i]], tcol.Kind, inputNames[tcol.Input]); } return new TypeConvertingEstimator(env, infos); } @@ -1021,14 +1021,16 @@ private sealed class Rec : EstimatorReconciler public static readonly Rec Inst = new Rec(); public override IEstimator Reconcile(IHostEnvironment env, PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) { var infos = new ValueToKeyMappingTransformer.ColumnInfo[toOutput.Length]; Action onFit = null; for (int i = 0; i < toOutput.Length; ++i) { var tcol = (ITermCol)toOutput[i]; - infos[i] = new ValueToKeyMappingTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], + infos[i] = new ValueToKeyMappingTransformer.ColumnInfo(outputNames[toOutput[i]], inputNames[tcol.Input], tcol.Config.Max, (ValueToKeyMappingTransformer.SortOrder)tcol.Config.Order); if (tcol.Config.OnFit != null) { @@ -1110,11 +1112,11 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var cols = new (string input, string output)[toOutput.Length]; + var cols = new (string outputColumnName, string inputColumnName)[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var outCol = (IColInput)toOutput[i]; - cols[i] = (inputNames[outCol.Input], 
outputNames[toOutput[i]]); + cols[i] = (outputNames[toOutput[i]], inputNames[outCol.Input]); } return new KeyToValueMappingEstimator(env, cols); } @@ -1421,11 +1423,11 @@ public override IEstimator Reconcile(IHostEnvironment env, IReadOnlyDictionary outputNames, IReadOnlyCollection usedNames) { - var columnPairs = new (string input, string output)[toOutput.Length]; + var columnPairs = new (string outputColumnName, string inputColumnName)[toOutput.Length]; for (int i = 0; i < toOutput.Length; ++i) { var col = (IColInput)toOutput[i]; - columnPairs[i] = (inputNames[col.Input], outputNames[toOutput[i]]); + columnPairs[i] = (outputNames[toOutput[i]], inputNames[col.Input]); } return new MissingValueIndicatorEstimator(env, columnPairs); } @@ -1533,7 +1535,7 @@ public override IEstimator Reconcile(IHostEnvironment env, var outCol = (OutPipelineColumn)toOutput[0]; var inputs = outCol.Inputs.Select(x => inputNames[x]); - return new TextFeaturizingEstimator(env, inputs, outputNames[outCol], _settings); + return new TextFeaturizingEstimator(env, outputNames[outCol], inputs, _settings); } } /// @@ -1597,7 +1599,7 @@ public override IEstimator Reconcile(IHostEnvironment env, Pipelin for (int i = 0; i < toOutput.Length; ++i) { var tcol = (IColInput)toOutput[i]; - infos[i] = new RandomFourierFeaturizingTransformer.ColumnInfo(inputNames[tcol.Input], outputNames[toOutput[i]], tcol.Config.NewDim, tcol.Config.UseSin, tcol.Config.Generator, tcol.Config.Seed); + infos[i] = new RandomFourierFeaturizingTransformer.ColumnInfo(outputNames[toOutput[i]], tcol.Config.NewDim, tcol.Config.UseSin, inputNames[tcol.Input], tcol.Config.Generator, tcol.Config.Seed); } return new RandomFourierFeaturizingEstimator(env, infos); } @@ -1656,7 +1658,7 @@ public override IEstimator Reconcile(IHostEnvironment env, var outCol = (OutPipelineColumn)toOutput[0]; var inputColName = inputNames[outCol.Input]; var outputColName = outputNames[outCol]; - return new PrincipalComponentAnalysisEstimator(env, 
inputColName, outputColName, + return new PrincipalComponentAnalysisEstimator(env, outputColName, inputColName, _colInfo.WeightColumn, _colInfo.Rank, _colInfo.Oversampling, _colInfo.Center, _colInfo.Seed); } diff --git a/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs b/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs index f928045c4e..1c0908d9ae 100644 --- a/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs +++ b/src/Microsoft.ML.StaticPipe/WordEmbeddingsStaticExtensions.cs @@ -76,7 +76,7 @@ public override IEstimator Reconcile(IHostEnvironment env, for (int i = 0; i < toOutput.Length; ++i) { var outCol = (OutColumn)toOutput[i]; - cols[i] = new WordEmbeddingsExtractingTransformer.ColumnInfo(inputNames[outCol.Input], outputNames[outCol]); + cols[i] = new WordEmbeddingsExtractingTransformer.ColumnInfo(outputNames[outCol], inputNames[outCol.Input]); } bool customLookup = !string.IsNullOrWhiteSpace(_customLookupTable); diff --git a/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs b/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs index 4baf0aa156..d131a17dc7 100644 --- a/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs +++ b/src/Microsoft.ML.TensorFlow.StaticPipe/TensorFlowStaticExtensions.cs @@ -59,13 +59,9 @@ public override IEstimator Reconcile(IHostEnvironment env, var outCol = (OutColumn)toOutput[0]; if (_modelFile == null) - { - return new TensorFlowEstimator(env, _tensorFlowModel, new[] { inputNames[outCol.Input] }, new[] { outputNames[outCol] }); - } + return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _tensorFlowModel); else - { - return new TensorFlowEstimator(env, _modelFile, new[] { inputNames[outCol.Input] }, new[] { outputNames[outCol] }); - } + return new TensorFlowEstimator(env, new[] { outputNames[outCol] }, new[] { inputNames[outCol.Input] }, _modelFile); } } diff --git 
a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs index 374aed0cd3..7c5eac8f4a 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowCatalog.cs @@ -16,8 +16,8 @@ public static class TensorflowCatalog /// /// The transform's catalog. /// Location of the TensorFlow model. - /// The names of the model inputs. - /// The names of the requested model outputs. + /// The names of the model inputs. + /// The names of the requested model outputs. /// /// /// public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, string modelLocation, - string[] inputs, - string[] outputs) - => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), modelLocation, inputs, outputs); + string[] outputColumnNames, + string[] inputColumnNames) + => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, modelLocation); /// /// Scores a dataset using a pre-traiend TensorFlow model specified via . /// /// The transform's catalog. /// The pre-trained TensorFlow model. - /// The names of the model inputs. - /// The names of the requested model outputs. + /// The names of the model inputs. + /// The names of the requested model outputs. public static TensorFlowEstimator ScoreTensorFlowModel(this TransformsCatalog catalog, TensorFlowModelInfo tensorFlowModel, - string[] inputs, - string[] outputs) - => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), tensorFlowModel, inputs, outputs); + string[] outputColumnNames, + string[] inputColumnNames) + => new TensorFlowEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnNames, inputColumnNames, tensorFlowModel); /// /// Score or Retrain a tensorflow model (based on setting of the ) setting. 
diff --git a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs index 277f7bbf78..d804375d52 100644 --- a/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs +++ b/src/Microsoft.ML.TensorFlow/TensorflowTransform.cs @@ -185,10 +185,10 @@ private static VersionInfo GetVersionInfo() /// /// The environment to use. /// Model file path. - /// The name of the input data column. Must match model input name. - /// The output columns to generate. Names must match model specifications. Data types are inferred from model. - public TensorFlowTransformer(IHostEnvironment env, string modelFile, string inputColumn, string outputColumn) - : this(env, TensorFlowUtils.GetSession(env, modelFile), new[] { inputColumn }, new[] { outputColumn }, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false) + /// The output columns to generate. Names must match model specifications. Data types are inferred from model. + /// The name of the input data column. Must match model input name. If set to , the value of the will be used as source. + public TensorFlowTransformer(IHostEnvironment env, string modelFile, string outputColumnName, string inputColumnName = null) + : this(env, TensorFlowUtils.GetSession(env, modelFile), new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false) { } @@ -199,10 +199,10 @@ public TensorFlowTransformer(IHostEnvironment env, string modelFile, string inpu /// /// The environment to use. /// Model file path. - /// The name of the input data columns. Must match model's input names. - /// The output columns to generate. Names must match model specifications. Data types are inferred from model. 
- public TensorFlowTransformer(IHostEnvironment env, string modelFile, string[] inputColumns, string[] outputColumns) - : this(env, TensorFlowUtils.GetSession(env, modelFile), inputColumns, outputColumns, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false) + /// The name of the input data columns. Must match model's input names. + /// The output columns to generate. Names must match model specifications. Data types are inferred from model. + public TensorFlowTransformer(IHostEnvironment env, string modelFile, string[] outputColumnNames, string[] inputColumnNames) + : this(env, TensorFlowUtils.GetSession(env, modelFile), outputColumnNames, inputColumnNames, TensorFlowUtils.IsSavedModel(env, modelFile) ? modelFile : null, false) { } @@ -214,10 +214,10 @@ public TensorFlowTransformer(IHostEnvironment env, string modelFile, string[] in /// /// The environment to use. /// object created with . - /// The name of the input data columns. Must match model's input names. - /// The output columns to generate. Names must match model specifications. Data types are inferred from model. - public TensorFlowTransformer(IHostEnvironment env, TensorFlowModelInfo tfModelInfo, string inputColumn, string outputColumn) - : this(env, tfModelInfo.Session, new[] { inputColumn }, new[] { outputColumn }, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false) + /// The output columns to generate. Names must match model specifications. Data types are inferred from model. + /// The name of the input data columns. Must match model's input names. If set to , the value of the will be used as source. + public TensorFlowTransformer(IHostEnvironment env, TensorFlowModelInfo tfModelInfo, string outputColumnName, string inputColumnName = null) + : this(env, tfModelInfo.Session, new[] { outputColumnName }, new[] { inputColumnName ?? outputColumnName }, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? 
tfModelInfo.ModelPath : null, false) { } @@ -229,10 +229,10 @@ public TensorFlowTransformer(IHostEnvironment env, TensorFlowModelInfo tfModelIn /// /// The environment to use. /// object created with . - /// The name of the input data columns. Must match model's input names. - /// The output columns to generate. Names must match model specifications. Data types are inferred from model. - public TensorFlowTransformer(IHostEnvironment env, TensorFlowModelInfo tfModelInfo, string[] inputColumns, string[] outputColumns) - : this(env, tfModelInfo.Session, inputColumns, outputColumns, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? tfModelInfo.ModelPath : null, false) + /// The name of the input data columns. Must match model's input names. + /// The output columns to generate. Names must match model specifications. Data types are inferred from model. + public TensorFlowTransformer(IHostEnvironment env, TensorFlowModelInfo tfModelInfo, string[] outputColumnNames, string[] inputColumnNames) + : this(env, tfModelInfo.Session, outputColumnNames, inputColumnNames, TensorFlowUtils.IsSavedModel(env, tfModelInfo.ModelPath) ? 
tfModelInfo.ModelPath : null, false) { } @@ -258,7 +258,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte byte[] modelBytes = null; if (!ctx.TryLoadBinaryStream("TFModel", r => modelBytes = r.ReadByteArray())) throw env.ExceptDecode(); - return new TensorFlowTransformer(env, TensorFlowUtils.LoadTFSession(env, modelBytes), inputs, outputs, null, false); + return new TensorFlowTransformer(env, TensorFlowUtils.LoadTFSession(env, modelBytes), outputs, inputs, null, false); } var tempDirPath = Path.GetFullPath(Path.Combine(Path.GetTempPath(), nameof(TensorFlowTransformer) + "_" + Guid.NewGuid())); @@ -287,7 +287,7 @@ private static TensorFlowTransformer Create(IHostEnvironment env, ModelLoadConte } }); - return new TensorFlowTransformer(env, TensorFlowUtils.GetSession(env, tempDirPath), inputs, outputs, tempDirPath, true); + return new TensorFlowTransformer(env, TensorFlowUtils.GetSession(env, tempDirPath), outputs, inputs, tempDirPath, true); } catch (Exception) { @@ -314,7 +314,7 @@ internal TensorFlowTransformer(IHostEnvironment env, Arguments args, IDataView i } internal TensorFlowTransformer(IHostEnvironment env, Arguments args, TensorFlowModelInfo tensorFlowModel, IDataView input) - : this(env, tensorFlowModel.Session, args.InputColumns, args.OutputColumns, TensorFlowUtils.IsSavedModel(env, args.ModelLocation) ? args.ModelLocation : null, false) + : this(env, tensorFlowModel.Session, args.OutputColumns, args.InputColumns, TensorFlowUtils.IsSavedModel(env, args.ModelLocation) ? 
args.ModelLocation : null, false) { Contracts.CheckValue(env, nameof(env)); @@ -626,19 +626,19 @@ private static void GetModelInfo(IHostEnvironment env, ModelLoadContext ctx, out outputs[j] = ctx.LoadNonEmptyString(); } - internal TensorFlowTransformer(IHostEnvironment env, TFSession session, string[] inputs, string[] outputs, string savedModelPath, bool isTemporarySavedModel) : + internal TensorFlowTransformer(IHostEnvironment env, TFSession session, string[] outputColumnNames, string[] inputColumnNames, string savedModelPath, bool isTemporarySavedModel) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TensorFlowTransformer))) { Host.CheckValue(session, nameof(session)); - Host.CheckNonEmpty(inputs, nameof(inputs)); - Host.CheckNonEmpty(outputs, nameof(outputs)); + Host.CheckNonEmpty(inputColumnNames, nameof(inputColumnNames)); + Host.CheckNonEmpty(outputColumnNames, nameof(outputColumnNames)); Session = session; _savedModelPath = savedModelPath; _isTemporarySavedModel = isTemporarySavedModel; - Inputs = inputs; - Outputs = outputs; + Inputs = inputColumnNames; + Outputs = outputColumnNames; (TFInputTypes, TFInputShapes) = GetInputInfo(Host, Session, Inputs); (TFOutputTypes, OutputTypes) = GetOutputInfo(Host, Session, Outputs); @@ -1089,13 +1089,13 @@ public sealed class TensorFlowEstimator : IEstimator private readonly ColumnType[] _outputTypes; private TensorFlowTransformer _transformer; - public TensorFlowEstimator(IHostEnvironment env, string modelLocation, string[] inputs, string[] outputs) - : this(env, TensorFlowUtils.LoadTensorFlowModel(env, modelLocation), inputs, outputs) + public TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, string modelLocation) + : this(env, outputColumnNames, inputColumnNames, TensorFlowUtils.LoadTensorFlowModel(env, modelLocation)) { } - public TensorFlowEstimator(IHostEnvironment env, TensorFlowModelInfo tensorFlowModel, string[] inputs, string[] outputs) - : this(env, 
CreateArguments(tensorFlowModel, inputs, outputs), tensorFlowModel) + public TensorFlowEstimator(IHostEnvironment env, string[] outputColumnNames, string[] inputColumnNames, TensorFlowModelInfo tensorFlowModel) + : this(env, CreateArguments(tensorFlowModel, outputColumnNames, inputColumnNames), tensorFlowModel) { } @@ -1115,12 +1115,12 @@ public TensorFlowEstimator(IHostEnvironment env, TensorFlowTransformer.Arguments _outputTypes = outputTuple.outputTypes; } - private static TensorFlowTransformer.Arguments CreateArguments(TensorFlowModelInfo tensorFlowModel, string[] inputs, string[] outputs) + private static TensorFlowTransformer.Arguments CreateArguments(TensorFlowModelInfo tensorFlowModel, string[] outputColumnNames, string[] inputColumnName) { var args = new TensorFlowTransformer.Arguments(); args.ModelLocation = tensorFlowModel.ModelPath; - args.InputColumns = inputs; - args.OutputColumns = outputs; + args.InputColumns = inputColumnName; + args.OutputColumns = outputColumnNames; args.ReTrain = false; return args; } @@ -1155,7 +1155,7 @@ public TensorFlowTransformer Fit(IDataView input) if (_transformer == null) { _transformer = _args.ReTrain ? new TensorFlowTransformer(_host, _args, _tensorFlowModel, input) : - new TensorFlowTransformer(_host, _tensorFlowModel.Session, _args.InputColumns, _args.OutputColumns, + new TensorFlowTransformer(_host, _tensorFlowModel.Session, _args.OutputColumns, _args.InputColumns, TensorFlowUtils.IsSavedModel(_host, _args.ModelLocation) ? _args.ModelLocation : null, false); } // Validate input schema. 
diff --git a/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs b/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs index 5bb29f7f99..8e8f7e137b 100644 --- a/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs +++ b/src/Microsoft.ML.TimeSeries.StaticPipe/TimeSeriesStatic.cs @@ -61,10 +61,10 @@ public override IEstimator Reconcile(IHostEnvironment env, Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; return new IidChangePointEstimator(env, - inputNames[outCol.Input], outputNames[outCol], _confidence, _changeHistoryLength, + inputNames[outCol.Input], _martingale, _eps); } @@ -125,10 +125,10 @@ public override IEstimator Reconcile(IHostEnvironment env, Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; return new IidSpikeEstimator(env, - inputNames[outCol.Input], outputNames[outCol], _confidence, _pvalueHistoryLength, + inputNames[outCol.Input], _side); } } @@ -204,12 +204,12 @@ public override IEstimator Reconcile(IHostEnvironment env, Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; return new SsaChangePointEstimator(env, - inputNames[outCol.Input], outputNames[outCol], _confidence, _changeHistoryLength, _trainingWindowSize, _seasonalityWindowSize, + inputNames[outCol.Input], _errorFunction, _martingale, _eps); @@ -286,12 +286,12 @@ public override IEstimator Reconcile(IHostEnvironment env, Contracts.Assert(toOutput.Length == 1); var outCol = (OutColumn)toOutput[0]; return new SsaSpikeEstimator(env, - inputNames[outCol.Input], outputNames[outCol], _confidence, _pvalueHistoryLength, _trainingWindowSize, _seasonalityWindowSize, + inputNames[outCol.Input], _side, _errorFunction); } diff --git a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs index 110ae76db2..15e7a84a8f 100644 --- a/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs +++ 
b/src/Microsoft.ML.TimeSeries/ExponentialAverageTransform.cs @@ -58,7 +58,7 @@ private static VersionInfo GetVersionInfo() private readonly Single _decay; public ExponentialAverageTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(1, 1, args.Source, args.Name, LoaderSignature, env, input) + : base(1, 1, args.Name, args.Source, LoaderSignature, env, input) { Host.CheckUserArg(0 <= args.Decay && args.Decay <= 1, nameof(args.Decay), "Should be in [0, 1]."); _decay = args.Decay; diff --git a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs index 2b46f2a6d5..1735d6a4dd 100644 --- a/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidChangePointDetector.cs @@ -212,20 +212,20 @@ public sealed class IidChangePointEstimator : TrivialEstimator /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. Column is a vector of type double and size 4. - /// The vector contains Alert, Raw Score, P-Value and Martingale score as first four values. + /// Name of the column resulting from the transformation of . + /// Column is a vector of type double and size 4. The vector contains Alert, Raw Score, P-Value and Martingale score as first four values. /// The confidence for change point detection in the range [0, 100]. /// The length of the sliding window on p-values for computing the martingale score. + /// Name of column to transform. If set to , the value of the will be used as source. /// The martingale used for scoring. /// The epsilon parameter for the Power martingale. 
- public IidChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence, - int changeHistoryLength, MartingaleType martingale = MartingaleType.Power, double eps = 0.1) + public IidChangePointEstimator(IHostEnvironment env, string outputColumnName, int confidence, + int changeHistoryLength, string inputColumnName, MartingaleType martingale = MartingaleType.Power, double eps = 0.1) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidChangePointEstimator)), new IidChangePointDetector(env, new IidChangePointDetector.Arguments { - Name = outputColumn, - Source = inputColumn, + Name = outputColumnName, + Source = inputColumnName ?? outputColumnName, Confidence = confidence, ChangeHistoryLength = changeHistoryLength, Martingale = martingale, diff --git a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs index 7f2fc5d42e..dabbf45d22 100644 --- a/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/IidSpikeDetector.cs @@ -191,18 +191,18 @@ public sealed class IidSpikeEstimator : TrivialEstimator /// Create a new instance of /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. Column is a vector of type double and size 3. - /// The vector contains Alert, Raw Score, P-Value as first three values. + /// Name of the column resulting from the transformation of . + /// Column is a vector of type double and size 3. The vector contains Alert, Raw Score, P-Value as first three values. /// The confidence for spike detection in the range [0, 100]. /// The size of the sliding window for computing the p-value. + /// Name of column to transform. If set to , the value of the will be used as source. /// The argument that determines whether to detect positive or negative anomalies, or both.
- public IidSpikeEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence, int pvalueHistoryLength, AnomalySide side = AnomalySide.TwoSided) + public IidSpikeEstimator(IHostEnvironment env, string outputColumnName, int confidence, int pvalueHistoryLength, string inputColumnName, AnomalySide side = AnomalySide.TwoSided) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(IidSpikeDetector)), new IidSpikeDetector(env, new IidSpikeDetector.Arguments { - Name = outputColumn, - Source = inputColumn, + Name = outputColumnName, + Source = inputColumnName, Confidence = confidence, PvalueHistoryLength = pvalueHistoryLength, Side = side diff --git a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs index 58d8f3fa61..f207c7b75a 100644 --- a/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs +++ b/src/Microsoft.ML.TimeSeries/MovingAverageTransform.cs @@ -67,7 +67,7 @@ private static VersionInfo GetVersionInfo() private readonly Single[] _weights; public MovingAverageTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(args.WindowSize + args.Lag - 1, args.WindowSize + args.Lag - 1, args.Source, args.Name, LoaderSignature, env, input) + : base(args.WindowSize + args.Lag - 1, args.WindowSize + args.Lag - 1, args.Name, args.Source, LoaderSignature, env, input) { Host.CheckUserArg(args.WindowSize >= 1, nameof(args.WindowSize), "Should be at least 1."); Host.CheckUserArg(args.Lag >= 0, nameof(args.Lag), "Should be positive."); diff --git a/src/Microsoft.ML.TimeSeries/PValueTransform.cs b/src/Microsoft.ML.TimeSeries/PValueTransform.cs index 8fd5724535..d8f3adcb29 100644 --- a/src/Microsoft.ML.TimeSeries/PValueTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PValueTransform.cs @@ -72,7 +72,7 @@ private static VersionInfo GetVersionInfo() private readonly bool _isPositiveSide; public PValueTransform(IHostEnvironment env, Arguments args, IDataView input) - : 
base(args.WindowSize, args.InitialWindowSize, args.Source, args.Name, LoaderSignature, env, input) + : base(args.WindowSize, args.InitialWindowSize, args.Name, args.Source, LoaderSignature, env, input) { Host.CheckUserArg(args.WindowSize >= 1, nameof(args.WindowSize), "The size of the sliding window should be at least 1."); _seed = args.Seed; diff --git a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs index 374d09ec55..868e446899 100644 --- a/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs +++ b/src/Microsoft.ML.TimeSeries/PercentileThresholdTransform.cs @@ -67,7 +67,7 @@ private static VersionInfo GetVersionInfo() private readonly Double _percentile; public PercentileThresholdTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(args.WindowSize, args.WindowSize, args.Source, args.Name, LoaderSignature, env, input) + : base(args.WindowSize, args.WindowSize, args.Name, args.Source, LoaderSignature, env, input) { Host.CheckUserArg(args.WindowSize >= 1, nameof(args.WindowSize), "The size of the sliding window should be at least 1."); Host.CheckUserArg(MinPercentile <= args.Percentile && args.Percentile <= MaxPercentile, nameof(args.Percentile), "The percentile value should be in [0, 100]."); diff --git a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs index 3cbdf5880b..7e82cc889f 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialAnomalyDetectionTransformBase.cs @@ -166,7 +166,7 @@ private static int GetOutputLength(AlertingScore alertingScore, IHostEnvironment private protected SequentialAnomalyDetectionTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, string name, IHostEnvironment env, AnomalySide anomalySide, MartingaleType martingale, 
AlertingScore alertingScore, Double powerMartingaleEpsilon, Double alertThreshold) - : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, inputColumnName, outputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env))) + : base(Contracts.CheckRef(env, nameof(env)).Register(name), windowSize, initialWindowSize, outputColumnName, inputColumnName, new VectorType(NumberType.R8, GetOutputLength(alertingScore, env))) { Host.CheckUserArg(Enum.IsDefined(typeof(MartingaleType), martingale), nameof(ArgumentsBase.Martingale), "Value is undefined."); Host.CheckUserArg(Enum.IsDefined(typeof(AnomalySide), anomalySide), nameof(ArgumentsBase.Side), "Value is undefined."); diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs index 9218458df1..febe2d4fe9 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformBase.cs @@ -212,7 +212,7 @@ public void ProcessWithoutBuffer(ref TInput input, ref TOutput output) protected string InputColumnName; protected string OutputColumnName; - private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, string inputColumnName, string outputColumnName, + private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, string outputColumnName, string inputColumnName, Action initFunction, bool hasBuffer, ColumnType outputColTypeOverride) { var inputSchema = SchemaDefinition.Create(typeof(DataBox)); @@ -238,19 +238,19 @@ private static IDataTransform CreateLambdaTransform(IHost host, IDataView input, /// /// The size of buffer used for windowed buffering. /// The number of datapoints picked from the beginning of the series for training the transform parameters if needed. - /// The name of the input column. /// The name of the dst column. - /// + /// The name of the input column. + /// Name of the extending type. 
/// A reference to the environment variable. /// A reference to the input data view. /// - private protected SequentialTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, + private protected SequentialTransformBase(int windowSize, int initialWindowSize, string outputColumnName, string inputColumnName, string name, IHostEnvironment env, IDataView input, ColumnType outputColTypeOverride = null) - : this(windowSize, initialWindowSize, inputColumnName, outputColumnName, Contracts.CheckRef(env, nameof(env)).Register(name), input, outputColTypeOverride) + : this(windowSize, initialWindowSize, outputColumnName, inputColumnName, Contracts.CheckRef(env, nameof(env)).Register(name), input, outputColTypeOverride) { } - private protected SequentialTransformBase(int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, + private protected SequentialTransformBase(int windowSize, int initialWindowSize, string outputColumnName, string inputColumnName, IHost host, IDataView input, ColumnType outputColTypeOverride = null) : base(host, input) { @@ -267,7 +267,7 @@ private protected SequentialTransformBase(int windowSize, int initialWindowSize, InitialWindowSize = initialWindowSize; WindowSize = windowSize; - _transform = CreateLambdaTransform(Host, input, InputColumnName, OutputColumnName, InitFunction, WindowSize > 0, outputColTypeOverride); + _transform = CreateLambdaTransform(Host, input, OutputColumnName, InputColumnName, InitFunction, WindowSize > 0, outputColTypeOverride); } private protected SequentialTransformBase(IHostEnvironment env, ModelLoadContext ctx, string name, IDataView input) @@ -278,7 +278,7 @@ private protected SequentialTransformBase(IHostEnvironment env, ModelLoadContext // *** Binary format *** // int: _windowSize // int: _initialWindowSize - // int (string ID): _inputColumnName + // int (string ID): _sourceColumnName // int (string ID): _outputColumnName // ColumnType: 
_transform.Schema.GetColumnType(0) @@ -299,7 +299,7 @@ private protected SequentialTransformBase(IHostEnvironment env, ModelLoadContext BinarySaver bs = new BinarySaver(Host, new BinarySaver.Arguments()); ColumnType ct = bs.LoadTypeDescriptionOrNull(ctx.Reader.BaseStream); - _transform = CreateLambdaTransform(Host, input, InputColumnName, OutputColumnName, InitFunction, WindowSize > 0, ct); + _transform = CreateLambdaTransform(Host, input, OutputColumnName, InputColumnName, InitFunction, WindowSize > 0, ct); } public override void Save(ModelSaveContext ctx) @@ -311,7 +311,7 @@ public override void Save(ModelSaveContext ctx) // *** Binary format *** // int: _windowSize // int: _initialWindowSize - // int (string ID): _inputColumnName + // int (string ID): _sourceColumnName // int (string ID): _outputColumnName // ColumnType: _transform.Schema.GetColumnType(0) diff --git a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs index ce8afda544..28fb5cdc85 100644 --- a/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs +++ b/src/Microsoft.ML.TimeSeries/SequentialTransformerBase.cs @@ -270,10 +270,10 @@ private protected virtual void CloneCore(StateBase state) /// The host. /// The size of buffer used for windowed buffering. /// The number of datapoints picked from the beginning of the series for training the transform parameters if needed. - /// The name of the input column. /// The name of the dst column. + /// The name of the input column. 
/// - private protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize, string inputColumnName, string outputColumnName, ColumnType outputColType) + private protected SequentialTransformerBase(IHost host, int windowSize, int initialWindowSize, string outputColumnName, string inputColumnName, ColumnType outputColType) { Host = host; Host.CheckParam(initialWindowSize >= 0, nameof(initialWindowSize), "Must be non-negative."); @@ -298,7 +298,7 @@ private protected SequentialTransformerBase(IHost host, ModelLoadContext ctx) // *** Binary format *** // int: _windowSize // int: _initialWindowSize - // int (string ID): _inputColumnName + // int (string ID): _sourceColumnName // int (string ID): _outputColumnName // ColumnType: _transform.Schema.GetColumnType(0) @@ -329,7 +329,7 @@ public virtual void Save(ModelSaveContext ctx) // *** Binary format *** // int: _windowSize // int: _initialWindowSize - // int (string ID): _inputColumnName + // int (string ID): _sourceColumnName // int (string ID): _outputColumnName // ColumnType: _transform.Schema.GetColumnType(0) diff --git a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs index 5813916b55..1555c98afa 100644 --- a/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs +++ b/src/Microsoft.ML.TimeSeries/SlidingWindowTransformBase.cs @@ -65,7 +65,7 @@ public sealed class Arguments : TransformInputBase private TInput _nanValue; protected SlidingWindowTransformBase(Arguments args, string loaderSignature, IHostEnvironment env, IDataView input) - : base(args.WindowSize + args.Lag - 1, args.WindowSize + args.Lag - 1, args.Source, args.Name, loaderSignature, env, input) + : base(args.WindowSize + args.Lag - 1, args.WindowSize + args.Lag - 1, args.Name, args.Source, loaderSignature, env, input) { Host.CheckUserArg(args.WindowSize >= 1, nameof(args.WindowSize), "Must be at least 1."); Host.CheckUserArg(args.Lag >= 0, nameof(args.Lag), 
"Must be positive."); diff --git a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs index db74ef1c79..669c023abf 100644 --- a/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaChangePointDetector.cs @@ -37,9 +37,9 @@ namespace Microsoft.ML.TimeSeriesProcessing public sealed class SsaChangePointDetector : SsaAnomalyDetectionBase { internal const string Summary = "This transform detects the change-points in a seasonal time-series using Singular Spectrum Analysis (SSA)."; - public const string LoaderSignature = "SsaChangePointDetector"; - public const string UserName = "SSA Change Point Detection"; - public const string ShortName = "chgpnt"; + internal const string LoaderSignature = "SsaChangePointDetector"; + internal const string UserName = "SSA Change Point Detection"; + internal const string ShortName = "chgpnt"; public sealed class Arguments : TransformInputBase { @@ -225,24 +225,29 @@ public sealed class SsaChangePointEstimator : IEstimator /// Create a new instance of /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. Column is a vector of type double and size 4. - /// The vector contains Alert, Raw Score, P-Value and Martingale score as first four values. + /// Name of the column resulting from the transformation of . + /// Column is a vector of type double and size 4. The vector contains Alert, Raw Score, P-Value and Martingale score as first four values. /// The confidence for change point detection in the range [0, 100]. /// The number of points from the beginning of the sequence used for training. /// The size of the sliding window for computing the p-value. /// An upper bound on the largest relevant seasonality in the input time-series. + /// Name of column to transform. If set to , the value of the will be used as source. /// The function used to compute the error between the expected and the observed value. 
/// The martingale used for scoring. /// The epsilon parameter for the Power martingale. - public SsaChangePointEstimator(IHostEnvironment env, string inputColumn, string outputColumn, - int confidence, int changeHistoryLength, int trainingWindowSize, int seasonalityWindowSize, + public SsaChangePointEstimator(IHostEnvironment env, string outputColumnName, + int confidence, + int changeHistoryLength, + int trainingWindowSize, + int seasonalityWindowSize, + string inputColumnName = null, ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference, - MartingaleType martingale = MartingaleType.Power, double eps = 0.1) + MartingaleType martingale = MartingaleType.Power, + double eps = 0.1) : this(env, new SsaChangePointDetector.Arguments { - Name = outputColumn, - Source = inputColumn, + Name = outputColumnName, + Source = inputColumnName ?? outputColumnName, Confidence = confidence, ChangeHistoryLength = changeHistoryLength, TrainingWindowSize = trainingWindowSize, diff --git a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs index 5796413589..976a223a5a 100644 --- a/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs +++ b/src/Microsoft.ML.TimeSeries/SsaSpikeDetector.cs @@ -206,22 +206,28 @@ public sealed class SsaSpikeEstimator : IEstimator /// Create a new instance of /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. Column is a vector of type double and size 3. - /// The vector contains Alert, Raw Score, P-Value as first three values. + /// Name of the column resulting from the transformation of . /// The confidence for spike detection in the range [0, 100]. /// The size of the sliding window for computing the p-value. /// The number of points from the beginning of the sequence used for training. /// An upper bound on the largest relevant seasonality in the input time-series. + /// Name of column to transform. 
If set to , the value of the will be used as source. + /// The vector contains Alert, Raw Score, P-Value as first three values. /// The argument that determines whether to detect positive or negative anomalies, or both. /// The function used to compute the error between the expected and the observed value. - public SsaSpikeEstimator(IHostEnvironment env, string inputColumn, string outputColumn, int confidence, - int pvalueHistoryLength, int trainingWindowSize, int seasonalityWindowSize, AnomalySide side = AnomalySide.TwoSided, + public SsaSpikeEstimator(IHostEnvironment env, + string outputColumnName, + int confidence, + int pvalueHistoryLength, + int trainingWindowSize, + int seasonalityWindowSize, + string inputColumnName = null, + AnomalySide side = AnomalySide.TwoSided, ErrorFunctionUtils.ErrorFunction errorFunction = ErrorFunctionUtils.ErrorFunction.SignedDifference) : this(env, new SsaSpikeDetector.Arguments { - Name = outputColumn, - Source = inputColumn, + Source = inputColumnName ?? outputColumnName, + Name = outputColumnName, Confidence = confidence, PvalueHistoryLength = pvalueHistoryLength, TrainingWindowSize = trainingWindowSize, diff --git a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs index 1b29befae2..68fb4a15ab 100644 --- a/src/Microsoft.ML.Transforms/CategoricalCatalog.cs +++ b/src/Microsoft.ML.Transforms/CategoricalCatalog.cs @@ -16,15 +16,15 @@ public static class CategoricalCatalog /// Convert a text column into one-hot encoded vector. /// /// The transform catalog - /// The input column - /// The output column. If null, is used. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// The conversion mode. 
/// public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.CategoricalTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, OneHotEncodingTransformer.OutputKind outputKind = OneHotEncodingTransformer.OutputKind.Ind) - => new OneHotEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, outputKind); + => new OneHotEncodingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, outputKind); /// /// Convert several text column into one-hot encoded vectors. @@ -40,8 +40,8 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate /// Convert a text column into hash-based one-hot encoded vector. /// /// The transform catalog - /// The input column - /// The output column. If null, is used. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// During hashing we constuct mappings between original values and the produced hash values. /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. @@ -50,12 +50,12 @@ public static OneHotEncodingEstimator OneHotEncoding(this TransformsCatalog.Cate /// The conversion mode. 
/// public static OneHotHashEncodingEstimator OneHotHashEncoding(this TransformsCatalog.CategoricalTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits, int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash, OneHotEncodingTransformer.OutputKind outputKind = OneHotEncodingTransformer.OutputKind.Ind) - => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, hashBits, invertHash, outputKind); + => new OneHotHashEncodingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName ?? outputColumnName, hashBits, invertHash, outputKind); /// /// Convert several text column into hash-based one-hot encoded vectors. diff --git a/src/Microsoft.ML.Transforms/ConversionsCatalog.cs b/src/Microsoft.ML.Transforms/ConversionsCatalog.cs index 53c9e5aca5..2e5bde64c0 100644 --- a/src/Microsoft.ML.Transforms/ConversionsCatalog.cs +++ b/src/Microsoft.ML.Transforms/ConversionsCatalog.cs @@ -26,11 +26,11 @@ public static KeyToBinaryVectorMappingEstimator MapKeyToBinaryVector(this Transf /// Convert the key types back to binary verctor. /// /// The categorical transform's catalog. - /// The name of the input column of the transformation. - /// The name of the column produced by the transformation. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. 
public static KeyToBinaryVectorMappingEstimator MapKeyToBinaryVector(this TransformsCatalog.ConversionTransforms catalog, - string inputColumn, - string outputColumn = null) - => new KeyToBinaryVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn); + string outputColumnName, + string inputColumnName = null) + => new KeyToBinaryVectorMappingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName); } } diff --git a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs index 0f1bdeef36..6b7320406e 100644 --- a/src/Microsoft.ML.Transforms/CountFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/CountFeatureSelection.cs @@ -48,22 +48,23 @@ public sealed class Arguments : TransformInputBase public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly long MinCount; /// /// Describes the parameters of the feature selection process for a column pair. /// - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// If the count of non-default values for a slot is greater than or equal to this threshold in the training data, the slot is preserved. - public ColumnInfo(string input, string output = null, long minCount = Defaults.Count) + public ColumnInfo(string name, string inputColumnName = null, long minCount = Defaults.Count) { - Input = input; - Contracts.CheckValue(Input, nameof(Input)); - Output = output ?? input; - Contracts.CheckValue(Output, nameof(Output)); + Name = name; + Contracts.CheckValue(Name, nameof(Name)); + + InputColumnName = inputColumnName ?? 
name; + Contracts.CheckValue(InputColumnName, nameof(InputColumnName)); MinCount = minCount; } } @@ -89,8 +90,8 @@ public CountFeatureSelectingEstimator(IHostEnvironment env, params ColumnInfo[] /// /// The environment to use. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// If the count of non-default values for a slot is greater than or equal to this threshold in the training data, the slot is preserved. /// /// @@ -99,8 +100,8 @@ public CountFeatureSelectingEstimator(IHostEnvironment env, params ColumnInfo[] /// ]]> /// /// - public CountFeatureSelectingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, long minCount = Defaults.Count) - : this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, minCount)) + public CountFeatureSelectingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, long minCount = Defaults.Count) + : this(env, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, minCount)) { } @@ -110,17 +111,17 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _columns) { - if (!inputSchema.TryFindColumn(colPair.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.Input); + if (!inputSchema.TryFindColumn(colPair.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.InputColumnName); if (!CountFeatureSelectionUtils.IsValidColumnType(col.ItemType)) - throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. 
Expected types are float, double or string.", colPair.Input); + throw _host.ExceptUserArg(nameof(inputSchema), "Column '{0}' does not have compatible type. Expected types are float, double or string.", colPair.InputColumnName); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.CategoricalSlotRanges, out var categoricalSlotMeta)) metadata.Add(categoricalSlotMeta); metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); - result[colPair.Output] = new SchemaShape.Column(colPair.Output, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray())); + result[colPair.Name] = new SchemaShape.Column(colPair.Name, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); } @@ -130,18 +131,18 @@ public ITransformer Fit(IDataView input) _host.CheckValue(input, nameof(input)); int[] colSizes; - var scores = CountFeatureSelectionUtils.Train(_host, input, _columns.Select(column => column.Input).ToArray(), out colSizes); + var scores = CountFeatureSelectionUtils.Train(_host, input, _columns.Select(column => column.InputColumnName).ToArray(), out colSizes); var size = _columns.Length; using (var ch = _host.Start("Dropping Slots")) { // If no slots should be dropped from a column, use copy column to generate the corresponding output column. 
SlotsDroppingTransformer.ColumnInfo[] dropSlotsColumns; - (string input, string output)[] copyColumnsPairs; + (string outputColumnName, string inputColumnName)[] copyColumnsPairs; CreateDropAndCopyColumns(_columns, size, scores, out int[] selectedCount, out dropSlotsColumns, out copyColumnsPairs); for (int i = 0; i < selectedCount.Length; i++) - ch.Info(MessageSensitivity.Schema, "Selected {0} slots out of {1} in column '{2}'", selectedCount[i], colSizes[i], _columns[i].Input); + ch.Info(MessageSensitivity.Schema, "Selected {0} slots out of {1} in column '{2}'", selectedCount[i], colSizes[i], _columns[i].InputColumnName); ch.Info("Total number of slots selected: {0}", selectedCount.Sum()); if (dropSlotsColumns.Length <= 0) @@ -176,7 +177,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat } private static void CreateDropAndCopyColumns(ColumnInfo[] columnInfos, int size, long[][] scores, - out int[] selectedCount, out SlotsDroppingTransformer.ColumnInfo[] dropSlotsColumns, out (string input, string output)[] copyColumnsPairs) + out int[] selectedCount, out SlotsDroppingTransformer.ColumnInfo[] dropSlotsColumns, out (string outputColumnName, string inputColumnName)[] copyColumnsPairs) { Contracts.Assert(size > 0); Contracts.Assert(Utils.Size(scores) == size); @@ -185,7 +186,7 @@ private static void CreateDropAndCopyColumns(ColumnInfo[] columnInfos, int size, selectedCount = new int[scores.Length]; var dropSlotsCols = new List(); - var copyCols = new List<(string input, string output)>(); + var copyCols = new List<(string outputColumnName, string inputColumnName)>(); for (int i = 0; i < size; i++) { var slots = new List<(int min, int? 
max)>(); @@ -208,9 +209,9 @@ private static void CreateDropAndCopyColumns(ColumnInfo[] columnInfos, int size, selectedCount[i]++; } if (slots.Count <= 0) - copyCols.Add((columnInfos[i].Input, columnInfos[i].Output)); + copyCols.Add((columnInfos[i].Name, columnInfos[i].InputColumnName)); else - dropSlotsCols.Add(new SlotsDroppingTransformer.ColumnInfo(columnInfos[i].Input, columnInfos[i].Output, slots.ToArray())); + dropSlotsCols.Add(new SlotsDroppingTransformer.ColumnInfo(columnInfos[i].Name, columnInfos[i].InputColumnName, slots.ToArray())); } dropSlotsColumns = dropSlotsCols.ToArray(); copyColumnsPairs = copyCols.ToArray(); diff --git a/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs b/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs index 11acd22eae..71f8464892 100644 --- a/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs +++ b/src/Microsoft.ML.Transforms/ExtensionsCatalog.cs @@ -16,41 +16,41 @@ public static class ExtensionsCatalog /// The transform extensions' catalog. /// The names of the input columns of the transformation and the corresponding names for the output columns. public static MissingValueIndicatorEstimator IndicateMissingValues(this TransformsCatalog catalog, - params (string inputColumn, string outputColumn)[] columns) + params (string outputColumnName, string inputColumnName)[] columns) => new MissingValueIndicatorEstimator(CatalogUtils.GetEnvironment(catalog), columns); /// - /// Creates a new output column, or replaces the inputColumn with a new column - /// (depending on whether the is given a value, or left to null) + /// Creates a new output column, or replaces the source with a new column + /// (depending on whether the is given a value, or left to null) /// of boolean type, with the same number of slots as the input column. The value in the output column /// is true if the value in the input column is missing. /// /// The transform extensions' catalog. - /// The name of the input column of the transformation. 
- /// The name of the optional column produced by the transformation. - /// If left to null the will get replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + /// If left to null the will get replaced. public static MissingValueIndicatorEstimator IndicateMissingValues(this TransformsCatalog catalog, - string inputColumn, - string outputColumn = null) - => new MissingValueIndicatorEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn); + string outputColumnName, + string inputColumnName = null) + => new MissingValueIndicatorEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName); /// - /// Creates a new output column, or replaces the inputColumn with a new column - /// (depending on whether the is given a value, or left to null) + /// Creates a new output column, or replaces the source with a new column + /// (depending on whether the is given a value, or left to null) /// identical to the input column for everything but the missing values. The missing values of the input column, in this new column are replaced with /// one of the values specifid in the . The default for the is /// . /// /// The transform extensions' catalog. - /// The name of the input column. - /// The optional name of the output column, - /// If not provided, the will be replaced with the results of the transforms. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. + /// If not provided, the will be replaced with the results of the transforms. 
/// The type of replacement to use as specified in public static MissingValueReplacingEstimator ReplaceMissingValues(this TransformsCatalog catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementKind = MissingValueReplacingEstimator.Defaults.ReplacementMode) - => new MissingValueReplacingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, replacementKind); + => new MissingValueReplacingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, replacementKind); /// /// Creates a new output column, identical to the input column for everything but the missing values. diff --git a/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs b/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs index e6e5e524b5..427d961f85 100644 --- a/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs +++ b/src/Microsoft.ML.Transforms/FeatureSelectionCatalog.cs @@ -29,13 +29,13 @@ public static MutualInformationFeatureSelectingEstimator SelectFeaturesBasedOnMu string labelColumn = MutualInfoSelectDefaults.LabelColumn, int slotsInOutput = MutualInfoSelectDefaults.SlotsInOutput, int numBins = MutualInfoSelectDefaults.NumBins, - params (string input, string output)[] columns) + params (string outputColumnName, string inputColumnName)[] columns) => new MutualInformationFeatureSelectingEstimator(CatalogUtils.GetEnvironment(catalog), labelColumn, slotsInOutput, numBins, columns); /// /// The transform's catalog. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Name of the column to use for labels. /// The maximum number of slots to preserve in the output. 
The number of slots to preserve is taken across all input columns. /// Max number of bins used to approximate mutual information between each input column and the label column. Power of 2 recommended. @@ -47,11 +47,11 @@ public static MutualInformationFeatureSelectingEstimator SelectFeaturesBasedOnMu /// /// public static MutualInformationFeatureSelectingEstimator SelectFeaturesBasedOnMutualInformation(this TransformsCatalog.FeatureSelectionTransforms catalog, - string inputColumn, string outputColumn = null, + string outputColumnName, string inputColumnName = null, string labelColumn = MutualInfoSelectDefaults.LabelColumn, int slotsInOutput = MutualInfoSelectDefaults.SlotsInOutput, int numBins = MutualInfoSelectDefaults.NumBins) - => new MutualInformationFeatureSelectingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, labelColumn, slotsInOutput, numBins); + => new MutualInformationFeatureSelectingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, labelColumn, slotsInOutput, numBins); /// /// The transform's catalog. @@ -69,8 +69,8 @@ public static CountFeatureSelectingEstimator SelectFeaturesBasedOnCount(this Tra /// /// The transform's catalog. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// If the count of non-default values for a slot is greater than or equal to this threshold in the training data, the slot is preserved. 
/// /// @@ -80,9 +80,9 @@ public static CountFeatureSelectingEstimator SelectFeaturesBasedOnCount(this Tra /// /// public static CountFeatureSelectingEstimator SelectFeaturesBasedOnCount(this TransformsCatalog.FeatureSelectionTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, long count = CountSelectDefaults.Count) - => new CountFeatureSelectingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, count); + => new CountFeatureSelectingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, count); } } diff --git a/src/Microsoft.ML.Transforms/GcnTransform.cs b/src/Microsoft.ML.Transforms/GcnTransform.cs index 507e294d0c..e23d5cbf5f 100644 --- a/src/Microsoft.ML.Transforms/GcnTransform.cs +++ b/src/Microsoft.ML.Transforms/GcnTransform.cs @@ -155,16 +155,16 @@ public sealed class GcnColumnInfo : ColumnInfoBase /// /// Describes how the transformer handles one Gcn column pair. /// - /// Name of input column. - /// Name of output column. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Subtract mean from each value before normalizing. /// Normalize by standard deviation rather than L2 norm. /// Scale features by this value. - public GcnColumnInfo(string input, string output, + public GcnColumnInfo(string name, string inputColumnName = null, bool substractMean = LpNormalizingEstimatorBase.Defaults.GcnSubstractMean, bool useStdDev = LpNormalizingEstimatorBase.Defaults.UseStdDev, float scale = LpNormalizingEstimatorBase.Defaults.Scale) - : base(input, output, substractMean, useStdDev ? LpNormalizingEstimatorBase.NormalizerKind.StdDev : LpNormalizingEstimatorBase.NormalizerKind.L2Norm, scale) + : base(name, inputColumnName, substractMean, useStdDev ? 
LpNormalizingEstimatorBase.NormalizerKind.StdDev : LpNormalizingEstimatorBase.NormalizerKind.L2Norm, scale) { } } @@ -177,22 +177,22 @@ public sealed class LpNormColumnInfo : ColumnInfoBase /// /// Describes how the transformer handles one LpNorm column pair. /// - /// Name of input column. - /// Name of output column. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Subtract mean from each value before normalizing. /// The norm to use to normalize each sample. - public LpNormColumnInfo(string input, string output, + public LpNormColumnInfo(string name, string inputColumnName = null, bool substractMean = LpNormalizingEstimatorBase.Defaults.LpSubstractMean, LpNormalizingEstimatorBase.NormalizerKind normalizerKind = LpNormalizingEstimatorBase.Defaults.NormKind) - : base(input, output, substractMean, normalizerKind, 1) + : base(name, inputColumnName ?? name, substractMean, normalizerKind, 1) { } } private sealed class ColumnInfoLoaded : ColumnInfoBase { - internal ColumnInfoLoaded(ModelLoadContext ctx, string input, string output, bool normKindSerialized) - : base(ctx, input, output, normKindSerialized) + internal ColumnInfoLoaded(ModelLoadContext ctx, string name, string inputColumnName, bool normKindSerialized) + : base(ctx, name, inputColumnName, normKindSerialized) { } @@ -203,31 +203,31 @@ internal ColumnInfoLoaded(ModelLoadContext ctx, string input, string output, boo /// public abstract class ColumnInfoBase { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly bool SubtractMean; public readonly LpNormalizingEstimatorBase.NormalizerKind NormKind; public readonly float Scale; - internal ColumnInfoBase(string input, string output, bool substractMean, LpNormalizingEstimatorBase.NormalizerKind normalizerKind, float scale) + internal ColumnInfoBase(string name, 
string inputColumnName, bool substractMean, LpNormalizingEstimatorBase.NormalizerKind normalizerKind, float scale) { - Contracts.CheckNonWhiteSpace(input, nameof(input)); - Contracts.CheckNonWhiteSpace(output, nameof(output)); - Input = input; - Output = output; + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Contracts.CheckNonWhiteSpace(inputColumnName, nameof(inputColumnName)); + Name = name; + InputColumnName = inputColumnName; SubtractMean = substractMean; Contracts.CheckUserArg(0 < scale && scale < float.PositiveInfinity, nameof(scale), "scale must be a positive finite value"); Scale = scale; NormKind = normalizerKind; } - internal ColumnInfoBase(ModelLoadContext ctx, string input, string output, bool normKindSerialized) + internal ColumnInfoBase(ModelLoadContext ctx, string name, string inputColumnName, bool normKindSerialized) { Contracts.AssertValue(ctx); - Contracts.CheckNonWhiteSpace(input, nameof(input)); - Contracts.CheckNonWhiteSpace(output, nameof(output)); - Input = input; - Output = output; + Contracts.CheckNonWhiteSpace(inputColumnName, nameof(inputColumnName)); + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Name = name; + InputColumnName = inputColumnName; // *** Binary format *** // byte: SubtractMean @@ -299,10 +299,10 @@ private static VersionInfo GetVersionInfo() public IReadOnlyCollection Columns => _columns.AsReadOnly(); private readonly ColumnInfoBase[] _columns; - private static (string input, string output)[] GetColumnPairs(ColumnInfoBase[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfoBase[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) @@ -334,8 +334,9 @@ internal static IDataTransform Create(IHostEnvironment env, GcnArguments args, I for (int i 
= 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = new GcnColumnInfo(item.Source ?? item.Name, + cols[i] = new GcnColumnInfo( item.Name, + item.Source ?? item.Name, item.SubMean ?? args.SubMean, item.UseStdDev ?? args.UseStdDev, item.Scale ?? args.Scale); @@ -360,8 +361,9 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = new LpNormColumnInfo(item.Source ?? item.Name, + cols[i] = new LpNormColumnInfo( item.Name, + item.Source ?? item.Name, item.SubMean ?? args.SubMean, item.NormKind ?? args.NormKind); } @@ -403,7 +405,7 @@ private LpNormalizingTransformer(IHost host, ModelLoadContext ctx) var columnsLength = ColumnPairs.Length; _columns = new ColumnInfoLoaded[columnsLength]; for (int i = 0; i < columnsLength; i++) - _columns[i] = new ColumnInfoLoaded(ctx, ColumnPairs[i].input, ColumnPairs[i].output, ctx.Header.ModelVerWritten >= VerVectorNormalizerSupported); + _columns[i] = new ColumnInfoLoaded(ctx, ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, ctx.Header.ModelVerWritten >= VerVectorNormalizerSupported); } public override void Save(ModelSaveContext ctx) @@ -441,7 +443,7 @@ public Mapper(LpNormalizingTransformer parent, Schema inputSchema) _srcCols = new int[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]); var srcCol = inputSchema[_srcCols[i]]; _srcTypes[i] = srcCol.Type; _types[i] = srcCol.Type; @@ -457,7 +459,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() builder.Add(InputSchema[ColMapNewToOld[i]].Metadata, name => name == MetadataUtils.Kinds.SlotNames); ValueGetter getter = (ref bool dst) => dst = true; builder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, getter); - result[i] = 
new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.GetMetadata()); } return result; } @@ -801,15 +803,15 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colPair.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.Input); + if (!inputSchema.TryFindColumn(colPair.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.InputColumnName); if (!IsSchemaColumnValid(col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.Input, ExpectedColumnType, col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.InputColumnName, ExpectedColumnType, col.GetTypeString()); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); - result[colPair.Output] = new SchemaShape.Column(colPair.Output, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray())); + result[colPair.Name] = new SchemaShape.Column(colPair.Name, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); } @@ -822,13 +824,13 @@ public sealed class LpNormalizingEstimator : LpNormalizingEstimatorBase { /// /// The environment. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. 
/// Type of norm to use to normalize each sample. /// Subtract mean from each value before normalizing. - public LpNormalizingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + public LpNormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, NormalizerKind normKind = Defaults.NormKind, bool substractMean = Defaults.LpSubstractMean) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, normKind, substractMean) + : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, normKind, substractMean) { } @@ -837,9 +839,9 @@ public LpNormalizingEstimator(IHostEnvironment env, string inputColumn, string o /// Pairs of columns to run the normalization on. /// Type of norm to use to normalize each sample. /// Subtract mean from each value before normalizing. - public LpNormalizingEstimator(IHostEnvironment env, (string input, string output)[] columns, + public LpNormalizingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, NormalizerKind normKind = Defaults.NormKind, bool substractMean = Defaults.LpSubstractMean) - : this(env, columns.Select(x => new LpNormalizingTransformer.LpNormColumnInfo(x.input, x.output, substractMean, normKind)).ToArray()) + : this(env, columns.Select(x => new LpNormalizingTransformer.LpNormColumnInfo(x.outputColumnName, x.inputColumnName, substractMean, normKind)).ToArray()) { } @@ -859,14 +861,14 @@ public sealed class GlobalContrastNormalizingEstimator : LpNormalizingEstimatorB { /// /// The environment. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Subtract mean from each value before normalizing. /// Normalize by standard deviation rather than L2 norm. 
/// Scale features by this value. - public GlobalContrastNormalizingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + public GlobalContrastNormalizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, bool substractMean = Defaults.GcnSubstractMean, bool useStdDev = Defaults.UseStdDev, float scale = Defaults.Scale) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, substractMean, useStdDev, scale) + : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, substractMean, useStdDev, scale) { } @@ -876,9 +878,9 @@ public GlobalContrastNormalizingEstimator(IHostEnvironment env, string inputColu /// Subtract mean from each value before normalizing. /// Normalize by standard deviation rather than L2 norm. /// Scale features by this value. - public GlobalContrastNormalizingEstimator(IHostEnvironment env, (string input, string output)[] columns, + public GlobalContrastNormalizingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, bool substractMean = Defaults.GcnSubstractMean, bool useStdDev = Defaults.UseStdDev, float scale = Defaults.Scale) - : this(env, columns.Select(x => new LpNormalizingTransformer.GcnColumnInfo(x.input, x.output, substractMean, useStdDev, scale)).ToArray()) + : this(env, columns.Select(x => new LpNormalizingTransformer.GcnColumnInfo(x.outputColumnName, x.inputColumnName, substractMean, useStdDev, scale)).ToArray()) { } diff --git a/src/Microsoft.ML.Transforms/GroupTransform.cs b/src/Microsoft.ML.Transforms/GroupTransform.cs index 15131205e4..cfcdef497f 100644 --- a/src/Microsoft.ML.Transforms/GroupTransform.cs +++ b/src/Microsoft.ML.Transforms/GroupTransform.cs @@ -59,11 +59,11 @@ namespace Microsoft.ML.Transforms /// public sealed class GroupTransform : TransformBase { - public const string Summary = "Groups values of a scalar column into a vector, by a contiguous group ID"; - public const string 
UserName = "Group Transform"; - public const string ShortName = "Group"; + internal const string Summary = "Groups values of a scalar column into a vector, by a contiguous group ID"; + internal const string UserName = "Group Transform"; + internal const string ShortName = "Group"; private const string RegistrationName = "GroupTransform"; - public const string LoaderSignature = "GroupTransform"; + internal const string LoaderSignature = "GroupTransform"; private static VersionInfo GetVersionInfo() { diff --git a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs index 85d371e724..a357df8fec 100644 --- a/src/Microsoft.ML.Transforms/HashJoiningTransform.cs +++ b/src/Microsoft.ML.Transforms/HashJoiningTransform.cs @@ -155,7 +155,7 @@ private static KeyType GetItemType(int hashBits) internal const string UserName = "Hash Join Transform"; - public const string LoaderSignature = "HashJoinTransform"; + internal const string LoaderSignature = "HashJoinTransform"; private static VersionInfo GetVersionInfo() { return new VersionInfo( diff --git a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs index d1514aa920..871197afdd 100644 --- a/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs +++ b/src/Microsoft.ML.Transforms/KeyToVectorMapping.cs @@ -43,20 +43,21 @@ public sealed class Arguments /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. + /// If set to , the value of the will be used as source. 
- public ColumnInfo(string input, string output = null) + public ColumnInfo(string name, string inputColumnName = null) { - Contracts.CheckNonWhiteSpace(input, nameof(input)); - Input = input; - Output = output ?? input; + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Name = name; + InputColumnName = inputColumnName ?? name; } } @@ -77,10 +78,10 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "KeyToBinary"; - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } public IReadOnlyCollection Columns => _columns.AsReadOnly(); @@ -98,7 +99,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol var type = inputSchema[srcCol].Type; string reason = TestIsKey(type); if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, reason, type.ToString()); } public KeyToBinaryVectorMappingTransformer(IHostEnvironment env, params ColumnInfo[] columns) @@ -137,7 +138,7 @@ private KeyToBinaryVectorMappingTransformer(IHost host, ModelLoadContext ctx) { _columns = new ColumnInfo[ColumnPairs.Length]; for (int i = 0; i < ColumnPairs.Length; i++) - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName); } private static IDataTransform Create(IHostEnvironment env, IDataView input, params ColumnInfo[] columns) => @@ -157,7 +158,7 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData for (int 
i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source ?? item.Name, item.Name); + cols[i] = new ColumnInfo(item.Name, item.Source ?? item.Name); }; } return new KeyToBinaryVectorMappingTransformer(env, cols).MakeDataTransform(input); @@ -178,13 +179,13 @@ private sealed class Mapper : OneToOneMapperBase private sealed class ColInfo { public readonly string Name; - public readonly string Source; + public readonly string InputColumnName; public readonly ColumnType TypeSrc; - public ColInfo(string name, string source, ColumnType type) + public ColInfo(string name, string inputColumnName, ColumnType type) { Name = name; - Source = source; + InputColumnName = inputColumnName; TypeSrc = type; } } @@ -221,11 +222,11 @@ private ColInfo[] CreateInfos(Schema inputSchema) var infos = new ColInfo[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); var type = inputSchema[colSrc].Type; _parent.CheckInputColumn(inputSchema, i, colSrc); - infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + infos[i] = new ColInfo(_parent.ColumnPairs[i].outputColumnName, _parent.ColumnPairs[i].inputColumnName, type); } return infos; } @@ -235,19 +236,19 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); 
Host.Assert(colIndex >= 0); var builder = new MetadataBuilder(); AddMetadata(i, builder); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.GetMetadata()); } return result; } private void AddMetadata(int iinfo, MetadataBuilder builder) { - InputSchema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + InputSchema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol); var inputMetadata = InputSchema[srcCol].Metadata; var srcType = _infos[iinfo].TypeSrc; // See if the source has key names. @@ -319,7 +320,7 @@ private void GetSlotNames(int iinfo, ref VBuffer> dst) // Get the source slot names, defaulting to empty text. var namesSlotSrc = default(VBuffer>); - var inputMetadata = InputSchema[_infos[iinfo].Source].Metadata; + var inputMetadata = InputSchema[_infos[iinfo].InputColumnName].Metadata; VectorType typeSlotSrc = null; if (inputMetadata != null) typeSlotSrc = inputMetadata.Schema.GetColumnOrNull(MetadataUtils.Kinds.SlotNames)?.Type as VectorType; @@ -388,7 +389,7 @@ private ValueGetter> MakeGetterOne(Row input, int iinfo) int dstLength = _types[iinfo].Size; Host.Assert(dstLength > 0); - input.Schema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + input.Schema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol); Host.Assert(srcCol >= 0); var getSrc = RowCursorUtils.GetGetterAs(NumberType.U4, input, srcCol); var src = default(uint); @@ -416,7 +417,7 @@ private ValueGetter> MakeGetterInd(Row input, int iinfo, VectorTy int cv = typeSrc.Size; Host.Assert(cv >= 0); - input.Schema.TryGetColumnIndex(_infos[iinfo].Source, out int srcCol); + input.Schema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int srcCol); Host.Assert(srcCol >= 0); var getSrc = RowCursorUtils.GetVecGetterAs(NumberType.U4, input, srcCol); var src = default(VBuffer); @@ -463,8 +464,8 @@ public 
KeyToBinaryVectorMappingEstimator(IHostEnvironment env, params KeyToBinar { } - public KeyToBinaryVectorMappingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null) - : this(env, new KeyToBinaryVectorMappingTransformer(env, new KeyToBinaryVectorMappingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn))) + public KeyToBinaryVectorMappingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null) + : this(env, new KeyToBinaryVectorMappingTransformer(env, new KeyToBinaryVectorMappingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName))) { } @@ -479,10 +480,10 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!(col.ItemType is VectorType || col.ItemType is PrimitiveType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) @@ -490,7 +491,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, keyMeta.ItemType, false)); if (col.Kind == SchemaShape.Column.VectorKind.Scalar) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, 
NumberType.R4, false, new SchemaShape(metadata)); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs index e58357cb19..ddc4b69a22 100644 --- a/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueDroppingTransformer.cs @@ -75,14 +75,14 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = "DropNAs"; - public IReadOnlyList<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyList<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); /// /// Initializes a new instance of /// /// The environment to use. /// The names of the input columns of the transformation and the corresponding names for the output columns. - public MissingValueDroppingTransformer(IHostEnvironment env, params (string input, string output)[] columns) + public MissingValueDroppingTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingTransformer)), columns) { } @@ -98,8 +98,8 @@ private MissingValueDroppingTransformer(IHostEnvironment env, ModelLoadContext c Host.CheckValue(ctx, nameof(ctx)); } - private static (string input, string output)[] GetColumnPairs(Column[] columns) - => columns.Select(c => (c.Source ?? c.Name, c.Name)).ToArray(); + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(Column[] columns) + => columns.Select(c => (c.Name, c.Source ?? 
c.Name)).ToArray(); protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { @@ -161,7 +161,7 @@ public Mapper(MissingValueDroppingTransformer parent, Schema inputSchema) : _isNAs = new Delegate[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]); var srcCol = inputSchema[_srcCols[i]]; _srcTypes[i] = srcCol.Type; _types[i] = new VectorType((PrimitiveType)srcCol.Type.GetItemType()); @@ -187,7 +187,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var builder = new MetadataBuilder(); builder.Add(InputSchema[ColMapNewToOld[i]].Metadata, x => x == MetadataUtils.Kinds.KeyValues || x == MetadataUtils.Kinds.IsNormalized); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.GetMetadata()); } return result; } @@ -348,7 +348,7 @@ public sealed class MissingValueDroppingEstimator : TrivialEstimator /// The environment to use. /// The names of the input columns of the transformation and the corresponding names for the output columns. - public MissingValueDroppingEstimator(IHostEnvironment env, params (string input, string output)[] columns) + public MissingValueDroppingEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueDroppingEstimator)), new MissingValueDroppingTransformer(env, columns)) { Contracts.CheckValue(env, nameof(env)); @@ -358,10 +358,10 @@ public MissingValueDroppingEstimator(IHostEnvironment env, params (string input, /// Drops missing values from columns. /// /// The environment to use. - /// The name of the input column of the transformation. 
- /// The name of the column produced by the transformation. - public MissingValueDroppingEstimator(IHostEnvironment env, string input, string output = null) - : this(env, (input, output ?? input)) + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. + public MissingValueDroppingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null) + : this(env, (outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -374,16 +374,16 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input); + if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col) || !Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName); if (!(col.Kind == SchemaShape.Column.VectorKind.Vector || col.Kind == SchemaShape.Column.VectorKind.VariableVector)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input, "Vector", col.GetTypeString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName, "Vector", col.GetTypeString()); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.KeyValues, out var keyMeta)) metadata.Add(keyMeta); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.IsNormalized, out var normMeta)) metadata.Add(normMeta); - result[colPair.output] = new SchemaShape.Column(colPair.output, SchemaShape.Column.VectorKind.VariableVector, col.ItemType, false, new SchemaShape(metadata.ToArray())); + 
result[colPair.outputColumnName] = new SchemaShape.Column(colPair.outputColumnName, SchemaShape.Column.VectorKind.VariableVector, col.ItemType, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs index 74bfc5a40f..d63c9792b2 100644 --- a/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueHandlingTransformer.cs @@ -114,16 +114,16 @@ internal bool TryUnparse(StringBuilder sb) /// /// Host Environment. /// Input . This is the output from previous transform or loader. - /// Name of the output column. - /// Name of the column to be transformed. If this is null '' will be used. + /// Name of the output column. + /// Name of the column to be transformed. If this is null '' will be used. /// The replacement method to utilize. - public static IDataView Create(IHostEnvironment env, IDataView input, string name, string source = null, ReplacementKind replaceWith = ReplacementKind.DefaultValue) + public static IDataView Create(IHostEnvironment env, IDataView input, string outputColumnName, string inputColumnName = null, ReplacementKind replaceWith = ReplacementKind.DefaultValue) { var args = new Arguments() { Column = new[] { - new Column() { Source = source ?? name, Name = name } + new Column() { Name = outputColumnName, Source = inputColumnName ?? outputColumnName } }, ReplaceWith = replaceWith }; @@ -153,7 +153,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat var addInd = column.ConcatIndicator ?? args.Concat; if (!addInd) { - replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(column.Source, column.Name, (MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? 
args.ImputeBySlot)); + replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(column.Name, column.Source, (MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); continue; } @@ -183,11 +183,11 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { throw h.Except("Cannot get a DataKind for type '{0}'", replaceItemType.RawType); } - naConvCols.Add(new TypeConvertingTransformer.ColumnInfo(tmpIsMissingColName, tmpIsMissingColName, replaceItemTypeKind)); + naConvCols.Add(new TypeConvertingTransformer.ColumnInfo(tmpIsMissingColName, replaceItemTypeKind, tmpIsMissingColName)); // Add the NAReplaceTransform column. - replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(column.Source, tmpReplacementColName, (MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); + replaceCols.Add(new MissingValueReplacingTransformer.ColumnInfo(tmpReplacementColName, column.Source, (MissingValueReplacingTransformer.ColumnInfo.ReplacementMode)(column.Kind ?? args.ReplaceWith), column.ImputeBySlot ?? args.ImputeBySlot)); // Add the ConcatTransform column.
if (replaceType is VectorType) diff --git a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs index 5e0a8ce10f..10386ea02e 100644 --- a/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs +++ b/src/Microsoft.ML.Transforms/MissingValueIndicatorTransformer.cs @@ -78,14 +78,14 @@ private static VersionInfo GetVersionInfo() private const string RegistrationName = nameof(MissingValueIndicatorTransformer); - public IReadOnlyList<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyList<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); /// /// Initializes a new instance of /// /// The environment to use. /// The names of the input columns of the transformation and the corresponding names for the output columns. - public MissingValueIndicatorTransformer(IHostEnvironment env, params (string input, string output)[] columns) + public MissingValueIndicatorTransformer(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), columns) { } @@ -101,8 +101,8 @@ private MissingValueIndicatorTransformer(IHostEnvironment env, ModelLoadContext Host.CheckValue(ctx, nameof(ctx)); } - private static (string input, string output)[] GetColumnPairs(Column[] columns) - => columns.Select(c => (c.Source ?? c.Name, c.Name)).ToArray(); + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(Column[] columns) + => columns.Select(c => (c.Name, c.Source ?? 
c.Name)).ToArray(); // Factory method for SignatureLoadModel internal static MissingValueIndicatorTransformer Create(IHostEnvironment env, ModelLoadContext ctx) @@ -145,16 +145,16 @@ private sealed class Mapper : OneToOneMapperBase private sealed class ColInfo { - public readonly string Output; - public readonly string Input; + public readonly string Name; + public readonly string InputColumnName; public readonly ColumnType OutputType; public readonly ColumnType InputType; public readonly Delegate InputIsNA; - public ColInfo(string input, string output, ColumnType inType, ColumnType outType) + public ColInfo(string name, string inputColumnName, ColumnType inType, ColumnType outType) { - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName; InputType = inType; OutputType = outType; InputIsNA = GetIsNADelegate(InputType); @@ -174,8 +174,8 @@ private ColInfo[] CreateInfos(Schema inputSchema) var infos = new ColInfo[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); _parent.CheckInputColumn(inputSchema, i, colSrc); var inType = inputSchema[colSrc].Type; ColumnType outType; @@ -183,7 +183,7 @@ private ColInfo[] CreateInfos(Schema inputSchema) outType = BoolType.Instance; else outType = new VectorType(BoolType.Instance, vectorType); - infos[i] = new ColInfo(_parent.ColumnPairs[i].input, _parent.ColumnPairs[i].output, inType, outType); + infos[i] = new ColInfo(_parent.ColumnPairs[i].outputColumnName, _parent.ColumnPairs[i].inputColumnName, inType, outType); } return infos; } @@ -193,7 +193,7 @@ protected override Schema.DetachedColumn[] 
GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int iinfo = 0; iinfo < _infos.Length; iinfo++) { - InputSchema.TryGetColumnIndex(_infos[iinfo].Input, out int colIndex); + InputSchema.TryGetColumnIndex(_infos[iinfo].InputColumnName, out int colIndex); Host.Assert(colIndex >= 0); var builder = new MetadataBuilder(); builder.Add(InputSchema[colIndex].Metadata, x => x == MetadataUtils.Kinds.SlotNames); @@ -202,7 +202,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() dst = true; }; builder.Add(MetadataUtils.Kinds.IsNormalized, BoolType.Instance, getter); - result[iinfo] = new Schema.DetachedColumn(_infos[iinfo].Output, _infos[iinfo].OutputType, builder.GetMetadata()); + result[iinfo] = new Schema.DetachedColumn(_infos[iinfo].Name, _infos[iinfo].OutputType, builder.GetMetadata()); } return result; } @@ -434,7 +434,7 @@ public sealed class MissingValueIndicatorEstimator : TrivialEstimator /// The environment to use. /// The names of the input columns of the transformation and the corresponding names for the output columns. - public MissingValueIndicatorEstimator(IHostEnvironment env, params (string input, string output)[] columns) + public MissingValueIndicatorEstimator(IHostEnvironment env, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(MissingValueIndicatorTransformer)), new MissingValueIndicatorTransformer(env, columns)) { Contracts.CheckValue(env, nameof(env)); @@ -444,10 +444,10 @@ public MissingValueIndicatorEstimator(IHostEnvironment env, params (string input /// Initializes a new instance of /// /// The environment to use. - /// The name of the input column of the transformation. - /// The name of the column produced by the transformation. - public MissingValueIndicatorEstimator(IHostEnvironment env, string input, string output = null) - : this(env, (input, output ?? 
input)) + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. + public MissingValueIndicatorEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null) + : this(env, (outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -460,8 +460,8 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colPair.input, out var col) || !Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input); + if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col) || !Data.Conversion.Conversions.Instance.TryGetIsNAPredicate(col.ItemType, out Delegate del)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); @@ -469,7 +469,7 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) ColumnType type = !(col.ItemType is VectorType vectorType) ? 
(ColumnType)BoolType.Instance : new VectorType(BoolType.Instance, vectorType); - result[colPair.output] = new SchemaShape.Column(colPair.output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); + result[colPair.outputColumnName] = new SchemaShape.Column(colPair.outputColumnName, col.Kind, type, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs index 1e01802b01..675c740b5a 100644 --- a/src/Microsoft.ML.Transforms/MissingValueReplacing.cs +++ b/src/Microsoft.ML.Transforms/MissingValueReplacing.cs @@ -187,26 +187,26 @@ public enum ReplacementMode : byte Maximum = 3, } - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly bool ImputeBySlot; public readonly ReplacementMode Replacement; /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// What to replace the missing value with. /// If true, per-slot imputation of replacement is performed. /// Otherwise, replacement value is imputed for the entire vector column. This setting is ignored for scalars and variable vectors, /// where imputation is always for the entire column. 
- public ColumnInfo(string input, string output = null, ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, + public ColumnInfo(string name, string inputColumnName = null, ReplacementMode replacementMode = MissingValueReplacingEstimator.Defaults.ReplacementMode, bool imputeBySlot = MissingValueReplacingEstimator.Defaults.ImputeBySlot) { - Contracts.CheckNonWhiteSpace(input, nameof(input)); - Input = input; - Output = output ?? input; + Contracts.CheckNonWhiteSpace(name, nameof(name)); + Name = name; + InputColumnName = inputColumnName ?? name; ImputeBySlot = imputeBySlot; Replacement = replacementMode; } @@ -214,10 +214,10 @@ public ColumnInfo(string input, string output = null, ReplacementMode replacemen internal string ReplacementString { get; set; } } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } // The output column types, parallel to Infos. @@ -249,8 +249,8 @@ public MissingValueReplacingTransformer(IHostEnvironment env, IDataView input, p // Check that all the input columns are present and correct. 
for (int i = 0; i < ColumnPairs.Length; i++) { - if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) - throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].inputColumnName, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].inputColumnName); CheckInputColumn(input.Schema, i, srcCol); } GetReplacementValues(input, columns, out _repValues, out _repIsDefault, out _replaceTypes); @@ -324,7 +324,7 @@ private void GetReplacementValues(IDataView input, ColumnInfo[] columns, out obj var sourceColumns = new List(); for (int iinfo = 0; iinfo < columns.Length; iinfo++) { - input.Schema.TryGetColumnIndex(columns[iinfo].Input, out int colSrc); + input.Schema.TryGetColumnIndex(columns[iinfo].InputColumnName, out int colSrc); sources[iinfo] = colSrc; var type = input.Schema[colSrc].Type; if (type is VectorType vectorType) @@ -481,8 +481,9 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat if (!Enum.IsDefined(typeof(ReplacementKind), kind)) throw env.ExceptUserArg(nameof(args.ReplacementKind), "Undefined sorting criteria '{0}' detected for column '{1}'", kind, item.Name); - cols[i] = new ColumnInfo(item.Source, + cols[i] = new ColumnInfo( item.Name, + item.Source, (ColumnInfo.ReplacementMode)(item.Kind ?? args.ReplacementKind), item.Slot ?? 
args.ImputeBySlot); cols[i].ReplacementString = item.ReplacementString; @@ -568,13 +569,13 @@ private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx private sealed class ColInfo { public readonly string Name; - public readonly string Source; + public readonly string InputColumnName; public readonly ColumnType TypeSrc; - public ColInfo(string name, string source, ColumnType type) + public ColInfo(string outputColumnName, string inputColumnName, ColumnType type) { - Name = name; - Source = source; + Name = outputColumnName; + InputColumnName = inputColumnName; TypeSrc = type; } } @@ -605,18 +606,18 @@ public Mapper(MissingValueReplacingTransformer parent, Schema inputSchema) var repType = _parent._repIsDefault[i] != null ? _parent._replaceTypes[i] : _parent._replaceTypes[i].GetItemType(); if (!type.GetItemType().Equals(repType.GetItemType())) throw Host.ExceptParam(nameof(InputSchema), "Column '{0}' item type '{1}' does not match expected ColumnType of '{2}'", - _infos[i].Source, _parent._replaceTypes[i].GetItemType().ToString(), _infos[i].TypeSrc); + _infos[i].InputColumnName, _parent._replaceTypes[i].GetItemType().ToString(), _infos[i].TypeSrc); // If type is a vector and the value is not either a scalar or a vector of the same size, throw an error. 
if (repType is VectorType repVectorType) { if (vectorType == null) throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' cannot be a vector when Columntype is a scalar of type '{2}'", - _infos[i].Source, repType, type); + _infos[i].InputColumnName, repType, type); if (!vectorType.IsKnownSize) - throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' is unknown size vector '{1}' must be a scalar instead of type '{2}'", _infos[i].Source, type, parent._replaceTypes[i]); + throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' is unknown size vector '{1}' must be a scalar instead of type '{2}'", _infos[i].InputColumnName, type, parent._replaceTypes[i]); if (vectorType.Size != repVectorType.Size) throw Host.ExceptParam(nameof(inputSchema), "Column '{0}' item type '{1}' must be a scalar or a vector of the same size as Columntype '{2}'", - _infos[i].Source, repType, type); + _infos[i].InputColumnName, repType, type); } _types[i] = type; _isNAs[i] = _parent.GetIsNADelegate(type); @@ -629,11 +630,11 @@ private ColInfo[] CreateInfos(Schema inputSchema) var infos = new ColInfo[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colSrc)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colSrc)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); _parent.CheckInputColumn(inputSchema, i, colSrc); var type = inputSchema[colSrc].Type; - infos[i] = new ColInfo(_parent.ColumnPairs[i].output, _parent.ColumnPairs[i].input, type); + infos[i] = new ColInfo(_parent.ColumnPairs[i].outputColumnName, _parent.ColumnPairs[i].inputColumnName, type); } return infos; } @@ -643,11 +644,11 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new 
Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); var builder = new MetadataBuilder(); builder.Add(InputSchema[colIndex].Metadata, x => x == MetadataUtils.Kinds.SlotNames || x == MetadataUtils.Kinds.IsNormalized); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.GetMetadata()); } return result; } @@ -885,14 +886,14 @@ public void SaveAsOnnx(OnnxContext ctx) for (int iinfo = 0; iinfo < _infos.Length; ++iinfo) { ColInfo info = _infos[iinfo]; - string sourceColumnName = info.Source; - if (!ctx.ContainsColumn(sourceColumnName)) + string inputColumnName = info.InputColumnName; + if (!ctx.ContainsColumn(inputColumnName)) { ctx.RemoveColumn(info.Name, false); continue; } - if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(sourceColumnName), + if (!SaveAsOnnxCore(ctx, iinfo, info, ctx.GetVariableName(inputColumnName), ctx.AddIntermediateVariable(_parent._replaceTypes[iinfo], info.Name))) { ctx.RemoveColumn(info.Name, true); @@ -941,18 +942,18 @@ public static class Defaults private readonly IHost _host; private readonly MissingValueReplacingTransformer.ColumnInfo[] _columns; - public MissingValueReplacingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementKind = Defaults.ReplacementMode) - : this(env, new MissingValueReplacingTransformer.ColumnInfo(outputColumn ?? 
inputColumn, inputColumn, replacementKind)) - { + public MissingValueReplacingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, MissingValueReplacingTransformer.ColumnInfo.ReplacementMode replacementKind = Defaults.ReplacementMode) + : this(env, new MissingValueReplacingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, replacementKind)) + { - } + } - public MissingValueReplacingEstimator(IHostEnvironment env, params MissingValueReplacingTransformer.ColumnInfo[] columns) - { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(MissingValueReplacingEstimator)); - _columns = columns; - } + public MissingValueReplacingEstimator(IHostEnvironment env, params MissingValueReplacingTransformer.ColumnInfo[] columns) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(MissingValueReplacingEstimator)); + _columns = columns; + } public SchemaShape GetOutputSchema(SchemaShape inputSchema) { @@ -960,8 +961,8 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); string reason = MissingValueReplacingTransformer.TestType(col.ItemType); if (reason != null) throw _host.ExceptParam(nameof(inputSchema), reason); @@ -973,12 +974,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var type = !(col.ItemType is VectorType vectorType) ? 
col.ItemType : new VectorType(vectorType.ItemType, vectorType); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, col.Kind, type, false, new SchemaShape(metadata.ToArray())); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, col.Kind, type, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); } - public MissingValueReplacingTransformer Fit(IDataView input) => new MissingValueReplacingTransformer(_host, input, _columns); - } + public MissingValueReplacingTransformer Fit(IDataView input) => new MissingValueReplacingTransformer(_host, input, _columns); + } -} + } diff --git a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs index 24e8bf3d66..caaa550fb6 100644 --- a/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs +++ b/src/Microsoft.ML.Transforms/MutualInformationFeatureSelection.cs @@ -60,7 +60,7 @@ public sealed class Arguments : TransformInputBase } private IHost _host; - private readonly (string input, string output)[] _columns; + private readonly (string outputColumnName, string inputColumnName)[] _columns; private readonly string _labelColumn; private readonly int _slotsInOutput; private readonly int _numBins; @@ -82,7 +82,7 @@ public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, string labelColumn = Defaults.LabelColumn, int slotsInOutput = Defaults.SlotsInOutput, int numBins = Defaults.NumBins, - params(string input, string output)[] columns) + params (string outputColumnName, string inputColumnName)[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(RegistrationName); @@ -100,8 +100,8 @@ public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, /// /// The environment to use. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. 
+ /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Name of the column to use for labels. /// The maximum number of slots to preserve in the output. The number of slots to preserve is taken across all input columns. /// Max number of bins used to approximate mutual information between each input column and the label column. Power of 2 recommended. @@ -112,9 +112,9 @@ public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, /// ]]> /// /// - public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + public MutualInformationFeatureSelectingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, string labelColumn = Defaults.LabelColumn, int slotsInOutput = Defaults.SlotsInOutput, int numBins = Defaults.NumBins) - : this(env, labelColumn, slotsInOutput, numBins, (inputColumn, outputColumn ?? inputColumn)) + : this(env, labelColumn, slotsInOutput, numBins, (outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -129,7 +129,7 @@ public ITransformer Fit(IDataView input) var colSet = new HashSet(); foreach (var col in _columns) { - if (!colSet.Add(col.input)) + if (!colSet.Add(col.inputColumnName)) ch.Warning("Column '{0}' specified multiple time.", col); } var colArr = colSet.ToArray(); @@ -143,8 +143,8 @@ public ITransformer Fit(IDataView input) // If no slots should be dropped in a column, use CopyColumn to generate the corresponding output column. 
SlotsDroppingTransformer.ColumnInfo[] dropSlotsColumns; - (string input, string output)[] copyColumnPairs; - CreateDropAndCopyColumns(colArr.Length, scores, threshold, tiedScoresToKeep, _columns.Where(col => colSet.Contains(col.input)).ToArray(), out int[] selectedCount, out dropSlotsColumns, out copyColumnPairs); + (string outputColumnName, string inputColumnName)[] copyColumnPairs; + CreateDropAndCopyColumns(colArr.Length, scores, threshold, tiedScoresToKeep, _columns.Where(col => colSet.Contains(col.inputColumnName)).ToArray(), out int[] selectedCount, out dropSlotsColumns, out copyColumnPairs); for (int i = 0; i < selectedCount.Length; i++) ch.Info("Selected {0} slots out of {1} in column '{2}'", selectedCount[i], colSizes[i], colArr[i]); @@ -170,18 +170,18 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colPair in _columns) { - if (!inputSchema.TryFindColumn(colPair.input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.input); + if (!inputSchema.TryFindColumn(colPair.inputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colPair.inputColumnName); if (!MutualInformationFeatureSelectionUtils.IsValidColumnType(col.ItemType)) throw _host.ExceptUserArg(nameof(inputSchema), - "Column '{0}' does not have compatible type. Expected types are float, double, int, bool and key.", colPair.input); + "Column '{0}' does not have compatible type. 
Expected types are float, double, int, bool and key.", colPair.inputColumnName); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(slotMeta); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.CategoricalSlotRanges, out var categoricalSlotMeta)) metadata.Add(categoricalSlotMeta); metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); - result[colPair.output] = new SchemaShape.Column(colPair.output, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray())); + result[colPair.outputColumnName] = new SchemaShape.Column(colPair.outputColumnName, col.Kind, col.ItemType, false, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); } @@ -200,7 +200,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat host.CheckNonWhiteSpace(args.LabelColumn, nameof(args.LabelColumn)); host.Check(args.NumBins > 1, "numBins must be greater than 1."); - (string input, string output)[] cols = args.Column.Select(col => (col, col)).ToArray(); + (string outputColumnName, string inputColumnName)[] cols = args.Column.Select(col => (col, col)).ToArray(); return new MutualInformationFeatureSelectingEstimator(env, args.LabelColumn, args.SlotsInOutput, args.NumBins, cols).Fit(input).Transform(input) as IDataTransform; } @@ -253,8 +253,8 @@ private static float ComputeThreshold(float[][] scores, int topk, out int tiedSc return threshold; } - private static void CreateDropAndCopyColumns(int size, float[][] scores, float threshold, int tiedScoresToKeep, (string input, string output)[] cols, - out int[] selectedCount, out SlotsDroppingTransformer.ColumnInfo[] dropSlotsColumns, out (string input, string output)[] copyColumnsPairs) + private static void CreateDropAndCopyColumns(int size, float[][] scores, float threshold, int tiedScoresToKeep, (string outputColumnName, string 
inputColumnName)[] cols, + out int[] selectedCount, out SlotsDroppingTransformer.ColumnInfo[] dropSlotsColumns, out (string outputColumnName, string inputColumnName)[] copyColumnsPairs) { Contracts.Assert(size > 0); Contracts.Assert(Utils.Size(scores) == size); @@ -262,7 +262,7 @@ private static void CreateDropAndCopyColumns(int size, float[][] scores, float t Contracts.Assert(threshold > 0 || (threshold == 0 && tiedScoresToKeep == 0)); var dropCols = new List(); - var copyCols = new List<(string input, string output)>(); + var copyCols = new List<(string outputColumnName, string inputColumnName)>(); selectedCount = new int[scores.Length]; for (int i = 0; i < size; i++) { @@ -307,7 +307,7 @@ private static void CreateDropAndCopyColumns(int size, float[][] scores, float t if (slots.Count <= 0) copyCols.Add(cols[i]); else - dropCols.Add(new SlotsDroppingTransformer.ColumnInfo(cols[i].input, cols[i].output, slots.ToArray())); + dropCols.Add(new SlotsDroppingTransformer.ColumnInfo(cols[i].outputColumnName, cols[i].inputColumnName, slots.ToArray())); } dropSlotsColumns = dropCols.ToArray(); copyColumnsPairs = copyCols.ToArray(); diff --git a/src/Microsoft.ML.Transforms/OneHotEncoding.cs b/src/Microsoft.ML.Transforms/OneHotEncoding.cs index 03ef15114a..afd3e70726 100644 --- a/src/Microsoft.ML.Transforms/OneHotEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotEncoding.cs @@ -131,8 +131,8 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat foreach (var column in args.Column) { var col = new OneHotEncodingEstimator.ColumnInfo( - column.Source ?? column.Name, column.Name, + column.Source ?? column.Name, column.OutputKind ?? args.OutputKind, column.MaxNumTerms ?? args.MaxNumTerms, column.Sort ?? args.Sort, @@ -191,18 +191,18 @@ public class ColumnInfo : ValueToKeyMappingTransformer.ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. 
- /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Output kind: Bag (multi-set vector), Ind (indicator vector), Key (index), or Binary encoded indicator vector. /// Maximum number of terms to keep per column when auto-training. /// How items should be ordered when vectorized. If choosen they will be in the order encountered. /// If , items are sorted according to their default comparison, for example, text sorting will be case sensitive (for example, 'A' then 'Z' then 'a'). /// List of terms. - public ColumnInfo(string input, string output = null, + public ColumnInfo(string name, string inputColumnName = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutKind, int maxNumTerms = ValueToKeyMappingEstimator.Defaults.MaxNumTerms, ValueToKeyMappingTransformer.SortOrder sort = ValueToKeyMappingEstimator.Defaults.Sort, string[] term = null) - : base(input, output, maxNumTerms, sort, term, true) + : base(name, inputColumnName ?? name, maxNumTerms, sort, term, true) { OutputKind = outputKind; } @@ -220,12 +220,12 @@ internal void SetTerms(string terms) /// Initializes an instance of the . /// Host Environment. - /// Name of the column to be transformed. - /// Name of the output column. If this is null, is used. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The type of output expected. - public OneHotEncodingEstimator(IHostEnvironment env, string inputColumn, - string outputColumn = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutKind) - : this(env, new[] { new ColumnInfo(inputColumn, outputColumn ?? 
inputColumn, outputKind) }) + public OneHotEncodingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, + OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutKind) + : this(env, new[] { new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, outputKind) }) { } @@ -234,8 +234,8 @@ public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns, IData Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(OneHotEncodingEstimator)); _term = new ValueToKeyMappingEstimator(_host, columns, keyData); - var binaryCols = new List<(string input, string output)>(); - var cols = new List<(string input, string output, bool bag)>(); + var binaryCols = new List<(string outputColumnName, string inputColumnName)>(); + var cols = new List<(string outputColumnName, string inputColumnName, bool bag)>(); for (int i = 0; i < columns.Length; i++) { var column = columns[i]; @@ -247,22 +247,22 @@ public OneHotEncodingEstimator(IHostEnvironment env, ColumnInfo[] columns, IData case OneHotEncodingTransformer.OutputKind.Key: continue; case OneHotEncodingTransformer.OutputKind.Bin: - binaryCols.Add((column.Output, column.Output)); + binaryCols.Add((column.Name, column.Name)); break; case OneHotEncodingTransformer.OutputKind.Ind: - cols.Add((column.Output, column.Output, false)); + cols.Add((column.Name, column.Name, false)); break; case OneHotEncodingTransformer.OutputKind.Bag: - cols.Add((column.Output, column.Output, true)); + cols.Add((column.Name, column.Name, true)); break; } } IEstimator toBinVector = null; IEstimator toVector = null; if (binaryCols.Count > 0) - toBinVector = new KeyToBinaryVectorMappingEstimator(_host, binaryCols.Select(x => new KeyToBinaryVectorMappingTransformer.ColumnInfo(x.input, x.output)).ToArray()); + toBinVector = new KeyToBinaryVectorMappingEstimator(_host, binaryCols.Select(x => new KeyToBinaryVectorMappingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName)).ToArray()); 
if (cols.Count > 0) - toVector = new KeyToVectorMappingEstimator(_host, cols.Select(x => new KeyToVectorMappingTransformer.ColumnInfo(x.input, x.output, x.bag)).ToArray()); + toVector = new KeyToVectorMappingEstimator(_host, cols.Select(x => new KeyToVectorMappingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, x.bag)).ToArray()); if (toBinVector != null && toVector != null) _toSomething = toVector.Append(toBinVector); diff --git a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs index bfbb6fc10c..25bda1d07b 100644 --- a/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs +++ b/src/Microsoft.ML.Transforms/OneHotHashEncoding.cs @@ -124,7 +124,7 @@ public sealed class Arguments : TransformInputBase internal const string Summary = "Converts the categorical value into an indicator array by hashing the value and using the hash as an index in the " + "bag. If the input column is a vector, a single indicator bag is returned for it."; - public const string UserName = "Categorical Hash Transform"; + internal const string UserName = "Categorical Hash Transform"; /// /// A helper method to create . @@ -162,8 +162,8 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat foreach (var column in args.Column) { var col = new OneHotHashEncodingEstimator.ColumnInfo( - column.Source ?? column.Name, column.Name, + column.Source ?? column.Name, column.OutputKind ?? args.OutputKind, column.HashBits ?? args.HashBits, column.Seed ?? args.Seed, @@ -218,8 +218,8 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of output column. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Kind of output: bag, indicator vector etc. /// Number of bits to hash into. Must be between 1 and 31, inclusive. /// Hashing seed. 
@@ -228,14 +228,14 @@ public sealed class ColumnInfo /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. - public ColumnInfo(string input, string output, + public ColumnInfo(string name, string inputColumnName = null, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind, int hashBits = Defaults.HashBits, uint seed = Defaults.Seed, bool ordered = Defaults.Ordered, int invertHash = Defaults.InvertHash) { - HashInfo = new HashingTransformer.ColumnInfo(input, output, hashBits, seed, ordered, invertHash); + HashInfo = new HashingTransformer.ColumnInfo(name, inputColumnName ?? name, hashBits, seed, ordered, invertHash); OutputKind = outputKind; } } @@ -244,10 +244,13 @@ public ColumnInfo(string input, string output, private readonly IEstimator _toSomething; private HashingEstimator _hash; + /// /// A helper method to create for public facing API. + /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. If this is null '' will be used. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. + /// If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// During hashing we constuct mappings between original values and the produced hash values. /// Text representation of original values are stored in the slot names of the metadata for the new column.Hashing, as such, can map many initial values to one. @@ -255,12 +258,12 @@ public ColumnInfo(string input, string output, /// 0 does not retain any input values. -1 retains all input values mapping to each hash. /// The type of output expected. 
public OneHotHashEncodingEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn, + string outputColumnName, + string inputColumnName = null, int hashBits = OneHotHashEncodingEstimator.Defaults.HashBits, int invertHash = OneHotHashEncodingEstimator.Defaults.InvertHash, OneHotEncodingTransformer.OutputKind outputKind = Defaults.OutputKind) - : this(env, new ColumnInfo(inputColumn, outputColumn ?? inputColumn, outputKind, hashBits, invertHash: invertHash)) + : this(env, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName, outputKind, hashBits, invertHash: invertHash)) { } @@ -271,8 +274,8 @@ public OneHotHashEncodingEstimator(IHostEnvironment env, params ColumnInfo[] col _hash = new HashingEstimator(_host, columns.Select(x => x.HashInfo).ToArray()); using (var ch = _host.Start(nameof(OneHotHashEncodingEstimator))) { - var binaryCols = new List<(string input, string output)>(); - var cols = new List<(string input, string output, bool bag)>(); + var binaryCols = new List<(string outputColumnName, string inputColumnName)>(); + var cols = new List<(string outputColumnName, string inputColumnName, bool bag)>(); for (int i = 0; i < columns.Length; i++) { var column = columns[i]; @@ -286,22 +289,22 @@ public OneHotHashEncodingEstimator(IHostEnvironment env, params ColumnInfo[] col case OneHotEncodingTransformer.OutputKind.Bin: if ((column.HashInfo.InvertHash) != 0) ch.Warning("Invert hashing is being used with binary encoding."); - binaryCols.Add((column.HashInfo.Output, column.HashInfo.Output)); + binaryCols.Add((column.HashInfo.Name, column.HashInfo.Name)); break; case OneHotEncodingTransformer.OutputKind.Ind: - cols.Add((column.HashInfo.Output, column.HashInfo.Output, false)); + cols.Add((column.HashInfo.Name, column.HashInfo.Name, false)); break; case OneHotEncodingTransformer.OutputKind.Bag: - cols.Add((column.HashInfo.Output, column.HashInfo.Output, true)); + cols.Add((column.HashInfo.Name, column.HashInfo.Name, true)); break; } } 
IEstimator toBinVector = null; IEstimator toVector = null; if (binaryCols.Count > 0) - toBinVector = new KeyToBinaryVectorMappingEstimator(_host, binaryCols.Select(x => new KeyToBinaryVectorMappingTransformer.ColumnInfo(x.input, x.output)).ToArray()); + toBinVector = new KeyToBinaryVectorMappingEstimator(_host, binaryCols.Select(x => new KeyToBinaryVectorMappingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName)).ToArray()); if (cols.Count > 0) - toVector = new KeyToVectorMappingEstimator(_host, cols.Select(x => new KeyToVectorMappingTransformer.ColumnInfo(x.input, x.output, x.bag)).ToArray()); + toVector = new KeyToVectorMappingEstimator(_host, cols.Select(x => new KeyToVectorMappingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, x.bag)).ToArray()); if (toBinVector != null && toVector != null) _toSomething = toVector.Append(toBinVector); diff --git a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs index d6e33231e8..63abdc723f 100644 --- a/src/Microsoft.ML.Transforms/ProduceIdTransform.cs +++ b/src/Microsoft.ML.Transforms/ProduceIdTransform.cs @@ -80,7 +80,7 @@ public Func GetDependencies(Func predicate) } internal const string Summary = "Produces a new column with the row ID."; - public const string LoaderSignature = "ProduceIdTransform"; + internal const string LoaderSignature = "ProduceIdTransform"; private static VersionInfo GetVersionInfo() { return new VersionInfo( diff --git a/src/Microsoft.ML.Transforms/ProjectionCatalog.cs b/src/Microsoft.ML.Transforms/ProjectionCatalog.cs index 50eff9b01c..5ef9670f25 100644 --- a/src/Microsoft.ML.Transforms/ProjectionCatalog.cs +++ b/src/Microsoft.ML.Transforms/ProjectionCatalog.cs @@ -13,8 +13,8 @@ public static class ProjectionCatalog /// Takes column filled with a vector of floats and maps its to a random low-dimensional feature space. /// /// The transform's catalog. - /// Name of the column to be transformed. 
- /// Name of the output column. If this is null '' will be used. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// The number of random Fourier features to create. /// Create two features for every random Fourier frequency? (one for cos and one for sin). /// @@ -25,11 +25,11 @@ public static class ProjectionCatalog /// /// public static RandomFourierFeaturizingEstimator CreateRandomFourierFeatures(this TransformsCatalog.ProjectionTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int newDim = RandomFourierFeaturizingEstimator.Defaults.NewDim, bool useSin = RandomFourierFeaturizingEstimator.Defaults.UseSin) - => new RandomFourierFeaturizingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, newDim, useSin); + => new RandomFourierFeaturizingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, newDim, useSin); /// /// Takes columns filled with a vector of floats and maps its to a random low-dimensional feature space. @@ -43,8 +43,8 @@ public static RandomFourierFeaturizingEstimator CreateRandomFourierFeatures(this /// Takes column filled with a vector of floats and computes L-p norm of it. /// /// The transform's catalog. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Type of norm to use to normalize each sample. /// Subtract mean from each value before normalizing. 
/// @@ -54,9 +54,9 @@ public static RandomFourierFeaturizingEstimator CreateRandomFourierFeatures(this /// ]]> /// /// - public static LpNormalizingEstimator LpNormalize(this TransformsCatalog.ProjectionTransforms catalog, string inputColumn, string outputColumn =null, + public static LpNormalizingEstimator LpNormalize(this TransformsCatalog.ProjectionTransforms catalog, string outputColumnName, string inputColumnName = null, LpNormalizingEstimatorBase.NormalizerKind normKind = LpNormalizingEstimatorBase.Defaults.NormKind, bool subMean = LpNormalizingEstimatorBase.Defaults.LpSubstractMean) - => new LpNormalizingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, normKind, subMean); + => new LpNormalizingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, normKind, subMean); /// /// Takes columns filled with a vector of floats and computes L-p norm of it. @@ -70,8 +70,8 @@ public static LpNormalizingEstimator LpNormalize(this TransformsCatalog.Projecti /// Takes column filled with a vector of floats and computes global contrast normalization of it. /// /// The transform's catalog. - /// Name of the input column. - /// Name of the column resulting from the transformation of . Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Subtract mean from each value before normalizing. /// Normalize by standard deviation rather than L2 norm. /// Scale features by this value. 
@@ -82,11 +82,11 @@ public static LpNormalizingEstimator LpNormalize(this TransformsCatalog.Projecti /// ]]> /// /// - public static GlobalContrastNormalizingEstimator GlobalContrastNormalize(this TransformsCatalog.ProjectionTransforms catalog, string inputColumn, string outputColumn = null, + public static GlobalContrastNormalizingEstimator GlobalContrastNormalize(this TransformsCatalog.ProjectionTransforms catalog, string outputColumnName, string inputColumnName = null, bool substractMean = LpNormalizingEstimatorBase.Defaults.GcnSubstractMean, bool useStdDev = LpNormalizingEstimatorBase.Defaults.UseStdDev, float scale = LpNormalizingEstimatorBase.Defaults.Scale) - => new GlobalContrastNormalizingEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, substractMean, useStdDev, scale); + => new GlobalContrastNormalizingEstimator(CatalogUtils.GetEnvironment(catalog), outputColumnName, inputColumnName, substractMean, useStdDev, scale); /// /// Takes columns filled with a vector of floats and computes global contrast normalization of it. diff --git a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs index 3b7a45a610..f9149f09f7 100644 --- a/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs +++ b/src/Microsoft.ML.Transforms/RandomFourierFeaturizing.cs @@ -234,8 +234,8 @@ private static string TestColumnType(ColumnType type) public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly IComponentFactory Generator; public readonly int NewDim; public readonly bool UseSin; @@ -244,17 +244,17 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of output column. - /// Which fourier generator to use. + /// Name of the column resulting from the transformation of . 
/// The number of random Fourier features to create. /// Create two features for every random Fourier frequency? (one for cos and one for sin). + /// Name of column to transform. + /// Which fourier generator to use. /// The seed of the random number generator for generating the new features (if unspecified, the global random is used. - public ColumnInfo(string input, string output, int newDim, bool useSin, IComponentFactory generator = null, int? seed = null) + public ColumnInfo(string name, int newDim, bool useSin, string inputColumnName = null, IComponentFactory generator = null, int? seed = null) { Contracts.CheckUserArg(newDim > 0, nameof(newDim), "must be positive."); - Input = input; - Output = output; + InputColumnName = inputColumnName ?? name; + Name = name; Generator = generator ?? new GaussianFourierSampler.Arguments(); NewDim = newDim; UseSin = useSin; @@ -262,10 +262,10 @@ public ColumnInfo(string input, string output, int newDim, bool useSin, ICompone } } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) @@ -273,9 +273,9 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol var type = inputSchema[srcCol].Type; string reason = TestColumnType(type); if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, reason, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, reason, type.ToString()); if (_transformInfos[col].SrcDim != type.GetVectorSize()) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", 
ColumnPairs[col].input, + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, new VectorType(NumberType.Float, _transformInfos[col].SrcDim).ToString(), type.ToString()); } @@ -286,7 +286,7 @@ public RandomFourierFeaturizingTransformer(IHostEnvironment env, IDataView input _transformInfos = new TransformInfo[columns.Length]; for (int i = 0; i < columns.Length; i++) { - input.Schema.TryGetColumnIndex(columns[i].Input, out int srcCol); + input.Schema.TryGetColumnIndex(columns[i].InputColumnName, out int srcCol); var typeSrc = input.Schema[srcCol].Type; _transformInfos[i] = new TransformInfo(Host.Register(string.Format("column{0}", i)), columns[i], typeSrc.GetValueCount(), avgDistances[i]); @@ -313,12 +313,12 @@ private float[] GetAvgDistances(ColumnInfo[] columns, IDataView input) int[] srcCols = new int[columns.Length]; for (int i = 0; i < columns.Length; i++) { - if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].input, out int srcCol)) - throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input); + if (!input.Schema.TryGetColumnIndex(ColumnPairs[i].inputColumnName, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].inputColumnName); var type = input.Schema[srcCol].Type; string reason = TestColumnType(type); if (reason != null) - throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].input, reason, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(input), "input", ColumnPairs[i].inputColumnName, reason, type.ToString()); srcCols[i] = srcCol; activeColumns.Add(input.Schema[srcCol]); } @@ -462,10 +462,11 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source ?? item.Name, + cols[i] = new ColumnInfo( item.Name, item.NewDim ?? args.NewDim, item.UseSin ?? args.UseSin, + item.Source ?? 
item.Name, item.MatrixGenerator ?? args.MatrixGenerator, item.Seed ?? args.Seed); }; @@ -521,7 +522,7 @@ public Mapper(RandomFourierFeaturizingTransformer parent, Schema inputSchema) _srcCols = new int[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]); var srcCol = inputSchema[_srcCols[i]]; _srcTypes[i] = srcCol.Type; //validate typeSrc.ValueCount and transformInfo.SrcDim @@ -534,7 +535,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], null); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], null); return result; } @@ -656,12 +657,12 @@ internal static class Defaults /// Convenience constructor for simple one column case /// /// Host Environment. - /// Name of the column to be transformed. - /// Name of the output column. If this is null '' will be used. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The number of random Fourier features to create. /// Create two features for every random Fourier frequency? (one for cos and one for sin). - public RandomFourierFeaturizingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, int newDim = Defaults.NewDim, bool useSin = Defaults.UseSin) - : this(env, new RandomFourierFeaturizingTransformer.ColumnInfo(inputColumn, outputColumn ??
inputColumn, newDim, useSin)) + public RandomFourierFeaturizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, int newDim = Defaults.NewDim, bool useSin = Defaults.UseSin) + : this(env, new RandomFourierFeaturizingTransformer.ColumnInfo(outputColumnName, newDim, useSin, inputColumnName ?? outputColumnName)) { } @@ -680,12 +681,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (col.ItemType.RawType != typeof(float) || col.Kind != SchemaShape.Column.VectorKind.Vector) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/SerializableLambdaTransform.cs b/src/Microsoft.ML.Transforms/SerializableLambdaTransform.cs index 5a0aeb735b..a7ef1afc89 100644 --- a/src/Microsoft.ML.Transforms/SerializableLambdaTransform.cs +++ b/src/Microsoft.ML.Transforms/SerializableLambdaTransform.cs @@ -34,8 +34,8 @@ public static VersionInfo GetVersionInfo() loaderAssemblyName: typeof(SerializableLambdaTransform).Assembly.FullName); } - public const string LoaderSignature = "UserLambdaMapTransform"; - public const string Summary = "Allows the definition of convenient user defined transforms"; + internal 
const string LoaderSignature = "UserLambdaMapTransform"; + internal const string Summary = "Allows the definition of convenient user defined transforms"; /// /// Creates an instance of the transform from a context. diff --git a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs index b278f9bdef..bd312b55fe 100644 --- a/src/Microsoft.ML.Transforms/Text/LdaTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/LdaTransform.cs @@ -163,8 +163,8 @@ internal bool TryUnparse(StringBuilder sb) public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly int NumTopic; public readonly float AlphaSum; public readonly float Beta; @@ -180,8 +180,8 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// The column representing the document as a vector of floats. - /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. + /// The column containing the output scores over a set of topics, represented as a vector of floats. + /// The column representing the document as a vector of floats. A null value for the column means is replaced. /// The number of topics. /// Dirichlet prior on document-topic vectors. /// Dirichlet prior on vocab-topic vectors. @@ -193,8 +193,8 @@ public sealed class ColumnInfo /// The number of words to summarize the topic. /// The number of burn-in iterations. /// Reset the random number generator for each document.
- public ColumnInfo(string input, - string output = null, + public ColumnInfo(string name, + string inputColumnName = null, int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, float beta = LatentDirichletAllocationEstimator.Defaults.Beta, @@ -207,8 +207,8 @@ public ColumnInfo(string input, int numBurninIter = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) { - Contracts.CheckValue(input, nameof(input)); - Contracts.CheckValueOrNull(output); + Contracts.CheckValue(name, nameof(name)); + Contracts.CheckValueOrNull(inputColumnName); Contracts.CheckParam(numTopic > 0, nameof(numTopic), "Must be positive."); Contracts.CheckParam(mhStep > 0, nameof(mhStep), "Must be positive."); Contracts.CheckParam(numIter > 0, nameof(numIter), "Must be positive."); @@ -218,8 +218,8 @@ public ColumnInfo(string input, Contracts.CheckParam(numSummaryTermPerTopic > 0, nameof(numSummaryTermPerTopic), "Must be positive"); Contracts.CheckParam(numBurninIter >= 0, nameof(numBurninIter), "Must be non-negative."); - Input = input; - Output = output ?? input; + Name = name; + InputColumnName = inputColumnName ?? name; NumTopic = numTopic; AlphaSum = alphaSum; Beta = beta; @@ -234,7 +234,8 @@ public ColumnInfo(string input, } internal ColumnInfo(Column item, Arguments args) : - this(item.Source ?? item.Name, item.Name, + this(item.Name, + item.Source ?? item.Name, item.NumTopic ?? args.NumTopic, item.AlphaSum ?? args.AlphaSum, item.Beta ?? 
args.Beta, @@ -710,13 +711,13 @@ public Mapper(LatentDirichletAllocationTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i])) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i])) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); var srcCol = inputSchema[_srcCols[i]]; var srcType = srcCol.Type as VectorType; if (srcType == null || !srcType.IsKnownSize || !(srcType.ItemType is NumberType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input, "a fixed vector of floats", srcCol.Type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName, "a fixed vector of floats", srcCol.Type.ToString()); } } @@ -726,7 +727,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() for (int i = 0; i < _parent.ColumnPairs.Length; i++) { var info = _parent._columns[i]; - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, new VectorType(NumberType.Float, info.NumTopic), null); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, new VectorType(NumberType.Float, info.NumTopic), null); } return result; } @@ -780,10 +781,10 @@ private static VersionInfo GetVersionInfo() internal const string UserName = "Latent Dirichlet Allocation Transform"; internal const string ShortName = "LightLda"; - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, 
x.InputColumnName)).ToArray(); } /// @@ -944,12 +945,12 @@ private static List>> Train(IHostEnvironment env, I var inputSchema = inputData.Schema; for (int i = 0; i < columns.Length; i++) { - if (!inputData.Schema.TryGetColumnIndex(columns[i].Input, out int srcCol)) - throw env.ExceptSchemaMismatch(nameof(inputData), "input", columns[i].Input); + if (!inputData.Schema.TryGetColumnIndex(columns[i].InputColumnName, out int srcCol)) + throw env.ExceptSchemaMismatch(nameof(inputData), "input", columns[i].InputColumnName); var srcColType = inputSchema[srcCol].Type as VectorType; if (srcColType == null || !srcColType.IsKnownSize || !(srcColType.ItemType is NumberType)) - throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].Input, "a fixed vector of floats", srcColType.ToString()); + throw env.ExceptSchemaMismatch(nameof(inputSchema), "input", columns[i].InputColumnName, "a fixed vector of floats", srcColType.ToString()); srcCols[i] = srcCol; activeColumns.Add(inputData.Schema[srcCol]); @@ -1029,7 +1030,7 @@ private static List>> Train(IHostEnvironment env, I if (numDocArray[i] != rowCount) { ch.Assert(numDocArray[i] < rowCount); - ch.Warning($"Column '{columns[i].Input}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); + ch.Warning($"Column '{columns[i].InputColumnName}' has skipped {rowCount - numDocArray[i]} of {rowCount} rows either empty or with negative, non-finite, or fractional values."); } } } @@ -1040,7 +1041,7 @@ private static List>> Train(IHostEnvironment env, I var state = new LdaState(env, columns[i], numVocabs[i]); if (numDocArray[i] == 0 || corpusSize[i] == 0) - throw ch.Except("The specified documents are all empty in column '{0}'.", columns[i].Input); + throw ch.Except("The specified documents are all empty in column '{0}'.", columns[i].InputColumnName); state.AllocateDataMemory(numDocArray[i], corpusSize[i]); states[i] = state; @@ -1107,8 +1108,8 @@ 
internal static class Defaults /// /// The environment. - /// The column representing the document as a vector of floats. - /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The number of topics. /// Dirichlet prior on document-topic vectors. /// Dirichlet prior on vocab-topic vectors. @@ -1121,8 +1122,7 @@ internal static class Defaults /// The number of burn-in iterations. /// Reset the random number generator for each document. public LatentDirichletAllocationEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, + string outputColumnName, string inputColumnName = null, int numTopic = Defaults.NumTopic, float alphaSum = Defaults.AlphaSum, float beta = Defaults.Beta, @@ -1134,7 +1134,7 @@ public LatentDirichletAllocationEstimator(IHostEnvironment env, int numSummaryTermPerTopic = Defaults.NumSummaryTermPerTopic, int numBurninIterations = Defaults.NumBurninIterations, bool resetRandomGenerator = Defaults.ResetRandomGenerator) - : this(env, new[] { new LatentDirichletAllocationTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn, + : this(env, new[] { new LatentDirichletAllocationTransformer.ColumnInfo(outputColumnName, inputColumnName ?? 
outputColumnName, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator) }) { } @@ -1158,12 +1158,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (col.ItemType.RawType != typeof(float) || col.Kind == SchemaShape.Column.VectorKind.Scalar) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, "a vector of floats", col.GetTypeString()); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, "a vector of floats", col.GetTypeString()); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs index 1c5650e681..cb76b256e7 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramHashingTransformer.cs @@ -178,8 +178,8 @@ private static VersionInfo GetVersionInfo() /// public sealed class ColumnInfo { - public readonly string[] Inputs; - public readonly string Output; + public readonly string Name; + public readonly string[] InputColumnNames; public readonly int NgramLength; public readonly int SkipLength; public readonly bool AllLengths; @@ -195,8 +195,8 @@ public sealed class 
ColumnInfo /// /// Describes how the transformer handles one column pair. /// - /// Name of input columns. - /// Name of output column. + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. /// Maximum ngram length. /// Maximum number of tokens to skip when constructing an ngram. /// "Whether to store all ngram lengths up to ngramLength, or only ngramLength. @@ -208,7 +208,8 @@ public sealed class ColumnInfo /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. /// Whether to rehash unigrams. - public ColumnInfo(string[] inputs, string output, + public ColumnInfo(string name, + string[] inputColumnNames, int ngramLength = NgramHashingEstimator.Defaults.NgramLength, int skipLength = NgramHashingEstimator.Defaults.SkipLength, bool allLengths = NgramHashingEstimator.Defaults.AllLengths, @@ -218,8 +219,9 @@ public ColumnInfo(string[] inputs, string output, int invertHash = NgramHashingEstimator.Defaults.InvertHash, bool rehashUnigrams = NgramHashingEstimator.Defaults.RehashUnigrams) { - Contracts.CheckValue(inputs, nameof(inputs)); - Contracts.CheckParam(!inputs.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputs), + Contracts.CheckValue(name, nameof(name)); + Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); + Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames), "Contained some null or empty items"); if (invertHash < -1) throw Contracts.ExceptParam(nameof(invertHash), "Value too small, must be -1 or larger"); @@ -234,8 +236,8 @@ public ColumnInfo(string[] inputs, string output, $"The sum of skipLength and ngramLength must be less than or equal to {NgramBufferBuilder.MaxSkipNgramLength}"); } FriendlyNames = null; - Inputs = inputs; - Output = output; + Name = name; + InputColumnNames = inputColumnNames; 
NgramLength = ngramLength; SkipLength = skipLength; AllLengths = allLengths; @@ -262,10 +264,10 @@ internal ColumnInfo(ModelLoadContext ctx) // byte: Ordered // byte: AllLengths var inputsLength = ctx.Reader.ReadInt32(); - Inputs = new string[inputsLength]; - for (int i = 0; i < Inputs.Length; i++) - Inputs[i] = ctx.LoadNonEmptyString(); - Output = ctx.LoadNonEmptyString(); + InputColumnNames = new string[inputsLength]; + for (int i = 0; i < InputColumnNames.Length; i++) + InputColumnNames[i] = ctx.LoadNonEmptyString(); + Name = ctx.LoadNonEmptyString(); NgramLength = ctx.Reader.ReadInt32(); Contracts.CheckDecode(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); SkipLength = ctx.Reader.ReadInt32(); @@ -279,14 +281,14 @@ internal ColumnInfo(ModelLoadContext ctx) AllLengths = ctx.Reader.ReadBoolByte(); } - internal ColumnInfo(ModelLoadContext ctx, string[] inputs, string output) + internal ColumnInfo(ModelLoadContext ctx, string name, string[] inputColumnNames) { Contracts.AssertValue(ctx); - Contracts.CheckValue(inputs, nameof(inputs)); - Contracts.CheckParam(!inputs.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputs), + Contracts.CheckValue(inputColumnNames, nameof(inputColumnNames)); + Contracts.CheckParam(!inputColumnNames.Any(r => string.IsNullOrWhiteSpace(r)), nameof(inputColumnNames), "Contained some null or empty items"); - Inputs = inputs; - Output = output; + InputColumnNames = inputColumnNames; + Name = name; // *** Binary format *** // string Output; // int: NgramLength @@ -324,11 +326,11 @@ internal void Save(ModelSaveContext ctx) // byte: Rehash // byte: Ordered // byte: AllLengths - Contracts.Assert(Inputs.Length > 0); - ctx.Writer.Write(Inputs.Length); - for (int i = 0; i < Inputs.Length; i++) - ctx.SaveNonEmptyString(Inputs[i]); - ctx.SaveNonEmptyString(Output); + Contracts.Assert(InputColumnNames.Length > 0); + ctx.Writer.Write(InputColumnNames.Length); + for (int i = 0; i < InputColumnNames.Length; i++) + 
ctx.SaveNonEmptyString(InputColumnNames[i]); + ctx.SaveNonEmptyString(Name); Contracts.Assert(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength); ctx.Writer.Write(NgramLength); @@ -385,13 +387,13 @@ internal NgramHashingTransformer(IHostEnvironment env, IDataView input, params C { columnWithInvertHash.Add(i); invertHashMaxCounts[i] = invertHashMaxCount; - for (int j = 0; j < _columns[i].Inputs.Length; j++) + for (int j = 0; j < _columns[i].InputColumnNames.Length; j++) { - if (!input.Schema.TryGetColumnIndex(_columns[i].Inputs[j], out int srcCol)) - throw Host.ExceptSchemaMismatch(nameof(input), "input", _columns[i].Inputs[j]); + if (!input.Schema.TryGetColumnIndex(_columns[i].InputColumnNames[j], out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(input), "input", _columns[i].InputColumnNames[j]); var columnType = input.Schema[srcCol].Type; if (!NgramHashingEstimator.IsColumnTypeValid(input.Schema[srcCol].Type)) - throw Host.ExceptSchemaMismatch(nameof(input), "input", _columns[i].Inputs[j], NgramHashingEstimator.ExpectedColumnType, columnType.ToString()); + throw Host.ExceptSchemaMismatch(nameof(input), "input", _columns[i].InputColumnNames[j], NgramHashingEstimator.ExpectedColumnType, columnType.ToString()); sourceColumnsForInvertHash.Add(input.Schema[srcCol]); } } @@ -498,7 +500,7 @@ private NgramHashingTransformer(IHostEnvironment env, ModelLoadContext ctx, bool // int number of columns // columns for (int i = 0; i < columnsLength; i++) - columns[i] = new ColumnInfo(ctx, inputs[i], outputs[i]); + columns[i] = new ColumnInfo(ctx, outputs[i], inputs[i]); } _columns = columns.ToImmutableArray(); TextModelHelper.LoadAll(Host, ctx, columnsLength, out _slotNames, out _slotNamesTypes); @@ -519,8 +521,9 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source ?? 
new string[] { item.Name }, + cols[i] = new ColumnInfo( item.Name, + item.Source ?? new string[] { item.Name }, item.NgramLength ?? args.NgramLength, item.SkipLength ?? args.SkipLength, item.AllLengths ?? args.AllLengths, @@ -564,11 +567,11 @@ public Mapper(NgramHashingTransformer parent, Schema inputSchema, FinderDecorato _srcTypes = new ColumnType[_parent._columns.Length][]; for (int i = 0; i < _parent._columns.Length; i++) { - _srcIndices[i] = new int[_parent._columns[i].Inputs.Length]; - _srcTypes[i] = new ColumnType[_parent._columns[i].Inputs.Length]; - for (int j = 0; j < _parent._columns[i].Inputs.Length; j++) + _srcIndices[i] = new int[_parent._columns[i].InputColumnNames.Length]; + _srcTypes[i] = new ColumnType[_parent._columns[i].InputColumnNames.Length]; + for (int j = 0; j < _parent._columns[i].InputColumnNames.Length; j++) { - var srcName = _parent._columns[i].Inputs[j]; + var srcName = _parent._columns[i].InputColumnNames[j]; if (!inputSchema.TryGetColumnIndex(srcName, out int srcCol)) throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); var columnType = inputSchema[srcCol].Type; @@ -784,7 +787,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var builder = new MetadataBuilder(); AddMetadata(i, builder); - result[i] = new Schema.DetachedColumn(_parent._columns[i].Output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent._columns[i].Name, _types[i], builder.GetMetadata()); } return result; } @@ -842,10 +845,10 @@ public InvertHashHelper(NgramHashingTransformer parent, Schema inputSchema, stri _srcIndices = new int[_parent._columns.Length][]; for (int i = 0; i < _parent._columns.Length; i++) { - _srcIndices[i] = new int[_parent._columns[i].Inputs.Length]; - for (int j = 0; j < _parent._columns[i].Inputs.Length; j++) + _srcIndices[i] = new int[_parent._columns[i].InputColumnNames.Length]; + for (int j = 0; j < _parent._columns[i].InputColumnNames.Length; j++) { - var srcName = 
_parent._columns[i].Inputs[j]; + var srcName = _parent._columns[i].InputColumnNames[j]; if (!inputSchema.TryGetColumnIndex(srcName, out int srcCol)) throw _parent.Host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); _srcIndices[i][j] = srcCol; @@ -965,7 +968,7 @@ public NgramIdFinder Decorate(int iinfo, NgramIdFinder finder) { srcNames = new string[srcIndices.Length]; for (int i = 0; i < srcIndices.Length; ++i) - srcNames[i] = _parent._columns[iinfo].Inputs[i]; + srcNames[i] = _parent._columns[iinfo].InputColumnNames[i]; } Contracts.Assert(Utils.Size(srcNames) == srcIndices.Length); string[] friendlyNames = _friendlyNames?[iinfo]; @@ -1057,15 +1060,15 @@ internal static class Defaults private readonly NgramHashingTransformer.ColumnInfo[] _columns; /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs ngram vector as + /// Produces a bag of counts of hashed ngrams in + /// and outputs ngram vector as /// /// is different from in a way that /// takes tokenized text as input while tokenizes text internally. /// /// The environment. - /// Name of input column containing tokenized text. - /// Name of output column, will contain the ngram vector. Null means is replaced. + /// Name of output column, will contain the ngram vector. Null means is replaced. + /// Name of input column containing tokenized text. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. @@ -1077,8 +1080,8 @@ internal static class Defaults /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public NgramHashingEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int hashBits = 16, int ngramLength = 2, int skipLength = 0, @@ -1086,20 +1089,20 @@ public NgramHashingEstimator(IHostEnvironment env, uint seed = 314489979, bool ordered = true, int invertHash = 0) - : this(env, new[] { (new[] { inputColumn }, outputColumn ?? inputColumn) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) + : this(env, new[] { (outputColumnName, new[] { inputColumnName ?? outputColumnName }) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) { } /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs ngram vector as + /// Produces a bag of counts of hashed ngrams in + /// and outputs ngram vector as /// /// is different from in a way that /// takes tokenized text as input while tokenizes text internally. /// /// The environment. - /// Name of input columns containing tokenized text. - /// Name of output column, will contain the ngram vector. + /// Name of output column, will contain the ngram vector. + /// Name of input columns containing tokenized text. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. @@ -1111,8 +1114,8 @@ public NgramHashingEstimator(IHostEnvironment env, /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public NgramHashingEstimator(IHostEnvironment env, - string[] inputColumns, - string outputColumn, + string outputColumnName, + string[] inputColumnNames, int hashBits = 16, int ngramLength = 2, int skipLength = 0, @@ -1120,7 +1123,7 @@ public NgramHashingEstimator(IHostEnvironment env, uint seed = 314489979, bool ordered = true, int invertHash = 0) - : this(env, new[] { (inputColumns, outputColumn) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) + : this(env, new[] { (outputColumnName, inputColumnNames) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) { } @@ -1144,7 +1147,7 @@ public NgramHashingEstimator(IHostEnvironment env, /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. public NgramHashingEstimator(IHostEnvironment env, - (string[] inputs, string output)[] columns, + (string outputColumnName, string[] inputColumnName)[] columns, int hashBits = 16, int ngramLength = 2, int skipLength = 0, @@ -1152,7 +1155,7 @@ public NgramHashingEstimator(IHostEnvironment env, uint seed = 314489979, bool ordered = true, int invertHash = 0) - : this(env, columns.Select(x => new NgramHashingTransformer.ColumnInfo(x.inputs, x.output, ngramLength, skipLength, allLengths, hashBits, seed, ordered, invertHash)).ToArray()) + : this(env, columns.Select(x => new NgramHashingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, hashBits, seed, ordered, invertHash)).ToArray()) { } @@ -1205,7 +1208,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - foreach (var input in colInfo.Inputs) + foreach (var input in colInfo.InputColumnNames) { if (!inputSchema.TryFindColumn(input, out var col)) throw 
_host.ExceptSchemaMismatch(nameof(inputSchema), "input", input); @@ -1214,7 +1217,7 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) } var metadata = new List(); metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); } return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs index 590bd55c73..c267a224c6 100644 --- a/src/Microsoft.ML.Transforms/Text/NgramTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/NgramTransform.cs @@ -127,8 +127,8 @@ private static VersionInfo GetVersionInfo() /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly int NgramLength; public readonly int SkipLength; public readonly bool AllLengths; @@ -142,31 +142,33 @@ public sealed class ColumnInfo /// /// Describes how the transformer handles one Gcn column pair. /// - /// Name of input column. - /// Name of output column. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Maximum ngram length. /// Maximum number of tokens to skip when constructing an ngram. /// "Whether to store all ngram lengths up to ngramLength, or only ngramLength. /// The weighting criteria. /// Maximum number of ngrams to store in the dictionary. 
- public ColumnInfo(string input, string output, + public ColumnInfo(string name, string inputColumnName = null, int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, int skipLength = NgramExtractingEstimator.Defaults.SkipLength, bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting, - int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms) : this(input, output, ngramLength, skipLength, allLengths, weighting, new int[] { maxNumTerms }) + int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms) + : this(name, ngramLength, skipLength, allLengths, weighting, new int[] { maxNumTerms }, inputColumnName ?? name) { } - internal ColumnInfo(string input, string output, + internal ColumnInfo(string name, int ngramLength, int skipLength, bool allLengths, NgramExtractingEstimator.WeightingCriteria weighting, - int[] maxNumTerms) + int[] maxNumTerms, + string inputColumnName = null) { - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName ?? 
name; NgramLength = ngramLength; Contracts.CheckUserArg(0 < NgramLength && NgramLength <= NgramBufferBuilder.MaxSkipNgramLength, nameof(ngramLength)); SkipLength = skipLength; @@ -265,17 +267,17 @@ public void Save(ModelSaveContext ctx) // Ngram inverse document frequencies private readonly double[][] _invDocFreqs; - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { var type = inputSchema[srcCol].Type; if (!NgramExtractingEstimator.IsColumnTypeValid(type)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, NgramExtractingEstimator.ExpectedColumnType, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, NgramExtractingEstimator.ExpectedColumnType, type.ToString()); } internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, ColumnInfo[] columns) @@ -284,7 +286,7 @@ internal NgramExtractingTransformer(IHostEnvironment env, IDataView input, Colum var transformInfos = new TransformInfo[columns.Length]; for (int i = 0; i < columns.Length; i++) { - input.Schema.TryGetColumnIndex(columns[i].Input, out int srcCol); + input.Schema.TryGetColumnIndex(columns[i].InputColumnName, out int srcCol); var typeSrc = input.Schema[srcCol].Type; transformInfos[i] = new TransformInfo(columns[i]); } @@ -307,7 +309,7 @@ private static SequencePool[] Train(IHostEnvironment env, ColumnInfo[] columns, var srcCols = new int[columns.Length]; for (int iinfo = 0; iinfo < columns.Length; iinfo++) { - trainingData.Schema.TryGetColumnIndex(columns[iinfo].Input, out srcCols[iinfo]); + 
trainingData.Schema.TryGetColumnIndex(columns[iinfo].InputColumnName, out srcCols[iinfo]); srcTypes[iinfo] = trainingData.Schema[srcCols[iinfo]].Type; activeCols.Add(trainingData.Schema[srcCols[iinfo]]); } @@ -495,13 +497,14 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { var item = args.Column[i]; var maxNumTerms = Utils.Size(item.MaxNumTerms) > 0 ? item.MaxNumTerms : args.MaxNumTerms; - cols[i] = new ColumnInfo(item.Source ?? item.Name, + cols[i] = new ColumnInfo( item.Name, item.NgramLength ?? args.NgramLength, item.SkipLength ?? args.SkipLength, item.AllLengths ?? args.AllLengths, item.Weighting ?? args.Weighting, - maxNumTerms); + maxNumTerms, + item.Source ?? item.Name); }; } return new NgramExtractingTransformer(env, input, cols).MakeDataTransform(input); @@ -562,7 +565,7 @@ public Mapper(NgramExtractingTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { _types[i] = new VectorType(NumberType.Float, _parent._ngramMaps[i].Count); - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out _srcCols[i]); + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out _srcCols[i]); _srcTypes[i] = inputSchema[_srcCols[i]].Type; } } @@ -575,7 +578,7 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var builder = new MetadataBuilder(); AddMetadata(i, builder); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], builder.GetMetadata()); } return result; } @@ -777,26 +780,25 @@ internal static class Defaults private readonly NgramExtractingTransformer.ColumnInfo[] _columns; /// - /// Produces a bag of counts of ngrams (sequences of consecutive words) in - /// and outputs bag of word vector as + /// Produces a bag of counts of ngrams (sequences of consecutive words) in + /// and outputs bag of word vector as 
/// /// The environment. - /// The column containing text to compute bag of word vector. - /// The column containing bag of word vector. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. public NgramExtractingEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, + string outputColumnName, string inputColumnName = null, int ngramLength = Defaults.NgramLength, int skipLength = Defaults.SkipLength, bool allLengths = Defaults.AllLengths, int maxNumTerms = Defaults.MaxNumTerms, WeightingCriteria weighting = Defaults.Weighting) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, ngramLength, skipLength, allLengths, maxNumTerms, weighting) + : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, ngramLength, skipLength, allLengths, maxNumTerms, weighting) { } @@ -812,13 +814,13 @@ public NgramExtractingEstimator(IHostEnvironment env, /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. 
public NgramExtractingEstimator(IHostEnvironment env, - (string input, string output)[] columns, + (string outputColumnName, string inputColumnName)[] columns, int ngramLength = Defaults.NgramLength, int skipLength = Defaults.SkipLength, bool allLengths = Defaults.AllLengths, int maxNumTerms = Defaults.MaxNumTerms, WeightingCriteria weighting = Defaults.Weighting) - : this(env, columns.Select(x => new NgramExtractingTransformer.ColumnInfo(x.input, x.output, ngramLength, skipLength, allLengths, weighting, maxNumTerms)).ToArray()) + : this(env, columns.Select(x => new NgramExtractingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, ngramLength, skipLength, allLengths, weighting, maxNumTerms)).ToArray()) { } @@ -869,14 +871,14 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!IsSchemaColumnValid(col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, ExpectedColumnType, col.GetTypeString()); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, ExpectedColumnType, col.GetTypeString()); var metadata = new List(); if (col.HasKeyValues()) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, new SchemaShape(metadata)); } return new 
SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Transforms/Text/SentimentAnalyzingTransform.cs b/src/Microsoft.ML.Transforms/Text/SentimentAnalyzingTransform.cs index b1db0ca187..1963da8131 100644 --- a/src/Microsoft.ML.Transforms/Text/SentimentAnalyzingTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/SentimentAnalyzingTransform.cs @@ -81,7 +81,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat // 2. Copy source column to a column with the name expected by the pretrained model featurization // transform pipeline. - var copyTransformer = new ColumnCopyingTransformer(env, (args.Source, ModelInputColumnName)); + var copyTransformer = new ColumnCopyingTransformer(env, (ModelInputColumnName, args.Source)); input = copyTransformer.Transform(input); @@ -90,7 +90,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat // 4. Copy the output column from the pretrained model to a temporary column. var scoreTempName = input.Schema.GetTempColumnName("sa_out"); - copyTransformer = new ColumnCopyingTransformer(env, (ModelScoreColumnName, scoreTempName)); + copyTransformer = new ColumnCopyingTransformer(env, (scoreTempName, ModelScoreColumnName)); input = copyTransformer.Transform(input); // 5. Drop all the columns created by the pretrained model, including the expected input column @@ -103,7 +103,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat input = UnaliasIfNeeded(env, input, aliased); // 7. Copy the temporary column with the score we created in (4) to a column with the user-specified destination name. - copyTransformer = new ColumnCopyingTransformer(env, (scoreTempName, args.Name)); + copyTransformer = new ColumnCopyingTransformer(env, (args.Name, scoreTempName)); input = copyTransformer.Transform(input); // 8. Drop the temporary column with the score created in (4). 
@@ -131,8 +131,8 @@ private static IDataView AliasIfNeeded(IHostEnvironment env, IDataView input, st return input; hiddenNames = toHide.Select(colName => - new KeyValuePair(colName, input.Schema.GetTempColumnName(colName))).ToArray(); - return new ColumnCopyingTransformer(env, hiddenNames.Select(x => (Input: x.Key, Output: x.Value)).ToArray()).Transform(input); + new KeyValuePair(input.Schema.GetTempColumnName(colName), colName)).ToArray(); + return new ColumnCopyingTransformer(env, hiddenNames.Select(x => (Name: x.Key, Source: x.Value)).ToArray()).Transform(input); } private static IDataView UnaliasIfNeeded(IHostEnvironment env, IDataView input, KeyValuePair[] hiddenNames) @@ -140,7 +140,7 @@ private static IDataView UnaliasIfNeeded(IHostEnvironment env, IDataView input, if (Utils.Size(hiddenNames) == 0) return input; - input = new ColumnCopyingTransformer(env, hiddenNames.Select(x => (Input: x.Key, Output: x.Value)).ToArray()).Transform(input); + input = new ColumnCopyingTransformer(env, hiddenNames.Select(x => (outputColumnName: x.Key, inputColumnName: x.Value)).ToArray()).Transform(input); return ColumnSelectingTransformer.CreateDrop(env, input, hiddenNames.Select(pair => pair.Value).ToArray()); } diff --git a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs index b8cce1c769..e8690253a2 100644 --- a/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs +++ b/src/Microsoft.ML.Transforms/Text/StopWordsRemovingTransformer.cs @@ -60,7 +60,7 @@ public sealed class PredefinedStopWordsRemoverFactory : IStopWordsRemoverFactory { public IDataTransform CreateComponent(IHostEnvironment env, IDataView input, OneToOneColumn[] columns) { - return new StopWordsRemovingEstimator(env, columns.Select(x => new StopWordsRemovingTransformer.ColumnInfo(x.Source, x.Name)).ToArray()).Fit(input).Transform(input) as IDataTransform; + return new StopWordsRemovingEstimator(env, columns.Select(x 
=> new StopWordsRemovingTransformer.ColumnInfo(x.Name, x.Source)).ToArray()).Fit(input).Transform(input) as IDataTransform; } } @@ -182,38 +182,43 @@ private static NormStr.Pool[] StopWords /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly StopWordsRemovingEstimator.Language Language; public readonly string LanguageColumn; /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of output column. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Language-specific stop words list. /// Optional column to use for languages. This overrides language value. - public ColumnInfo(string input, string output, StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Defaults.DefaultLanguage, string languageColumn = null) + public ColumnInfo(string name, + string inputColumnName = null, + StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Defaults.DefaultLanguage, + string languageColumn = null) { - Input = input; - Output = output; + Contracts.CheckNonWhiteSpace(name, nameof(name)); + + Name = name; + InputColumnName = inputColumnName ?? 
name; Language = language; LanguageColumn = languageColumn; } } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { var type = inputSchema[srcCol].Type; if (!StopWordsRemovingEstimator.IsColumnTypeValid(type)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, StopWordsRemovingEstimator.ExpectedColumnType, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, StopWordsRemovingEstimator.ExpectedColumnType, type.ToString()); } /// @@ -261,7 +266,7 @@ private StopWordsRemovingTransformer(IHost host, ModelLoadContext ctx) : var lang = (StopWordsRemovingEstimator.Language)ctx.Reader.ReadInt32(); Contracts.CheckDecode(Enum.IsDefined(typeof(StopWordsRemovingEstimator.Language), lang)); var langColName = ctx.LoadStringOrNull(); - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, lang, langColName); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, lang, langColName); } } @@ -287,8 +292,9 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source ?? item.Name, + cols[i] = new ColumnInfo( item.Name, + item.Source ?? item.Name, item.Language ?? args.Language, item.LanguagesColumn ?? 
args.LanguagesColumn); } @@ -387,14 +393,14 @@ public Mapper(StopWordsRemovingTransformer parent, Schema inputSchema) for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int srcCol)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].input); + if (!inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int srcCol)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", _parent.ColumnPairs[i].inputColumnName); _parent.CheckInputColumn(InputSchema, i, srcCol); _colMapNewToOld.Add(i, srcCol); var srcType = InputSchema[srcCol].Type; if (!StopWordsRemovingEstimator.IsColumnTypeValid(srcType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent._columns[i].Input, StopWordsRemovingEstimator.ExpectedColumnType, srcType.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent._columns[i].InputColumnName, StopWordsRemovingEstimator.ExpectedColumnType, srcType.ToString()); _types[i] = new VectorType(TextType.Instance); if (!string.IsNullOrEmpty(_parent._columns[i].LanguageColumn)) @@ -414,9 +420,9 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i]); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i]); } return result; } @@ -547,15 +553,15 @@ public static bool IsColumnTypeValid(ColumnType type) => internal const string ExpectedColumnType = "vector of Text type"; /// - /// Removes stop words from incoming token streams in - /// and outputs the 
token streams without stopwords as . + /// Removes stop words from incoming token streams in + /// and outputs the token streams without stopwords as . /// /// The environment. - /// The column containing text to remove stop words on. - /// The column containing output text. Null means is replaced. - /// Langauge of the input text column . - public StopWordsRemovingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, Language language = Language.English) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, language) + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. + /// Langauge of the input text column . + public StopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, Language language = Language.English) + : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, language) { } @@ -566,8 +572,8 @@ public StopWordsRemovingEstimator(IHostEnvironment env, string inputColumn, stri /// The environment. /// Pairs of columns to remove stop words on. /// Langauge of the input text columns . 
- public StopWordsRemovingEstimator(IHostEnvironment env, (string input, string output)[] columns, Language language = Language.English) - : this(env, columns.Select(x => new StopWordsRemovingTransformer.ColumnInfo(x.input, x.output, language)).ToArray()) + public StopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, Language language = Language.English) + : this(env, columns.Select(x => new StopWordsRemovingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, language)).ToArray()) { } @@ -582,11 +588,11 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (col.Kind == SchemaShape.Column.VectorKind.Scalar || !(col.ItemType is TextType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, ExpectedColumnType, col.ItemType.ToString()); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, ExpectedColumnType, col.ItemType.ToString()); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false); } return new SchemaShape(result.Values); } @@ -647,9 +653,9 @@ public sealed class LoaderArguments : ArgumentsBase, IStopWordsRemoverFactory public IDataTransform CreateComponent(IHostEnvironment env, IDataView input, OneToOneColumn[] column) { if (Utils.Size(Stopword) > 0) - return new CustomStopWordsRemovingTransformer(env, 
Stopword, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform; + return new CustomStopWordsRemovingTransformer(env, Stopword, column.Select(x => (x.Name, x.Source)).ToArray()).Transform(input) as IDataTransform; else - return new CustomStopWordsRemovingTransformer(env, Stopwords, DataFile, StopwordsColumn, Loader, column.Select(x => (x.Source, x.Name)).ToArray()).Transform(input) as IDataTransform; + return new CustomStopWordsRemovingTransformer(env, Stopwords, DataFile, StopwordsColumn, Loader, column.Select(x => (x.Name, x.Source)).ToArray()).Transform(input) as IDataTransform; } } @@ -801,7 +807,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory stopwords, string d } } - public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); /// /// Custom stopword remover removes specified list of stop words. @@ -809,7 +815,7 @@ private void LoadStopWords(IChannel ch, ReadOnlyMemory stopwords, string d /// The environment. /// Array of words to remove. /// Pairs of columns to remove stop words from. 
- public CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string input, string output)[] columns) : + public CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwords, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { _stopWordsMap = new NormStr.Pool(); @@ -828,7 +834,7 @@ public CustomStopWordsRemovingTransformer(IHostEnvironment env, string[] stopwor } internal CustomStopWordsRemovingTransformer(IHostEnvironment env, string stopwords, - string dataFile, string stopwordsColumn, IComponentFactory loader, params (string input, string output)[] columns) : + string dataFile, string stopwordsColumn, IComponentFactory loader, params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { var ch = Host.Start("LoadStopWords"); @@ -937,11 +943,11 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat env.CheckValue(input, nameof(input)); env.CheckValue(args.Column, nameof(args.Column)); - var cols = new (string input, string output)[args.Column.Length]; + var cols = new (string outputColumnName, string inputColumnName)[args.Column.Length]; for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = (item.Source ?? item.Name, item.Name); + cols[i] = (item.Name, item.Source ?? 
item.Name); } CustomStopWordsRemovingTransformer transfrom = null; if (Utils.Size(args.Stopword) > 0) @@ -973,10 +979,10 @@ public Mapper(CustomStopWordsRemovingTransformer parent, Schema inputSchema) _types = new ColumnType[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int srcCol); + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int srcCol); var srcType = inputSchema[srcCol].Type; if (!StopWordsRemovingEstimator.IsColumnTypeValid(srcType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.ColumnPairs[i].input, StopWordsRemovingEstimator.ExpectedColumnType, srcType.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", parent.ColumnPairs[i].inputColumnName, StopWordsRemovingEstimator.ExpectedColumnType, srcType.ToString()); _types[i] = new VectorType(TextType.Instance); } @@ -987,9 +993,9 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i]); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i]); } return result; } @@ -1042,15 +1048,15 @@ public sealed class CustomStopWordsRemovingEstimator : TrivialEstimator - /// Removes stop words from incoming token streams in - /// and outputs the token streams without stopwords as . + /// Removes stop words from incoming token streams in + /// and outputs the token streams without stopwords as . /// /// The environment. - /// The column containing text to remove stop words on. 
- /// The column containing output text. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Array of words to remove. - public CustomStopWordsRemovingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, params string[] stopwords) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, stopwords) + public CustomStopWordsRemovingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, params string[] stopwords) + : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, stopwords) { } @@ -1061,7 +1067,7 @@ public CustomStopWordsRemovingEstimator(IHostEnvironment env, string inputColumn /// The environment. /// Pairs of columns to remove stop words on. /// Array of words to remove. - public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string input, string output)[] columns, string[] stopwords) : + public CustomStopWordsRemovingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, string[] stopwords) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(CustomStopWordsRemovingEstimator)), new CustomStopWordsRemovingTransformer(env, stopwords, columns)) { } @@ -1072,11 +1078,11 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!inputSchema.TryFindColumn(colInfo.inputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName); if (col.Kind == SchemaShape.Column.VectorKind.Scalar || !(col.ItemType is TextType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", 
colInfo.input, ExpectedColumnType, col.ItemType.ToString()); - result[colInfo.output] = new SchemaShape.Column(colInfo.output, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName, ExpectedColumnType, col.ItemType.ToString()); + result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false); } return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs index 7c8913d157..1bb41c8b49 100644 --- a/src/Microsoft.ML.Transforms/Text/TextCatalog.cs +++ b/src/Microsoft.ML.Transforms/Text/TextCatalog.cs @@ -18,8 +18,8 @@ public static class TextCatalog /// Transform a text column into featurized float array that represents counts of ngrams and char-grams. /// /// The text-related transform's catalog. - /// The input column - /// The output column + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Advanced transform settings /// /// @@ -29,39 +29,39 @@ public static class TextCatalog /// /// public static TextFeaturizingEstimator FeaturizeText(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, Action advancedSettings = null) => new TextFeaturizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumn, outputColumn, advancedSettings); + outputColumnName, inputColumnName, advancedSettings); /// /// Transform several text columns into featurized float array that represents counts of ngrams and char-grams. /// /// The text-related transform's catalog. 
- /// The input columns - /// The output column + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. If set to , the value of the will be used as source. /// Advanced transform settings public static TextFeaturizingEstimator FeaturizeText(this TransformsCatalog.TextTransforms catalog, - IEnumerable inputColumns, - string outputColumn, + string outputColumnName, + IEnumerable inputColumnNames, Action advancedSettings = null) => new TextFeaturizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumns, outputColumn, advancedSettings); + outputColumnName, inputColumnNames, advancedSettings); /// - /// Tokenize incoming text in and output the tokens as . + /// Tokenize incoming text in and output the tokens as . /// /// The text-related transform's catalog. - /// The column containing text to tokenize. - /// The column containing output tokens. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Whether to use marker characters to separate words. public static TokenizingByCharactersEstimator TokenizeCharacters(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, bool useMarkerCharacters = CharTokenizingDefaults.UseMarkerCharacters) => new TokenizingByCharactersEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - useMarkerCharacters, new[] { (inputColumn, outputColumn) }); + useMarkerCharacters, new[] { (outputColumnName, inputColumnName) }); /// /// Tokenize incoming text in input columns and output the tokens as output columns. 
@@ -72,56 +72,56 @@ public static TokenizingByCharactersEstimator TokenizeCharacters(this Transforms public static TokenizingByCharactersEstimator TokenizeCharacters(this TransformsCatalog.TextTransforms catalog, bool useMarkerCharacters = CharTokenizingDefaults.UseMarkerCharacters, - params (string inputColumn, string outputColumn)[] columns) + params (string outputColumnName, string inputColumnName)[] columns) => new TokenizingByCharactersEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), useMarkerCharacters, columns); /// - /// Normalizes incoming text in by changing case, removing diacritical marks, punctuation marks and/or numbers - /// and outputs new text as . + /// Normalizes incoming text in by changing case, removing diacritical marks, punctuation marks and/or numbers + /// and outputs new text as . /// /// The text-related transform's catalog. - /// The column containing text to normalize. - /// The column containing output tokens. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Casing text using the rules of the invariant culture. /// Whether to keep diacritical marks or remove them. /// Whether to keep punctuation marks or remove them. /// Whether to keep numbers or remove them. 
public static TextNormalizingEstimator NormalizeText(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, TextNormalizingEstimator.CaseNormalizationMode textCase = TextNormalizeDefaults.TextCase, bool keepDiacritics = TextNormalizeDefaults.KeepDiacritics, bool keepPunctuations = TextNormalizeDefaults.KeepPunctuations, bool keepNumbers = TextNormalizeDefaults.KeepNumbers) => new TextNormalizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumn, outputColumn, textCase, keepDiacritics, keepPunctuations, keepNumbers); + outputColumnName, inputColumnName, textCase, keepDiacritics, keepPunctuations, keepNumbers); /// /// Extracts word embeddings. /// /// The text-related transform's catalog. - /// The input column. - /// The optional output column. If it is null the input column will be substituted with its value. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The embeddings to use. public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe) - => new WordEmbeddingsExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), inputColumn, outputColumn, modelKind); + => new WordEmbeddingsExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, modelKind); /// /// Extracts word embeddings. /// /// The text-related transform's catalog. - /// The input column. - /// The optional output column. 
If it is null the input column will be substituted with its value. + /// Name of the column resulting from the transformation of . /// The path of the pre-trained embeedings model to use. + /// Name of the column to transform. public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this TransformsCatalog.TextTransforms catalog, - string inputColumn, + string outputColumnName, string customModelFile, - string outputColumn = null) + string inputColumnName = null) => new WordEmbeddingsExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumn, outputColumn, customModelFile); + outputColumnName, customModelFile, inputColumnName ?? outputColumnName); /// /// Extracts word embeddings. @@ -135,18 +135,18 @@ public static WordEmbeddingsExtractingEstimator ExtractWordEmbeddings(this Trans => new WordEmbeddingsExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), modelKind, columns); /// - /// Tokenizes incoming text in , using as separators, - /// and outputs the tokens as . + /// Tokenizes incoming text in , using as separators, + /// and outputs the tokens as . /// /// The text-related transform's catalog. - /// The column containing text to tokenize. - /// The column containing output tokens. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The separators to use (uses space character by default). 
public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, char[] separators = null) - => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), inputColumn, outputColumn, separators); + => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, separators); /// /// Tokenizes incoming text in input columns and outputs the tokens using as separators. @@ -155,7 +155,7 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT /// Pairs of columns to run the tokenization on. /// The separators to use (uses space character by default). public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextTransforms catalog, - (string inputColumn, string outputColumn)[] columns, + (string outputColumnName, string inputColumnName)[] columns, char[] separators = null) => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns, separators); @@ -169,12 +169,12 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT => new WordTokenizingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); /// - /// Produces a bag of counts of ngrams (sequences of consecutive words) in - /// and outputs bag of word vector as + /// Produces a bag of counts of ngrams (sequences of consecutive words) in + /// and outputs bag of word vector as /// /// The text-related transform's catalog. - /// Name of input column containing tokenized text. - /// Name of output column, will contain the ngram vector. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Ngram length. 
/// Maximum number of tokens to skip when constructing an ngram. /// Whether to include all ngram lengths up to or only . @@ -188,14 +188,14 @@ public static WordTokenizingEstimator TokenizeWords(this TransformsCatalog.TextT /// /// public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, int skipLength = NgramExtractingEstimator.Defaults.SkipLength, bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms, NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.Defaults.Weighting) => - new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), inputColumn, outputColumn, + new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, ngramLength, skipLength, allLengths, maxNumTerms, weighting); /// @@ -210,7 +210,7 @@ public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.Text /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. 
public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.TextTransforms catalog, - (string input, string output)[] columns, + (string outputColumnName, string inputColumnName)[] columns, int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, int skipLength = NgramExtractingEstimator.Defaults.SkipLength, bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, @@ -230,13 +230,13 @@ public static NgramExtractingEstimator ProduceNgrams(this TransformsCatalog.Text => new NgramExtractingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns); /// - /// Removes stop words from incoming token streams in - /// and outputs the token streams without stopwords as . + /// Removes stop words from incoming token streams in + /// and outputs the token streams without stopwords as . /// /// The text-related transform's catalog. - /// The column containing text to remove stop words on. - /// The column containing output text. Null means is replaced. - /// Langauge of the input text column . + /// The column containing output text. Null means is replaced. + /// The column containing text to remove stop words on. + /// Langauge of the input text column . 
/// /// /// /// public static StopWordsRemovingEstimator RemoveDefaultStopWords(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Language.English) - => new StopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), inputColumn, outputColumn, language); + => new StopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, language); /// /// Removes stop words from incoming token streams in input columns @@ -263,17 +263,17 @@ public static StopWordsRemovingEstimator RemoveDefaultStopWords(this TransformsC /// ]]> /// public static StopWordsRemovingEstimator RemoveDefaultStopWords(this TransformsCatalog.TextTransforms catalog, - (string input, string output)[] columns, + (string outputColumnName, string inputColumnName)[] columns, StopWordsRemovingEstimator.Language language = StopWordsRemovingEstimator.Language.English) => new StopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns, language); /// - /// Removes stop words from incoming token streams in - /// and outputs the token streams without stopwords as . + /// Removes stop words from incoming token streams in + /// and outputs the token streams without stopwords as . /// /// The text-related transform's catalog. - /// The column containing text to remove stop words on. - /// The column containing output text. Null means is replaced. + /// The column containing output text. Null means is replaced. + /// The column containing text to remove stop words on. /// Array of words to remove. 
/// /// @@ -282,10 +282,10 @@ public static StopWordsRemovingEstimator RemoveDefaultStopWords(this TransformsC /// ]]> /// public static CustomStopWordsRemovingEstimator RemoveStopWords(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, params string[] stopwords) - => new CustomStopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), inputColumn, outputColumn, stopwords); + => new CustomStopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), outputColumnName, inputColumnName, stopwords); /// /// Removes stop words from incoming token streams in input columns @@ -301,55 +301,55 @@ public static CustomStopWordsRemovingEstimator RemoveStopWords(this TransformsCa /// ]]> /// public static CustomStopWordsRemovingEstimator RemoveStopWords(this TransformsCatalog.TextTransforms catalog, - (string input, string output)[] columns, + (string outputColumnName, string inputColumnName)[] columns, params string[] stopwords) => new CustomStopWordsRemovingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns, stopwords); /// - /// Produces a bag of counts of ngrams (sequences of consecutive words) in - /// and outputs bag of word vector as + /// Produces a bag of counts of ngrams (sequences of consecutive words) in + /// and outputs bag of word vector as /// /// The text-related transform's catalog. - /// The column containing text to compute bag of word vector. - /// The column containing bag of word vector. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. 
/// Statistical measure used to evaluate how important a word is to a document in a corpus. public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, int skipLength = NgramExtractingEstimator.Defaults.SkipLength, bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms, NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf) => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumn, outputColumn, ngramLength, skipLength, allLengths, maxNumTerms); + outputColumnName, inputColumnName, ngramLength, skipLength, allLengths, maxNumTerms); /// - /// Produces a bag of counts of ngrams (sequences of consecutive words) in - /// and outputs bag of word vector as + /// Produces a bag of counts of ngrams (sequences of consecutive words) in + /// and outputs bag of word vector as /// /// The text-related transform's catalog. - /// The columns containing text to compute bag of word vector. - /// The column containing output tokens. + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. 
public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog, - string[] inputColumns, - string outputColumn, + string outputColumnName, + string[] inputColumnNames, int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, int skipLength = NgramExtractingEstimator.Defaults.SkipLength, bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, int maxNumTerms = NgramExtractingEstimator.Defaults.MaxNumTerms, NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf) => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumns, outputColumn, ngramLength, skipLength, allLengths, maxNumTerms, weighting); + outputColumnName, inputColumnNames, ngramLength, skipLength, allLengths, maxNumTerms, weighting); /// /// Produces a bag of counts of ngrams (sequences of consecutive words) in @@ -363,7 +363,7 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. 
public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransforms catalog, - (string[] inputs, string output)[] columns, + (string outputColumnName, string[] inputColumnNames)[] columns, int ngramLength = NgramExtractingEstimator.Defaults.NgramLength, int skipLength = NgramExtractingEstimator.Defaults.SkipLength, bool allLengths = NgramExtractingEstimator.Defaults.AllLengths, @@ -372,12 +372,12 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf => new WordBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), columns, ngramLength, skipLength, allLengths, maxNumTerms, weighting); /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs bag of word vector as + /// Produces a bag of counts of hashed ngrams in + /// and outputs bag of word vector as /// /// The text-related transform's catalog. - /// The column containing text to compute bag of word vector. - /// The column containing bag of word vector. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. @@ -389,8 +389,8 @@ public static WordBagEstimator ProduceWordBags(this TransformsCatalog.TextTransf /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int hashBits = NgramHashExtractingTransformer.DefaultArguments.HashBits, int ngramLength = NgramHashExtractingTransformer.DefaultArguments.NgramLength, int skipLength = NgramHashExtractingTransformer.DefaultArguments.SkipLength, @@ -399,15 +399,15 @@ public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog. bool ordered = NgramHashExtractingTransformer.DefaultArguments.Ordered, int invertHash = NgramHashExtractingTransformer.DefaultArguments.InvertHash) => new WordHashBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumn, outputColumn, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); + outputColumnName, inputColumnName, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs bag of word vector as + /// Produces a bag of counts of hashed ngrams in + /// and outputs bag of word vector as /// /// The text-related transform's catalog. - /// The columns containing text to compute bag of word vector. - /// The column containing output tokens. + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. @@ -419,8 +419,8 @@ public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog.TextTransforms catalog, - string[] inputColumns, - string outputColumn, + string outputColumnName, + string[] inputColumnNames, int hashBits = NgramHashExtractingTransformer.DefaultArguments.HashBits, int ngramLength = NgramHashExtractingTransformer.DefaultArguments.NgramLength, int skipLength = NgramHashExtractingTransformer.DefaultArguments.SkipLength, @@ -429,7 +429,7 @@ public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog. bool ordered = NgramHashExtractingTransformer.DefaultArguments.Ordered, int invertHash = NgramHashExtractingTransformer.DefaultArguments.InvertHash) => new WordHashBagEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumns, outputColumn, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); + outputColumnName, inputColumnNames, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); /// /// Produces a bag of counts of hashed ngrams in @@ -448,7 +448,7 @@ public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog.TextTransforms catalog, - (string[] inputs, string output)[] columns, + (string outputColumnName, string[] inputColumnNames)[] columns, int hashBits = NgramHashExtractingTransformer.DefaultArguments.HashBits, int ngramLength = NgramHashExtractingTransformer.DefaultArguments.NgramLength, int skipLength = NgramHashExtractingTransformer.DefaultArguments.SkipLength, @@ -460,15 +460,15 @@ public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog. 
columns, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs ngram vector as + /// Produces a bag of counts of hashed ngrams in + /// and outputs ngram vector as /// /// is different from in a way that /// takes tokenized text as input while tokenizes text internally. /// /// The text-related transform's catalog. - /// Name of input column containing tokenized text. - /// Name of output column, will contain the ngram vector. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. @@ -480,8 +480,8 @@ public static WordHashBagEstimator ProduceHashedWordBags(this TransformsCatalog. /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int hashBits = NgramHashingEstimator.Defaults.HashBits, int ngramLength = NgramHashingEstimator.Defaults.NgramLength, int skipLength = NgramHashingEstimator.Defaults.SkipLength, @@ -490,18 +490,18 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T bool ordered = NgramHashingEstimator.Defaults.Ordered, int invertHash = NgramHashingEstimator.Defaults.InvertHash) => new NgramHashingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumn, outputColumn, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); + outputColumnName, inputColumnName, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs ngram vector as + /// Produces a bag of counts of hashed ngrams in + /// and outputs ngram vector as /// /// is different from in a way that /// takes tokenized text as input while tokenizes text internally. /// /// The text-related transform's catalog. - /// Name of input columns containing tokenized text. - /// Name of output column, will contain the ngram vector. + /// Name of the column resulting from the transformation of . + /// Name of the columns to transform. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. @@ -513,8 +513,8 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. 
public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.TextTransforms catalog, - string[] inputColumns, - string outputColumn, + string outputColumnName, + string[] inputColumnNames, int hashBits = NgramHashingEstimator.Defaults.HashBits, int ngramLength = NgramHashingEstimator.Defaults.NgramLength, int skipLength = NgramHashingEstimator.Defaults.SkipLength, @@ -523,7 +523,7 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T bool ordered = NgramHashingEstimator.Defaults.Ordered, int invertHash = NgramHashingEstimator.Defaults.InvertHash) => new NgramHashingEstimator(Contracts.CheckRef(catalog, nameof(catalog)).GetEnvironment(), - inputColumns, outputColumn, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); + outputColumnName, inputColumnNames, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash); /// /// Produces a bag of counts of hashed ngrams in @@ -545,7 +545,7 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.TextTransforms catalog, - (string[] inputs, string output)[] columns, + (string outputColumnName, string[] inputColumnNames)[] columns, int hashBits = NgramHashingEstimator.Defaults.HashBits, int ngramLength = NgramHashingEstimator.Defaults.NgramLength, int skipLength = NgramHashingEstimator.Defaults.SkipLength, @@ -561,8 +561,8 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T /// into a vector of floats over a set of topics. /// /// The transform's catalog. - /// The column representing the document as a vector of floats. 
- /// The column containing the output scores over a set of topics, represented as a vector of floats. A null value for the column means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The number of topics. /// Dirichlet prior on document-topic vectors. /// Dirichlet prior on vocab-topic vectors. @@ -582,8 +582,8 @@ public static NgramHashingEstimator ProduceHashedNgrams(this TransformsCatalog.T /// /// public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this TransformsCatalog.TextTransforms catalog, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int numTopic = LatentDirichletAllocationEstimator.Defaults.NumTopic, float alphaSum = LatentDirichletAllocationEstimator.Defaults.AlphaSum, float beta = LatentDirichletAllocationEstimator.Defaults.Beta, @@ -595,8 +595,9 @@ public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this int numSummaryTermPerTopic = LatentDirichletAllocationEstimator.Defaults.NumSummaryTermPerTopic, int numBurninIterations = LatentDirichletAllocationEstimator.Defaults.NumBurninIterations, bool resetRandomGenerator = LatentDirichletAllocationEstimator.Defaults.ResetRandomGenerator) - => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), inputColumn, outputColumn, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, numMaxDocToken, - numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator); + => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), + outputColumnName, inputColumnName, numTopic, alphaSum, beta, mhstep, numIterations, likelihoodInterval, numThreads, + numMaxDocToken, numSummaryTermPerTopic, numBurninIterations, resetRandomGenerator); /// /// Uses LightLDA to transform a document (represented as a vector of floats) @@ 
-604,7 +605,9 @@ public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this /// /// The transform's catalog. /// Describes the parameters of LDA for each column pair. - public static LatentDirichletAllocationEstimator LatentDirichletAllocation(this TransformsCatalog.TextTransforms catalog, params LatentDirichletAllocationTransformer.ColumnInfo[] columns) + public static LatentDirichletAllocationEstimator LatentDirichletAllocation( + this TransformsCatalog.TextTransforms catalog, + params LatentDirichletAllocationTransformer.ColumnInfo[] columns) => new LatentDirichletAllocationEstimator(CatalogUtils.GetEnvironment(catalog), columns); } } diff --git a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs index 07d7a7211f..19453c203d 100644 --- a/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs +++ b/src/Microsoft.ML.Transforms/Text/TextFeaturizingEstimator.cs @@ -258,25 +258,25 @@ public TransformApplierParams(TextFeaturizingEstimator parent) private const string TransformedTextColFormat = "{0}_TransformedText"; - public TextFeaturizingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + public TextFeaturizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, Action advancedSettings = null) - : this(env, new[] { inputColumn }, outputColumn ?? inputColumn, advancedSettings) + : this(env, outputColumnName, new[] { inputColumnName ?? 
outputColumnName }, advancedSettings) { } - public TextFeaturizingEstimator(IHostEnvironment env, IEnumerable inputColumns, string outputColumn, + public TextFeaturizingEstimator(IHostEnvironment env, string name, IEnumerable source, Action advancedSettings = null) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(TextFeaturizingEstimator)); - _host.CheckValue(inputColumns, nameof(inputColumns)); - _host.CheckParam(inputColumns.Any(), nameof(inputColumns)); - _host.CheckParam(!inputColumns.Any(string.IsNullOrWhiteSpace), nameof(inputColumns)); - _host.CheckNonEmpty(outputColumn, nameof(outputColumn)); + _host.CheckValue(source, nameof(source)); + _host.CheckParam(source.Any(), nameof(source)); + _host.CheckParam(!source.Any(string.IsNullOrWhiteSpace), nameof(source)); + _host.CheckNonEmpty(name, nameof(name)); _host.CheckValueOrNull(advancedSettings); - _inputColumns = inputColumns.ToArray(); - OutputColumn = outputColumn; + _inputColumns = source.ToArray(); + OutputColumn = name; AdvancedSettings = new Settings(); advancedSettings?.Invoke(AdvancedSettings); @@ -312,13 +312,13 @@ public ITransformer Fit(IDataView input) if (tparams.NeedsNormalizeTransform) { - var xfCols = new (string input, string output)[textCols.Length]; + var xfCols = new (string outputColumnName, string inputColumnName)[textCols.Length]; string[] dstCols = new string[textCols.Length]; for (int i = 0; i < textCols.Length; i++) { dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer"); tempCols.Add(dstCols[i]); - xfCols[i] = (textCols[i], dstCols[i]); + xfCols[i] = (dstCols[i], textCols[i]); } view = new TextNormalizingEstimator(h, tparams.TextCase, tparams.KeepDiacritics, tparams.KeepPunctuations, tparams.KeepNumbers, xfCols).Fit(view).Transform(view); @@ -332,10 +332,10 @@ public ITransformer Fit(IDataView input) wordTokCols = new string[textCols.Length]; for (int i = 0; i < textCols.Length; i++) { - var col = new 
WordTokenizingTransformer.ColumnInfo(textCols[i], GenerateColumnName(view.Schema, textCols[i], "WordTokenizer")); + var col = new WordTokenizingTransformer.ColumnInfo(GenerateColumnName(view.Schema, textCols[i], "WordTokenizer"), textCols[i]); xfCols[i] = col; - wordTokCols[i] = col.Output; - tempCols.Add(col.Output); + wordTokCols[i] = col.Name; + tempCols.Add(col.Name); } view = new WordTokenizingEstimator(h, xfCols).Fit(view).Transform(view); @@ -349,7 +349,7 @@ public ITransformer Fit(IDataView input) for (int i = 0; i < wordTokCols.Length; i++) { var tempName = GenerateColumnName(view.Schema, wordTokCols[i], "StopWordsRemoverTransform"); - var col = new StopWordsRemovingTransformer.ColumnInfo(wordTokCols[i], tempName, tparams.StopwordsLanguage); + var col = new StopWordsRemovingTransformer.ColumnInfo(tempName, wordTokCols[i], tparams.StopwordsLanguage); dstCols[i] = tempName; tempCols.Add(tempName); @@ -384,12 +384,12 @@ public ITransformer Fit(IDataView input) { var srcCols = tparams.UsePredefinedStopWordRemover ? 
wordTokCols : textCols; charTokCols = new string[srcCols.Length]; - var xfCols = new (string input, string output)[srcCols.Length]; + var xfCols = new (string outputColumnName, string inputColumnName)[srcCols.Length]; for (int i = 0; i < srcCols.Length; i++) { - xfCols[i] = (srcCols[i], GenerateColumnName(view.Schema, srcCols[i], "CharTokenizer")); - tempCols.Add(xfCols[i].output); - charTokCols[i] = xfCols[i].output; + xfCols[i] = (GenerateColumnName(view.Schema, srcCols[i], "CharTokenizer"), srcCols[i]); + tempCols.Add(xfCols[i].outputColumnName); + charTokCols[i] = xfCols[i].outputColumnName; } view = new TokenizingByCharactersTransformer(h, columns: xfCols).Transform(view); } @@ -415,7 +415,7 @@ public ITransformer Fit(IDataView input) { var dstCol = GenerateColumnName(view.Schema, charFeatureCol, "LpCharNorm"); tempCols.Add(dstCol); - xfCols.Add(new LpNormalizingTransformer.LpNormColumnInfo(charFeatureCol, dstCol, normalizerKind: tparams.LpNormalizerKind)); + xfCols.Add(new LpNormalizingTransformer.LpNormColumnInfo(dstCol, charFeatureCol, normalizerKind: tparams.LpNormalizerKind)); charFeatureCol = dstCol; } @@ -423,7 +423,7 @@ public ITransformer Fit(IDataView input) { var dstCol = GenerateColumnName(view.Schema, wordFeatureCol, "LpWordNorm"); tempCols.Add(dstCol); - xfCols.Add(new LpNormalizingTransformer.LpNormColumnInfo(wordFeatureCol, dstCol, normalizerKind: tparams.LpNormalizerKind)); + xfCols.Add(new LpNormalizingTransformer.LpNormColumnInfo(dstCol, wordFeatureCol, normalizerKind: tparams.LpNormalizerKind)); wordFeatureCol = dstCol; } @@ -516,7 +516,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV s.UseCharExtractor = args.CharFeatureExtractor != null; }; - var estimator = new TextFeaturizingEstimator(env, args.Column.Source ?? new[] { args.Column.Name }, args.Column.Name, settings); + var estimator = new TextFeaturizingEstimator(env, args.Column.Name, args.Column.Source ?? 
new[] { args.Column.Name }, settings); estimator._dictionary = args.Dictionary; estimator._wordFeatureExtractor = args.WordFeatureExtractor; estimator._charFeatureExtractor = args.CharFeatureExtractor; diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs index 8c56872e35..534fb61992 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizing.cs @@ -91,7 +91,7 @@ private static VersionInfo GetVersionInfo() } private const string RegistrationName = "TextNormalizer"; - public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); private readonly TextNormalizingEstimator.CaseNormalizationMode _textCase; private readonly bool _keepDiacritics; @@ -103,7 +103,7 @@ public TextNormalizingTransformer(IHostEnvironment env, bool keepDiacritics = TextNormalizingEstimator.Defaults.KeepDiacritics, bool keepPunctuations = TextNormalizingEstimator.Defaults.KeepPunctuations, bool keepNumbers = TextNormalizingEstimator.Defaults.KeepNumbers, - params (string input, string output)[] columns) : + params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { _textCase = textCase; @@ -117,7 +117,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol { var type = inputSchema[srcCol].Type; if (!TextNormalizingEstimator.IsColumnTypeValid(type)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextNormalizingEstimator.ExpectedColumnType, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, TextNormalizingEstimator.ExpectedColumnType, type.ToString()); } public override void Save(ModelSaveContext ctx) @@ 
-176,11 +176,11 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData env.CheckValue(input, nameof(input)); env.CheckValue(args.Column, nameof(args.Column)); - var cols = new (string input, string output)[args.Column.Length]; + var cols = new (string outputColumnName, string inputColumnName)[args.Column.Length]; for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = (item.Source ?? item.Name, item.Name); + cols[i] = (item.Name, item.Source ?? item.Name); } return new TextNormalizingTransformer(env, args.TextCase, args.KeepDiacritics, args.KeepPunctuations, args.KeepNumbers, cols).MakeDataTransform(input); } @@ -207,7 +207,7 @@ public Mapper(TextNormalizingTransformer parent, Schema inputSchema) _types = new ColumnType[_parent.ColumnPairs.Length]; for (int i = 0; i < _types.Length; i++) { - inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int srcCol); + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int srcCol); var srcType = inputSchema[srcCol].Type; _types[i] = srcType is VectorType ? 
new VectorType(TextType.Instance) : srcType; } @@ -218,9 +218,9 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _types[i], null); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _types[i], null); } return result; } @@ -286,7 +286,7 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); disposer = null; - var srcType = input.Schema[_parent.ColumnPairs[iinfo].input].Type; + var srcType = input.Schema[_parent.ColumnPairs[iinfo].inputColumnName].Type; Host.Assert(srcType.GetItemType() is TextType); if (srcType is VectorType vectorType) @@ -455,24 +455,24 @@ internal static class Defaults internal const string ExpectedColumnType = "Text or vector of text."; /// - /// Normalizes incoming text in by changing case, removing diacritical marks, punctuation marks and/or numbers - /// and outputs new text as . + /// Normalizes incoming text in by changing case, removing diacritical marks, punctuation marks and/or numbers + /// and outputs new text as . /// /// The environment. - /// The column containing text to normalize. - /// The column containing output tokens. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Casing text using the rules of the invariant culture. /// Whether to keep diacritical marks or remove them. /// Whether to keep punctuation marks or remove them. /// Whether to keep numbers or remove them. 
public TextNormalizingEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, CaseNormalizationMode textCase = Defaults.TextCase, bool keepDiacritics = Defaults.KeepDiacritics, bool keepPunctuations = Defaults.KeepPunctuations, bool keepNumbers = Defaults.KeepNumbers) - : this(env, textCase, keepDiacritics, keepPunctuations, keepNumbers, (inputColumn, outputColumn ?? inputColumn)) + : this(env, textCase, keepDiacritics, keepPunctuations, keepNumbers, (outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -491,8 +491,9 @@ public TextNormalizingEstimator(IHostEnvironment env, bool keepDiacritics = Defaults.KeepDiacritics, bool keepPunctuations = Defaults.KeepPunctuations, bool keepNumbers = Defaults.KeepNumbers, - params (string input, string output)[] columns) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextNormalizingEstimator)), new TextNormalizingTransformer(env, textCase, keepDiacritics, keepPunctuations, keepNumbers, columns)) + params (string outputColumnName, string inputColumnName)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextNormalizingEstimator)), + new TextNormalizingTransformer(env, textCase, keepDiacritics, keepPunctuations, keepNumbers, columns)) { } @@ -502,11 +503,11 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!inputSchema.TryFindColumn(colInfo.inputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName); if (!IsColumnTypeValid(col.ItemType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, TextNormalizingEstimator.ExpectedColumnType, 
col.ItemType.ToString()); - result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind == SchemaShape.Column.VectorKind.Vector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Scalar, col.ItemType, false); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName, TextNormalizingEstimator.ExpectedColumnType, col.ItemType.ToString()); + result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, col.Kind == SchemaShape.Column.VectorKind.Vector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Scalar, col.ItemType, false); } return new SchemaShape(result.Values); } diff --git a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs index dbfcf59f21..7720158246 100644 --- a/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs +++ b/src/Microsoft.ML.Transforms/Text/TokenizingByCharacters.cs @@ -108,13 +108,14 @@ private static VersionInfo GetVersionInfo() /// The environment. /// Whether to use marker characters to separate words. /// Pairs of columns to run the tokenization on. 
- public TokenizingByCharactersTransformer(IHostEnvironment env, bool useMarkerCharacters = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters, params (string input, string output)[] columns) : + public TokenizingByCharactersTransformer(IHostEnvironment env, bool useMarkerCharacters = TokenizingByCharactersEstimator.Defaults.UseMarkerCharacters, + params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) { _useMarkerChars = useMarkerCharacters; } - public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol) { @@ -169,11 +170,11 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat env.CheckValue(input, nameof(input)); env.CheckValue(args.Column, nameof(args.Column)); - var cols = new (string input, string output)[args.Column.Length]; + var cols = new (string outputColumnName, string inputColumnName)[args.Column.Length]; for (int i = 0; i < cols.Length; i++) { var item = args.Column[i]; - cols[i] = (item.Source ?? item.Name, item.Name); + cols[i] = (item.Name, item.Source ?? 
item.Name); } return new TokenizingByCharactersTransformer(env, args.UseMarkerChars, cols).MakeDataTransform(input); } @@ -201,7 +202,7 @@ public Mapper(TokenizingByCharactersTransformer parent, Schema inputSchema) _type = new VectorType(keyType); _isSourceVector = new bool[_parent.ColumnPairs.Length]; for (int i = 0; i < _isSourceVector.Length; i++) - _isSourceVector[i] = inputSchema[_parent.ColumnPairs[i].input].Type is VectorType; + _isSourceVector[i] = inputSchema[_parent.ColumnPairs[i].inputColumnName].Type is VectorType; } protected override Schema.DetachedColumn[] GetOutputColumnsCore() @@ -211,14 +212,14 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() { var builder = new MetadataBuilder(); AddMetadata(i, builder); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _type, builder.GetMetadata()); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _type, builder.GetMetadata()); } return result; } private void AddMetadata(int iinfo, MetadataBuilder builder) { - builder.Add(InputSchema[_parent.ColumnPairs[iinfo].input].Metadata, name => name == MetadataUtils.Kinds.SlotNames); + builder.Add(InputSchema[_parent.ColumnPairs[iinfo].inputColumnName].Metadata, name => name == MetadataUtils.Kinds.SlotNames); ValueGetter>> getter = (ref VBuffer> dst) => { @@ -403,7 +404,7 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); disposer = null; - if (!(input.Schema[_parent.ColumnPairs[iinfo].input].Type is VectorType)) + if (!(input.Schema[_parent.ColumnPairs[iinfo].inputColumnName].Type is VectorType)) return MakeGetterOne(input, iinfo); return MakeGetterVec(input, iinfo); } @@ -560,14 +561,15 @@ internal static class Defaults internal const string ExpectedColumnType = "Text"; /// - /// Tokenize incoming text in and output the tokens as . + /// Tokenize incoming text in and output the tokens as . /// /// The environment. 
- /// The column containing text to tokenize. - /// The column containing output tokens. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Whether to use marker characters to separate words. - public TokenizingByCharactersEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, bool useMarkerCharacters = Defaults.UseMarkerCharacters) - : this(env, useMarkerCharacters, new[] { (inputColumn, outputColumn ?? inputColumn) }) + public TokenizingByCharactersEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, + bool useMarkerCharacters = Defaults.UseMarkerCharacters) + : this(env, useMarkerCharacters, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }) { } @@ -578,7 +580,8 @@ public TokenizingByCharactersEstimator(IHostEnvironment env, string inputColumn, /// Whether to use marker characters to separate words. /// Pairs of columns to run the tokenization on. 
- public TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerCharacters = Defaults.UseMarkerCharacters, params (string input, string output)[] columns) + public TokenizingByCharactersEstimator(IHostEnvironment env, bool useMarkerCharacters = Defaults.UseMarkerCharacters, + params (string outputColumnName, string inputColumnName)[] columns) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TokenizingByCharactersEstimator)), new TokenizingByCharactersTransformer(env, useMarkerCharacters, columns)) { } @@ -589,15 +592,15 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!inputSchema.TryFindColumn(colInfo.inputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName); if (!IsColumnTypeValid(col.ItemType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, ExpectedColumnType, col.ItemType.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.inputColumnName, ExpectedColumnType, col.ItemType.ToString()); var metadata = new List(); if (col.Metadata.TryFindColumn(MetadataUtils.Kinds.SlotNames, out var slotMeta)) metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, slotMeta.ItemType, false)); metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.KeyValues, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); - result[colInfo.output] = new SchemaShape.Column(colInfo.output, SchemaShape.Column.VectorKind.VariableVector, NumberType.U2, true, new SchemaShape(metadata.ToArray())); + result[colInfo.outputColumnName] = new SchemaShape.Column(colInfo.outputColumnName, SchemaShape.Column.VectorKind.VariableVector, 
NumberType.U2, true, new SchemaShape(metadata.ToArray())); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs index 0442dc546a..a43d4625de 100644 --- a/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordBagTransform.cs @@ -144,7 +144,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV h.CheckUserArg(Utils.Size(column.Source) > 0, nameof(column.Source)); h.CheckUserArg(column.Source.All(src => !string.IsNullOrWhiteSpace(src)), nameof(column.Source)); - tokenizeColumns[iinfo] = new WordTokenizingTransformer.ColumnInfo(column.Source.Length > 1 ? column.Name : column.Source[0], column.Name); + tokenizeColumns[iinfo] = new WordTokenizingTransformer.ColumnInfo(column.Name, column.Source.Length > 1 ? column.Name : column.Source[0]); extractorArgs.Column[iinfo] = new NgramExtractorTransform.Column() @@ -353,12 +353,13 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV for (int iinfo = 0; iinfo < args.Column.Length; iinfo++) { var column = args.Column[iinfo]; - ngramColumns[iinfo] = new NgramExtractingTransformer.ColumnInfo(isTermCol[iinfo] ? column.Name : column.Source, column.Name, + ngramColumns[iinfo] = new NgramExtractingTransformer.ColumnInfo(column.Name, column.NgramLength ?? args.NgramLength, column.SkipLength ?? args.SkipLength, column.AllLengths ?? args.AllLengths, column.Weighting ?? args.Weighting, - column.MaxNumTerms ?? args.MaxNumTerms + column.MaxNumTerms ?? args.MaxNumTerms, + isTermCol[iinfo] ? 
column.Name : column.Source ); } diff --git a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs index 63aa847afd..e927995070 100644 --- a/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs +++ b/src/Microsoft.ML.Transforms/Text/WordEmbeddingsExtractor.cs @@ -74,7 +74,7 @@ public sealed class Arguments : TransformInputBase "vectors using a pre-trained model"; internal const string UserName = "Word Embeddings Transform"; internal const string ShortName = "WordEmbeddings"; - public const string LoaderSignature = "WordEmbeddingsTransform"; + internal const string LoaderSignature = "WordEmbeddingsTransform"; public static VersionInfo GetVersionInfo() { @@ -94,7 +94,7 @@ public static VersionInfo GetVersionInfo() private readonly int _linesToSkip; private readonly Model _currentVocab; private static Dictionary> _vocab = new Dictionary>(); - public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); + public IReadOnlyCollection<(string outputColumnName, string inputColumnName)> Columns => ColumnPairs.AsReadOnly(); private sealed class Model { @@ -153,16 +153,15 @@ public List GetWordLabels() /// public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; - public ColumnInfo(string input, string output) + public ColumnInfo(string name, string inputColumnName = null) { - Contracts.CheckNonEmpty(input, nameof(input)); - Contracts.CheckNonEmpty(output, nameof(output)); + Contracts.CheckNonEmpty(name, nameof(name)); - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName ?? name; } } @@ -174,12 +173,12 @@ public ColumnInfo(string input, string output) /// Instantiates using the pretrained word embedding model specified by . /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. 
+ /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The pretrained word embedding model. - public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string inputColumn, string outputColumn, + public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string inputColumnName = null, PretrainedModelKind modelKind = PretrainedModelKind.Sswe) - : this(env, modelKind, new ColumnInfo(inputColumn, outputColumn)) + : this(env, modelKind, new ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -187,11 +186,11 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string inputCol /// Instantiates using the custom word embedding model by loading it from the file specified by the . /// /// Host Environment. - /// Name of the input column. - /// Name of the output column. + /// Name of the column resulting from the transformation of . /// Filename for custom word embedding model. - public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string inputColumn, string outputColumn, string customModelFile) - : this(env, customModelFile, new ColumnInfo(inputColumn, outputColumn)) + /// Name of the column to transform. If set to , the value of the will be used as source. + public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null) + : this(env, customModelFile, new ColumnInfo(outputColumnName, inputColumnName ?? 
outputColumnName)) { } @@ -229,10 +228,10 @@ public WordEmbeddingsExtractingTransformer(IHostEnvironment env, string customMo _currentVocab = GetVocabularyDictionary(env); } - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string outputColumnName, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckValue(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } // Factory method for SignatureDataTransform. @@ -253,8 +252,8 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { var item = args.Column[i]; cols[i] = new ColumnInfo( - item.Source ?? item.Name, - item.Name); + item.Name, + item.Source ?? item.Name); } bool customLookup = !string.IsNullOrWhiteSpace(args.CustomLookupTable); @@ -322,7 +321,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol { var colType = inputSchema[srcCol].Type; if (!(colType is VectorType vectorType && vectorType.ItemType is TextType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, "Text", inputSchema[srcCol].Type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, "Text", inputSchema[srcCol].Type.ToString()); } private sealed class Mapper : OneToOneMapperBase, ISaveAsOnnx @@ -347,15 +346,15 @@ public Mapper(WordEmbeddingsExtractingTransformer parent, Schema inputSchema) public bool CanSaveOnnx(OnnxContext ctx) => true; protected override Schema.DetachedColumn[] GetOutputColumnsCore() - => _parent.ColumnPairs.Select(x => new Schema.DetachedColumn(x.output, _outputType, null)).ToArray(); + => _parent.ColumnPairs.Select(x => new Schema.DetachedColumn(x.outputColumnName, _outputType, null)).ToArray(); public void SaveAsOnnx(OnnxContext ctx) { - foreach (var (input, output) in _parent.Columns) + 
foreach (var (outputColumnName, inputColumnName) in _parent.Columns) { - var srcVariableName = ctx.GetVariableName(input); + var srcVariableName = ctx.GetVariableName(inputColumnName); var schema = _parent.GetOutputSchema(InputSchema); - var dstVariableName = ctx.AddIntermediateVariable(schema[output].Type, output); + var dstVariableName = ctx.AddIntermediateVariable(schema[outputColumnName].Type, outputColumnName); SaveAsOnnxCore(ctx, srcVariableName, dstVariableName); } } @@ -795,12 +794,12 @@ public sealed class WordEmbeddingsExtractingEstimator : IEstimator /// /// The local instance of - /// The input column. - /// The optional output column. If it is null the input column will be substituted with its value. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The embeddings to use. - public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, + public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe) - : this(env, modelKind, new WordEmbeddingsExtractingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn)) + : this(env, modelKind, new WordEmbeddingsExtractingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -808,11 +807,11 @@ public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string inputColum /// Initializes a new instance of /// /// The local instance of - /// The input column. - /// The optional output column. If it is null the input column will be substituted with its value. + /// Name of the column resulting from the transformation of . /// The path of the pre-trained embeedings model to use. 
- public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string inputColumn, string outputColumn, string customModelFile) - : this(env, customModelFile, new WordEmbeddingsExtractingTransformer.ColumnInfo(inputColumn, outputColumn ?? inputColumn)) + /// Name of the column to transform. + public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string outputColumnName, string customModelFile, string inputColumnName = null) + : this(env, customModelFile, new WordEmbeddingsExtractingTransformer.ColumnInfo(outputColumnName, inputColumnName ?? outputColumnName)) { } @@ -823,7 +822,8 @@ public WordEmbeddingsExtractingEstimator(IHostEnvironment env, string inputColum /// The embeddings to use. /// The array columns, and per-column configurations to extract embeedings from. public WordEmbeddingsExtractingEstimator(IHostEnvironment env, - WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe, params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns) + WordEmbeddingsExtractingTransformer.PretrainedModelKind modelKind = WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe, + params WordEmbeddingsExtractingTransformer.ColumnInfo[] columns) { Contracts.CheckValue(env, nameof(env)); _host = env.Register(nameof(WordEmbeddingsExtractingEstimator)); @@ -847,12 +847,12 @@ public SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in _columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!(col.ItemType is TextType) || (col.Kind != SchemaShape.Column.VectorKind.VariableVector && col.Kind != SchemaShape.Column.VectorKind.Vector)) - throw 
_host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, new VectorType(TextType.Instance).ToString(), col.GetTypeString()); + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, new VectorType(TextType.Instance).ToString(), col.GetTypeString()); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs index 8d002ca62b..e16576e3a3 100644 --- a/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/WordHashBagProducingTransform.cs @@ -114,7 +114,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV var curTmpNames = new string[srcCount]; Contracts.Assert(uniqueSourceNames[iinfo].Length == args.Column[iinfo].Source.Length); for (int isrc = 0; isrc < srcCount; isrc++) - tokenizeColumns.Add(new WordTokenizingTransformer.ColumnInfo(args.Column[iinfo].Source[isrc], curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc])); + tokenizeColumns.Add(new WordTokenizingTransformer.ColumnInfo(curTmpNames[isrc] = uniqueSourceNames[iinfo][isrc], args.Column[iinfo].Source[isrc])); tmpColNames.AddRange(curTmpNames); extractorCols[iinfo] = @@ -360,12 +360,12 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV }); } - hashColumns.Add(new HashingTransformer.ColumnInfo(termLoaderArgs == null ? column.Source[isrc] : tmpName, - tmpName, 30, column.Seed ?? args.Seed, false, column.InvertHash ?? args.InvertHash)); + hashColumns.Add(new HashingTransformer.ColumnInfo(tmpName, termLoaderArgs == null ? column.Source[isrc] : tmpName, + 30, column.Seed ?? 
args.Seed, false, column.InvertHash ?? args.InvertHash)); } ngramHashColumns[iinfo] = - new NgramHashingTransformer.ColumnInfo(tmpColNames[iinfo], column.Name, + new NgramHashingTransformer.ColumnInfo(column.Name, tmpColNames[iinfo], column.NgramLength ?? args.NgramLength, column.SkipLength ?? args.SkipLength, column.AllLengths ?? args.AllLengths, @@ -395,7 +395,7 @@ public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataV if (termLoaderArgs.DropUnknowns) { - var missingDropColumns = new (string input, string output)[termCols.Count]; + var missingDropColumns = new (string outputColumnName, string inputColumnName)[termCols.Count]; for (int iinfo = 0; iinfo < termCols.Count; iinfo++) missingDropColumns[iinfo] = (termCols[iinfo].Name, termCols[iinfo].Name); view = new MissingValueDroppingTransformer(h, missingDropColumns).Transform(view); diff --git a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs index 8554ce3d18..a6e8155c87 100644 --- a/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs +++ b/src/Microsoft.ML.Transforms/Text/WordTokenizing.cs @@ -113,30 +113,30 @@ private static VersionInfo GetVersionInfo() public sealed class ColumnInfo { - public readonly string Input; - public readonly string Output; + public readonly string Name; + public readonly string InputColumnName; public readonly char[] Separators; /// /// Describes how the transformer handles one column pair. /// - /// Name of input column. - /// Name of output column. + /// Name of the column resulting from the transformation of . + /// Name of column to transform. If set to , the value of the will be used as source. /// Casing text using the rules of the invariant culture. If not specified, space will be used as separator. 
- public ColumnInfo(string input, string output, char[] separators = null) + public ColumnInfo(string name, string inputColumnName = null, char[] separators = null) { - Input = input; - Output = output; + Name = name; + InputColumnName = inputColumnName ?? name; Separators = separators ?? new[] { ' ' }; } } public IReadOnlyCollection Columns => _columns.AsReadOnly(); private readonly ColumnInfo[] _columns; - private static (string input, string output)[] GetColumnPairs(ColumnInfo[] columns) + private static (string name, string inputColumnName)[] GetColumnPairs(ColumnInfo[] columns) { Contracts.CheckNonEmpty(columns, nameof(columns)); - return columns.Select(x => (x.Input, x.Output)).ToArray(); + return columns.Select(x => (x.Name, x.InputColumnName)).ToArray(); } public WordTokenizingTransformer(IHostEnvironment env, params ColumnInfo[] columns) : @@ -149,7 +149,7 @@ protected override void CheckInputColumn(Schema inputSchema, int col, int srcCol { var type = inputSchema[srcCol].Type; if (!WordTokenizingEstimator.IsColumnTypeValid(type)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, WordTokenizingEstimator.ExpectedColumnType, type.ToString()); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].inputColumnName, WordTokenizingEstimator.ExpectedColumnType, type.ToString()); } private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) : @@ -165,7 +165,7 @@ private WordTokenizingTransformer(IHost host, ModelLoadContext ctx) : { var separators = ctx.Reader.ReadCharArray(); Contracts.CheckDecode(Utils.Size(separators) > 0); - _columns[i] = new ColumnInfo(ColumnPairs[i].input, ColumnPairs[i].output, separators); + _columns[i] = new ColumnInfo(ColumnPairs[i].outputColumnName, ColumnPairs[i].inputColumnName, separators); } } @@ -211,7 +211,7 @@ internal static IDataTransform Create(IHostEnvironment env, Arguments args, IDat { var item = args.Column[i]; var separators = 
args.CharArrayTermSeparators ?? PredictionUtil.SeparatorFromString(item.TermSeparators ?? args.TermSeparators); - cols[i] = new ColumnInfo(item.Source ?? item.Name, item.Name, separators); + cols[i] = new ColumnInfo(item.Name, item.Source ?? item.Name, separators); } return new WordTokenizingTransformer(env, cols).MakeDataTransform(input); @@ -239,7 +239,7 @@ public Mapper(WordTokenizingTransformer parent, Schema inputSchema) _isSourceVector = new bool[_parent._columns.Length]; for (int i = 0; i < _isSourceVector.Length; i++) { - inputSchema.TryGetColumnIndex(_parent._columns[i].Input, out int srcCol); + inputSchema.TryGetColumnIndex(_parent._columns[i].InputColumnName, out int srcCol); var srcType = inputSchema[srcCol].Type; _isSourceVector[i] = srcType is VectorType; } @@ -250,9 +250,9 @@ protected override Schema.DetachedColumn[] GetOutputColumnsCore() var result = new Schema.DetachedColumn[_parent.ColumnPairs.Length]; for (int i = 0; i < _parent.ColumnPairs.Length; i++) { - InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].inputColumnName, out int colIndex); Host.Assert(colIndex >= 0); - result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].output, _type, null); + result[i] = new Schema.DetachedColumn(_parent.ColumnPairs[i].outputColumnName, _type, null); } return result; } @@ -263,7 +263,7 @@ protected override Delegate MakeGetter(Row input, int iinfo, Func act Host.Assert(0 <= iinfo && iinfo < _parent._columns.Length); disposer = null; - input.Schema.TryGetColumnIndex(_parent._columns[iinfo].Input, out int srcCol); + input.Schema.TryGetColumnIndex(_parent._columns[iinfo].InputColumnName, out int srcCol); var srcType = input.Schema[srcCol].Type; Host.Assert(srcType.GetItemType() is TextType); @@ -366,20 +366,20 @@ void ISaveAsPfa.SaveAsPfa(BoundPfaContext ctx) for (int iinfo = 0; iinfo < _parent._columns.Length; ++iinfo) { var info = _parent._columns[iinfo]; - var 
srcName = info.Input; + var srcName = info.InputColumnName; string srcToken = ctx.TokenOrNullForName(srcName); if (srcToken == null) { - toHide.Add(info.Output); + toHide.Add(info.Name); continue; } var result = SaveAsPfaCore(ctx, iinfo, srcToken); if (result == null) { - toHide.Add(info.Output); + toHide.Add(info.Name); continue; } - toDeclare.Add(new KeyValuePair(info.Output, result)); + toDeclare.Add(new KeyValuePair(info.Name, result)); } ctx.Hide(toHide.ToArray()); ctx.DeclareVar(toDeclare.ToArray()); @@ -433,14 +433,14 @@ public sealed class WordTokenizingEstimator : TrivialEstimator - /// Tokenize incoming text in and output the tokens as . + /// Tokenize incoming text in and output the tokens as . /// /// The environment. - /// The column containing text to tokenize. - /// The column containing output tokens. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// The separators to use (uses space character by default). - public WordTokenizingEstimator(IHostEnvironment env, string inputColumn, string outputColumn = null, char[] separators = null) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, separators) + public WordTokenizingEstimator(IHostEnvironment env, string outputColumnName, string inputColumnName = null, char[] separators = null) + : this(env, new[] { (outputColumnName, inputColumnName ?? outputColumnName) }, separators) { } @@ -450,8 +450,8 @@ public WordTokenizingEstimator(IHostEnvironment env, string inputColumn, string /// The environment. /// Pairs of columns to run the tokenization on. /// The separators to use (uses space character by default). 
- public WordTokenizingEstimator(IHostEnvironment env, (string input, string output)[] columns, char[] separators = null) - : this(env, columns.Select(x => new WordTokenizingTransformer.ColumnInfo(x.input, x.output, separators)).ToArray()) + public WordTokenizingEstimator(IHostEnvironment env, (string outputColumnName, string inputColumnName)[] columns, char[] separators = null) + : this(env, columns.Select(x => new WordTokenizingTransformer.ColumnInfo(x.outputColumnName, x.inputColumnName, separators)).ToArray()) { } @@ -471,11 +471,11 @@ public override SchemaShape GetOutputSchema(SchemaShape inputSchema) var result = inputSchema.ToDictionary(x => x.Name); foreach (var colInfo in Transformer.Columns) { - if (!inputSchema.TryFindColumn(colInfo.Input, out var col)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input); + if (!inputSchema.TryFindColumn(colInfo.InputColumnName, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName); if (!IsColumnTypeValid(col.ItemType)) - throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.Input, ExpectedColumnType, col.ItemType.ToString()); - result[colInfo.Output] = new SchemaShape.Column(colInfo.Output, SchemaShape.Column.VectorKind.VariableVector, col.ItemType, false); + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.InputColumnName, ExpectedColumnType, col.ItemType.ToString()); + result[colInfo.Name] = new SchemaShape.Column(colInfo.Name, SchemaShape.Column.VectorKind.VariableVector, col.ItemType, false); } return new SchemaShape(result.Values); diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index 4f6a98137d..191901305d 100644 --- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -16,7 +16,7 @@ namespace Microsoft.ML.Transforms.Text /// public sealed 
class WordBagEstimator : TrainedWrapperEstimatorBase { - private readonly (string[] inputs, string output)[] _columns; + private readonly (string outputColumnName, string[] sourceColumnsNames)[] _columns; private readonly int _ngramLength; private readonly int _skipLength; private readonly bool _allLengths; @@ -24,50 +24,50 @@ public sealed class WordBagEstimator : TrainedWrapperEstimatorBase private readonly NgramExtractingEstimator.WeightingCriteria _weighting; /// - /// Produces a bag of counts of ngrams (sequences of consecutive words) in - /// and outputs bag of word vector as + /// Produces a bag of counts of ngrams (sequences of consecutive words) in + /// and outputs bag of word vector as /// /// The environment. - /// The column containing text to compute bag of word vector. - /// The column containing bag of word vector. Null means is replaced. + /// Name of the column resulting from the transformation of . + /// Name of the column to transform. If set to , the value of the will be used as source. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. public WordBagEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int ngramLength = 1, int skipLength = 0, bool allLengths = true, int maxNumTerms = 10000000, NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf) - : this(env, new[] { (new[] { inputColumn }, outputColumn ?? inputColumn) }, ngramLength, skipLength, allLengths, maxNumTerms, weighting) + : this(env, outputColumnName, new[] { inputColumnName ?? 
outputColumnName }, ngramLength, skipLength, allLengths, maxNumTerms, weighting) { } /// - /// Produces a bag of counts of ngrams (sequences of consecutive words) in - /// and outputs bag of word vector as + /// Produces a bag of counts of ngrams (sequences of consecutive words) in + /// and outputs bag of word vector as /// /// The environment. - /// The columns containing text to compute bag of word vector. - /// The column containing output tokens. + /// The column containing output tokens. + /// The columns containing text to compute bag of word vector. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. /// Whether to include all ngram lengths up to or only . /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. public WordBagEstimator(IHostEnvironment env, - string[] inputColumns, - string outputColumn, + string outputColumnName, + string[] inputColumnNames, int ngramLength = 1, int skipLength = 0, bool allLengths = true, int maxNumTerms = 10000000, NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf) - : this(env, new[] { (inputColumns, outputColumn) }, ngramLength, skipLength, allLengths, maxNumTerms, weighting) + : this(env, new[] { (outputColumnName, inputColumnNames) }, ngramLength, skipLength, allLengths, maxNumTerms, weighting) { } @@ -83,7 +83,7 @@ public WordBagEstimator(IHostEnvironment env, /// Maximum number of ngrams to store in the dictionary. /// Statistical measure used to evaluate how important a word is to a document in a corpus. 
public WordBagEstimator(IHostEnvironment env, - (string[] inputs, string output)[] columns, + (string outputColumnName, string[] inputColumnNames)[] columns, int ngramLength = 1, int skipLength = 0, bool allLengths = true, @@ -91,10 +91,10 @@ public WordBagEstimator(IHostEnvironment env, NgramExtractingEstimator.WeightingCriteria weighting = NgramExtractingEstimator.WeightingCriteria.Tf) : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(WordBagEstimator))) { - foreach (var (input, output) in columns) + foreach (var (outputColumnName, inputColumnName) in columns) { - Host.CheckUserArg(Utils.Size(input) > 0, nameof(input)); - Host.CheckValue(output, nameof(input)); + Host.CheckUserArg(Utils.Size(inputColumnName) > 0, nameof(columns)); + Host.CheckValue(outputColumnName, nameof(columns)); } _columns = columns; @@ -110,7 +110,7 @@ public override TransformWrapper Fit(IDataView input) // Create arguments. var args = new WordBagBuildingTransformer.Arguments { - Column = _columns.Select(x => new WordBagBuildingTransformer.Column { Source = x.inputs, Name = x.output }).ToArray(), + Column = _columns.Select(x => new WordBagBuildingTransformer.Column { Name = x.outputColumnName, Source = x.sourceColumnsNames }).ToArray(), NgramLength = _ngramLength, SkipLength = _skipLength, AllLengths = _allLengths, @@ -128,7 +128,7 @@ public override TransformWrapper Fit(IDataView input) /// public sealed class WordHashBagEstimator : TrainedWrapperEstimatorBase { - private readonly (string[] inputs, string output)[] _columns; + private readonly (string outputColumnName, string[] inputColumnNames)[] _columns; private readonly int _hashBits; private readonly int _ngramLength; private readonly int _skipLength; @@ -138,12 +138,12 @@ public sealed class WordHashBagEstimator : TrainedWrapperEstimatorBase private readonly int _invertHash; /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs bag of word vector as + /// Produces a bag of counts of hashed ngrams in + 
/// and outputs bag of word vector as /// /// The environment. - /// The column containing text to compute bag of word vector. - /// The column containing bag of word vector. Null means is replaced. + /// The column containing bag of word vector. Null means is replaced. + /// The column containing text to compute bag of word vector. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. /// Maximum number of tokens to skip when constructing an ngram. @@ -155,8 +155,8 @@ public sealed class WordHashBagEstimator : TrainedWrapperEstimatorBase /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. public WordHashBagEstimator(IHostEnvironment env, - string inputColumn, - string outputColumn = null, + string outputColumnName, + string inputColumnName = null, int hashBits = 16, int ngramLength = 1, int skipLength = 0, @@ -164,17 +164,17 @@ public WordHashBagEstimator(IHostEnvironment env, uint seed = 314489979, bool ordered = true, int invertHash = 0) - : this(env, new[] { (new[] { inputColumn }, outputColumn ?? inputColumn) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) + : this(env, new[] { (outputColumnName, new[] { inputColumnName ?? outputColumnName }) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) { } /// - /// Produces a bag of counts of hashed ngrams in - /// and outputs bag of word vector as + /// Produces a bag of counts of hashed ngrams in + /// and outputs bag of word vector as /// /// The environment. - /// The columns containing text to compute bag of word vector. - /// The column containing output tokens. + /// The column containing output tokens. + /// The columns containing text to compute bag of word vector. /// Number of bits to hash into. Must be between 1 and 30, inclusive. /// Ngram length. 
/// Maximum number of tokens to skip when constructing an ngram. @@ -186,8 +186,8 @@ public WordHashBagEstimator(IHostEnvironment env, /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. public WordHashBagEstimator(IHostEnvironment env, - string[] inputColumns, - string outputColumn, + string outputColumnName, + string[] inputColumnNames, int hashBits = 16, int ngramLength = 1, int skipLength = 0, @@ -195,7 +195,7 @@ public WordHashBagEstimator(IHostEnvironment env, uint seed = 314489979, bool ordered = true, int invertHash = 0) - : this(env, new[] { (inputColumns, outputColumn) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) + : this(env, new[] { (outputColumnName, inputColumnNames) }, hashBits, ngramLength, skipLength, allLengths, seed, ordered, invertHash) { } @@ -216,7 +216,7 @@ public WordHashBagEstimator(IHostEnvironment env, /// specifies the upper bound of the number of distinct input values mapping to a hash that should be retained. /// 0 does not retain any input values. -1 retains all input values mapping to each hash. public WordHashBagEstimator(IHostEnvironment env, - (string[] inputs, string output)[] columns, + (string outputColumnName, string[] inputColumnNames)[] columns, int hashBits = 16, int ngramLength = 1, int skipLength = 0, @@ -247,7 +247,7 @@ public override TransformWrapper Fit(IDataView input) // Create arguments. 
var args = new WordHashBagProducingTransformer.Arguments { - Column = _columns.Select(x => new WordHashBagProducingTransformer.Column { Source = x.inputs, Name = x.output }).ToArray(), + Column = _columns.Select(x => new WordHashBagProducingTransformer.Column { Name = x.outputColumnName ,Source = x.inputColumnNames}).ToArray(), HashBits = _hashBits, NgramLength = _ngramLength, SkipLength = _skipLength, diff --git a/test/Microsoft.ML.Benchmarks/HashBench.cs b/test/Microsoft.ML.Benchmarks/HashBench.cs index 7bea396ef7..06461276fd 100644 --- a/test/Microsoft.ML.Benchmarks/HashBench.cs +++ b/test/Microsoft.ML.Benchmarks/HashBench.cs @@ -73,7 +73,7 @@ private void InitMap(T val, ColumnType type, int hashBits = 20, ValueGetter dst = val; _inRow = RowImpl.Create(type, getter); // One million features is a nice, typical number. - var info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: hashBits); + var info = new HashingTransformer.ColumnInfo("Bar", "Foo", hashBits: hashBits); var xf = new HashingTransformer(_env, new[] { info }); var mapper = xf.GetRowToRowMapper(_inRow.Schema); var column = mapper.OutputSchema["Bar"]; diff --git a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs index 8ac1e68496..d0d9ba7d0c 100644 --- a/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs +++ b/test/Microsoft.ML.Benchmarks/PredictionEngineBench.cs @@ -83,7 +83,7 @@ public void SetupSentimentPipeline() IDataView data = reader.Read(_sentimentDataPath); - var pipeline = new TextFeaturizingEstimator(env, "SentimentText", "Features") + var pipeline = new TextFeaturizingEstimator(env, "Features", "SentimentText") .Append(env.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options {NumThreads = 1, ConvergenceTolerance = 1e-2f, })); diff --git a/test/Microsoft.ML.Benchmarks/README.md b/test/Microsoft.ML.Benchmarks/README.md index 244f6c2bb8..4cefdfbaab 100644 --- 
a/test/Microsoft.ML.Benchmarks/README.md +++ b/test/Microsoft.ML.Benchmarks/README.md @@ -84,3 +84,24 @@ public class NonTrainingBenchmark [Config(typeof(TrainConfig))] public class TrainingBenchmark ``` +## Running the `BenchmarksProjectIsNotBroken` test + +If your build is failing in the build machines, in the release configuraiton due to the `BenchmarksProjectIsNotBroken` test failing, +you can debug this test locally by: + +1- Building the solution in the release mode locally + +build.cmd -release -buildNative + +2- Changing the configuration in Visual Studio from Debug -> Release +3- Changing the annotation in the `BenchmarksProjectIsNotBroken` to replace `ConditionalTheory` with `Theory`, as below. + +```cs +[Theory] +[MemberData(nameof(GetBenchmarks))] +public void BenchmarksProjectIsNotBroken(Type type) + +``` + +4- Restart Visual Studio +5- Proceed to running the tests normally from the Test Explorer view. \ No newline at end of file diff --git a/test/Microsoft.ML.Benchmarks/RffTransform.cs b/test/Microsoft.ML.Benchmarks/RffTransform.cs index 3bf628eb15..46aed852e7 100644 --- a/test/Microsoft.ML.Benchmarks/RffTransform.cs +++ b/test/Microsoft.ML.Benchmarks/RffTransform.cs @@ -43,7 +43,7 @@ public void CV_Multiclass_Digits_RffTransform_OVAAveragedPerceptron() var data = reader.Read(_dataPath_Digits); - var pipeline = mlContext.Transforms.Projection.CreateRandomFourierFeatures("Features", "FeaturesRFF") + var pipeline = mlContext.Transforms.Projection.CreateRandomFourierFeatures("FeaturesRFF", "Features") .AppendCacheCheckpoint(mlContext) .Append(mlContext.Transforms.Concatenate("Features", "FeaturesRFF")) .Append(new ValueToKeyMappingEstimator(mlContext, "Label")) diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index 231ffa0dff..c7fef788cd 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ 
b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -101,7 +101,7 @@ public void TrainSentiment() AllowSparse = false }; var loader = _env.Data.ReadFromTextFile(_sentimentDataPath, arguments); - var text = new TextFeaturizingEstimator(_env, "SentimentText", "WordEmbeddings", args => + var text = new TextFeaturizingEstimator(_env, "WordEmbeddings", "SentimentText", args => { args.OutputTokens = true; args.KeepPunctuations = false; @@ -110,7 +110,7 @@ public void TrainSentiment() args.UseCharExtractor = false; args.UseWordExtractor = false; }).Fit(loader).Transform(loader); - var trans = new WordEmbeddingsExtractingEstimator(_env, "WordEmbeddings_TransformedText", "Features", + var trans = new WordEmbeddingsExtractingEstimator(_env, "Features", "WordEmbeddings_TransformedText", WordEmbeddingsExtractingTransformer.PretrainedModelKind.Sswe).Fit(text).Transform(text); // Train var trainer = _env.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(); diff --git a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs index dd03997b9b..47afd3cee8 100644 --- a/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs +++ b/test/Microsoft.ML.Core.Tests/UnitTests/TestEntryPoints.cs @@ -460,9 +460,9 @@ public void EntryPointCreateEnsemble() ScoreModel.Score(Env, new ScoreModel.Input { Data = splitOutput.TestData[nModels], PredictorModel = predictorModels[i] }) .ScoredData; - individualScores[i] = new ColumnCopyingTransformer(Env,( - MetadataUtils.Const.ScoreValueKind.Score, - (MetadataUtils.Const.ScoreValueKind.Score + i).ToString()) + individualScores[i] = new ColumnCopyingTransformer(Env, ( + (MetadataUtils.Const.ScoreValueKind.Score + i).ToString(), + MetadataUtils.Const.ScoreValueKind.Score) ).Transform(individualScores[i]); individualScores[i] = ColumnSelectingTransformer.CreateDrop(Env, individualScores[i], MetadataUtils.Const.ScoreValueKind.Score); @@ -746,8 
+746,8 @@ public void EntryPointPipelineEnsemble() { var data = splitOutput.TrainData[i]; data = new RandomFourierFeaturizingEstimator(Env, new[] { - new RandomFourierFeaturizingTransformer.ColumnInfo("Features", "Features1", 10, false), - new RandomFourierFeaturizingTransformer.ColumnInfo("Features", "Features2", 10, false), + new RandomFourierFeaturizingTransformer.ColumnInfo("Features1", 10, false, "Features"), + new RandomFourierFeaturizingTransformer.ColumnInfo("Features2", 10, false, "Features"), }).Fit(data).Transform(data); data = new ColumnConcatenatingTransformer(Env, "Features", new[] { "Features1", "Features2" }).Transform(data); @@ -999,7 +999,7 @@ public void EntryPointPipelineEnsembleText() var data = splitOutput.TrainData[i]; if (i % 2 == 0) { - data = new TextFeaturizingEstimator(Env, "Text", "Features", args => + data = new TextFeaturizingEstimator(Env, "Features", "Text", args => { args.UseStopRemover = true; }).Fit(data).Transform(data); @@ -1198,8 +1198,8 @@ public void EntryPointMulticlassPipelineEnsemble() { var data = splitOutput.TrainData[i]; data = new RandomFourierFeaturizingEstimator(Env, new[] { - new RandomFourierFeaturizingTransformer.ColumnInfo("Features", "Features1", 10, false), - new RandomFourierFeaturizingTransformer.ColumnInfo("Features", "Features2", 10, false), + new RandomFourierFeaturizingTransformer.ColumnInfo("Features1", 10, false, "Features"), + new RandomFourierFeaturizingTransformer.ColumnInfo("Features2", 10, false, "Features"), }).Fit(data).Transform(data); data = new ColumnConcatenatingTransformer(Env, "Features", new[] { "Features1", "Features2" }).Transform(data); diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index df0bfd2185..4cf6548d6c 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -79,7 +79,7 @@ module SmokeTest1 = let ml = MLContext(seed = new System.Nullable(1), conc = 1) let data = 
ml.Data.ReadFromTextFile(testDataPath, hasHeader = true) - let pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + let pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .Append(ml.BinaryClassification.Trainers.FastTree(numLeaves = 5, numTrees = 5)) let model = pipeline.Fit(data) @@ -119,7 +119,7 @@ module SmokeTest2 = let ml = MLContext(seed = new System.Nullable(1), conc = 1) let data = ml.Data.ReadFromTextFile(testDataPath, hasHeader = true) - let pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + let pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .Append(ml.BinaryClassification.Trainers.FastTree(numLeaves = 5, numTrees = 5)) let model = pipeline.Fit(data) @@ -156,7 +156,7 @@ module SmokeTest3 = let ml = MLContext(seed = new System.Nullable(1), conc = 1) let data = ml.Data.ReadFromTextFile(testDataPath, hasHeader = true) - let pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + let pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .Append(ml.BinaryClassification.Trainers.FastTree(numLeaves = 5, numTrees = 5)) let model = pipeline.Fit(data) diff --git a/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs b/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs index 588e9a018a..722b634547 100644 --- a/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs +++ b/test/Microsoft.ML.OnnxTransformTest/DnnImageFeaturizerTest.cs @@ -78,7 +78,7 @@ void TestDnnImageFeaturizer() var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; var sizeData = new List { new TestDataSize() { data_0 = new float[2] } }; - var pipe = new DnnImageFeaturizerEstimator(Env, m => m.ModelSelector.ResNet18(m.Environment, m.InputColumn, m.OutputColumn), "data_0", "output_1"); + var pipe = new DnnImageFeaturizerEstimator(Env, 
"output_1", m => m.ModelSelector.ResNet18(m.Environment, m.OutputColumn, m.InputColumn), "data_0"); var invalidDataWrongNames = ML.Data.ReadFromEnumerable(xyData); var invalidDataWrongTypes = ML.Data.ReadFromEnumerable(stringData); @@ -117,7 +117,7 @@ public void OnnxStatic() .Append(row => ( row.name, data_0: row.imagePath.LoadAsImage(imageFolder).Resize(imageHeight, imageWidth).ExtractPixels(interleaveArgb: true))) - .Append(row => (row.name, output_1: row.data_0.DnnImageFeaturizer(m => m.ModelSelector.ResNet18(m.Environment, m.InputColumn, m.OutputColumn)))); + .Append(row => (row.name, output_1: row.data_0.DnnImageFeaturizer(m => m.ModelSelector.ResNet18(m.Environment, m.OutputColumn, m.InputColumn)))); TestEstimatorCore(pipe.AsDynamic, data.AsDynamic); @@ -158,7 +158,7 @@ public void TestOldSavingAndLoading() var inputNames = "data_0"; var outputNames = "output_1"; - var est = new DnnImageFeaturizerEstimator(Env, m => m.ModelSelector.ResNet18(m.Environment, m.InputColumn, m.OutputColumn), inputNames, outputNames); + var est = new DnnImageFeaturizerEstimator(Env, outputNames, m => m.ModelSelector.ResNet18(m.Environment, m.OutputColumn ,m.InputColumn), inputNames); var transformer = est.Fit(dataView); var result = transformer.Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs index 280ad35c59..e06222ac2a 100644 --- a/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs +++ b/test/Microsoft.ML.OnnxTransformTest/OnnxTransformTests.cs @@ -108,7 +108,7 @@ void TestSimpleCase() var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; var sizeData = new List { new TestDataSize() { data_0 = new float[2] } }; - var pipe = new OnnxScoringEstimator(Env, modelFile, new[] { "data_0" }, new[] { "softmaxout_1" }); + var 
pipe = new OnnxScoringEstimator(Env, new[] { "softmaxout_1" }, new[] { "data_0" }, modelFile); var invalidDataWrongNames = ML.Data.ReadFromEnumerable(xyData); var invalidDataWrongTypes = ML.Data.ReadFromEnumerable(stringData); @@ -148,7 +148,7 @@ void TestOldSavingAndLoading(int? gpuDeviceId, bool fallbackToCpu) var inputNames = new[] { "data_0" }; var outputNames = new[] { "softmaxout_1" }; - var est = new OnnxScoringEstimator(Env, modelFile, inputNames, outputNames, gpuDeviceId, fallbackToCpu); + var est = new OnnxScoringEstimator(Env, outputNames, inputNames, modelFile, gpuDeviceId, fallbackToCpu); var transformer = est.Fit(dataView); var result = transformer.Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -263,7 +263,7 @@ public void OnnxModelScenario() } }); - var onnx = new OnnxTransformer(env, modelFile, "data_0", "softmaxout_1").Transform(dataView); + var onnx = new OnnxTransformer(env, "softmaxout_1", modelFile, "data_0").Transform(dataView); onnx.Schema.TryGetColumnIndex("softmaxout_1", out int score); @@ -299,7 +299,7 @@ public void OnnxModelMultiInput() inb = new float[] {1,2,3,4,5} } }); - var onnx = new OnnxTransformer(env, modelFile, new[] { "ina", "inb" }, new[] { "outa", "outb" }).Transform(dataView); + var onnx = new OnnxTransformer(env, new[] { "outa", "outb" }, new[] { "ina", "inb" }, modelFile).Transform(dataView); onnx.Schema.TryGetColumnIndex("outa", out int scoresa); onnx.Schema.TryGetColumnIndex("outb", out int scoresb); diff --git a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs index d0062d35ec..e1666d8d77 100644 --- a/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs +++ b/test/Microsoft.ML.Predictor.Tests/TestPredictors.cs @@ -631,7 +631,7 @@ public void TestTreeEnsembleCombinerWithCategoricalSplits() var dataPath = GetDataPath("adult.tiny.with-schema.txt"); var dataView = ML.Data.ReadFromTextFile(dataPath); - var cat = new OneHotEncodingEstimator(ML, 
"Categories", "Features").Fit(dataView).Transform(dataView); + var cat = new OneHotEncodingEstimator(ML, "Features", "Categories").Fit(dataView).Transform(dataView); var fastTrees = new PredictorModel[3]; for (int i = 0; i < 3; i++) { diff --git a/test/Microsoft.ML.Tests/CachingTests.cs b/test/Microsoft.ML.Tests/CachingTests.cs index 622d0b9200..a08c1d9867 100644 --- a/test/Microsoft.ML.Tests/CachingTests.cs +++ b/test/Microsoft.ML.Tests/CachingTests.cs @@ -42,19 +42,19 @@ public void CacheCheckpointTest() { var trainData = Enumerable.Range(0, 100).Select(c => new MyData()).ToArray(); - var pipe = ML.Transforms.CopyColumns("Features", "F1") - .Append(ML.Transforms.Normalize("F1", "Norm1")) - .Append(ML.Transforms.Normalize("F1", "Norm2", Transforms.Normalizers.NormalizingEstimator.NormalizerMode.MeanVariance)); + var pipe = ML.Transforms.CopyColumns("F1", "Features") + .Append(ML.Transforms.Normalize("Norm1", "F1")) + .Append(ML.Transforms.Normalize("Norm2", "F1", Transforms.Normalizers.NormalizingEstimator.NormalizerMode.MeanVariance)); pipe.Fit(ML.Data.ReadFromEnumerable(trainData)); Assert.True(trainData.All(x => x.AccessCount == 2)); trainData = Enumerable.Range(0, 100).Select(c => new MyData()).ToArray(); - pipe = ML.Transforms.CopyColumns("Features", "F1") + pipe = ML.Transforms.CopyColumns("F1", "Features") .AppendCacheCheckpoint(ML) - .Append(ML.Transforms.Normalize("F1", "Norm1")) - .Append(ML.Transforms.Normalize("F1", "Norm2", Transforms.Normalizers.NormalizingEstimator.NormalizerMode.MeanVariance)); + .Append(ML.Transforms.Normalize("Norm1", "F1")) + .Append(ML.Transforms.Normalize("Norm2", "F1", Transforms.Normalizers.NormalizingEstimator.NormalizerMode.MeanVariance)); pipe.Fit(ML.Data.ReadFromEnumerable(trainData)); diff --git a/test/Microsoft.ML.Tests/ImagesTests.cs b/test/Microsoft.ML.Tests/ImagesTests.cs index 9810559d3d..9e2552237f 100644 --- a/test/Microsoft.ML.Tests/ImagesTests.cs +++ b/test/Microsoft.ML.Tests/ImagesTests.cs @@ -44,10 +44,10 @@ 
public void TestEstimatorChain() } }, new MultiFileSource(dataFile)); - var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal")) - .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100)) - .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels")) - .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray"))); + var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImageReal", "ImagePath")) + .Append(new ImageResizingEstimator(env, "ImageReal", 100, 100, "ImageReal")) + .Append(new ImagePixelExtractingEstimator(env, "ImagePixels", "ImageReal")) + .Append(new ImageGrayscalingEstimator(env, ("ImageGray", "ImageReal"))); TestEstimatorCore(pipe, data, null, invalidData); Done(); @@ -68,10 +68,10 @@ public void TestEstimatorSaveLoad() } }, new MultiFileSource(dataFile)); - var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImagePath", "ImageReal")) - .Append(new ImageResizingEstimator(env, "ImageReal", "ImageReal", 100, 100)) - .Append(new ImagePixelExtractingEstimator(env, "ImageReal", "ImagePixels")) - .Append(new ImageGrayscalingEstimator(env, ("ImageReal", "ImageGray"))); + var pipe = new ImageLoadingEstimator(env, imageFolder, ("ImageReal", "ImagePath")) + .Append(new ImageResizingEstimator(env, "ImageReal", 100, 100, "ImageReal")) + .Append(new ImagePixelExtractingEstimator(env, "ImagePixels", "ImageReal")) + .Append(new ImageGrayscalingEstimator(env, ("ImageGray", "ImageReal"))); pipe.GetOutputSchema(Core.Data.SchemaShape.Create(data.Schema)); var model = pipe.Fit(data); @@ -106,8 +106,8 @@ public void TestSaveImages() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", 100, 100, ImageResizerTransformer.ResizingKind.IsoPad).Transform(images); + var images = new 
ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", 100, 100, "ImageReal", ImageResizerTransformer.ResizingKind.IsoPad).Transform(images); cropped.Schema.TryGetColumnIndex("ImagePath", out int pathColumn); cropped.Schema.TryGetColumnIndex("ImageCropped", out int cropBitmapColumn); @@ -145,11 +145,11 @@ public void TestGreyscaleTransformImages() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); - IDataView grey = new ImageGrayscaleTransformer(env, ("ImageCropped", "ImageGrey")).Transform(cropped); + IDataView grey = new ImageGrayscaleTransformer(env, ("ImageGrey", "ImageCropped")).Transform(cropped); var fname = nameof(TestGreyscaleTransformImages) + "_model.zip"; var fh = env.CreateOutputFile(fname); @@ -196,17 +196,17 @@ public void TestBackAndForthConversionWithAlphaInterleave() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", 
"ImagePixels", ImagePixelExtractorTransformer.ColorBits.All, true, 2f / 255, 127.5f).Transform(cropped); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped", ImagePixelExtractorTransformer.ColorBits.All, true, 2f / 255, 127.5f).Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { InterleaveArgb = true, Offset = -1f, Scale = 255f / 2, Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } }, pixels); @@ -263,9 +263,9 @@ public void TestBackAndForthConversionWithoutAlphaInterleave() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", "ImagePixels", ImagePixelExtractorTransformer.ColorBits.Rgb, true, 2f / 255, 127.5f).Transform(cropped); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped", ImagePixelExtractorTransformer.ColorBits.Rgb, true, 2f / 255, 127.5f).Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { @@ -273,7 +273,7 @@ public void TestBackAndForthConversionWithoutAlphaInterleave() Offset = -1f, Scale = 255f / 2, 
Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } }, pixels); @@ -330,9 +330,9 @@ public void TestBackAndForthConversionWithAlphaNoInterleave() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", "ImagePixels", ImagePixelExtractorTransformer.ColorBits.All, false, 2f / 255, 127.5f).Transform(cropped); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped", ImagePixelExtractorTransformer.ColorBits.All, false, 2f / 255, 127.5f).Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { @@ -340,7 +340,7 @@ public void TestBackAndForthConversionWithAlphaNoInterleave() Offset = -1f, Scale = 255f / 2, Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } }, pixels); @@ -397,9 +397,9 @@ public void 
TestBackAndForthConversionWithoutAlphaNoInterleave() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", "ImagePixels", ImagePixelExtractorTransformer.ColorBits.Rgb, false, 2f / 255, 127.5f).Transform(cropped); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped", ImagePixelExtractorTransformer.ColorBits.Rgb, false, 2f / 255, 127.5f).Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { @@ -407,7 +407,7 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleave() Offset = -1f, Scale = 255f / 2, Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } }, pixels); @@ -464,16 +464,16 @@ public void TestBackAndForthConversionWithAlphaInterleaveNoOffset() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); + var images = new ImageLoaderTransformer(env, imageFolder, 
("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", "ImagePixels", ImagePixelExtractorTransformer.ColorBits.All, true).Transform(cropped); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped", ImagePixelExtractorTransformer.ColorBits.All, true).Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { InterleaveArgb = true, Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } }, pixels); @@ -530,16 +530,16 @@ public void TestBackAndForthConversionWithoutAlphaInterleaveNoOffset() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", "ImagePixels", ImagePixelExtractorTransformer.ColorBits.Rgb, true).Transform(cropped); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped", ImagePixelExtractorTransformer.ColorBits.Rgb, true).Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new 
VectorToImageTransform.Arguments() { InterleaveArgb = true, Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } }, pixels); @@ -596,16 +596,16 @@ public void TestBackAndForthConversionWithAlphaNoInterleaveNoOffset() new TextLoader.Column("Name", DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", "ImagePixels", ImagePixelExtractorTransformer.ColorBits.All).Transform(cropped); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped", ImagePixelExtractorTransformer.ColorBits.All).Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { InterleaveArgb = false, Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=true} } }, pixels); @@ -662,15 +662,15 @@ public void TestBackAndForthConversionWithoutAlphaNoInterleaveNoOffset() new TextLoader.Column("Name", 
DataKind.TX, 1), } }, new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(env, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(env, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); - var pixels = new ImagePixelExtractorTransformer(env, "ImageCropped", "ImagePixels").Transform(cropped); + var images = new ImageLoaderTransformer(env, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(env, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); + var pixels = new ImagePixelExtractorTransformer(env, "ImagePixels", "ImageCropped").Transform(cropped); IDataView backToBitmaps = new VectorToImageTransform(env, new VectorToImageTransform.Arguments() { InterleaveArgb = false, Column = new VectorToImageTransform.Column[1]{ - new VectorToImageTransform.Column() { Source= "ImagePixels", Name = "ImageRestored" , ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} + new VectorToImageTransform.Column() { Name = "ImageRestored" , Source= "ImagePixels", ImageHeight=imageHeight, ImageWidth=imageWidth, ContainsAlpha=false} } }, pixels); diff --git a/test/Microsoft.ML.Tests/OnnxConversionTest.cs b/test/Microsoft.ML.Tests/OnnxConversionTest.cs index 5260c0578c..74abe16841 100644 --- a/test/Microsoft.ML.Tests/OnnxConversionTest.cs +++ b/test/Microsoft.ML.Tests/OnnxConversionTest.cs @@ -72,7 +72,7 @@ public void SimpleEndToEndOnnxConversionTest() // Step 3: Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. 
string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); - var onnxEstimator = new OnnxScoringEstimator(mlContext, onnxModelPath, inputNames, outputNames); + var onnxEstimator = new OnnxScoringEstimator(mlContext, outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(data); var onnxResult = onnxTransformer.Transform(data); @@ -156,7 +156,7 @@ public void KmeansOnnxConversionTest() // Evaluate the saved ONNX model using the data used to train the ML.NET pipeline. string[] inputNames = onnxModel.Graph.Input.Select(valueInfoProto => valueInfoProto.Name).ToArray(); string[] outputNames = onnxModel.Graph.Output.Select(valueInfoProto => valueInfoProto.Name).ToArray(); - var onnxEstimator = new OnnxScoringEstimator(mlContext, onnxModelPath, inputNames, outputNames); + var onnxEstimator = new OnnxScoringEstimator(mlContext, outputNames, inputNames, onnxModelPath); var onnxTransformer = onnxEstimator.Fit(data); var onnxResult = onnxTransformer.Transform(data); CompareSelectedR4VectorColumns("Score", "Score0", transformedData, onnxResult, 3); @@ -453,7 +453,7 @@ public void WordEmbeddingsTest() var embedNetworkPath = GetDataPath(@"shortsentiment.emd"); var data = mlContext.Data.ReadFromTextFile(dataPath, hasHeader: false, separatorChar: '\t'); - var pipeline = mlContext.Transforms.Text.ExtractWordEmbeddings("Tokens", embedNetworkPath, "Embed"); + var pipeline = mlContext.Transforms.Text.ExtractWordEmbeddings("Embed", embedNetworkPath, "Tokens"); var model = pipeline.Fit(data); var transformedData = model.Transform(data); diff --git a/test/Microsoft.ML.Tests/RangeFilterTests.cs b/test/Microsoft.ML.Tests/RangeFilterTests.cs index 8775a789d2..0dab9b59f5 100644 --- a/test/Microsoft.ML.Tests/RangeFilterTests.cs +++ b/test/Microsoft.ML.Tests/RangeFilterTests.cs @@ -29,7 +29,7 @@ public void RangeFilterTest() 
var cnt = data1.GetColumn(ML, "Floats").Count(); Assert.Equal(2L, cnt); - data = ML.Transforms.Conversion.Hash("Strings", "Key", hashBits: 20).Fit(data).Transform(data); + data = ML.Transforms.Conversion.Hash("Key", "Strings", hashBits: 20).Fit(data).Transform(data); var data2 = ML.Data.FilterByKeyColumnFraction(data, "Key", upperBound: 0.5); cnt = data2.GetColumn(ML, "Floats").Count(); Assert.Equal(1L, cnt); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs index 4f25bbec30..e2a1036f28 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/CookbookSamples/CookbookSamplesDynamicApi.cs @@ -167,7 +167,7 @@ private ITransformer TrainOnIris(string irisDataPath) // Use the multi-class SDCA model to predict the label using features. .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent()) // Apply the inverse conversion from 'PredictedLabel' column back to string value. - .Append(mlContext.Transforms.Conversion.MapKeyToValue(("PredictedLabel", "Data"))); + .Append(mlContext.Transforms.Conversion.MapKeyToValue(("Data", "PredictedLabel"))); // Train the model. var model = dynamicPipeline.Fit(trainData); @@ -216,9 +216,9 @@ private void NormalizationWorkout(string dataPath) // Apply all kinds of standard ML.NET normalization to the raw features. 
var pipeline = mlContext.Transforms.Normalize( - new NormalizingEstimator.MinMaxColumn("Features", "MinMaxNormalized", fixZero: true), - new NormalizingEstimator.MeanVarColumn("Features", "MeanVarNormalized", fixZero: true), - new NormalizingEstimator.BinningColumn("Features", "BinNormalized", numBins: 256)); + new NormalizingEstimator.MinMaxColumn("MinMaxNormalized", "Features", fixZero: true), + new NormalizingEstimator.MeanVarColumn("MeanVarNormalized", "Features", fixZero: true), + new NormalizingEstimator.BinningColumn("BinNormalized", "Features", numBins: 256)); // Let's train our pipeline of normalizers, and then apply it to the same data. var normalizedData = pipeline.Fit(trainData).Transform(trainData); @@ -267,26 +267,26 @@ private void TextFeaturizationOn(string dataPath) // Apply various kinds of text operations supported by ML.NET. var dynamicPipeline = // One-stop shop to run the full text featurization. - mlContext.Transforms.Text.FeaturizeText("Message", "TextFeatures") + mlContext.Transforms.Text.FeaturizeText("TextFeatures", "Message") // Normalize the message for later transforms - .Append(mlContext.Transforms.Text.NormalizeText("Message", "NormalizedMessage")) + .Append(mlContext.Transforms.Text.NormalizeText("NormalizedMessage", "Message")) // NLP pipeline 1: bag of words. - .Append(new WordBagEstimator(mlContext, "NormalizedMessage", "BagOfWords")) + .Append(new WordBagEstimator(mlContext, "BagOfWords", "NormalizedMessage")) // NLP pipeline 2: bag of bigrams, using hashes instead of dictionary indices. - .Append(new WordHashBagEstimator(mlContext, "NormalizedMessage", "BagOfBigrams", + .Append(new WordHashBagEstimator(mlContext, "BagOfBigrams","NormalizedMessage", ngramLength: 2, allLengths: false)) // NLP pipeline 3: bag of tri-character sequences with TF-IDF weighting. 
- .Append(mlContext.Transforms.Text.TokenizeCharacters("Message", "MessageChars")) - .Append(new NgramExtractingEstimator(mlContext, "MessageChars", "BagOfTrichar", + .Append(mlContext.Transforms.Text.TokenizeCharacters("MessageChars", "Message")) + .Append(new NgramExtractingEstimator(mlContext, "BagOfTrichar", "MessageChars", ngramLength: 3, weighting: NgramExtractingEstimator.WeightingCriteria.TfIdf)) // NLP pipeline 4: word embeddings. - .Append(mlContext.Transforms.Text.TokenizeWords("NormalizedMessage", "TokenizedMessage")) - .Append(mlContext.Transforms.Text.ExtractWordEmbeddings("TokenizedMessage", "Embeddings", + .Append(mlContext.Transforms.Text.TokenizeWords("TokenizedMessage", "NormalizedMessage")) + .Append(mlContext.Transforms.Text.ExtractWordEmbeddings("Embeddings", "TokenizedMessage", WordEmbeddingsExtractingTransformer.PretrainedModelKind.GloVeTwitter25D)); // Let's train our pipeline, and then apply it to the same data. @@ -339,12 +339,12 @@ private void CategoricalFeaturizationOn(params string[] dataPath) // Build several alternative featurization pipelines. var dynamicPipeline = // Convert each categorical feature into one-hot encoding independently. - mlContext.Transforms.Categorical.OneHotEncoding("CategoricalFeatures", "CategoricalOneHot") + mlContext.Transforms.Categorical.OneHotEncoding("CategoricalOneHot", "CategoricalFeatures") // Convert all categorical features into indices, and build a 'word bag' of these. - .Append(mlContext.Transforms.Categorical.OneHotEncoding("CategoricalFeatures", "CategoricalBag", OneHotEncodingTransformer.OutputKind.Bag)) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("CategoricalBag", "CategoricalFeatures", OneHotEncodingTransformer.OutputKind.Bag)) // One-hot encode the workclass column, then drop all the categories that have fewer than 10 instances in the train set. 
- .Append(mlContext.Transforms.Categorical.OneHotEncoding("Workclass", "WorkclassOneHot")) - .Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("WorkclassOneHot", "WorkclassOneHotTrimmed", count: 10)); + .Append(mlContext.Transforms.Categorical.OneHotEncoding("WorkclassOneHot", "Workclass")) + .Append(mlContext.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("WorkclassOneHotTrimmed", "WorkclassOneHot", count: 10)); // Let's train our pipeline, and then apply it to the same data. var transformedData = dynamicPipeline.Fit(data).Transform(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs index 11b34438e6..a4e3afc2cc 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/CrossValidation.cs @@ -26,7 +26,7 @@ void CrossValidation() var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { ConvergenceTolerance = 1f, NumThreads = 1, })); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs index ee7561ea95..60fad2c0a3 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Evaluation.cs @@ -24,7 +24,7 @@ public void Evaluation() // Pipeline. 
var pipeline = ml.Data.CreateTextLoader(TestDatasets.Sentiment.GetLoaderColumns(), hasHeader: true) - .Append(ml.Transforms.Text.FeaturizeText("SentimentText", "Features")) + .Append(ml.Transforms.Text.FeaturizeText("Features", "SentimentText")) .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { NumThreads = 1 })); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs index 7a75415c61..f1abe56543 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/FileBasedSavingOfData.cs @@ -27,7 +27,7 @@ void FileBasedSavingOfData() var ml = new MLContext(seed: 1, conc: 1); var src = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)); var trainData = ml.Data.CreateTextLoader(TestDatasets.Sentiment.GetLoaderColumns(), hasHeader: true) - .Append(ml.Transforms.Text.FeaturizeText("SentimentText", "Features")) + .Append(ml.Transforms.Text.FeaturizeText("Features", "SentimentText")) .Fit(src).Read(src); var path = DeleteOutputPath("i.idv"); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs index eae28ef7f4..e6b6736400 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/IntrospectiveTraining.cs @@ -31,7 +31,7 @@ public void IntrospectiveTraining() var ml = new MLContext(seed: 1, conc: 1); var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); - var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .AppendCacheCheckpoint(ml) 
.Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { NumThreads = 1 })); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs index 0da3d7aaec..b4794bb23f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/MultithreadedPrediction.cs @@ -28,7 +28,7 @@ void MultithreadedPrediction() var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .AppendCacheCheckpoint(ml) .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { NumThreads = 1 })); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs index e6d4592a84..254dd73e45 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/ReconfigurablePrediction.cs @@ -27,7 +27,7 @@ public void ReconfigurablePrediction() var testData = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.testFilename), hasHeader: true); // Pipeline. 
- var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .Fit(data); var trainer = ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs index 36625d225d..6d70b43d61 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/SimpleTrainAndPredict.cs @@ -26,7 +26,7 @@ public void SimpleTrainAndPredict() var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .AppendCacheCheckpoint(ml) .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { NumThreads = 1 })); @@ -63,7 +63,7 @@ public void SimpleTrainAndPredictSymSGD() var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); // Pipeline. 
- var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .AppendCacheCheckpoint(ml) .Append(ml.BinaryClassification.Trainers.SymbolicStochasticGradientDescent(new SymSgdClassificationTrainer.Options { diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs index 49be506fa1..6170b61a65 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainSaveModelAndPredict.cs @@ -27,7 +27,7 @@ public void TrainSaveModelAndPredict() var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features") + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText") .AppendCacheCheckpoint(ml) .Append(ml.BinaryClassification.Trainers.StochasticDualCoordinateAscent( new SdcaBinaryTrainer.Options { NumThreads = 1 })); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs index cac3c7f9be..84a8e7d173 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithInitialPredictor.cs @@ -24,7 +24,7 @@ public void TrainWithInitialPredictor() var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features"); + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText"); // Train the pipeline, prepare train set. 
Since it will be scanned multiple times in the subsequent trainer, we cache the // transformed data in memory. diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs index 64e0b42587..9fde458562 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/TrainWithValidationSet.cs @@ -19,7 +19,7 @@ public void TrainWithValidationSet() var ml = new MLContext(seed: 1, conc: 1); // Pipeline. var data = ml.Data.ReadFromTextFile(GetDataPath(TestDatasets.Sentiment.trainFilename), hasHeader: true); - var pipeline = ml.Transforms.Text.FeaturizeText("SentimentText", "Features"); + var pipeline = ml.Transforms.Text.FeaturizeText("Features", "SentimentText"); // Train the pipeline, prepare train and validation set. var preprocess = pipeline.Fit(data); diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs index 4d39f774c8..4b546475fd 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/Estimators/Visibility.cs @@ -24,7 +24,7 @@ void Visibility() { var ml = new MLContext(seed: 1, conc: 1); var pipeline = ml.Data.CreateTextLoader(TestDatasets.Sentiment.GetLoaderColumns(), hasHeader: true) - .Append(ml.Transforms.Text.FeaturizeText("SentimentText", "Features", s => s.OutputTokens = true)); + .Append(ml.Transforms.Text.FeaturizeText("Features", "SentimentText", s => s.OutputTokens = true)); var src = new MultiFileSource(GetDataPath(TestDatasets.Sentiment.trainFilename)); var data = pipeline.Fit(src).Read(src); diff --git a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs index 22d648967d..f84279c32d 100644 --- 
a/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/IrisPlantClassificationWithStringLabelTests.cs @@ -35,11 +35,11 @@ public void TrainAndPredictIrisModelWithStringLabelTest() // Create Estimator var pipe = mlContext.Transforms.Concatenate("Features", "SepalLength", "SepalWidth", "PetalLength", "PetalWidth") .Append(mlContext.Transforms.Normalize("Features")) - .Append(mlContext.Transforms.Conversion.MapValueToKey("IrisPlantType", "Label"), TransformerScope.TrainTest) + .Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "IrisPlantType"), TransformerScope.TrainTest) .AppendCacheCheckpoint(mlContext) .Append(mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent( new SdcaMultiClassTrainer.Options { NumThreads = 1 })) - .Append(mlContext.Transforms.Conversion.MapKeyToValue(("PredictedLabel", "Plant"))); + .Append(mlContext.Transforms.Conversion.MapKeyToValue(("Plant", "PredictedLabel"))); // Train the pipeline var trainedModel = pipe.Fit(trainData); diff --git a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs index 8140220de4..f5088c7a0f 100644 --- a/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/TensorflowTests.cs @@ -34,10 +34,10 @@ public void TensorFlowTransforCifarEndToEndTest() } }, new MultiFileSource(dataFile)); - var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImagePath", "ImageReal")) - .Append(new ImageResizingEstimator(mlContext, "ImageReal", "ImageCropped", imageHeight, imageWidth)) - .Append(new ImagePixelExtractingEstimator(mlContext, "ImageCropped", "Input", interleave: true)) - .Append(new TensorFlowEstimator(mlContext, model_location, new[] { "Input" }, new[] { "Output" })) + var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImageReal", "ImagePath")) + .Append(new ImageResizingEstimator(mlContext, 
"ImageCropped", imageHeight, imageWidth, "ImageReal")) + .Append(new ImagePixelExtractingEstimator(mlContext, "Input", "ImageCropped", interleave: true)) + .Append(new TensorFlowEstimator(mlContext, new[] { "Output" }, new[] { "Input" }, model_location)) .Append(new ColumnConcatenatingEstimator(mlContext, "Features", "Output")) .Append(new ValueToKeyMappingEstimator(mlContext, "Label")) .AppendCacheCheckpoint(mlContext) diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs index 9ca6292d6c..3b66f3eda1 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/TensorflowTests.cs @@ -43,7 +43,7 @@ public void TensorFlowTransformMatrixMultiplicationTest() 2.0f, 2.0f }, b = new[] { 3.0f, 3.0f, 3.0f, 3.0f } } })); - var trans = new TensorFlowTransformer(mlContext, modelLocation, new[] { "a", "b" }, new[] { "c" }).Transform(loader); + var trans = new TensorFlowTransformer(mlContext, modelLocation, new[] { "c" }, new[] { "a", "b" }).Transform(loader); using (var cursor = trans.GetRowCursorForAllColumns()) { @@ -142,7 +142,7 @@ public void TensorFlowTransformInputOutputTypesTest() var inputs = new string[]{"f64", "f32", "i64", "i32", "i16", "i8", "u64", "u32", "u16", "u8","b"}; var outputs = new string[] { "o_f64", "o_f32", "o_i64", "o_i32", "o_i16", "o_i8", "o_u64", "o_u32", "o_u16", "o_u8", "o_b" }; - var trans = new TensorFlowTransformer(mlContext, model_location, inputs, outputs).Transform(loader); ; + var trans = new TensorFlowTransformer(mlContext, model_location, outputs, inputs).Transform(loader); ; using (var cursor = trans.GetRowCursorForAllColumns()) { @@ -233,12 +233,12 @@ public void TensorFlowTransformObjectDetectionTest() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = 
mlContext.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(mlContext, "ImageReal", "ImageCropped", 32, 32).Transform(images); + var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(mlContext, "ImageCropped", 32, 32, "ImageReal").Transform(images); - var pixels = new ImagePixelExtractorTransformer(mlContext, "ImageCropped", "image_tensor", asFloat: false).Transform(cropped); - var tf = new TensorFlowTransformer(mlContext, modelLocation, new[] { "image_tensor" }, - new[] { "detection_boxes", "detection_scores", "num_detections", "detection_classes" }).Transform(pixels); + var pixels = new ImagePixelExtractorTransformer(mlContext, "image_tensor", "ImageCropped", asFloat: false).Transform(cropped); + var tf = new TensorFlowTransformer(mlContext, modelLocation, + new[] { "detection_boxes", "detection_scores", "num_detections", "detection_classes" }, new[] { "image_tensor" }).Transform(pixels); tf.Schema.TryGetColumnIndex("image_tensor", out int input); tf.Schema.TryGetColumnIndex("detection_boxes", out int boxes); @@ -274,10 +274,10 @@ public void TensorFlowTransformInceptionTest() var dataFile = GetDataPath("images/images.tsv"); var imageFolder = Path.GetDirectoryName(dataFile); var data = mlContext.CreateLoader("Text{col=ImagePath:TX:0 col=Name:TX:1}", new MultiFileSource(dataFile)); - var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(mlContext, "ImageReal", "ImageCropped", 224, 224).Transform(images); - var pixels = new ImagePixelExtractorTransformer(mlContext, "ImageCropped", "input").Transform(cropped); - var tf = new TensorFlowTransformer(mlContext, modelLocation, "input", 
"softmax2_pre_activation").Transform(pixels); + var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(mlContext, "ImageCropped", 224, 224 , "ImageReal").Transform(images); + var pixels = new ImagePixelExtractorTransformer(mlContext, "input","ImageCropped").Transform(cropped); + var tf = new TensorFlowTransformer(mlContext, modelLocation, "softmax2_pre_activation", "input").Transform(pixels); tf.Schema.TryGetColumnIndex("input", out int input); tf.Schema.TryGetColumnIndex("softmax2_pre_activation", out int b); @@ -389,8 +389,8 @@ public void TensorFlowTransformMNISTConvTest() var trainData = reader.Read(GetDataPath(TestDatasets.mnistTiny28.trainFilename)); var testData = reader.Read(GetDataPath(TestDatasets.mnistOneClass.testFilename)); - var pipe = mlContext.Transforms.CopyColumns(("Placeholder", "reshape_input")) - .Append(new TensorFlowEstimator(mlContext, "mnist_model/frozen_saved_model.pb", new[] { "Placeholder", "reshape_input" }, new[] { "Softmax", "dense/Relu" })) + var pipe = mlContext.Transforms.CopyColumns(("reshape_input","Placeholder")) + .Append(new TensorFlowEstimator(mlContext, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }, "mnist_model/frozen_saved_model.pb")) .Append(mlContext.Transforms.Concatenate("Features", "Softmax", "dense/Relu")) .Append(mlContext.MulticlassClassification.Trainers.LightGbm("Label", "Features")); @@ -429,8 +429,8 @@ public void TensorFlowTransformMNISTLRTrainingTest() var trainData = reader.Read(GetDataPath(TestDatasets.mnistTiny28.trainFilename)); var testData = reader.Read(GetDataPath(TestDatasets.mnistOneClass.testFilename)); - var pipe = mlContext.Transforms.Categorical.OneHotEncoding("Label", "OneHotLabel") - .Append(mlContext.Transforms.Normalize(new NormalizingEstimator.MinMaxColumn("Placeholder", "Features"))) + var pipe = mlContext.Transforms.Categorical.OneHotEncoding("OneHotLabel", 
"Label") + .Append(mlContext.Transforms.Normalize(new NormalizingEstimator.MinMaxColumn("Features", "Placeholder"))) .Append(new TensorFlowEstimator(mlContext, new TensorFlowTransformer.Arguments() { ModelLocation = model_location, @@ -447,7 +447,7 @@ public void TensorFlowTransformMNISTLRTrainingTest() ReTrain = true })) .Append(mlContext.Transforms.Concatenate("Features", "Prediction")) - .Append(mlContext.Transforms.Conversion.MapValueToKey("Label", "KeyLabel", maxNumTerms: 10)) + .Append(mlContext.Transforms.Conversion.MapValueToKey("KeyLabel","Label", maxNumTerms: 10)) .Append(mlContext.MulticlassClassification.Trainers.LightGbm("KeyLabel", "Features")); var trainedModel = pipe.Fit(trainData); @@ -544,7 +544,7 @@ private void ExecuteTFTransformMNISTConvTrainingTest(bool shuffle, int? shuffleS preprocessedTestData = testData; } - var pipe = mlContext.Transforms.CopyColumns(("Placeholder", "Features")) + var pipe = mlContext.Transforms.CopyColumns(("Features", "Placeholder")) .Append(new TensorFlowEstimator(mlContext, new TensorFlowTransformer.Arguments() { ModelLocation = modelLocation, @@ -609,8 +609,8 @@ public void TensorFlowTransformMNISTConvSavedModelTest() var trainData = reader.Read(GetDataPath(TestDatasets.mnistTiny28.trainFilename)); var testData = reader.Read(GetDataPath(TestDatasets.mnistOneClass.testFilename)); - var pipe = mlContext.Transforms.CopyColumns(("Placeholder", "reshape_input")) - .Append(new TensorFlowEstimator(mlContext, "mnist_model", new[] { "Placeholder", "reshape_input" }, new[] { "Softmax", "dense/Relu" })) + var pipe = mlContext.Transforms.CopyColumns(("reshape_input", "Placeholder")) + .Append(new TensorFlowEstimator(mlContext, new[] { "Softmax", "dense/Relu" }, new[] { "Placeholder", "reshape_input" }, "mnist_model")) .Append(mlContext.Transforms.Concatenate("Features", new[] { "Softmax", "dense/Relu" })) .Append(mlContext.MulticlassClassification.Trainers.LightGbm("Label", "Features")); @@ -736,12 +736,12 @@ public void 
TensorFlowTransformCifar() } ); - var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImagePath", "ImageReal")) - .Append(new ImageResizingEstimator(mlContext, "ImageReal", "ImageCropped", imageWidth, imageHeight)) - .Append(new ImagePixelExtractingEstimator(mlContext, "ImageCropped", "Input", interleave: true)); + var pipeEstimator = new ImageLoadingEstimator(mlContext, imageFolder, ("ImageReal", "ImagePath")) + .Append(new ImageResizingEstimator(mlContext, "ImageCropped", imageWidth, imageHeight, "ImageReal")) + .Append(new ImagePixelExtractingEstimator(mlContext, "Input", "ImageCropped", interleave: true)); var pixels = pipeEstimator.Fit(data).Transform(data); - IDataView trans = new TensorFlowTransformer(mlContext, tensorFlowModel, "Input", "Output").Transform(pixels); + IDataView trans = new TensorFlowTransformer(mlContext, tensorFlowModel, "Output","Input").Transform(pixels); trans.Schema.TryGetColumnIndex("Output", out int output); using (var cursor = trans.GetRowCursor(trans.Schema["Output"])) @@ -779,10 +779,10 @@ public void TensorFlowTransformCifarSavedModel() new TextLoader.Column("Name", DataKind.TX, 1), } ); - var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(mlContext, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); - var pixels = new ImagePixelExtractorTransformer(mlContext, "ImageCropped", "Input", interleave: true).Transform(cropped); - IDataView trans = new TensorFlowTransformer(mlContext, tensorFlowModel, "Input", "Output").Transform(pixels); + var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(mlContext, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); + var pixels = new ImagePixelExtractorTransformer(mlContext, "Input", "ImageCropped", interleave: true).Transform(cropped); + 
IDataView trans = new TensorFlowTransformer(mlContext, tensorFlowModel, "Output", "Input").Transform(pixels); trans.Schema.TryGetColumnIndex("Output", out int output); using (var cursor = trans.GetRowCursorForAllColumns()) @@ -831,14 +831,14 @@ public void TensorFlowTransformCifarInvalidShape() new TextLoader.Column("Name", DataKind.TX, 1), } ); - var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImagePath", "ImageReal")).Transform(data); - var cropped = new ImageResizerTransformer(mlContext, "ImageReal", "ImageCropped", imageWidth, imageHeight).Transform(images); - var pixels = new ImagePixelExtractorTransformer(mlContext, "ImageCropped", "Input").Transform(cropped); + var images = new ImageLoaderTransformer(mlContext, imageFolder, ("ImageReal", "ImagePath")).Transform(data); + var cropped = new ImageResizerTransformer(mlContext, "ImageCropped", imageWidth, imageHeight, "ImageReal").Transform(images); + var pixels = new ImagePixelExtractorTransformer(mlContext, "Input", "ImageCropped").Transform(cropped); var thrown = false; try { - IDataView trans = new TensorFlowTransformer(mlContext, modelLocation, "Input", "Output").Transform(pixels); + IDataView trans = new TensorFlowTransformer(mlContext, modelLocation, "Output", "Input").Transform(pixels); } catch { diff --git a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs index e38265745a..c7abf45ed5 100644 --- a/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TensorFlowEstimatorTests.cs @@ -78,7 +78,7 @@ void TestSimpleCase() var xyData = new List { new TestDataXY() { A = new float[4], B = new float[4] } }; var stringData = new List { new TestDataDifferntType() { a = new string[4], b = new string[4] } }; var sizeData = new List { new TestDataSize() { a = new float[2], b = new float[2] } }; - var pipe = new TensorFlowEstimator(Env, modelFile, new[] { "a", "b" }, new[] { "c" }); + var pipe = new 
TensorFlowEstimator(Env, new[] { "c" }, new[] { "a", "b" }, modelFile); var invalidDataWrongNames = ML.Data.ReadFromEnumerable(xyData); var invalidDataWrongTypes = ML.Data.ReadFromEnumerable( stringData); @@ -119,7 +119,7 @@ void TestOldSavingAndLoading() b = new[] { 10.0f, 8.0f, 6.0f, 6.0f } } })); - var est = new TensorFlowEstimator(Env, modelFile, new[] { "a", "b" }, new[] { "c" }); + var est = new TensorFlowEstimator(Env, new[] { "c" }, new[] { "a", "b" }, modelFile); var transformer = est.Fit(dataView); var result = transformer.Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/TermEstimatorTests.cs b/test/Microsoft.ML.Tests/TermEstimatorTests.cs index 2f19676d8e..2e43973321 100644 --- a/test/Microsoft.ML.Tests/TermEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/TermEstimatorTests.cs @@ -71,13 +71,13 @@ void TestDifferentTypes() }, new MultiFileSource(dataPath)); var pipe = new ValueToKeyMappingEstimator(Env, new[]{ - new ValueToKeyMappingTransformer.ColumnInfo("float1", "TermFloat1"), - new ValueToKeyMappingTransformer.ColumnInfo("float4", "TermFloat4"), - new ValueToKeyMappingTransformer.ColumnInfo("double1", "TermDouble1"), - new ValueToKeyMappingTransformer.ColumnInfo("double4", "TermDouble4"), - new ValueToKeyMappingTransformer.ColumnInfo("int1", "TermInt1"), - new ValueToKeyMappingTransformer.ColumnInfo("text1", "TermText1"), - new ValueToKeyMappingTransformer.ColumnInfo("text2", "TermText2") + new ValueToKeyMappingTransformer.ColumnInfo("TermFloat1", "float1"), + new ValueToKeyMappingTransformer.ColumnInfo("TermFloat4", "float4"), + new ValueToKeyMappingTransformer.ColumnInfo("TermDouble1", "double1"), + new ValueToKeyMappingTransformer.ColumnInfo("TermDouble4", "double4"), + new ValueToKeyMappingTransformer.ColumnInfo("TermInt1", "int1"), + new ValueToKeyMappingTransformer.ColumnInfo("TermText1", "text1"), + new ValueToKeyMappingTransformer.ColumnInfo("TermText2", "text2") }); var data = 
loader.Read(dataPath); data = TakeFilter.Create(Env, data, 10); @@ -102,9 +102,9 @@ void TestSimpleCase() var stringData = new[] { new TestClassDifferentTypes { A = "1", B = "c", C = "b" } }; var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new ValueToKeyMappingEstimator(Env, new[]{ - new ValueToKeyMappingTransformer.ColumnInfo("A", "TermA"), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TermB"), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TermC") + new ValueToKeyMappingTransformer.ColumnInfo("TermA", "A"), + new ValueToKeyMappingTransformer.ColumnInfo("TermB", "B"), + new ValueToKeyMappingTransformer.ColumnInfo("TermC", "C") }); var invalidData = ML.Data.ReadFromEnumerable(xydata); var validFitNotValidTransformData = ML.Data.ReadFromEnumerable(stringData); @@ -117,9 +117,9 @@ void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ML.Data.ReadFromEnumerable(data); var est = new ValueToKeyMappingEstimator(Env, new[]{ - new ValueToKeyMappingTransformer.ColumnInfo("A", "TermA"), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TermB"), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TermC") + new ValueToKeyMappingTransformer.ColumnInfo("TermA", "A"), + new ValueToKeyMappingTransformer.ColumnInfo("TermB", "B"), + new ValueToKeyMappingTransformer.ColumnInfo("TermC", "C") }); var transformer = est.Fit(dataView); var result = transformer.Transform(dataView); @@ -139,7 +139,7 @@ void TestMetadataCopy() var data = new[] { new TestMetaClass() { Term = "A", NotUsed = 1 }, new TestMetaClass() { Term = "B" }, new TestMetaClass() { Term = "C" } }; var dataView = ML.Data.ReadFromEnumerable(data); var termEst = new ValueToKeyMappingEstimator(Env, new[] { - new ValueToKeyMappingTransformer.ColumnInfo("Term" ,"T") }); + new ValueToKeyMappingTransformer.ColumnInfo("T", "Term") }); var termTransformer = termEst.Fit(dataView); var result = 
termTransformer.Transform(dataView); diff --git a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs index 28f9305322..91f5d7a7b1 100644 --- a/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs +++ b/test/Microsoft.ML.Tests/TrainerEstimators/TrainerEstimators.cs @@ -128,7 +128,7 @@ public void TestEstimatorMultiClassNaiveBayesTrainer() }).Read(GetDataPath(TestDatasets.Sentiment.trainFilename)); // Pipeline. - var pipeline = new TextFeaturizingEstimator(Env, "SentimentText", "Features"); + var pipeline = new TextFeaturizingEstimator(Env,"Features" ,"SentimentText"); return (pipeline, data); } @@ -150,8 +150,8 @@ public void TestEstimatorMultiClassNaiveBayesTrainer() // Pipeline. var pipeline = new ValueToKeyMappingEstimator(Env, new[]{ - new ValueToKeyMappingTransformer.ColumnInfo("Workclass", "Group"), - new ValueToKeyMappingTransformer.ColumnInfo("Label", "Label0") }); + new ValueToKeyMappingTransformer.ColumnInfo("Group", "Workclass"), + new ValueToKeyMappingTransformer.ColumnInfo("Label0", "Label") }); return (pipeline, data); } diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs index 782e672392..b41c12313d 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalHashTests.cs @@ -52,10 +52,10 @@ public void CategoricalHashWorkout() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new OneHotHashEncodingEstimator(Env, new[]{ - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag), - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatB", OneHotEncodingTransformer.OutputKind.Bin), - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatC", OneHotEncodingTransformer.OutputKind.Ind), - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatD", 
OneHotEncodingTransformer.OutputKind.Key), + new OneHotHashEncodingEstimator.ColumnInfo("CatA", "A", OneHotEncodingTransformer.OutputKind.Bag), + new OneHotHashEncodingEstimator.ColumnInfo("CatB", "A", OneHotEncodingTransformer.OutputKind.Bin), + new OneHotHashEncodingEstimator.ColumnInfo("CatC", "A", OneHotEncodingTransformer.OutputKind.Ind), + new OneHotHashEncodingEstimator.ColumnInfo("CatD", "A", OneHotEncodingTransformer.OutputKind.Key), }); TestEstimatorCore(pipe, dataView); @@ -95,7 +95,7 @@ public void CategoricalHashStatic() { var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); var savedData = TakeFilter.Create(Env, est.Fit(data).Transform(data).AsDynamic, 4); - var view = ColumnSelectingTransformer.CreateKeep(Env, savedData, new[] { "A", "B", "C", "D", "E", "F" }); + var view = ColumnSelectingTransformer.CreateKeep(Env, savedData, new[] { "A", "B", "C", "D", "E", "F" }); using (var fs = File.Create(outputPath)) DataSaverUtils.SaveDataView(ch, saver, view, fs, keepHidden: true); } @@ -114,16 +114,16 @@ public void TestMetadataPropagation() var dataView = ML.Data.ReadFromEnumerable(data); var bagPipe = new OneHotHashEncodingEstimator(Env, - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("B", "CatB", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("C", "CatC", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("D", "CatD", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("E", "CatE", OneHotEncodingTransformer.OutputKind.Ind, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("F", "CatF", OneHotEncodingTransformer.OutputKind.Ind, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatG", OneHotEncodingTransformer.OutputKind.Key, 
invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("B", "CatH", OneHotEncodingTransformer.OutputKind.Key, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatI", OneHotEncodingTransformer.OutputKind.Bin, invertHash: -1), - new OneHotHashEncodingEstimator.ColumnInfo("B", "CatJ", OneHotEncodingTransformer.OutputKind.Bin, invertHash: -1)); + new OneHotHashEncodingEstimator.ColumnInfo("CatA", "A", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatB", "B", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatC", "C", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatD", "D", OneHotEncodingTransformer.OutputKind.Bag, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatE", "E", OneHotEncodingTransformer.OutputKind.Ind, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatF", "F", OneHotEncodingTransformer.OutputKind.Ind, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatG", "A", OneHotEncodingTransformer.OutputKind.Key, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatH", "B", OneHotEncodingTransformer.OutputKind.Key, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatI", "A", OneHotEncodingTransformer.OutputKind.Bin, invertHash: -1), + new OneHotHashEncodingEstimator.ColumnInfo("CatJ", "B", OneHotEncodingTransformer.OutputKind.Bin, invertHash: -1)); var bagResult = bagPipe.Fit(dataView).Transform(dataView); ValidateMetadata(bagResult); @@ -218,9 +218,9 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = "1", B = "2", C = "3", }, new TestClass() { A = "4", B = "5", C = "6" } }; var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new OneHotHashEncodingEstimator(Env, new[]{ - new OneHotHashEncodingEstimator.ColumnInfo("A", "CatHashA"), - new 
OneHotHashEncodingEstimator.ColumnInfo("B", "CatHashB"), - new OneHotHashEncodingEstimator.ColumnInfo("C", "CatHashC") + new OneHotHashEncodingEstimator.ColumnInfo("CatHashA", "A"), + new OneHotHashEncodingEstimator.ColumnInfo("CatHashB", "B"), + new OneHotHashEncodingEstimator.ColumnInfo("CatHashC", "C") }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs index dd9b9ac849..d252f3d1d9 100644 --- a/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CategoricalTests.cs @@ -60,10 +60,10 @@ public void CategoricalWorkout() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new OneHotEncodingEstimator(Env, new[]{ - new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag), - new OneHotEncodingEstimator.ColumnInfo("A", "CatB", OneHotEncodingTransformer.OutputKind.Bin), - new OneHotEncodingEstimator.ColumnInfo("A", "CatC", OneHotEncodingTransformer.OutputKind.Ind), - new OneHotEncodingEstimator.ColumnInfo("A", "CatD", OneHotEncodingTransformer.OutputKind.Key), + new OneHotEncodingEstimator.ColumnInfo("CatA", "A", OneHotEncodingTransformer.OutputKind.Bag), + new OneHotEncodingEstimator.ColumnInfo("CatB", "A", OneHotEncodingTransformer.OutputKind.Bin), + new OneHotEncodingEstimator.ColumnInfo("CatC", "A", OneHotEncodingTransformer.OutputKind.Ind), + new OneHotEncodingEstimator.ColumnInfo("CatD", "A", OneHotEncodingTransformer.OutputKind.Key), }); TestEstimatorCore(pipe, dataView); @@ -78,10 +78,10 @@ public void CategoricalOneHotHashEncoding() var mlContext = new MLContext(); var dataView = mlContext.Data.ReadFromEnumerable(data); - var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatA", 3, 0, OneHotEncodingTransformer.OutputKind.Bag) - 
.Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatB", 2, 0, OneHotEncodingTransformer.OutputKind.Key)) - .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatC", 3, 0, OneHotEncodingTransformer.OutputKind.Ind)) - .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("A", "CatD", 2, 0, OneHotEncodingTransformer.OutputKind.Bin)); + var pipe = mlContext.Transforms.Categorical.OneHotHashEncoding("CatA", "A", 3, 0, OneHotEncodingTransformer.OutputKind.Bag) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatB", "A", 2, 0, OneHotEncodingTransformer.OutputKind.Key)) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatC", "A", 3, 0, OneHotEncodingTransformer.OutputKind.Ind)) + .Append(mlContext.Transforms.Categorical.OneHotHashEncoding("CatD", "A", 2, 0, OneHotEncodingTransformer.OutputKind.Bin)); TestEstimatorCore(pipe, dataView); Done(); @@ -95,10 +95,10 @@ public void CategoricalOneHotEncoding() var mlContext = new MLContext(); var dataView = mlContext.Data.ReadFromEnumerable(data); - var pipe = mlContext.Transforms.Categorical.OneHotEncoding("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag) - .Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatB", OneHotEncodingTransformer.OutputKind.Key)) - .Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatC", OneHotEncodingTransformer.OutputKind.Ind)) - .Append(mlContext.Transforms.Categorical.OneHotEncoding("A", "CatD", OneHotEncodingTransformer.OutputKind.Bin)); + var pipe = mlContext.Transforms.Categorical.OneHotEncoding("CatA", "A", OneHotEncodingTransformer.OutputKind.Bag) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("CatB", "A", OneHotEncodingTransformer.OutputKind.Key)) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("CatC", "A", OneHotEncodingTransformer.OutputKind.Ind)) + .Append(mlContext.Transforms.Categorical.OneHotEncoding("CatD", "A", OneHotEncodingTransformer.OutputKind.Bin)); 
TestEstimatorCore(pipe, dataView); Done(); @@ -121,7 +121,7 @@ public void CategoricalOneHotEncodingFromSideData() sideDataBuilder.AddColumn("Hello", "hello", "my", "friend"); var sideData = sideDataBuilder.GetDataView(); - var ci = new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag); + var ci = new OneHotEncodingEstimator.ColumnInfo("CatA", "A", OneHotEncodingTransformer.OutputKind.Bag); var pipe = new OneHotEncodingEstimator(mlContext, new[] { ci }, sideData); var output = pipe.Fit(dataView).Transform(dataView); @@ -184,18 +184,18 @@ public void TestMetadataPropagation() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new OneHotEncodingEstimator(Env, new[] { - new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Bag), - new OneHotEncodingEstimator.ColumnInfo("B", "CatB", OneHotEncodingTransformer.OutputKind.Bag), - new OneHotEncodingEstimator.ColumnInfo("C", "CatC", OneHotEncodingTransformer.OutputKind.Bag), - new OneHotEncodingEstimator.ColumnInfo("D", "CatD", OneHotEncodingTransformer.OutputKind.Bag), - new OneHotEncodingEstimator.ColumnInfo("E", "CatE", OneHotEncodingTransformer.OutputKind.Ind), - new OneHotEncodingEstimator.ColumnInfo("F", "CatF", OneHotEncodingTransformer.OutputKind.Ind), - new OneHotEncodingEstimator.ColumnInfo("G", "CatG", OneHotEncodingTransformer.OutputKind.Key), - new OneHotEncodingEstimator.ColumnInfo("H", "CatH", OneHotEncodingTransformer.OutputKind.Key), - new OneHotEncodingEstimator.ColumnInfo("A", "CatI", OneHotEncodingTransformer.OutputKind.Bin), - new OneHotEncodingEstimator.ColumnInfo("B", "CatJ", OneHotEncodingTransformer.OutputKind.Bin), - new OneHotEncodingEstimator.ColumnInfo("C", "CatK", OneHotEncodingTransformer.OutputKind.Bin), - new OneHotEncodingEstimator.ColumnInfo("D", "CatL", OneHotEncodingTransformer.OutputKind.Bin) }); + new OneHotEncodingEstimator.ColumnInfo("CatA", "A", OneHotEncodingTransformer.OutputKind.Bag), + new 
OneHotEncodingEstimator.ColumnInfo("CatB", "B", OneHotEncodingTransformer.OutputKind.Bag), + new OneHotEncodingEstimator.ColumnInfo("CatC", "C", OneHotEncodingTransformer.OutputKind.Bag), + new OneHotEncodingEstimator.ColumnInfo("CatD", "D", OneHotEncodingTransformer.OutputKind.Bag), + new OneHotEncodingEstimator.ColumnInfo("CatE", "E",OneHotEncodingTransformer.OutputKind.Ind), + new OneHotEncodingEstimator.ColumnInfo("CatF", "F", OneHotEncodingTransformer.OutputKind.Ind), + new OneHotEncodingEstimator.ColumnInfo("CatG", "G", OneHotEncodingTransformer.OutputKind.Key), + new OneHotEncodingEstimator.ColumnInfo("CatH", "H", OneHotEncodingTransformer.OutputKind.Key), + new OneHotEncodingEstimator.ColumnInfo("CatI", "A", OneHotEncodingTransformer.OutputKind.Bin), + new OneHotEncodingEstimator.ColumnInfo("CatJ", "B", OneHotEncodingTransformer.OutputKind.Bin), + new OneHotEncodingEstimator.ColumnInfo("CatK", "C", OneHotEncodingTransformer.OutputKind.Bin), + new OneHotEncodingEstimator.ColumnInfo("CatL", "D", OneHotEncodingTransformer.OutputKind.Bin) }); var result = pipe.Fit(dataView).Transform(dataView); @@ -307,9 +307,9 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new OneHotEncodingEstimator(Env, new[]{ - new OneHotEncodingEstimator.ColumnInfo("A", "TermA"), - new OneHotEncodingEstimator.ColumnInfo("B", "TermB"), - new OneHotEncodingEstimator.ColumnInfo("C", "TermC") + new OneHotEncodingEstimator.ColumnInfo("TermA", "A"), + new OneHotEncodingEstimator.ColumnInfo("TermB", "B"), + new OneHotEncodingEstimator.ColumnInfo("TermC", "C") }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/CharTokenizeTests.cs b/test/Microsoft.ML.Tests/Transformers/CharTokenizeTests.cs index 04a9e378d3..b42c84be00 100644 --- 
a/test/Microsoft.ML.Tests/Transformers/CharTokenizeTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CharTokenizeTests.cs @@ -41,7 +41,7 @@ public void CharTokenizeWorkout() var dataView = ML.Data.ReadFromEnumerable(data); var invalidData = new[] { new TestWrong() { A = 1, B = new float[2] { 2,3} } }; var invalidDataView = ML.Data.ReadFromEnumerable(invalidData); - var pipe = new TokenizingByCharactersEstimator(Env, columns: new[] { ("A", "TokenizeA"), ("B", "TokenizeB") }); + var pipe = new TokenizingByCharactersEstimator(Env, columns: new[] { ("TokenizeA", "A"), ("TokenizeB", "B") }); TestEstimatorCore(pipe, dataView, invalidInput:invalidDataView); Done(); @@ -59,7 +59,7 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = "This is a good sentence.", B = new string[2] { "Much words", "Wow So Cool" } } }; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new TokenizingByCharactersEstimator(Env, columns: new[] { ("A", "TokenizeA"), ("B", "TokenizeB") }); + var pipe = new TokenizingByCharactersEstimator(Env, columns: new[] { ("TokenizeA", "A"), ("TokenizeB", "B") }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); using (var ms = new MemoryStream()) diff --git a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs index 4105c6aa12..ef48abfc43 100644 --- a/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ConvertTests.cs @@ -78,8 +78,8 @@ public void TestConvertWorkout() var data = new[] { new TestClass() { A = 1, B = new int[2] { 1,4 } }, new TestClass() { A = 2, B = new int[2] { 3,4 } }}; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new TypeConvertingEstimator(Env, columns: new[] {new TypeConvertingTransformer.ColumnInfo("A", "ConvA", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("B", "ConvB", DataKind.R4)}); + var pipe = new 
TypeConvertingEstimator(Env, columns: new[] {new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R4, "A"), + new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.R4, "B")}); TestEstimatorCore(pipe, dataView); var allTypesData = new[] @@ -118,18 +118,18 @@ public void TestConvertWorkout() var allTypesDataView = ML.Data.ReadFromEnumerable(allTypesData); var allTypesPipe = new TypeConvertingEstimator(Env, columns: new[] { - new TypeConvertingTransformer.ColumnInfo("AA", "ConvA", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AB", "ConvB", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AC", "ConvC", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AD", "ConvD", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AE", "ConvE", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AF", "ConvF", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AG", "ConvG", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AH", "ConvH", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AK", "ConvK", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AL", "ConvL", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AM", "ConvM", DataKind.R4), - new TypeConvertingTransformer.ColumnInfo("AN", "ConvN", DataKind.R4)} + new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R4, "AA"), + new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.R4, "AB"), + new TypeConvertingTransformer.ColumnInfo("ConvC", DataKind.R4, "AC"), + new TypeConvertingTransformer.ColumnInfo("ConvD", DataKind.R4, "AD"), + new TypeConvertingTransformer.ColumnInfo("ConvE", DataKind.R4, "AE"), + new TypeConvertingTransformer.ColumnInfo("ConvF", DataKind.R4, "AF"), + new TypeConvertingTransformer.ColumnInfo("ConvG", DataKind.R4, "AG"), + new TypeConvertingTransformer.ColumnInfo("ConvH", DataKind.R4, "AH"), + new TypeConvertingTransformer.ColumnInfo("ConvK", DataKind.R4, "AK"), + new TypeConvertingTransformer.ColumnInfo("ConvL", 
DataKind.R4, "AL"), + new TypeConvertingTransformer.ColumnInfo("ConvM", DataKind.R4, "AM"), + new TypeConvertingTransformer.ColumnInfo("ConvN", DataKind.R4, "AN")} ); TestEstimatorCore(allTypesPipe, allTypesDataView); @@ -163,7 +163,7 @@ public void ValueToKeyFromSideData() var sideData = sideDataBuilder.GetDataView(); // For some reason the column info is on the *transformer*, not the estimator. Already tracked as issue #1760. - var ci = new ValueToKeyMappingTransformer.ColumnInfo("A", "CatA"); + var ci = new ValueToKeyMappingTransformer.ColumnInfo("CatA", "A"); var pipe = mlContext.Transforms.Conversion.MapValueToKey(new[] { ci }, sideData); var output = pipe.Fit(dataView).Transform(dataView); @@ -192,8 +192,8 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = new int[2] { 1,4 } }, new TestClass() { A = 2, B = new int[2] { 3,4 } }}; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new TypeConvertingEstimator(Env, columns: new[] {new TypeConvertingTransformer.ColumnInfo("A", "ConvA", DataKind.R8), - new TypeConvertingTransformer.ColumnInfo("B", "ConvB", DataKind.R8)}); + var pipe = new TypeConvertingEstimator(Env, columns: new[] {new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R8, "A"), + new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.R8, "B")}); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -211,11 +211,11 @@ public void TestMetadata() var data = new[] { new MetaClass() { A = 1, B = "A" }, new MetaClass() { A = 2, B = "B" }}; var pipe = new OneHotEncodingEstimator(Env, new[] { - new OneHotEncodingEstimator.ColumnInfo("A", "CatA", OneHotEncodingTransformer.OutputKind.Ind), - new OneHotEncodingEstimator.ColumnInfo("B", "CatB", OneHotEncodingTransformer.OutputKind.Key) + new OneHotEncodingEstimator.ColumnInfo("CatA", "A", OneHotEncodingTransformer.OutputKind.Ind), + new OneHotEncodingEstimator.ColumnInfo("CatB", "B", 
OneHotEncodingTransformer.OutputKind.Key) }).Append(new TypeConvertingEstimator(Env, new[] { - new TypeConvertingTransformer.ColumnInfo("CatA", "ConvA", DataKind.R8), - new TypeConvertingTransformer.ColumnInfo("CatB", "ConvB", DataKind.U2) + new TypeConvertingTransformer.ColumnInfo("ConvA", DataKind.R8, "CatA"), + new TypeConvertingTransformer.ColumnInfo("ConvB", DataKind.U2, "CatB") })); var dataView = ML.Data.ReadFromEnumerable(data); dataView = pipe.Fit(dataView).Transform(dataView); @@ -272,8 +272,8 @@ public void TypeConvertKeyBackCompatTest() } var outDataOld = modelOld.Transform(dataView); - var modelNew = ML.Transforms.Conversion.ConvertType(new[] { new TypeConvertingTransformer.ColumnInfo("key", "convertedKey", - DataKind.U8, new KeyCount(4)) }).Fit(dataView); + var modelNew = ML.Transforms.Conversion.ConvertType(new[] { new TypeConvertingTransformer.ColumnInfo("convertedKey", + DataKind.U8, "key", new KeyCount(4)) }).Fit(dataView); var outDataNew = modelNew.Transform(dataView); // Check that old and new model produce the same result. 
diff --git a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs index 8b00f43220..e14960e18d 100644 --- a/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/CopyColumnEstimatorTests.cs @@ -41,7 +41,7 @@ void TestWorking() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var env = new MLContext(); var dataView = env.Data.ReadFromEnumerable(data); - var est = new ColumnCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") }); var transformer = est.Fit(dataView); var result = transformer.Transform(dataView); ValidateCopyColumnTransformer(result); @@ -53,7 +53,7 @@ void TestBadOriginalSchema() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var env = new MLContext(); var dataView = env.Data.ReadFromEnumerable(data); - var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("B", "E") }); + var est = new ColumnCopyingEstimator(env, new[] { ("A", "D"), ("E", "B") }); try { var transformer = est.Fit(dataView); @@ -72,7 +72,7 @@ void TestBadTransformSchema() var env = new MLContext(); var dataView = env.Data.ReadFromEnumerable(data); var xyDataView = env.Data.ReadFromEnumerable(xydata); - var est = new ColumnCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") }); var transformer = est.Fit(dataView); try { @@ -90,7 +90,7 @@ void TestSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var env = new MLContext(); var dataView = env.Data.ReadFromEnumerable(data); - var est = new ColumnCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + 
var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") }); var transformer = est.Fit(dataView); using (var ms = new MemoryStream()) { @@ -108,7 +108,7 @@ void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var env = new MLContext(); var dataView = env.Data.ReadFromEnumerable(data); - var est = new ColumnCopyingEstimator(env, new[] { ("A", "D"), ("B", "E"), ("A", "F") }); + var est = new ColumnCopyingEstimator(env, new[] { ("D", "A"), ("E", "B"), ("F", "A") }); var transformer = est.Fit(dataView); var result = transformer.Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -131,7 +131,7 @@ void TestMetadataCopy() { Column = new[] { new ValueToKeyMappingTransformer.Column() { Source = "Term", Name = "T" } } }, dataView); - var est = new ColumnCopyingEstimator(env, "T", "T1"); + var est = new ColumnCopyingEstimator(env, "T1", "T"); var transformer = est.Fit(term); var result = transformer.Transform(term); result.Schema.TryGetColumnIndex("T", out int termIndex); diff --git a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs index 181d758dd1..9a9161a443 100644 --- a/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/FeatureSelectionTests.cs @@ -39,10 +39,10 @@ public void FeatureSelectionWorkout() text: ctx.LoadFloat(1)), hasHeader: true) .Read(sentimentDataPath); - var est = new WordBagEstimator(ML, "text", "bag_of_words") + var est = new WordBagEstimator(ML, "bag_of_words", "text") .AppendCacheCheckpoint(ML) - .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("bag_of_words", "bag_of_words_count", 10) - .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("bag_of_words", "bag_of_words_mi", labelColumn: "label"))); + 
.Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnCount("bag_of_words_count", "bag_of_words", 10) + .Append(ML.Transforms.FeatureSelection.SelectFeaturesBasedOnMutualInformation("bag_of_words_mi", "bag_of_words", labelColumn: "label"))); var outputPath = GetOutputPath("FeatureSelection", "featureselection.tsv"); using (var ch = Env.Start("save")) @@ -74,12 +74,12 @@ public void DropSlotsTransform() var columns = new[] { - new SlotsDroppingTransformer.ColumnInfo("VectorFloat", "dropped1", (min: 0, max: 1)), - new SlotsDroppingTransformer.ColumnInfo("VectorFloat", "dropped2"), - new SlotsDroppingTransformer.ColumnInfo("ScalarFloat", "dropped3", (min:0, max: 3)), - new SlotsDroppingTransformer.ColumnInfo("VectorFloat", "dropped4", (min: 1, max: 2)), - new SlotsDroppingTransformer.ColumnInfo("VectorDouble", "dropped5", (min: 1, null)), - new SlotsDroppingTransformer.ColumnInfo("VectorFloat", "dropped6", (min: 100, null)) + new SlotsDroppingTransformer.ColumnInfo("dropped1", "VectorFloat", (min: 0, max: 1)), + new SlotsDroppingTransformer.ColumnInfo("dropped2", "VectorFloat"), + new SlotsDroppingTransformer.ColumnInfo("dropped3", "ScalarFloat", (min:0, max: 3)), + new SlotsDroppingTransformer.ColumnInfo("dropped4", "VectorFloat", (min: 1, max: 2)), + new SlotsDroppingTransformer.ColumnInfo("dropped5", "VectorDouble", (min: 1, null)), + new SlotsDroppingTransformer.ColumnInfo("dropped6", "VectorFloat", (min: 100, null)) }; var trans = new SlotsDroppingTransformer(ML, columns); @@ -115,13 +115,13 @@ public void CountFeatureSelectionWorkout() var data = ML.Data.Cache(reader.Read(new MultiFileSource(dataPath)).AsDynamic); var columns = new[] { - new CountFeatureSelectingEstimator.ColumnInfo("VectorDouble", "FeatureSelectDouble", minCount: 1), - new CountFeatureSelectingEstimator.ColumnInfo("ScalarFloat", "ScalFeatureSelectMissing690", minCount: 690), - new CountFeatureSelectingEstimator.ColumnInfo("ScalarFloat", "ScalFeatureSelectMissing100", minCount: 100), - new 
CountFeatureSelectingEstimator.ColumnInfo("VectorDouble", "VecFeatureSelectMissing690", minCount: 690), - new CountFeatureSelectingEstimator.ColumnInfo("VectorDouble", "VecFeatureSelectMissing100", minCount: 100) + new CountFeatureSelectingEstimator.ColumnInfo("FeatureSelectDouble", "VectorDouble", minCount: 1), + new CountFeatureSelectingEstimator.ColumnInfo("ScalFeatureSelectMissing690", "ScalarFloat", minCount: 690), + new CountFeatureSelectingEstimator.ColumnInfo("ScalFeatureSelectMissing100", "ScalarFloat", minCount: 100), + new CountFeatureSelectingEstimator.ColumnInfo("VecFeatureSelectMissing690", "VectorDouble", minCount: 690), + new CountFeatureSelectingEstimator.ColumnInfo("VecFeatureSelectMissing100", "VectorDouble", minCount: 100) }; - var est = new CountFeatureSelectingEstimator(ML, "VectorFloat", "FeatureSelect", minCount: 1) + var est = new CountFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", minCount: 1) .Append(new CountFeatureSelectingEstimator(ML, columns)); TestEstimatorCore(est, data); @@ -156,7 +156,7 @@ public void TestCountSelectOldSavingAndLoading() var dataView = reader.Read(new MultiFileSource(dataPath)).AsDynamic; - var pipe = new CountFeatureSelectingEstimator(ML, "VectorFloat", "FeatureSelect", minCount: 1); + var pipe = new CountFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", minCount: 1); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -182,11 +182,11 @@ public void MutualInformationSelectionWorkout() var data = reader.Read(new MultiFileSource(dataPath)).AsDynamic; - var est = new MutualInformationFeatureSelectingEstimator(ML, "VectorFloat", "FeatureSelect", slotsInOutput: 1, labelColumn: "Label") + var est = new MutualInformationFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", slotsInOutput: 1, labelColumn: "Label") .Append(new MutualInformationFeatureSelectingEstimator(ML, labelColumn: "Label", slotsInOutput: 2, numBins: 100, columns: 
new[] { - (input: "VectorFloat", output: "out1"), - (input: "VectorDouble", output: "out2") + (name: "out1", source: "VectorFloat"), + (name: "out2", source: "VectorDouble") })); TestEstimatorCore(est, data); @@ -220,7 +220,7 @@ public void TestMutualInformationOldSavingAndLoading() var dataView = reader.Read(new MultiFileSource(dataPath)).AsDynamic; - var pipe = new MutualInformationFeatureSelectingEstimator(ML, "VectorFloat", "FeatureSelect", slotsInOutput: 1, labelColumn: "Label"); + var pipe = new MutualInformationFeatureSelectingEstimator(ML, "FeatureSelect", "VectorFloat", slotsInOutput: 1, labelColumn: "Label"); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/HashTests.cs b/test/Microsoft.ML.Tests/Transformers/HashTests.cs index 9dfc3e6a5a..edb2e858f9 100644 --- a/test/Microsoft.ML.Tests/Transformers/HashTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/HashTests.cs @@ -47,10 +47,10 @@ public void HashWorkout() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new HashingEstimator(Env, new[]{ - new HashingTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), - new HashingTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), - new HashingTransformer.ColumnInfo("C", "HashC", seed:42), - new HashingTransformer.ColumnInfo("A", "HashD"), + new HashingTransformer.ColumnInfo("HashA", "A", hashBits:4, invertHash:-1), + new HashingTransformer.ColumnInfo("HashB", "B", hashBits:3, ordered:true), + new HashingTransformer.ColumnInfo("HashC", "C", seed:42), + new HashingTransformer.ColumnInfo("HashD", "A"), }); TestEstimatorCore(pipe, dataView); @@ -69,9 +69,9 @@ public void TestMetadata() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new HashingEstimator(Env, new[] { - new HashingTransformer.ColumnInfo("A", "HashA", invertHash:1, hashBits:10), - new HashingTransformer.ColumnInfo("A", "HashAUnlim", invertHash:-1, hashBits:10), 
- new HashingTransformer.ColumnInfo("A", "HashAUnlimOrdered", invertHash:-1, hashBits:10, ordered:true) + new HashingTransformer.ColumnInfo("HashA", "A", invertHash:1, hashBits:10), + new HashingTransformer.ColumnInfo("HashAUnlim", "A", invertHash:-1, hashBits:10), + new HashingTransformer.ColumnInfo("HashAUnlimOrdered", "A", invertHash:-1, hashBits:10, ordered:true) }); var result = pipe.Fit(dataView).Transform(dataView); ValidateMetadata(result); @@ -109,10 +109,10 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new HashingEstimator(Env, new[]{ - new HashingTransformer.ColumnInfo("A", "HashA", hashBits:4, invertHash:-1), - new HashingTransformer.ColumnInfo("B", "HashB", hashBits:3, ordered:true), - new HashingTransformer.ColumnInfo("C", "HashC", seed:42), - new HashingTransformer.ColumnInfo("A", "HashD"), + new HashingTransformer.ColumnInfo("HashA", "A", hashBits:4, invertHash:-1), + new HashingTransformer.ColumnInfo("HashB", "B", hashBits:3, ordered:true), + new HashingTransformer.ColumnInfo("HashC", "C", seed:42), + new HashingTransformer.ColumnInfo("HashD", "A"), }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -133,7 +133,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe var inRow = MetadataUtils.MetadataAsRow(builder.GetMetadata()); // First do an unordered hash.
- var info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits); + var info = new HashingTransformer.ColumnInfo("Bar", "Foo", hashBits: bits); var xf = new HashingTransformer(Env, new[] { info }); var mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out int outCol); @@ -145,7 +145,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe Assert.Equal(expected, result); // Next do an ordered hash. - info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits, ordered: true); + info = new HashingTransformer.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: true); xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); @@ -163,7 +163,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe builder.Add("Foo", new VectorType(type, vecLen), (ref VBuffer dst) => denseVec.CopyTo(ref dst)); inRow = MetadataUtils.MetadataAsRow(builder.GetMetadata()); - info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits, ordered: false); + info = new HashingTransformer.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: false); xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); @@ -178,7 +178,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe Assert.All(vecResult.DenseValues(), v => Assert.Equal(expected, v)); // Now do ordered with the dense vector. 
- info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits, ordered: true); + info = new HashingTransformer.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: true); xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); @@ -197,7 +197,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe builder.Add("Foo", new VectorType(type, vecLen), (ref VBuffer dst) => sparseVec.CopyTo(ref dst)); inRow = MetadataUtils.MetadataAsRow(builder.GetMetadata()); - info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits, ordered: false); + info = new HashingTransformer.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: false); xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); @@ -210,7 +210,7 @@ private void HashTestCore(T val, PrimitiveType type, uint expected, uint expe Assert.Equal(expected, vecResult.GetItemOrDefault(3)); Assert.Equal(expected, vecResult.GetItemOrDefault(7)); - info = new HashingTransformer.ColumnInfo("Foo", "Bar", hashBits: bits, ordered: true); + info = new HashingTransformer.ColumnInfo("Bar", "Foo", hashBits: bits, ordered: true); xf = new HashingTransformer(Env, new[] { info }); mapper = xf.GetRowToRowMapper(inRow.Schema); mapper.OutputSchema.TryGetColumnIndex("Bar", out outCol); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs index 98233b215e..6099a1955f 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToBinaryVectorEstimatorTest.cs @@ -47,13 +47,13 @@ public void KeyToBinaryVectorWorkout() var dataView = ML.Data.ReadFromEnumerable(data); dataView = new ValueToKeyMappingEstimator(Env, new[]{ - new 
ValueToKeyMappingTransformer.ColumnInfo("A", "TermA"), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TermB"), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TermC", textKeyValues:true) + new ValueToKeyMappingTransformer.ColumnInfo("TermA", "A"), + new ValueToKeyMappingTransformer.ColumnInfo("TermB", "B"), + new ValueToKeyMappingTransformer.ColumnInfo("TermC", "C", textKeyValues:true) }).Fit(dataView).Transform(dataView); - var pipe = new KeyToBinaryVectorMappingEstimator(Env, new KeyToBinaryVectorMappingTransformer.ColumnInfo("TermA", "CatA"), - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TermC", "CatC")); + var pipe = new KeyToBinaryVectorMappingEstimator(Env, new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatA", "TermA"), + new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatC", "TermC")); TestEstimatorCore(pipe, dataView); Done(); } @@ -71,8 +71,8 @@ public void KeyToBinaryVectorStatic() // Non-pigsty Term. var dynamicData = new ValueToKeyMappingEstimator(Env, new[] { - new ValueToKeyMappingTransformer.ColumnInfo("ScalarString", "A"), - new ValueToKeyMappingTransformer.ColumnInfo("VectorString", "B") }) + new ValueToKeyMappingTransformer.ColumnInfo("A", "ScalarString"), + new ValueToKeyMappingTransformer.ColumnInfo("B", "VectorString") }) .Fit(data.AsDynamic).Transform(data.AsDynamic); var data2 = dynamicData.AssertStatic(Env, ctx => ( @@ -100,18 +100,18 @@ public void TestMetadataPropagation() var dataView = ML.Data.ReadFromEnumerable(data); var termEst = new ValueToKeyMappingEstimator(Env, new[] { - new ValueToKeyMappingTransformer.ColumnInfo("A", "TA", textKeyValues: true), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TB", textKeyValues: true), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TC"), - new ValueToKeyMappingTransformer.ColumnInfo("D", "TD") }); + new ValueToKeyMappingTransformer.ColumnInfo("TA", "A", textKeyValues: true), + new ValueToKeyMappingTransformer.ColumnInfo("TB", "B", textKeyValues: true), + new 
ValueToKeyMappingTransformer.ColumnInfo("TC", "C"), + new ValueToKeyMappingTransformer.ColumnInfo("TD", "D") }); var termTransformer = termEst.Fit(dataView); dataView = termTransformer.Transform(dataView); var pipe = new KeyToBinaryVectorMappingEstimator(Env, - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TA", "CatA"), - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TB", "CatB"), - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TC", "CatC"), - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TD", "CatD")); + new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatA", "TA"), + new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatB", "TB"), + new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatC", "TC"), + new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatD", "TD")); var result = pipe.Fit(dataView).Transform(dataView); ValidateMetadata(result); @@ -155,16 +155,16 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ML.Data.ReadFromEnumerable(data); var est = new ValueToKeyMappingEstimator(Env, new[]{ - new ValueToKeyMappingTransformer.ColumnInfo("A", "TermA"), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TermB", textKeyValues:true), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TermC") + new ValueToKeyMappingTransformer.ColumnInfo("TermA", "A"), + new ValueToKeyMappingTransformer.ColumnInfo("TermB", "B", textKeyValues:true), + new ValueToKeyMappingTransformer.ColumnInfo("TermC", "C") }); var transformer = est.Fit(dataView); dataView = transformer.Transform(dataView); var pipe = new KeyToBinaryVectorMappingEstimator(Env, - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TermA", "CatA"), - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TermB", "CatB"), - new KeyToBinaryVectorMappingTransformer.ColumnInfo("TermC", "CatC") + new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatA", "TermA"), + new 
KeyToBinaryVectorMappingTransformer.ColumnInfo("CatB", "TermB"), + new KeyToBinaryVectorMappingTransformer.ColumnInfo("CatC", "TermC") ); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs index cdebc1267c..a442fe7284 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToValueTests.cs @@ -45,13 +45,13 @@ public void KeyToValueWorkout() var data = reader.Read(dataPath); data = new ValueToKeyMappingEstimator(Env, new[] { - new ValueToKeyMappingTransformer.ColumnInfo("ScalarString", "A"), - new ValueToKeyMappingTransformer.ColumnInfo("VectorString", "B") }).Fit(data).Transform(data); + new ValueToKeyMappingTransformer.ColumnInfo("A", "ScalarString"), + new ValueToKeyMappingTransformer.ColumnInfo("B", "VectorString") }).Fit(data).Transform(data); - var badData1 = new ColumnCopyingTransformer(Env, ("BareKey", "A")).Transform(data); - var badData2 = new ColumnCopyingTransformer(Env, ("VectorString", "B")).Transform(data); + var badData1 = new ColumnCopyingTransformer(Env, ("A", "BareKey")).Transform(data); + var badData2 = new ColumnCopyingTransformer(Env, ("B", "VectorString")).Transform(data); - var est = new KeyToValueMappingEstimator(Env, ("A", "A_back"), ("B", "B_back")); + var est = new KeyToValueMappingEstimator(Env, ("A_back", "A"), ("B_back", "B")); TestEstimatorCore(est, data, invalidInput: badData1); TestEstimatorCore(est, data, invalidInput: badData2); @@ -82,8 +82,8 @@ public void KeyToValuePigsty() // Non-pigsty Term. 
var dynamicData = new ValueToKeyMappingEstimator(Env, new[] { - new ValueToKeyMappingTransformer.ColumnInfo("ScalarString", "A"), - new ValueToKeyMappingTransformer.ColumnInfo("VectorString", "B") }) + new ValueToKeyMappingTransformer.ColumnInfo("A", "ScalarString"), + new ValueToKeyMappingTransformer.ColumnInfo("B", "VectorString") }) .Fit(data.AsDynamic).Transform(data.AsDynamic); var data2 = dynamicData.AssertStatic(Env, ctx => ( diff --git a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs index 004d4f65ba..9c33ec87b7 100644 --- a/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/KeyToVectorEstimatorTests.cs @@ -53,15 +53,15 @@ public void KeyToVectorWorkout() var dataView = ML.Data.ReadFromEnumerable(data); dataView = new ValueToKeyMappingEstimator(Env, new[]{ - new ValueToKeyMappingTransformer.ColumnInfo("A", "TermA"), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TermB"), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TermC", textKeyValues:true) + new ValueToKeyMappingTransformer.ColumnInfo("TermA", "A"), + new ValueToKeyMappingTransformer.ColumnInfo("TermB", "B"), + new ValueToKeyMappingTransformer.ColumnInfo("TermC", "C", textKeyValues:true) }).Fit(dataView).Transform(dataView); - var pipe = new KeyToVectorMappingEstimator(Env, new KeyToVectorMappingTransformer.ColumnInfo("TermA", "CatA", false), - new KeyToVectorMappingTransformer.ColumnInfo("TermB", "CatB", true), - new KeyToVectorMappingTransformer.ColumnInfo("TermC", "CatC", true), - new KeyToVectorMappingTransformer.ColumnInfo("TermC", "CatCNonBag", false)); + var pipe = new KeyToVectorMappingEstimator(Env, new KeyToVectorMappingTransformer.ColumnInfo("CatA", "TermA", false), + new KeyToVectorMappingTransformer.ColumnInfo("CatB", "TermB", true), + new KeyToVectorMappingTransformer.ColumnInfo("CatC", "TermC", true), + new 
KeyToVectorMappingTransformer.ColumnInfo("CatCNonBag", "TermC", false)); TestEstimatorCore(pipe, dataView); Done(); } @@ -79,8 +79,8 @@ public void KeyToVectorStatic() // Non-pigsty Term. var dynamicData = new ValueToKeyMappingEstimator(Env, new[] { - new ValueToKeyMappingTransformer.ColumnInfo("ScalarString", "A"), - new ValueToKeyMappingTransformer.ColumnInfo("VectorString", "B") }) + new ValueToKeyMappingTransformer.ColumnInfo("A", "ScalarString"), + new ValueToKeyMappingTransformer.ColumnInfo("B", "VectorString") }) .Fit(data.AsDynamic).Transform(data.AsDynamic); var data2 = dynamicData.AssertStatic(Env, ctx => ( @@ -110,26 +110,26 @@ public void TestMetadataPropagation() var dataView = ML.Data.ReadFromEnumerable(data); var termEst = new ValueToKeyMappingEstimator(Env, new[] { - new ValueToKeyMappingTransformer.ColumnInfo("A", "TA", textKeyValues: true), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TB"), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TC", textKeyValues: true), - new ValueToKeyMappingTransformer.ColumnInfo("D", "TD", textKeyValues: true), - new ValueToKeyMappingTransformer.ColumnInfo("E", "TE"), - new ValueToKeyMappingTransformer.ColumnInfo("F", "TF"), - new ValueToKeyMappingTransformer.ColumnInfo("G", "TG"), - new ValueToKeyMappingTransformer.ColumnInfo("H", "TH", textKeyValues: true) }); + new ValueToKeyMappingTransformer.ColumnInfo("TA", "A", textKeyValues: true), + new ValueToKeyMappingTransformer.ColumnInfo("TB", "B"), + new ValueToKeyMappingTransformer.ColumnInfo("TC", "C", textKeyValues: true), + new ValueToKeyMappingTransformer.ColumnInfo("TD", "D", textKeyValues: true), + new ValueToKeyMappingTransformer.ColumnInfo("TE", "E"), + new ValueToKeyMappingTransformer.ColumnInfo("TF", "F"), + new ValueToKeyMappingTransformer.ColumnInfo("TG", "G"), + new ValueToKeyMappingTransformer.ColumnInfo("TH", "H", textKeyValues: true) }); var termTransformer = termEst.Fit(dataView); dataView = termTransformer.Transform(dataView); var pipe = 
new KeyToVectorMappingEstimator(Env, - new KeyToVectorMappingTransformer.ColumnInfo("TA", "CatA", true), - new KeyToVectorMappingTransformer.ColumnInfo("TB", "CatB", false), - new KeyToVectorMappingTransformer.ColumnInfo("TC", "CatC", false), - new KeyToVectorMappingTransformer.ColumnInfo("TD", "CatD", true), - new KeyToVectorMappingTransformer.ColumnInfo("TE", "CatE", false), - new KeyToVectorMappingTransformer.ColumnInfo("TF", "CatF", true), - new KeyToVectorMappingTransformer.ColumnInfo("TG", "CatG", true), - new KeyToVectorMappingTransformer.ColumnInfo("TH", "CatH", false) + new KeyToVectorMappingTransformer.ColumnInfo("CatA", "TA", true), + new KeyToVectorMappingTransformer.ColumnInfo("CatB", "TB", false), + new KeyToVectorMappingTransformer.ColumnInfo("CatC", "TC", false), + new KeyToVectorMappingTransformer.ColumnInfo("CatD", "TD", true), + new KeyToVectorMappingTransformer.ColumnInfo("CatE", "TE", false), + new KeyToVectorMappingTransformer.ColumnInfo("CatF", "TF", true), + new KeyToVectorMappingTransformer.ColumnInfo("CatG", "TG", true), + new KeyToVectorMappingTransformer.ColumnInfo("CatH", "TH", false) ); var result = pipe.Fit(dataView).Transform(dataView); @@ -215,15 +215,15 @@ public void TestOldSavingAndLoading() var data = new[] { new TestClass() { A = 1, B = 2, C = 3, }, new TestClass() { A = 4, B = 5, C = 6 } }; var dataView = ML.Data.ReadFromEnumerable(data); var est = new ValueToKeyMappingEstimator(Env, new[]{ - new ValueToKeyMappingTransformer.ColumnInfo("A", "TermA"), - new ValueToKeyMappingTransformer.ColumnInfo("B", "TermB"), - new ValueToKeyMappingTransformer.ColumnInfo("C", "TermC") + new ValueToKeyMappingTransformer.ColumnInfo("TermA", "A"), + new ValueToKeyMappingTransformer.ColumnInfo("TermB", "B"), + new ValueToKeyMappingTransformer.ColumnInfo("TermC", "C") }); var transformer = est.Fit(dataView); dataView = transformer.Transform(dataView); var pipe = new KeyToVectorMappingEstimator(Env, - new 
KeyToVectorMappingTransformer.ColumnInfo("TermA", "CatA", false), - new KeyToVectorMappingTransformer.ColumnInfo("TermB", "CatB", true) + new KeyToVectorMappingTransformer.ColumnInfo("CatA", "TermA", false), + new KeyToVectorMappingTransformer.ColumnInfo("CatB", "TermB", true) ); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs index 55a073d5e9..e30b43eff6 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAIndicatorTests.cs @@ -47,7 +47,7 @@ public void NAIndicatorWorkout() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new MissingValueIndicatorEstimator(Env, - new (string input, string output)[] { ("A", "NAA"), ("B", "NAB"), ("C", "NAC"), ("D", "NAD") }); + new (string outputColumnName, string inputColumnName)[] { ("NAA", "A"), ("NAB", "B"), ("NAC", "C"), ("NAD", "D") }); TestEstimatorCore(pipe, dataView); Done(); } @@ -71,7 +71,7 @@ public void TestOldSavingAndLoading() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new MissingValueIndicatorEstimator(Env, - new (string input, string output)[] { ("A", "NAA"), ("B", "NAB"), ("C", "NAC"), ("D", "NAD") }); + new (string outputColumnName, string inputColumnName)[] { ("NAA", "A"), ("NAB", "B"), ("NAC", "C"), ("NAD", "D") }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); using (var ms = new MemoryStream()) @@ -97,7 +97,7 @@ public void NAIndicatorFileOutput() var wrongCollection = new[] { new TestClass() { A = 1, B = 3, C = new float[2] { 1, 2 }, D = new double[2] { 3, 4 } } }; var invalidData = ML.Data.ReadFromEnumerable(wrongCollection); var est = new MissingValueIndicatorEstimator(Env, - new (string input, string output)[] { ("ScalarFloat", "A"), ("ScalarDouble", "B"), ("VectorFloat", "C"),
("VectorDoulbe", "D") }); + new (string outputColumnName, string inputColumnName)[] { ("A", "ScalarFloat"), ("B", "ScalarDouble"), ("C", "VectorFloat"), ("D", "VectorDoulbe") }); TestEstimatorCore(est, data, invalidInput: invalidData); var outputPath = GetOutputPath("NAIndicator", "featurized.tsv"); @@ -125,8 +125,8 @@ public void NAIndicatorMetadataTest() }; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new OneHotEncodingEstimator(Env, "A", "CatA"); - var newpipe = pipe.Append(new MissingValueIndicatorEstimator(Env, new (string input, string output)[] { ("CatA", "NAA") })); + var pipe = new OneHotEncodingEstimator(Env, "CatA", "A"); + var newpipe = pipe.Append(new MissingValueIndicatorEstimator(Env, new (string name, string source)[] { ("NAA", "CatA") })); var result = newpipe.Fit(dataView).Transform(dataView); Assert.True(result.Schema.TryGetColumnIndex("NAA", out var col)); // Check that the column is normalized. diff --git a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs index 6094033ed0..c7fabe38f2 100644 --- a/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NAReplaceTests.cs @@ -44,10 +44,10 @@ public void NAReplaceWorkout() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new MissingValueReplacingEstimator(Env, - new MissingValueReplacingTransformer.ColumnInfo("A", "NAA", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("B", "NAB", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("C", "NAC", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("D", "NAD", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean)); + new MissingValueReplacingTransformer.ColumnInfo("NAA", "A", 
MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingTransformer.ColumnInfo("NAB", "B", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingTransformer.ColumnInfo("NAC", "C", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingTransformer.ColumnInfo("NAD", "D", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean)); TestEstimatorCore(pipe, dataView); Done(); } @@ -109,10 +109,10 @@ public void TestOldSavingAndLoading() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new MissingValueReplacingEstimator(Env, - new MissingValueReplacingTransformer.ColumnInfo("A", "NAA", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("B", "NAB", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("C", "NAC", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), - new MissingValueReplacingTransformer.ColumnInfo("D", "NAD", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean)); + new MissingValueReplacingTransformer.ColumnInfo("NAA", "A", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingTransformer.ColumnInfo("NAB", "B", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingTransformer.ColumnInfo("NAC", "C", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean), + new MissingValueReplacingTransformer.ColumnInfo("NAD", "D", MissingValueReplacingTransformer.ColumnInfo.ReplacementMode.Mean)); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs index 5b67b5d372..290a42b9a9 100644 --- 
a/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/NormalizerTests.cs @@ -48,27 +48,27 @@ public void NormalizerWorkout() new NormalizingEstimator.MinMaxColumn("float4"), new NormalizingEstimator.MinMaxColumn("double1"), new NormalizingEstimator.MinMaxColumn("double4"), - new NormalizingEstimator.BinningColumn("float1", "float1bin"), - new NormalizingEstimator.BinningColumn("float4", "float4bin"), - new NormalizingEstimator.BinningColumn("double1", "double1bin"), - new NormalizingEstimator.BinningColumn("double4", "double4bin"), - new NormalizingEstimator.SupervisedBinningColumn("float1", "float1supervisedbin", labelColumn:"int1"), - new NormalizingEstimator.SupervisedBinningColumn("float4", "float4supervisedbin", labelColumn: "int1"), - new NormalizingEstimator.SupervisedBinningColumn("double1", "double1supervisedbin", labelColumn: "int1"), - new NormalizingEstimator.SupervisedBinningColumn("double4", "double4supervisedbin", labelColumn: "int1"), - new NormalizingEstimator.MeanVarColumn("float1", "float1mv"), - new NormalizingEstimator.MeanVarColumn("float4", "float4mv"), - new NormalizingEstimator.MeanVarColumn("double1", "double1mv"), - new NormalizingEstimator.MeanVarColumn("double4", "double4mv"), - new NormalizingEstimator.LogMeanVarColumn("float1", "float1lmv"), - new NormalizingEstimator.LogMeanVarColumn("float4", "float4lmv"), - new NormalizingEstimator.LogMeanVarColumn("double1", "double1lmv"), - new NormalizingEstimator.LogMeanVarColumn("double4", "double4lmv")); + new NormalizingEstimator.BinningColumn("float1bin", "float1"), + new NormalizingEstimator.BinningColumn("float4bin", "float4"), + new NormalizingEstimator.BinningColumn("double1bin", "double1"), + new NormalizingEstimator.BinningColumn("double4bin", "double4"), + new NormalizingEstimator.SupervisedBinningColumn("float1supervisedbin", "float1", labelColumn: "int1"), + new NormalizingEstimator.SupervisedBinningColumn("float4supervisedbin", 
"float4", labelColumn: "int1"), + new NormalizingEstimator.SupervisedBinningColumn("double1supervisedbin", "double1", labelColumn: "int1"), + new NormalizingEstimator.SupervisedBinningColumn("double4supervisedbin", "double4", labelColumn: "int1"), + new NormalizingEstimator.MeanVarColumn("float1mv", "float1"), + new NormalizingEstimator.MeanVarColumn("float4mv", "float4"), + new NormalizingEstimator.MeanVarColumn("double1mv", "double1"), + new NormalizingEstimator.MeanVarColumn("double4mv", "double4"), + new NormalizingEstimator.LogMeanVarColumn("float1lmv", "float1"), + new NormalizingEstimator.LogMeanVarColumn("float4lmv", "float4"), + new NormalizingEstimator.LogMeanVarColumn("double1lmv", "double1"), + new NormalizingEstimator.LogMeanVarColumn("double4lmv", "double4")); var data = loader.Read(dataPath); - var badData1 = new ColumnCopyingTransformer(Env, ("int1", "float1")).Transform(data); - var badData2 = new ColumnCopyingTransformer(Env, ("float0", "float4")).Transform(data); + var badData1 = new ColumnCopyingTransformer(Env, ("float1", "int1")).Transform(data); + var badData2 = new ColumnCopyingTransformer(Env, ("float4", "float0")).Transform(data); TestEstimatorCore(est, data, null, badData1); TestEstimatorCore(est, data, null, badData2); @@ -108,22 +108,22 @@ public void NormalizerParameters() }, new MultiFileSource(dataPath)); var est = new NormalizingEstimator(Env, - new NormalizingEstimator.MinMaxColumn("float1"), + new NormalizingEstimator.MinMaxColumn("float1"), new NormalizingEstimator.MinMaxColumn("float4"), - new NormalizingEstimator.MinMaxColumn("double1"), + new NormalizingEstimator.MinMaxColumn("double1"), new NormalizingEstimator.MinMaxColumn("double4"), - new NormalizingEstimator.BinningColumn("float1", "float1bin"), - new NormalizingEstimator.BinningColumn("float4", "float4bin"), - new NormalizingEstimator.BinningColumn("double1", "double1bin"), - new NormalizingEstimator.BinningColumn("double4", "double4bin"), - new 
NormalizingEstimator.MeanVarColumn("float1", "float1mv"), - new NormalizingEstimator.MeanVarColumn("float4", "float4mv"), - new NormalizingEstimator.MeanVarColumn("double1", "double1mv"), - new NormalizingEstimator.MeanVarColumn("double4", "double4mv"), - new NormalizingEstimator.LogMeanVarColumn("float1", "float1lmv"), - new NormalizingEstimator.LogMeanVarColumn("float4", "float4lmv"), - new NormalizingEstimator.LogMeanVarColumn("double1", "double1lmv"), - new NormalizingEstimator.LogMeanVarColumn("double4", "double4lmv")); + new NormalizingEstimator.BinningColumn("float1bin", "float1"), + new NormalizingEstimator.BinningColumn("float4bin", "float4"), + new NormalizingEstimator.BinningColumn("double1bin", "double1"), + new NormalizingEstimator.BinningColumn("double4bin", "double4"), + new NormalizingEstimator.MeanVarColumn("float1mv", "float1"), + new NormalizingEstimator.MeanVarColumn("float4mv", "float4"), + new NormalizingEstimator.MeanVarColumn("double1mv", "double1"), + new NormalizingEstimator.MeanVarColumn("double4mv", "double4"), + new NormalizingEstimator.LogMeanVarColumn("float1lmv", "float1"), + new NormalizingEstimator.LogMeanVarColumn("float4lmv", "float4"), + new NormalizingEstimator.LogMeanVarColumn("double1lmv", "double1"), + new NormalizingEstimator.LogMeanVarColumn("double4lmv", "double4")); var data = loader.Read(dataPath); @@ -144,10 +144,10 @@ public void NormalizerParameters() var doubleAffineDataVec = transformer.Columns[3].ModelParameters as NormalizingTransformer.AffineNormalizerModelParameters>; Assert.Equal(4, doubleAffineDataVec.Scale.Length); Assert.Empty(doubleAffineDataVec.Offset); - + var floatBinData = transformer.Columns[4].ModelParameters as NormalizingTransformer.BinNormalizerModelParameters; Assert.True(35 == floatBinData.UpperBounds.Length); - Assert.True(34 == floatBinData.Density); + Assert.True(34 == floatBinData.Density); Assert.True(0 == floatBinData.Offset); var floatBinDataVec = transformer.Columns[5].ModelParameters as 
NormalizingTransformer.BinNormalizerModelParameters>; @@ -165,7 +165,7 @@ public void NormalizerParameters() Assert.Equal(35, doubleBinDataVec.UpperBounds[0].Length); Assert.Equal(4, doubleBinDataVec.Density.Length); Assert.Empty(doubleBinDataVec.Offset); - + var floatCdfMeanData = transformer.Columns[8].ModelParameters as NormalizingTransformer.AffineNormalizerModelParameters; Assert.Equal(0.169309646f, floatCdfMeanData.Scale); Assert.Equal(0, floatCdfMeanData.Offset); @@ -272,9 +272,9 @@ public void LpGcNormAndWhiteningWorkout() separator: ';', hasHeader: true) .Read(dataSource); - var est = new LpNormalizingEstimator(ML, "features", "lpnorm") - .Append(new GlobalContrastNormalizingEstimator(ML, "features", "gcnorm")) - .Append(new VectorWhiteningEstimator(ML, "features", "whitened")); + var est = new LpNormalizingEstimator(ML, "lpnorm", "features") + .Append(new GlobalContrastNormalizingEstimator(ML, "gcnorm", "features")) + .Append(new VectorWhiteningEstimator(ML, "whitened", "features")); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); var outputPath = GetOutputPath("NormalizerEstimator", "lpnorm_gcnorm_whitened.tsv"); @@ -306,8 +306,8 @@ public void WhiteningWorkout() separator: ';', hasHeader: true) .Read(dataSource); - var est = new VectorWhiteningEstimator(ML, "features", "whitened1") - .Append(new VectorWhiteningEstimator(ML, "features", "whitened2", kind: WhiteningKind.Pca, pcaNum: 5)); + var est = new VectorWhiteningEstimator(ML, "whitened1", "features") + .Append(new VectorWhiteningEstimator(ML, "whitened2", "features", kind: WhiteningKind.Pca, pcaNum: 5)); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); var outputPath = GetOutputPath("NormalizerEstimator", "whitened.tsv"); @@ -339,7 +339,7 @@ public void TestWhiteningOldSavingAndLoading() c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), separator: ';', hasHeader: true) .Read(dataSource).AsDynamic; - var pipe = new 
VectorWhiteningEstimator(ML, "features", "whitened"); + var pipe = new VectorWhiteningEstimator(ML, "whitened", "features"); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -366,8 +366,8 @@ public void LpNormWorkout() separator: ';', hasHeader: true) .Read(dataSource); - var est = new LpNormalizingEstimator(ML, "features", "lpNorm1") - .Append(new LpNormalizingEstimator(ML, "features", "lpNorm2", normKind: LpNormalizingEstimatorBase.NormalizerKind.L1Norm, substractMean: true)); + var est = new LpNormalizingEstimator(ML, "lpNorm1", "features") + .Append(new LpNormalizingEstimator(ML, "lpNorm2", "features", normKind: LpNormalizingEstimatorBase.NormalizerKind.L1Norm, substractMean: true)); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); var outputPath = GetOutputPath("NormalizerEstimator", "lpNorm.tsv"); @@ -399,7 +399,7 @@ public void TestLpNormOldSavingAndLoading() c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), separator: ';', hasHeader: true) .Read(dataSource).AsDynamic; - var pipe = new LpNormalizingEstimator(ML, "features", "whitened"); + var pipe = new LpNormalizingEstimator(ML, "whitened", "features"); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); @@ -425,8 +425,8 @@ public void GcnWorkout() separator: ';', hasHeader: true) .Read(dataSource); - var est = new GlobalContrastNormalizingEstimator(ML, "features", "gcnNorm1") - .Append(new GlobalContrastNormalizingEstimator(ML, "features", "gcnNorm2", substractMean: false, useStdDev: true, scale: 3)); + var est = new GlobalContrastNormalizingEstimator(ML, "gcnNorm1", "features") + .Append(new GlobalContrastNormalizingEstimator(ML, "gcnNorm2", "features", substractMean: false, useStdDev: true, scale: 3)); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); var outputPath = GetOutputPath("NormalizerEstimator", "gcnNorm.tsv"); @@ -458,7 +458,7 @@ 
public void TestGcnNormOldSavingAndLoading() c => (label: c.LoadFloat(11), features: c.LoadFloat(0, 10)), separator: ';', hasHeader: true) .Read(dataSource).AsDynamic; - var pipe = new GlobalContrastNormalizingEstimator(ML, "features", "whitened"); + var pipe = new GlobalContrastNormalizingEstimator(ML, "whitened", "features"); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs index b5210b5c86..c536fa9966 100644 --- a/test/Microsoft.ML.Tests/Transformers/PcaTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/PcaTests.cs @@ -42,10 +42,10 @@ public void PcaWorkout() separator: ';', hasHeader: true) .Read(_dataSource); - var est = new PrincipalComponentAnalysisEstimator(_env, "features", "pca", rank: 4, seed: 10); + var est = new PrincipalComponentAnalysisEstimator(_env, "pca", "features", rank: 4, seed: 10); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); - var estNonDefaultArgs = new PrincipalComponentAnalysisEstimator(_env, "features", "pca", rank: 3, weightColumn: "weight", overSampling: 2, center: false); + var estNonDefaultArgs = new PrincipalComponentAnalysisEstimator(_env, "pca", "features", rank: 3, weightColumn: "weight", overSampling: 2, center: false); TestEstimatorCore(estNonDefaultArgs, data.AsDynamic, invalidInput: invalidData.AsDynamic); Done(); @@ -59,7 +59,7 @@ public void TestPcaEstimator() separator: ';', hasHeader: true) .Read(_dataSource); - var est = new PrincipalComponentAnalysisEstimator(_env, "features", "pca", rank: 5, seed: 1); + var est = new PrincipalComponentAnalysisEstimator(_env, "pca", "features", rank: 5, seed: 1); var outputPath = GetOutputPath("PCA", "pca.tsv"); using (var ch = _env.Start("save")) { diff --git a/test/Microsoft.ML.Tests/Transformers/RffTests.cs b/test/Microsoft.ML.Tests/Transformers/RffTests.cs index 92a64355a1..68af36863b 
100644 --- a/test/Microsoft.ML.Tests/Transformers/RffTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/RffTests.cs @@ -56,8 +56,8 @@ public void RffWorkout() var generator = new GaussianFourierSampler.Arguments(); var pipe = new RandomFourierFeaturizingEstimator(Env, new[]{ - new RandomFourierFeaturizingTransformer.ColumnInfo("A", "RffA", 5, false), - new RandomFourierFeaturizingTransformer.ColumnInfo("A", "RffB", 10, true, new LaplacianFourierSampler.Arguments()) + new RandomFourierFeaturizingTransformer.ColumnInfo("RffA", 5, false, "A"), + new RandomFourierFeaturizingTransformer.ColumnInfo("RffB", 10, true, "A", new LaplacianFourierSampler.Arguments()) }); TestEstimatorCore(pipe, dataView, invalidInput: invalidData, validForFitNotValidForTransformInput: validFitInvalidData); @@ -112,8 +112,8 @@ public void TestOldSavingAndLoading() var dataView = ML.Data.ReadFromEnumerable(data); var est = new RandomFourierFeaturizingEstimator(Env, new[]{ - new RandomFourierFeaturizingTransformer.ColumnInfo("A", "RffA", 5, false), - new RandomFourierFeaturizingTransformer.ColumnInfo("A", "RffB", 10, true,new LaplacianFourierSampler.Arguments()) + new RandomFourierFeaturizingTransformer.ColumnInfo("RffA", 5, false, "A"), + new RandomFourierFeaturizingTransformer.ColumnInfo("RffB", 10, true, "A", new LaplacianFourierSampler.Arguments()) }); var result = est.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index f391edb7ab..efe7532cbd 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -76,8 +76,8 @@ public void TextTokenizationWorkout() text: ctx.LoadFloat(1)), hasHeader: true) .Read(sentimentDataPath); - var est = new WordTokenizingEstimator(Env, "text", "words") - .Append(new TokenizingByCharactersEstimator(Env, "text", 
"chars")) + var est = new WordTokenizingEstimator(Env,"words", "text") + .Append(new TokenizingByCharactersEstimator(Env,"chars", "text")) .Append(new KeyToValueMappingEstimator(Env, "chars")); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); @@ -105,7 +105,7 @@ public void TokenizeWithSeparators() text: ctx.LoadText(1)), hasHeader: true) .Read(dataPath).AsDynamic; - var est = new WordTokenizingEstimator(Env, "text", "words", separators: new[] { ' ', '?', '!', '.', ',' }); + var est = new WordTokenizingEstimator(Env, "words", "text", separators: new[] { ' ', '?', '!', '.', ',' }); var outdata = TakeFilter.Create(Env, est.Fit(data).Transform(data), 4); var savedData = ColumnSelectingTransformer.CreateKeep(Env, outdata, new[] { "words" }); @@ -147,9 +147,9 @@ public void TextNormalizationAndStopwordRemoverWorkout() text: ctx.LoadFloat(1)), hasHeader: true) .Read(sentimentDataPath); var est = ML.Transforms.Text.NormalizeText("text") - .Append(ML.Transforms.Text.TokenizeWords("text", "words")) - .Append(ML.Transforms.Text.RemoveDefaultStopWords("words", "NoDefaultStopwords")) - .Append(ML.Transforms.Text.RemoveStopWords("words", "NoStopWords", "xbox", "this", "is", "a", "the","THAT","bY")); + .Append(ML.Transforms.Text.TokenizeWords("words", "text")) + .Append(ML.Transforms.Text.RemoveDefaultStopWords("NoDefaultStopwords", "words")) + .Append(ML.Transforms.Text.RemoveStopWords("NoStopWords", "words", "xbox", "this", "is", "a", "the","THAT","bY")); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); @@ -214,8 +214,8 @@ public void WordBagWorkout() text: ctx.LoadFloat(1)), hasHeader: true) .Read(sentimentDataPath); - var est = new WordBagEstimator(Env, "text", "bag_of_words"). - Append(new WordHashBagEstimator(Env, "text", "bag_of_wordshash", invertHash: -1)); + var est = new WordBagEstimator(Env, "bag_of_words", "text"). 
+ Append(new WordHashBagEstimator(Env, "bag_of_wordshash", "text", invertHash: -1)); // The following call fails because of the following issue // https://github.com/dotnet/machinelearning/issues/969 @@ -251,9 +251,9 @@ public void NgramWorkout() .Read(sentimentDataPath); var est = new WordTokenizingEstimator(Env, "text", "text") - .Append(new ValueToKeyMappingEstimator(Env, "text", "terms")) - .Append(new NgramExtractingEstimator(Env, "terms", "ngrams")) - .Append(new NgramHashingEstimator(Env, "terms", "ngramshash")); + .Append(new ValueToKeyMappingEstimator(Env, "terms", "text")) + .Append(new NgramExtractingEstimator(Env, "ngrams", "terms")) + .Append(new NgramHashingEstimator(Env, "ngramshash", "terms")); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); @@ -304,8 +304,8 @@ public void LdaWorkout() text: ctx.LoadFloat(1)), hasHeader: true) .Read(sentimentDataPath); - var est = new WordBagEstimator(env, "text", "bag_of_words"). - Append(new LatentDirichletAllocationEstimator(env, "bag_of_words", "topics", 10, numIterations: 10, + var est = new WordBagEstimator(env, "bag_of_words", "text"). + Append(new LatentDirichletAllocationEstimator(env, "topics", "bag_of_words", 10, numIterations: 10, resetRandomGenerator: true)); // The following call fails because of the following issue diff --git a/test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs b/test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs index 08782e5404..425b331ffa 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs @@ -43,7 +43,7 @@ public void TextNormalizerWorkout() var data = new[] { new TestClass() { A = "A 1, b. c! 
йЁ 24 ", B = new string[2] { "~``ё 52ds й vc", "6ksj94 vd ё dakl Юds Ё q й" } }, new TestClass() { A = null, B =new string[2] { null, string.Empty } } }; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new TextNormalizingEstimator(Env, columns: new[] { ("A", "NormA"), ("B", "NormB") }); + var pipe = new TextNormalizingEstimator(Env, columns: new[] { ("NormA", "A"), ("NormB", "B") }); var invalidData = new[] { new TestClassB() { A = 1, B = new float[2] { 1,4 } }, new TestClassB() { A = 2, B =new float[2] { 3,4 } } }; @@ -57,11 +57,11 @@ public void TextNormalizerWorkout() var dataSource = new MultiFileSource(dataPath); dataView = reader.Read(dataSource).AsDynamic; - var pipeVariations = new TextNormalizingEstimator(Env, columns: new[] { ("text", "NormText") }).Append( - new TextNormalizingEstimator(Env, textCase: TextNormalizingEstimator.CaseNormalizationMode.Upper, columns: new[] { ("text", "UpperText") })).Append( - new TextNormalizingEstimator(Env, keepDiacritics: true, columns: new[] { ("text", "WithDiacriticsText") })).Append( - new TextNormalizingEstimator(Env, keepNumbers: false, columns: new[] { ("text", "NoNumberText") })).Append( - new TextNormalizingEstimator(Env, keepPunctuations: false, columns: new[] { ("text", "NoPuncText") })); + var pipeVariations = new TextNormalizingEstimator(Env, columns: new[] { ("NormText", "text") }).Append( + new TextNormalizingEstimator(Env, textCase: TextNormalizingEstimator.CaseNormalizationMode.Upper, columns: new[] { ("UpperText", "text") })).Append( + new TextNormalizingEstimator(Env, keepDiacritics: true, columns: new[] { ("WithDiacriticsText", "text") })).Append( + new TextNormalizingEstimator(Env, keepNumbers: false, columns: new[] { ("NoNumberText", "text") })).Append( + new TextNormalizingEstimator(Env, keepPunctuations: false, columns: new[] { ("NoPuncText", "text") })); var outputPath = GetOutputPath("Text", "Normalized.tsv"); using (var ch = Env.Start("save")) @@ -87,7 +87,7 @@ public void 
TestOldSavingAndLoading() { var data = new[] { new TestClass() { A = "A 1, b. c! йЁ 24 ", B = new string[2] { "~``ё 52ds й vc", "6ksj94 vd ё dakl Юds Ё q й" } } }; var dataView = ML.Data.ReadFromEnumerable(data); - var pipe = new TextNormalizingEstimator(Env, columns: new[] { ("A", "NormA"), ("B", "NormB") }); + var pipe = new TextNormalizingEstimator(Env, columns: new[] { ("NormA", "A"), ("NormB", "B") }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs index 7d84dbdfa3..ae42eda606 100644 --- a/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/ValueMappingTests.cs @@ -56,7 +56,7 @@ public void ValueMapOneValueTest() var keys = new List() { "foo", "bar", "test", "wahoo" }; var values = new List() { 1, 2, 3, 4 }; - var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -87,8 +87,8 @@ public void ValueMapInputIsVectorTest() var values = new List() { 1, 2, 3, 4 }; var estimator = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("A", "TokenizeA") - }).Append(new ValueMappingEstimator, int>(Env, keys, values, new[] { ("TokenizeA", "VecD"), ("B", "E"), ("C", "F") })); + new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A") + }).Append(new ValueMappingEstimator, int>(Env, keys, values, new[] { ("VecD", "TokenizeA"), ("E", "B"), ("F", "C") })); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -120,8 +120,8 @@ public void ValueMapInputIsVectorAndValueAsStringKeyTypeTest() var values = new List>() { "a".AsMemory(), "b".AsMemory(), "c".AsMemory(), 
"d".AsMemory() }; var estimator = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("A", "TokenizeA") - }).Append(new ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("TokenizeA", "VecD"), ("B", "E"), ("C", "F") })); + new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A") + }).Append(new ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("VecD", "TokenizeA"), ("E", "B"), ("F", "C") })); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -155,7 +155,7 @@ public void ValueMapVectorValueTest() new int[] {100, 200 }, new int[] {400, 500, 600, 700 }}; - var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -196,7 +196,7 @@ public void ValueMapDataViewAsMapTest() }; var mapView = ML.Data.ReadFromEnumerable(map); - var estimator = new ValueMappingEstimator(Env, mapView, "Key", "Value", new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, mapView, "Key", "Value", new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -229,7 +229,7 @@ public void ValueMapVectorStringValueTest() new string[] {"forest", "city", "town" }, new string[] {"winter", "summer", "autumn", "spring" }}; - var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -262,7 +262,7 @@ public void ValueMappingMissingKey() var keys = new List() { "foo", "bar", "test", "wahoo" }; var values = new List() { 1, 2, 3, 4 }; - var estimator = new 
ValueMappingEstimator(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -292,7 +292,7 @@ void TestDuplicateKeys() var keys = new List() { "foo", "foo" }; var values = new List() { 1, 2 }; - Assert.Throws(() => new ValueMappingEstimator(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") })); + Assert.Throws(() => new ValueMappingEstimator(Env, keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") })); } [Fact] @@ -304,7 +304,7 @@ public void ValueMappingOutputSchema() var keys = new List() { "foo", "bar", "test", "wahoo" }; var values = new List() { 1, 2, 3, 4 }; - var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema)); Assert.Equal(6, outputSchema.Count()); @@ -331,7 +331,7 @@ public void ValueMappingWithValuesAsKeyTypesOutputSchema() var keys = new List() { "foo", "bar", "test", "wahoo" }; var values = new List() { "t", "s", "u", "v" }; - var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var outputSchema = estimator.GetOutputSchema(SchemaShape.Create(dataView.Schema)); Assert.Equal(6, outputSchema.Count()); Assert.True(outputSchema.TryFindColumn("D", out SchemaShape.Column dColumn)); @@ -361,7 +361,7 @@ public void ValueMappingValuesAsUintKeyTypes() // These are the expected key type values var values = new List() { 51, 25, 42, 61 }; - var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", 
"F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); @@ -400,7 +400,7 @@ public void ValueMappingValuesAsUlongKeyTypes() // These are the expected key type values var values = new List() { 51, Int32.MaxValue, 42, 61 }; - var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); @@ -438,7 +438,7 @@ public void ValueMappingValuesAsStringKeyTypes() // Generating the list of strings for the key type values, note that foo1 is duplicated as intended to test that the same index value is returned var values = new List() { "foo1", "foo2", "foo1", "foo3" }; - var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var estimator = new ValueMappingEstimator(Env, keys, values, true, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -475,8 +475,8 @@ public void ValueMappingValuesAsKeyTypesReverseLookup() // Generating the list of strings for the key type values, note that foo1 is duplicated as intended to test that the same index value is returned var values = new List>() { "foo1".AsMemory(), "foo2".AsMemory(), "foo1".AsMemory(), "foo3".AsMemory() }; - var estimator = new ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("A", "D") }) - .Append(new KeyToValueMappingEstimator(Env, ("D", "DOutput"))); + var estimator = new ValueMappingEstimator, ReadOnlyMemory>(Env, keys, values, true, new[] { ("D", "A") }) + .Append(new KeyToValueMappingEstimator(Env, ("DOutput","D"))); var t = estimator.Fit(dataView); var result = t.Transform(dataView); @@ -502,7 +502,7 @@ public void ValueMappingWorkout() var values = new List() { 1, 2, 3, 4 
}; // Workout on value mapping - var est = ML.Transforms.Conversion.ValueMap(keys, values, new[] { ("A", "D"), ("B", "E"), ("C", "F") }); + var est = ML.Transforms.Conversion.ValueMap(keys, values, new[] { ("D", "A"), ("E", "B"), ("F", "C") }); TestEstimatorCore(est, validFitInput: dataView, invalidInput: badDataView); } @@ -550,7 +550,7 @@ void TestSavingAndLoading() var est = new ValueMappingEstimator(Env, new List() { "foo", "bar", "test" }, new List() { 2, 43, 56 }, - new [] {("A","D"), ("B", "E")}); + new [] { ("D", "A"), ("E", "B") }); var transformer = est.Fit(dataView); using (var ms = new MemoryStream()) diff --git a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs index c164269cf3..6c572d4f04 100644 --- a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs @@ -29,7 +29,7 @@ public void TestWordEmbeddings() label: ctx.LoadBool(0), SentimentText: ctx.LoadText(1)), hasHeader: true) .Read(dataPath); - var dynamicData = new TextFeaturizingEstimator(Env, "SentimentText", "SentimentText_Features", args => + var dynamicData = new TextFeaturizingEstimator(Env, "SentimentText_Features", "SentimentText", args => { args.OutputTokens = true; args.KeepPunctuations = false; diff --git a/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs b/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs index 6137e78090..5f6723526d 100644 --- a/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/WordTokenizeTests.cs @@ -57,8 +57,8 @@ public void WordTokenizeWorkout() var invalidData = new[] { new TestWrong() { A =1, B = new float[2] { 2,3 } } }; var invalidDataView = ML.Data.ReadFromEnumerable(invalidData); var pipe = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("A", "TokenizeA"), - new WordTokenizingTransformer.ColumnInfo("B", "TokenizeB"), 
+ new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A"), + new WordTokenizingTransformer.ColumnInfo("TokenizeB", "B"), }); TestEstimatorCore(pipe, dataView, invalidInput: invalidDataView); @@ -99,8 +99,8 @@ public void TestOldSavingAndLoading() var dataView = ML.Data.ReadFromEnumerable(data); var pipe = new WordTokenizingEstimator(Env, new[]{ - new WordTokenizingTransformer.ColumnInfo("A", "TokenizeA"), - new WordTokenizingTransformer.ColumnInfo("B", "TokenizeB"), + new WordTokenizingTransformer.ColumnInfo("TokenizeA", "A"), + new WordTokenizingTransformer.ColumnInfo("TokenizeB", "B"), }); var result = pipe.Fit(dataView).Transform(dataView); var resultRoles = new RoleMappedData(result); diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs index 1d748dc510..93dc3f8a48 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesDirectApi.cs @@ -156,7 +156,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngineNoColumn() // Pipeline. - var pipeline = ml.Transforms.Text.FeaturizeText("Text", "Text_Featurized") + var pipeline = ml.Transforms.Text.FeaturizeText("Text_Featurized", "Text") .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments() { Confidence = 95, @@ -232,7 +232,7 @@ public void ChangePointDetectionWithSeasonalityPredictionEngine() // Pipeline. 
- var pipeline = ml.Transforms.Text.FeaturizeText("Text", "Text_Featurized") + var pipeline = ml.Transforms.Text.FeaturizeText("Text_Featurized", "Text") .Append(new SsaChangePointEstimator(ml, new SsaChangePointDetector.Arguments() { Confidence = 95, diff --git a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs index 68a3cdfdb6..e65c9da724 100644 --- a/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs +++ b/test/Microsoft.ML.TimeSeries.Tests/TimeSeriesEstimatorTests.cs @@ -59,8 +59,8 @@ void TestSsaChangePointEstimator() for (int i = 0; i < ChangeHistorySize; i++) data.Add(new Data(i * 100)); - var pipe = new SsaChangePointEstimator(Env, "Value", "Change", - Confidence, ChangeHistorySize, MaxTrainingSize, SeasonalitySize); + var pipe = new SsaChangePointEstimator(Env, "Change", + Confidence, ChangeHistorySize, MaxTrainingSize, SeasonalitySize, "Value"); var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; @@ -93,8 +93,8 @@ void TestSsaSpikeEstimator() for (int i = 0; i < PValueHistorySize; i++) data.Add(new Data(i * 100)); - var pipe = new SsaSpikeEstimator(Env, "Value", "Change", - Confidence, PValueHistorySize, MaxTrainingSize, SeasonalitySize); + var pipe = new SsaSpikeEstimator(Env, "Change", + Confidence, PValueHistorySize, MaxTrainingSize, SeasonalitySize, "Value"); var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } }; @@ -121,7 +121,7 @@ void TestIidChangePointEstimator() data.Add(new Data(i * 100)); var pipe = new IidChangePointEstimator(Env, - "Value", "Change", Confidence, ChangeHistorySize); + "Change", Confidence, ChangeHistorySize, "Value"); var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; var stringData = new List { 
new TestDataDifferntType() { data_0 = new string[inputSize] } }; @@ -147,8 +147,8 @@ void TestIidSpikeEstimator() for (int i = 0; i < PValueHistorySize; i++) data.Add(new Data(i * 100)); - var pipe = new IidSpikeEstimator(Env, - "Value", "Change", Confidence, PValueHistorySize); + var pipe = new IidSpikeEstimator(Env, + "Change", Confidence, PValueHistorySize, "Value"); var xyData = new List { new TestDataXY() { A = new float[inputSize] } }; var stringData = new List { new TestDataDifferntType() { data_0 = new string[inputSize] } };