From a285f8d5db15cd8bf77f3c03763bce16642f902b Mon Sep 17 00:00:00 2001 From: Ivan Matantsev Date: Thu, 18 Oct 2018 14:14:27 -0700 Subject: [PATCH] Convert TextNormalizer to estimator (#1276) --- .../Transforms/HashTransform.cs | 6 +- .../Transforms/KeyToVectorTransform.cs | 19 +- src/Microsoft.ML.Legacy/CSharpApi.cs | 4 +- src/Microsoft.ML.Transforms/RffTransform.cs | 2 +- .../Text/TextNormalizerTransform.cs | 588 +++++++++++------- .../Text/TextStaticExtensions.cs | 11 +- .../Text/TextTransform.cs | 444 +++++++------ .../Text/WrappedTextTransformers.cs | 85 --- .../BaselineOutput/Common/Text/Normalized.tsv | 17 + ...sticDualCoordinateAscentClassifierBench.cs | 4 +- test/Microsoft.ML.FSharp.Tests/SmokeTests.fs | 6 - .../StaticPipeTests.cs | 41 ++ .../Scenarios/Api/SimpleTrainAndPredict.cs | 5 +- .../Scenarios/ClusteringTests.cs | 2 - .../PipelineApi/SimpleTrainAndPredict.cs | 2 - .../Scenarios/SentimentPredictionTests.cs | 6 - .../SentimentPredictionTests.cs | 8 +- .../Transformers/TextFeaturizerTests.cs | 5 +- .../Transformers/TextNormalizer.cs | 101 +++ .../Transformers/WordEmbeddingsTests.cs | 8 +- 20 files changed, 761 insertions(+), 603 deletions(-) create mode 100644 test/BaselineOutput/Common/Text/Normalized.tsv create mode 100644 test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs diff --git a/src/Microsoft.ML.Data/Transforms/HashTransform.cs b/src/Microsoft.ML.Data/Transforms/HashTransform.cs index a21fbf8002..076d622d1a 100644 --- a/src/Microsoft.ML.Data/Transforms/HashTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/HashTransform.cs @@ -1289,10 +1289,10 @@ public override void Process() /// public sealed class HashEstimator : IEstimator { - public const int NumBitsMin = 1; - public const int NumBitsLim = 32; + internal const int NumBitsMin = 1; + internal const int NumBitsLim = 32; - public static class Defaults + internal static class Defaults { public const int HashBits = NumBitsLim - 1; public const uint Seed = 314489979; diff --git a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs index bdef0a5a15..8fa24d69ce 100644 --- a/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs +++ b/src/Microsoft.ML.Data/Transforms/KeyToVectorTransform.cs @@ -208,17 +208,14 @@ private static IDataTransform Create(IHostEnvironment env, Arguments args, IData env.CheckValue(args.Column, nameof(args.Column)); var cols = new ColumnInfo[args.Column.Length]; - using (var ch = env.Start("ValidateArgs")) + for (int i = 0; i < cols.Length; i++) { - for (int i = 0; i < cols.Length; i++) - { - var item = args.Column[i]; + var item = args.Column[i]; - cols[i] = new ColumnInfo(item.Source ?? item.Name, - item.Name, - item.Bag ?? args.Bag); - }; - } + cols[i] = new ColumnInfo(item.Source ?? item.Name, + item.Name, + item.Bag ?? args.Bag); + }; return new KeyToVectorTransform(env, cols).MakeDataTransform(input); } @@ -727,7 +724,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src // Note that one input feature got expended to an one-hot vector. opType = "ReduceSum"; var reduceNode = ctx.CreateNode(opType, encodedVariableName, dstVariableName, ctx.GetNodeName(opType), ""); - reduceNode.AddAttribute("axes", new long[] { shape.Count - 1}); + reduceNode.AddAttribute("axes", new long[] { shape.Count - 1 }); reduceNode.AddAttribute("keepdims", 0); } return true; @@ -737,7 +734,7 @@ private bool SaveAsOnnxCore(OnnxContext ctx, int iinfo, ColInfo info, string src public sealed class KeyToVectorEstimator : TrivialEstimator { - public static class Defaults + internal static class Defaults { public const bool Bag = false; } diff --git a/src/Microsoft.ML.Legacy/CSharpApi.cs b/src/Microsoft.ML.Legacy/CSharpApi.cs index 7bb4a3f79b..11c255e19b 100644 --- a/src/Microsoft.ML.Legacy/CSharpApi.cs +++ b/src/Microsoft.ML.Legacy/CSharpApi.cs @@ -16789,7 +16789,7 @@ public enum TextTransformLanguage Japanese = 7 } - public enum TextNormalizerTransformCaseNormalizationMode + public enum TextNormalizerEstimatorCaseNormalizationMode { Lower = 0, Upper = 1, @@ -16877,7 +16877,7 @@ public void AddColumn(string name, params string[] source) /// /// Casing text using the rules of the invariant culture. /// - public TextNormalizerTransformCaseNormalizationMode TextCase { get; set; } = TextNormalizerTransformCaseNormalizationMode.Lower; + public TextNormalizerEstimatorCaseNormalizationMode TextCase { get; set; } = TextNormalizerEstimatorCaseNormalizationMode.Lower; /// /// Whether to keep diacritical marks or remove them. diff --git a/src/Microsoft.ML.Transforms/RffTransform.cs b/src/Microsoft.ML.Transforms/RffTransform.cs index 478212ff35..8b86243453 100644 --- a/src/Microsoft.ML.Transforms/RffTransform.cs +++ b/src/Microsoft.ML.Transforms/RffTransform.cs @@ -642,7 +642,7 @@ private void TransformFeatures(ref VBuffer src, ref VBuffer dst, T /// public sealed class RffEstimator : IEstimator { - public static class Defaults + internal static class Defaults { public const int NewDim = 1000; public const bool UseSin = false; diff --git a/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs b/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs index 794b1cb9bd..0ac1e87736 100644 --- a/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/TextNormalizerTransform.cs @@ -4,41 +4,39 @@ #pragma warning disable 420 // volatile with Interlocked.CompareExchange -using System; -using System.Collections.Generic; -using System.Text; -using System.Threading; +using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.Model; -using Microsoft.ML.Runtime.TextAnalytics; +using Microsoft.ML.Transforms.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Threading; -[assembly: LoadableClass(TextNormalizerTransform.Summary, typeof(TextNormalizerTransform), typeof(TextNormalizerTransform.Arguments), typeof(SignatureDataTransform), +[assembly: LoadableClass(TextNormalizerTransform.Summary, typeof(IDataTransform), typeof(TextNormalizerTransform), typeof(TextNormalizerTransform.Arguments), typeof(SignatureDataTransform), "Text Normalizer Transform", "TextNormalizerTransform", "TextNormalizer", "TextNorm")] -[assembly: LoadableClass(TextNormalizerTransform.Summary, typeof(TextNormalizerTransform), null, typeof(SignatureLoadDataTransform), +[assembly: LoadableClass(TextNormalizerTransform.Summary, typeof(IDataTransform), typeof(TextNormalizerTransform), null, typeof(SignatureLoadDataTransform), "Text Normalizer Transform", TextNormalizerTransform.LoaderSignature)] -namespace Microsoft.ML.Runtime.TextAnalytics +[assembly: LoadableClass(TextNormalizerTransform.Summary, typeof(TextNormalizerTransform), null, typeof(SignatureLoadModel), + "Text Normalizer Transform", TextNormalizerTransform.LoaderSignature)] + +[assembly: LoadableClass(typeof(IRowMapper), typeof(TextNormalizerTransform), null, typeof(SignatureLoadRowMapper), + "Text Normalizer Transform", TextNormalizerTransform.LoaderSignature)] + +namespace Microsoft.ML.Transforms.Text { /// /// A text normalization transform that allows normalizing text case, removing diacritical marks, punctuation marks and/or numbers. /// The transform operates on text input as well as vector of tokens/text (vector of ReadOnlyMemory). /// - public sealed class TextNormalizerTransform : OneToOneTransformBase + public sealed class TextNormalizerTransform : OneToOneTransformerBase { - /// - /// Case normalization mode of text. This enumeration is serialized. - /// - public enum CaseNormalizationMode - { - Lower = 0, - Upper = 1, - None = 2 - } - public sealed class Column : OneToOneColumn { public static Column Parse(string str) @@ -62,23 +60,24 @@ public sealed class Arguments public Column[] Column; [Argument(ArgumentType.AtMostOnce, HelpText = "Casing text using the rules of the invariant culture.", ShortName = "case", SortOrder = 1)] - public CaseNormalizationMode TextCase = CaseNormalizationMode.Lower; + public TextNormalizerEstimator.CaseNormalizationMode TextCase = TextNormalizerEstimator.Defaults.TextCase; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep diacritical marks or remove them.", ShortName = "diac", SortOrder = 1)] - public bool KeepDiacritics = false; + public bool KeepDiacritics = TextNormalizerEstimator.Defaults.KeepDiacritics; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep punctuation marks or remove them.", ShortName = "punc", SortOrder = 2)] - public bool KeepPunctuations = true; + public bool KeepPunctuations = TextNormalizerEstimator.Defaults.KeepPunctuations; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep numbers or remove them.", ShortName = "num", SortOrder = 2)] - public bool KeepNumbers = true; + public bool KeepNumbers = TextNormalizerEstimator.Defaults.KeepNumbers; } internal const string Summary = "A text normalization transform that allows normalizing text case, removing diacritical marks, punctuation marks and/or numbers." + " The transform operates on text input as well as vector of tokens/text (vector of ReadOnlyMemory)."; - public const string LoaderSignature = "TextNormalizerTransform"; + internal const string LoaderSignature = nameof(TextNormalizerTransform); + private static VersionInfo GetVersionInfo() { return new VersionInfo( @@ -91,19 +90,146 @@ private static VersionInfo GetVersionInfo() } private const string RegistrationName = "TextNormalizer"; + public IReadOnlyCollection<(string input, string output)> Columns => ColumnPairs.AsReadOnly(); - // Arguments - private readonly CaseNormalizationMode _case; + private readonly TextNormalizerEstimator.CaseNormalizationMode _textCase; private readonly bool _keepDiacritics; private readonly bool _keepPunctuations; private readonly bool _keepNumbers; - // A map where keys are letters combined with diacritics and values are the letters without diacritics. - private static volatile Dictionary _combinedDiacriticsMap; + public TextNormalizerTransform(IHostEnvironment env, + TextNormalizerEstimator.CaseNormalizationMode textCase = TextNormalizerEstimator.Defaults.TextCase, + bool keepDiacritics = TextNormalizerEstimator.Defaults.KeepDiacritics, + bool keepPunctuations = TextNormalizerEstimator.Defaults.KeepPunctuations, + bool keepNumbers = TextNormalizerEstimator.Defaults.KeepNumbers, + params (string input, string output)[] columns) : + base(Contracts.CheckRef(env, nameof(env)).Register(RegistrationName), columns) + { + _textCase = textCase; + _keepDiacritics = keepDiacritics; + _keepPunctuations = keepPunctuations; + _keepNumbers = keepNumbers; - // List of pairs of (letters combined with diacritics, the letters without diacritics) from Office NL team. - private static readonly string[] _combinedDiacriticsPairs = + } + + protected override void CheckInputColumn(ISchema inputSchema, int col, int srcCol) + { + var type = inputSchema.GetColumnType(srcCol); + if (!TextNormalizerEstimator.IsColumnTypeValid(type)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", ColumnPairs[col].input, TextNormalizerEstimator.ExpectedColumnType, type.ToString()); + } + + public override void Save(ModelSaveContext ctx) + { + Host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); + + // *** Binary format *** + // + // byte: case + // bool: whether to keep diacritics + // bool: whether to keep punctuations + // bool: whether to keep numbers + SaveColumns(ctx); + + ctx.Writer.Write((byte)_textCase); + ctx.Writer.WriteBoolByte(_keepDiacritics); + ctx.Writer.WriteBoolByte(_keepPunctuations); + ctx.Writer.WriteBoolByte(_keepNumbers); + } + + // Factory method for SignatureLoadModel. + private static TextNormalizerTransform Create(IHostEnvironment env, ModelLoadContext ctx) { + Contracts.CheckValue(env, nameof(env)); + var host = env.Register(RegistrationName); + host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(GetVersionInfo()); + return new TextNormalizerTransform(host, ctx); + } + + private TextNormalizerTransform(IHost host, ModelLoadContext ctx) + : base(host, ctx) + { + var columnsLength = ColumnPairs.Length; + // *** Binary format *** + // + // byte: case + // bool: whether to keep diacritics + // bool: whether to keep punctuations + // bool: whether to keep numbers + _textCase = (TextNormalizerEstimator.CaseNormalizationMode)ctx.Reader.ReadByte(); + host.CheckDecode(Enum.IsDefined(typeof(TextNormalizerEstimator.CaseNormalizationMode), _textCase)); + + _keepDiacritics = ctx.Reader.ReadBoolByte(); + _keepPunctuations = ctx.Reader.ReadBoolByte(); + _keepNumbers = ctx.Reader.ReadBoolByte(); + } + + // Factory method for SignatureDataTransform. + private static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView input) + { + Contracts.CheckValue(env, nameof(env)); + env.CheckValue(args, nameof(args)); + env.CheckValue(input, nameof(input)); + + env.CheckValue(args.Column, nameof(args.Column)); + var cols = new (string input, string output)[args.Column.Length]; + for (int i = 0; i < cols.Length; i++) + { + var item = args.Column[i]; + cols[i] = (item.Source ?? item.Name, item.Name); + } + return new TextNormalizerTransform(env, args.TextCase, args.KeepDiacritics, args.KeepPunctuations, args.KeepNumbers, cols).MakeDataTransform(input); + } + + // Factory method for SignatureLoadDataTransform. + private static IDataTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) + => Create(env, ctx).MakeDataTransform(input); + + // Factory method for SignatureLoadRowMapper. + private static IRowMapper Create(IHostEnvironment env, ModelLoadContext ctx, ISchema inputSchema) + => Create(env, ctx).MakeRowMapper(inputSchema); + + protected override IRowMapper MakeRowMapper(ISchema schema) => new Mapper(this, Schema.Create(schema)); + + private sealed class Mapper : MapperBase + { + private readonly ColumnType[] _types; + private readonly TextNormalizerTransform _parent; + + public Mapper(TextNormalizerTransform parent, Schema inputSchema) + : base(parent.Host.Register(nameof(Mapper)), parent, inputSchema) + { + _parent = parent; + _types = new ColumnType[_parent.ColumnPairs.Length]; + for (int i = 0; i < _types.Length; i++) + { + inputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int srcCol); + var srcType = inputSchema.GetColumnType(srcCol); + _types[i] = srcType.IsVector ? new VectorType(TextType.Instance) : srcType; + } + } + + public override Schema.Column[] GetOutputColumns() + { + var result = new Schema.Column[_parent.ColumnPairs.Length]; + for (int i = 0; i < _parent.ColumnPairs.Length; i++) + { + InputSchema.TryGetColumnIndex(_parent.ColumnPairs[i].input, out int colIndex); + Host.Assert(colIndex >= 0); + result[i] = new Schema.Column(_parent.ColumnPairs[i].output, _types[i], null); + } + return result; + } + + // A map where keys are letters combined with diacritics and values are the letters without diacritics. + private static volatile Dictionary _combinedDiacriticsMap; + + // List of pairs of (letters combined with diacritics, the letters without diacritics) from Office NL team. + private static readonly string[] _combinedDiacriticsPairs = + { // Latin letters combined with diacritics: "ÀA", "ÁA", "ÂA", "ÃA", "ÄA", "ÅA", "ÇC", "ÈE", "ÉE", "ÊE", "ËE", "ÌI", "ÍI", "ÎI", "ÏI", "ÑN", "ÒO", "ÓO", "ÔO", "ÕO", "ÖO", "ÙU", "ÚU", "ÛU", "ÜU", "ÝY", "àa", "áa", "âa", "ãa", "äa", "åa", @@ -133,253 +259,255 @@ private static VersionInfo GetVersionInfo() "ӴЧ", "ӵч", "ӸЫ", "ӹы" }; - public TextNormalizerTransform(IHostEnvironment env, Arguments args, IDataView input) - : base(env, RegistrationName, Contracts.CheckRef(args, nameof(args)).Column, input, TestIsTextItem) - { - Host.AssertNonEmpty(Infos); - Host.Assert(Infos.Length == Utils.Size(args.Column)); - - using (var ch = Host.Start("Construction")) - { - ch.CheckUserArg(Enum.IsDefined(typeof(CaseNormalizationMode), args.TextCase), - nameof(args.TextCase), "Invalid case normalization mode"); - - _case = args.TextCase; - _keepDiacritics = args.KeepDiacritics; - _keepPunctuations = args.KeepPunctuations; - _keepNumbers = args.KeepNumbers; - } - Metadata.Seal(); - } - - private static Dictionary CombinedDiacriticsMap - { - get + private static Dictionary CombinedDiacriticsMap { - if (_combinedDiacriticsMap == null) + get { - var combinedDiacriticsMap = new Dictionary(); - for (int i = 0; i < _combinedDiacriticsPairs.Length; i++) + if (_combinedDiacriticsMap == null) { - Contracts.Assert(_combinedDiacriticsPairs[i].Length == 2); - combinedDiacriticsMap.Add(_combinedDiacriticsPairs[i][0], _combinedDiacriticsPairs[i][1]); + var combinedDiacriticsMap = new Dictionary(); + for (int i = 0; i < _combinedDiacriticsPairs.Length; i++) + { + Contracts.Assert(_combinedDiacriticsPairs[i].Length == 2); + combinedDiacriticsMap.Add(_combinedDiacriticsPairs[i][0], _combinedDiacriticsPairs[i][1]); + } + + Interlocked.CompareExchange(ref _combinedDiacriticsMap, combinedDiacriticsMap, null); } - Interlocked.CompareExchange(ref _combinedDiacriticsMap, combinedDiacriticsMap, null); + return _combinedDiacriticsMap; } - - return _combinedDiacriticsMap; } - } - private TextNormalizerTransform(IHost host, ModelLoadContext ctx, IDataView input) - : base(host, ctx, input, TestIsTextItem) - { - Host.AssertValue(ctx); - - using (var ch = Host.Start("Deserialization")) + protected override Delegate MakeGetter(IRow input, int iinfo, out Action disposer) { - // *** Binary format *** - // - // byte: case - // bool: whether to keep diacritics - // bool: whether to keep punctuations - // bool: whether to keep numbers - ch.AssertNonEmpty(Infos); - - _case = (CaseNormalizationMode)ctx.Reader.ReadByte(); - ch.CheckDecode(Enum.IsDefined(typeof(CaseNormalizationMode), _case)); - - _keepDiacritics = ctx.Reader.ReadBoolByte(); - _keepPunctuations = ctx.Reader.ReadBoolByte(); - _keepNumbers = ctx.Reader.ReadBoolByte(); - } - Metadata.Seal(); - } + Host.AssertValue(input); + Host.Assert(0 <= iinfo && iinfo < _parent.ColumnPairs.Length); + disposer = null; - public static TextNormalizerTransform Create(IHostEnvironment env, ModelLoadContext ctx, IDataView input) - { - Contracts.CheckValue(env, nameof(env)); - var h = env.Register(RegistrationName); - h.CheckValue(ctx, nameof(ctx)); - h.CheckValue(input, nameof(input)); - ctx.CheckAtModel(GetVersionInfo()); - return h.Apply("Loading Model", ch => new TextNormalizerTransform(h, ctx, input)); - } - - public override void Save(ModelSaveContext ctx) - { - Host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); - - // *** Binary format *** - // - // byte: case - // bool: whether to keep diacritics - // bool: whether to keep punctuations - // bool: whether to keep numbers - SaveBase(ctx); - - ctx.Writer.Write((byte)_case); - ctx.Writer.WriteBoolByte(_keepDiacritics); - ctx.Writer.WriteBoolByte(_keepPunctuations); - ctx.Writer.WriteBoolByte(_keepNumbers); - } - - protected override ColumnType GetColumnTypeCore(int iinfo) - { - Host.Assert(0 <= iinfo & iinfo < Infos.Length); - return Infos[iinfo].TypeSrc.IsVector ? new VectorType(TextType.Instance) : Infos[iinfo].TypeSrc; - } - - protected override Delegate GetGetterCore(IChannel ch, IRow input, int iinfo, out Action disposer) - { - Host.AssertValueOrNull(ch); - Host.AssertValue(input); - Host.Assert(0 <= iinfo && iinfo < Infos.Length); - disposer = null; + var srcType = input.Schema[_parent.ColumnPairs[iinfo].input].Type; + Host.Assert(srcType.ItemType.IsText); - var typeSrc = Infos[iinfo].TypeSrc; - Host.Assert(typeSrc.ItemType.IsText); + if (srcType.IsVector) + { + Host.Assert(srcType.VectorSize >= 0); + return MakeGetterVec(input, iinfo); + } - if (typeSrc.IsVector) - { - Host.Assert(typeSrc.VectorSize >= 0); - return MakeGetterVec(input, iinfo); + Host.Assert(!srcType.IsVector); + return MakeGetterOne(input, iinfo); } - Host.Assert(!typeSrc.IsVector); - return MakeGetterOne(input, iinfo); - } - - private ValueGetter> MakeGetterOne(IRow input, int iinfo) - { - Contracts.Assert(Infos[iinfo].TypeSrc.IsText); - var getSrc = GetSrcGetter>(input, iinfo); - Host.AssertValue(getSrc); - var src = default(ReadOnlyMemory); - var buffer = new StringBuilder(); - return - (ref ReadOnlyMemory dst) => - { - getSrc(ref src); - NormalizeSrc(ref src, ref dst, buffer); - }; - } - - private ValueGetter>> MakeGetterVec(IRow input, int iinfo) - { - var getSrc = GetSrcGetter>>(input, iinfo); - Host.AssertValue(getSrc); - var src = default(VBuffer>); - var buffer = new StringBuilder(); - var list = new List>(); - var temp = default(ReadOnlyMemory); - return - (ref VBuffer> dst) => - { - getSrc(ref src); - list.Clear(); - for (int i = 0; i < src.Count; i++) + private ValueGetter> MakeGetterOne(IRow input, int iinfo) + { + var getSrc = input.GetGetter>(ColMapNewToOld[iinfo]); + Host.AssertValue(getSrc); + var src = default(ReadOnlyMemory); + var buffer = new StringBuilder(); + return + (ref ReadOnlyMemory dst) => { - NormalizeSrc(ref src.Values[i], ref temp, buffer); - if (!temp.IsEmpty) - list.Add(temp); - } - - VBufferUtils.Copy(list, ref dst, list.Count); - }; - } - - private void NormalizeSrc(ref ReadOnlyMemory src, ref ReadOnlyMemory dst, StringBuilder buffer) - { - Host.AssertValue(buffer); + getSrc(ref src); + NormalizeSrc(ref src, ref dst, buffer); + }; + } - if (src.IsEmpty) + private ValueGetter>> MakeGetterVec(IRow input, int iinfo) { - dst = src; - return; + var getSrc = input.GetGetter>>(ColMapNewToOld[iinfo]); + Host.AssertValue(getSrc); + var src = default(VBuffer>); + var buffer = new StringBuilder(); + var list = new List>(); + var temp = default(ReadOnlyMemory); + return + (ref VBuffer> dst) => + { + getSrc(ref src); + list.Clear(); + for (int i = 0; i < src.Count; i++) + { + NormalizeSrc(ref src.Values[i], ref temp, buffer); + if (!temp.IsEmpty) + list.Add(temp); + } + + VBufferUtils.Copy(list, ref dst, list.Count); + }; } - buffer.Clear(); - - int i = 0; - int min = 0; - var span = src.Span; - while (i < src.Length) + private void NormalizeSrc(ref ReadOnlyMemory src, ref ReadOnlyMemory dst, StringBuilder buffer) { - char ch = span[i]; - if (!_keepPunctuations && char.IsPunctuation(ch) || !_keepNumbers && char.IsNumber(ch)) + Host.AssertValue(buffer); + + if (src.IsEmpty) { - // Append everything before ch and ignore ch. - buffer.AppendSpan(span.Slice(min, i - min)); - min = i + 1; - i++; - continue; + dst = src; + return; } - if (!_keepDiacritics) + buffer.Clear(); + + int i = 0; + int min = 0; + var span = src.Span; + while (i < src.Length) { - if (IsCombiningDiacritic(ch)) + char ch = span[i]; + if (!_parent._keepPunctuations && char.IsPunctuation(ch) || !_parent._keepNumbers && char.IsNumber(ch)) { + // Append everything before ch and ignore ch. buffer.AppendSpan(span.Slice(min, i - min)); min = i + 1; i++; continue; } - if (CombinedDiacriticsMap.ContainsKey(ch)) - ch = CombinedDiacriticsMap[ch]; - } + if (!_parent._keepDiacritics) + { + if (IsCombiningDiacritic(ch)) + { + buffer.AppendSpan(span.Slice(min, i - min)); + min = i + 1; + i++; + continue; + } + + if (CombinedDiacriticsMap.ContainsKey(ch)) + ch = CombinedDiacriticsMap[ch]; + } - if (_case == CaseNormalizationMode.Lower) - ch = CharUtils.ToLowerInvariant(ch); - else if (_case == CaseNormalizationMode.Upper) - ch = CharUtils.ToUpperInvariant(ch); + if (_parent._textCase == TextNormalizerEstimator.CaseNormalizationMode.Lower) + ch = CharUtils.ToLowerInvariant(ch); + else if (_parent._textCase == TextNormalizerEstimator.CaseNormalizationMode.Upper) + ch = CharUtils.ToUpperInvariant(ch); - if (ch != src.Span[i]) - { - buffer.AppendSpan(span.Slice(min, i - min)).Append(ch); - min = i + 1; + if (ch != src.Span[i]) + { + buffer.AppendSpan(span.Slice(min, i - min)).Append(ch); + min = i + 1; + } + + i++; } - i++; + Host.Assert(i == src.Length); + int len = i - min; + if (min == 0) + { + Host.Assert(src.Length == len); + dst = src; + } + else + { + buffer.AppendSpan(span.Slice(min, len)); + dst = buffer.ToString().AsMemory(); + } } - Host.Assert(i == src.Length); - int len = i - min; - if (min == 0) + /// + /// Whether a character is a combining diacritic character or not. + /// Combining diacritic characters are the set of diacritics intended to modify other characters. + /// The list is provided by Office NL team. + /// + private bool IsCombiningDiacritic(char ch) { - Host.Assert(src.Length == len); - dst = src; - } - else - { - buffer.AppendSpan(span.Slice(min, len)); - dst = buffer.ToString().AsMemory(); + if (ch < 0x0300 || ch > 0x0670) + return false; + + // Basic combining diacritics + return ch >= 0x0300 && ch <= 0x036F || + + // Hebrew combining diacritics + ch >= 0x0591 && ch <= 0x05BD || ch == 0x05C1 || ch == 0x05C2 || ch == 0x05C4 || + ch == 0x05C5 || ch == 0x05C7 || + + // Arabic combining diacritics + ch >= 0x0610 && ch <= 0x0615 || ch >= 0x064C && ch <= 0x065E || ch == 0x0670; } } + } + public sealed class TextNormalizerEstimator : TrivialEstimator + { /// - /// Whether a character is a combining diacritic character or not. - /// Combining diacritic characters are the set of diacritics intended to modify other characters. - /// The list is provided by Office NL team. + /// Case normalization mode of text. This enumeration is serialized. /// - private bool IsCombiningDiacritic(char ch) + public enum CaseNormalizationMode { - if (ch < 0x0300 || ch > 0x0670) - return false; + Lower = 0, + Upper = 1, + None = 2 + } + + internal static class Defaults + { + public const CaseNormalizationMode TextCase = CaseNormalizationMode.Lower; + public const bool KeepDiacritics = false; + public const bool KeepPunctuations = true; + public const bool KeepNumbers = true; + + } + + public static bool IsColumnTypeValid(ColumnType type) => (type.ItemType.IsText); + + internal const string ExpectedColumnType = "Text or vector of text."; - // Basic combining diacritics - return ch >= 0x0300 && ch <= 0x036F || + /// + /// Normalizes incoming text in by changing case, removing diacritical marks, punctuation marks and/or numbers + /// and outputs new text as . + /// + /// The environment. + /// The column containing text to normalize. + /// The column containing output tokens. Null means is replaced. + /// Casing text using the rules of the invariant culture. + /// Whether to keep diacritical marks or remove them. + /// Whether to keep punctuation marks or remove them. + /// Whether to keep numbers or remove them. + public TextNormalizerEstimator(IHostEnvironment env, + string inputColumn, + string outputColumn = null, + CaseNormalizationMode textCase = Defaults.TextCase, + bool keepDiacritics = Defaults.KeepDiacritics, + bool keepPunctuations = Defaults.KeepPunctuations, + bool keepNumbers = Defaults.KeepNumbers) + : this(env, textCase, keepDiacritics, keepPunctuations, keepNumbers, (inputColumn, outputColumn ?? inputColumn)) + { + } - // Hebrew combining diacritics - ch >= 0x0591 && ch <= 0x05BD || ch == 0x05C1 || ch == 0x05C2 || ch == 0x05C4 || - ch == 0x05C5 || ch == 0x05C7 || + /// + /// Normalizes incoming text in input columns by changing case, removing diacritical marks, punctuation marks and/or numbers + /// and outputs new text as output columns. + /// + /// The environment. + /// Casing text using the rules of the invariant culture. + /// Whether to keep diacritical marks or remove them. + /// Whether to keep punctuation marks or remove them. + /// Whether to keep numbers or remove them. + /// Pairs of columns to run the text normalization on. + public TextNormalizerEstimator(IHostEnvironment env, + CaseNormalizationMode textCase = Defaults.TextCase, + bool keepDiacritics = Defaults.KeepDiacritics, + bool keepPunctuations = Defaults.KeepPunctuations, + bool keepNumbers = Defaults.KeepNumbers, + params (string input, string output)[] columns) + : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextNormalizerEstimator)), new TextNormalizerTransform(env, textCase, keepDiacritics, keepPunctuations, keepNumbers, columns)) + { + } - // Arabic combining diacritics - ch >= 0x0610 && ch <= 0x0615 || ch >= 0x064C && ch <= 0x065E || ch == 0x0670; + public override SchemaShape GetOutputSchema(SchemaShape inputSchema) + { + Host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var colInfo in Transformer.Columns) + { + if (!inputSchema.TryFindColumn(colInfo.input, out var col)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input); + if (!IsColumnTypeValid(col.ItemType)) + throw Host.ExceptSchemaMismatch(nameof(inputSchema), "input", colInfo.input, TextNormalizerEstimator.ExpectedColumnType, col.ItemType.ToString()); + result[colInfo.output] = new SchemaShape.Column(colInfo.output, col.Kind == SchemaShape.Column.VectorKind.Vector ? SchemaShape.Column.VectorKind.VariableVector : SchemaShape.Column.VectorKind.Scalar, col.ItemType, false); + } + return new SchemaShape(result.Values); } } } diff --git a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs index 0dc577772e..763a3c768e 100644 --- a/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs +++ b/src/Microsoft.ML.Transforms/Text/TextStaticExtensions.cs @@ -11,7 +11,6 @@ using System; using System.Collections.Generic; using static Microsoft.ML.Runtime.TextAnalytics.StopWordsRemoverTransform; -using static Microsoft.ML.Runtime.TextAnalytics.TextNormalizerTransform; namespace Microsoft.ML.Transforms.Text { @@ -182,7 +181,7 @@ private sealed class OutPipelineColumn : Scalar { public readonly Scalar Input; - public OutPipelineColumn(Scalar input, CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers) + public OutPipelineColumn(Scalar input, TextNormalizerEstimator.CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers) : base(new Reconciler(textCase, keepDiacritics, keepPunctuations, keepNumbers), input) { Input = input; @@ -191,12 +190,12 @@ public OutPipelineColumn(Scalar input, CaseNormalizationMode textCase, b private sealed class Reconciler : EstimatorReconciler, IEquatable { - private readonly CaseNormalizationMode _textCase; + private readonly TextNormalizerEstimator.CaseNormalizationMode _textCase; private readonly bool _keepDiacritics; private readonly bool _keepPunctuations; private readonly bool _keepNumbers; - public Reconciler(CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers) + public Reconciler(TextNormalizerEstimator.CaseNormalizationMode textCase, bool keepDiacritics, bool keepPunctuations, bool keepNumbers) { _textCase = textCase; _keepDiacritics = keepDiacritics; @@ -225,7 +224,7 @@ public override IEstimator Reconcile(IHostEnvironment env, foreach (var outCol in toOutput) pairs.Add((inputNames[((OutPipelineColumn)outCol).Input], outputNames[outCol])); - return new TextNormalizer(env, pairs.ToArray(), _textCase, _keepDiacritics, _keepPunctuations, _keepNumbers); + return new TextNormalizerEstimator(env, _textCase, _keepDiacritics, _keepPunctuations, _keepNumbers, pairs.ToArray()); } } @@ -238,7 +237,7 @@ public override IEstimator Reconcile(IHostEnvironment env, /// Whether to keep punctuation marks or remove them. /// Whether to keep numbers or remove them. public static Scalar NormalizeText(this Scalar input, - CaseNormalizationMode textCase = CaseNormalizationMode.Lower, + TextNormalizerEstimator.CaseNormalizationMode textCase = TextNormalizerEstimator.CaseNormalizationMode.Lower, bool keepDiacritics = false, bool keepPunctuations = true, bool keepNumbers = true) => new OutPipelineColumn(input, textCase, keepDiacritics, keepPunctuations, keepNumbers); diff --git a/src/Microsoft.ML.Transforms/Text/TextTransform.cs b/src/Microsoft.ML.Transforms/Text/TextTransform.cs index f9e3972be8..2ce6c77629 100644 --- a/src/Microsoft.ML.Transforms/Text/TextTransform.cs +++ b/src/Microsoft.ML.Transforms/Text/TextTransform.cs @@ -2,10 +2,6 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using System; -using System.Collections.Generic; -using System.Linq; -using System.Text; using Microsoft.ML.Core.Data; using Microsoft.ML.Runtime; using Microsoft.ML.Runtime.CommandLine; @@ -18,6 +14,11 @@ using Microsoft.ML.Runtime.TextAnalytics; using Microsoft.ML.StaticPipe; using Microsoft.ML.StaticPipe.Runtime; +using Microsoft.ML.Transforms.Text; +using System; +using System.Collections.Generic; +using System.Linq; +using System.Text; [assembly: LoadableClass(TextTransform.Summary, typeof(IDataTransform), typeof(TextTransform), typeof(TextTransform.Arguments), typeof(SignatureDataTransform), TextTransform.UserName, "TextTransform", TextTransform.LoaderSignature)] @@ -27,12 +28,8 @@ namespace Microsoft.ML.Runtime.Data { - using StopWordsArgs = StopWordsRemoverTransform.Arguments; - using TextNormalizerArgs = TextNormalizerTransform.Arguments; + using CaseNormalizationMode = TextNormalizerEstimator.CaseNormalizationMode; using StopWordsCol = StopWordsRemoverTransform.Column; - using TextNormalizerCol = TextNormalizerTransform.Column; - using StopWordsLang = StopWordsRemoverTransform.Language; - using CaseNormalizationMode = TextNormalizerTransform.CaseNormalizationMode; // A transform that turns a collection of text documents into numerical feature vectors. The feature vectors are counts // of (word or character) ngrams in a given text. It offers ngram hashing (finding the ngram token string name to feature @@ -97,16 +94,16 @@ public sealed class Arguments : TransformInputBase public IStopWordsRemoverFactory StopWordsRemover; [Argument(ArgumentType.AtMostOnce, HelpText = "Casing text using the rules of the invariant culture.", ShortName = "case", SortOrder = 5)] - public CaseNormalizationMode TextCase = CaseNormalizationMode.Lower; + public CaseNormalizationMode TextCase = TextNormalizerEstimator.Defaults.TextCase; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep diacritical marks or remove them.", ShortName = "diac", SortOrder = 6)] - public bool KeepDiacritics; + public bool KeepDiacritics = TextNormalizerEstimator.Defaults.KeepDiacritics; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep punctuation marks or remove them.", ShortName = "punc", SortOrder = 7)] - public bool KeepPunctuations = true; + public bool KeepPunctuations = TextNormalizerEstimator.Defaults.KeepPunctuations; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to keep numbers or remove them.", ShortName = "num", SortOrder = 8)] - public bool KeepNumbers = true; + public bool KeepNumbers = TextNormalizerEstimator.Defaults.KeepNumbers; [Argument(ArgumentType.AtMostOnce, HelpText = "Whether to output the transformed text tokens as an additional column.", ShortName = "tokens,showtext,showTransformedText", SortOrder = 9)] public bool OutputTokens; @@ -323,58 +320,51 @@ public ITransformer Fit(IDataView input) if (tparams.NeedsNormalizeTransform) { - var xfCols = new TextNormalizerCol[textCols.Length]; + var xfCols = new (string input, string output)[textCols.Length]; string[] dstCols = new string[textCols.Length]; for (int i = 0; i < textCols.Length; i++) { dstCols[i] = GenerateColumnName(view.Schema, textCols[i], "TextNormalizer"); tempCols.Add(dstCols[i]); - xfCols[i] = new TextNormalizerCol() { Source = textCols[i], Name = dstCols[i] }; - } + xfCols[i] = (textCols[i], dstCols[i]); + } - view = new TextNormalizerTransform(h, - new TextNormalizerArgs() - { - Column = xfCols, - KeepDiacritics = tparams.KeepDiacritics, - KeepNumbers = tparams.KeepNumbers, - KeepPunctuations = tparams.KeepPunctuations, - TextCase = tparams.TextCase - }, view); + view = new TextNormalizerEstimator(h, tparams.TextCase, tparams.KeepDiacritics, tparams.KeepPunctuations, tparams.KeepNumbers, xfCols).Fit(view).Transform(view); - textCols = dstCols; - } + textCols = dstCols; + } if (tparams.NeedsWordTokenizationTransform) { var xfCols = new DelimitedTokenizeTransform.Column[textCols.Length]; - wordTokCols = new string[textCols.Length]; - for (int i = 0; i < textCols.Length; i++) + wordTokCols = new string[textCols.Length]; + for (int i = 0; i new Transformer(env, ctx); - private static string GenerateColumnName(ISchema schema, string srcName, string xfTag) - { - return schema.GetTempColumnName(string.Format("{0}_{1}", srcName, xfTag)); - } +private static string GenerateColumnName(ISchema schema, string srcName, string xfTag) +{ + return schema.GetTempColumnName(string.Format("{0}_{1}", srcName, xfTag)); +} - public SchemaShape GetOutputSchema(SchemaShape inputSchema) - { - _host.CheckValue(inputSchema, nameof(inputSchema)); - var result = inputSchema.Columns.ToDictionary(x => x.Name); - foreach (var srcName in _inputColumns) - { - if (!inputSchema.TryFindColumn(srcName, out var col)) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); - if (!col.ItemType.IsText) - throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "scalar or vector of text", col.GetTypeString()); - } +public SchemaShape GetOutputSchema(SchemaShape inputSchema) +{ + _host.CheckValue(inputSchema, nameof(inputSchema)); + var result = inputSchema.Columns.ToDictionary(x => x.Name); + foreach (var srcName in _inputColumns) + { + if (!inputSchema.TryFindColumn(srcName, out var col)) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName); + if (!col.ItemType.IsText) + throw _host.ExceptSchemaMismatch(nameof(inputSchema), "input", srcName, "scalar or vector of text", col.GetTypeString()); + } - var metadata = new List(2); - metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); - if (AdvancedSettings.VectorNormalizer != TextNormKind.None) - metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); + var metadata = new List(2); + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.SlotNames, SchemaShape.Column.VectorKind.Vector, TextType.Instance, false)); + if (AdvancedSettings.VectorNormalizer != TextNormKind.None) + metadata.Add(new SchemaShape.Column(MetadataUtils.Kinds.IsNormalized, SchemaShape.Column.VectorKind.Scalar, BoolType.Instance, false)); - result[OutputColumn] = new SchemaShape.Column(OutputColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, - new SchemaShape(metadata)); - if (AdvancedSettings.OutputTokens) - { - string name = string.Format(TransformedTextColFormat, OutputColumn); - result[name] = new SchemaShape.Column(name, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false); - } + result[OutputColumn] = new SchemaShape.Column(OutputColumn, SchemaShape.Column.VectorKind.Vector, NumberType.R4, false, + new SchemaShape(metadata)); + if (AdvancedSettings.OutputTokens) + { + string name = string.Format(TransformedTextColFormat, OutputColumn); + result[name] = new SchemaShape.Column(name, SchemaShape.Column.VectorKind.VariableVector, TextType.Instance, false); + } - return new SchemaShape(result.Values); - } + return new SchemaShape(result.Values); +} - public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data) - { - Action settings = s => - { - s.TextLanguage = args.Language; - s.TextCase = args.TextCase; - s.KeepDiacritics = args.KeepDiacritics; - s.KeepPunctuations = args.KeepPunctuations; - s.KeepNumbers = args.KeepNumbers; - s.OutputTokens = args.OutputTokens; - s.VectorNormalizer = args.VectorNormalizer; - }; - - var estimator = new TextTransform(env, args.Column.Source ?? new[] { args.Column.Name }, args.Column.Name, settings); - estimator._stopWordsRemover = args.StopWordsRemover; - estimator._dictionary = args.Dictionary; - estimator._wordFeatureExtractor = args.WordFeatureExtractor; - estimator._charFeatureExtractor = args.CharFeatureExtractor; - return estimator.Fit(data).Transform(data) as IDataTransform; - } +public static IDataTransform Create(IHostEnvironment env, Arguments args, IDataView data) +{ + Action settings = s => + { + s.TextLanguage = args.Language; + s.TextCase = args.TextCase; + s.KeepDiacritics = args.KeepDiacritics; + s.KeepPunctuations = args.KeepPunctuations; + s.KeepNumbers = args.KeepNumbers; + s.OutputTokens = args.OutputTokens; + s.VectorNormalizer = args.VectorNormalizer; + }; + + var estimator = new TextTransform(env, args.Column.Source ?? new[] { args.Column.Name }, args.Column.Name, settings); + estimator._stopWordsRemover = args.StopWordsRemover; + estimator._dictionary = args.Dictionary; + estimator._wordFeatureExtractor = args.WordFeatureExtractor; + estimator._charFeatureExtractor = args.CharFeatureExtractor; + return estimator.Fit(data).Transform(data) as IDataTransform; +} - private sealed class Transformer : ITransformer, ICanSaveModel - { - private const string TransformDirTemplate = "Step_{0:000}"; +private sealed class Transformer : ITransformer, ICanSaveModel +{ + private const string TransformDirTemplate = "Step_{0:000}"; - private readonly IHost _host; - private readonly IDataView _xf; + private readonly IHost _host; + private readonly IDataView _xf; - public Transformer(IHostEnvironment env, IDataView input, IDataView view) - { - _host = env.Register(nameof(Transformer)); - _xf = ApplyTransformUtils.ApplyAllTransformsToData(_host, view, new EmptyDataView(_host, input.Schema), input); - } + public Transformer(IHostEnvironment env, IDataView input, IDataView view) + { + _host = env.Register(nameof(Transformer)); + _xf = ApplyTransformUtils.ApplyAllTransformsToData(_host, view, new EmptyDataView(_host, input.Schema), input); + } - public Schema GetOutputSchema(Schema inputSchema) - { - _host.CheckValue(inputSchema, nameof(inputSchema)); - return Transform(new EmptyDataView(_host, inputSchema)).Schema; - } + public Schema GetOutputSchema(Schema inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + return Transform(new EmptyDataView(_host, inputSchema)).Schema; + } - public IDataView Transform(IDataView input) - { - _host.CheckValue(input, nameof(input)); - return ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); - } + public IDataView Transform(IDataView input) + { + _host.CheckValue(input, nameof(input)); + return ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); + } - public bool IsRowToRowMapper => true; + public bool IsRowToRowMapper => true; - public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) - { - _host.CheckValue(inputSchema, nameof(inputSchema)); - var input = new EmptyDataView(_host, inputSchema); - var revMaps = new List(); - IDataView chain; - for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source) - { - // Everything in the chain ought to be a row mapper. - _host.Assert(xf is IRowToRowMapper); - revMaps.Add((IRowToRowMapper)xf); - } - // The walkback should have ended at the input. - Contracts.Assert(chain == input); - revMaps.Reverse(); - return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray()); - } + public IRowToRowMapper GetRowToRowMapper(Schema inputSchema) + { + _host.CheckValue(inputSchema, nameof(inputSchema)); + var input = new EmptyDataView(_host, inputSchema); + var revMaps = new List(); + IDataView chain; + for (chain = ApplyTransformUtils.ApplyAllTransformsToData(_host, _xf, input); chain is IDataTransform xf; chain = xf.Source) + { + // Everything in the chain ought to be a row mapper. + _host.Assert(xf is IRowToRowMapper); + revMaps.Add((IRowToRowMapper)xf); + } + // The walkback should have ended at the input. + Contracts.Assert(chain == input); + revMaps.Reverse(); + return new CompositeRowToRowMapper(inputSchema, revMaps.ToArray()); + } - public void Save(ModelSaveContext ctx) - { - _host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(); - ctx.SetVersionInfo(GetVersionInfo()); + public void Save(ModelSaveContext ctx) + { + _host.CheckValue(ctx, nameof(ctx)); + ctx.CheckAtModel(); + ctx.SetVersionInfo(GetVersionInfo()); - var dataPipe = _xf; - var transforms = new List(); - while (dataPipe is IDataTransform xf) - { - transforms.Add(xf); - dataPipe = xf.Source; - Contracts.AssertValue(dataPipe); - } - transforms.Reverse(); + var dataPipe = _xf; + var transforms = new List(); + while (dataPipe is IDataTransform xf) + { + transforms.Add(xf); + dataPipe = xf.Source; + Contracts.AssertValue(dataPipe); + } + transforms.Reverse(); - ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema)); + ctx.SaveSubModel("Loader", c => BinaryLoader.SaveInstance(_host, c, dataPipe.Schema)); - ctx.Writer.Write(transforms.Count); - for (int i = 0; i < transforms.Count; i++) - { - var dirName = string.Format(TransformDirTemplate, i); - ctx.SaveModel(transforms[i], dirName); - } - } + ctx.Writer.Write(transforms.Count); + for (int i = 0; i < transforms.Count; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.SaveModel(transforms[i], dirName); + } + } - public Transformer(IHostEnvironment env, ModelLoadContext ctx) - { - Contracts.CheckValue(env, nameof(env)); - _host = env.Register(nameof(Transformer)); - _host.CheckValue(ctx, nameof(ctx)); + public Transformer(IHostEnvironment env, ModelLoadContext ctx) + { + Contracts.CheckValue(env, nameof(env)); + _host = env.Register(nameof(Transformer)); + _host.CheckValue(ctx, nameof(ctx)); - ctx.CheckAtModel(GetVersionInfo()); - int n = ctx.Reader.ReadInt32(); + ctx.CheckAtModel(GetVersionInfo()); + int n = ctx.Reader.ReadInt32(); - ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); + ctx.LoadModel(env, out var loader, "Loader", new MultiFileSource(null)); - IDataView data = loader; - for (int i = 0; i < n; i++) - { - var dirName = string.Format(TransformDirTemplate, i); - ctx.LoadModel(env, out var xf, dirName, data); - data = xf; - } + IDataView data = loader; + for (int i = 0; i < n; i++) + { + var dirName = string.Format(TransformDirTemplate, i); + ctx.LoadModel(env, out var xf, dirName, data); + data = xf; + } - _xf = data; - } + _xf = data; + } - private static VersionInfo GetVersionInfo() - { - return new VersionInfo( - modelSignature: "TEXT XFR", - verWrittenCur: 0x00010001, // Initial - verReadableCur: 0x00010001, - verWeCanReadBack: 0x00010001, - loaderSignature: LoaderSignature, - loaderAssemblyName: typeof(Transformer).Assembly.FullName); - } - } + private static VersionInfo GetVersionInfo() + { + return new VersionInfo( + modelSignature: "TEXT XFR", + verWrittenCur: 0x00010001, // Initial + verReadableCur: 0x00010001, + verWeCanReadBack: 0x00010001, + loaderSignature: LoaderSignature, + loaderAssemblyName: typeof(Transformer).Assembly.FullName); + } +} - internal sealed class OutPipelineColumn : Vector - { - public readonly Scalar[] Inputs; +internal sealed class OutPipelineColumn : Vector +{ + public readonly Scalar[] Inputs; - public OutPipelineColumn(IEnumerable> inputs, Action advancedSettings) - : base(new Reconciler(advancedSettings), inputs.ToArray()) - { - Inputs = inputs.ToArray(); - } - } + public OutPipelineColumn(IEnumerable> inputs, Action advancedSettings) + : base(new Reconciler(advancedSettings), inputs.ToArray()) + { + Inputs = inputs.ToArray(); + } +} - private sealed class Reconciler : EstimatorReconciler - { - private readonly Action _settings; +private sealed class Reconciler : EstimatorReconciler +{ + private readonly Action _settings; - public Reconciler(Action advancedSettings) - { - _settings = advancedSettings; - } + public Reconciler(Action advancedSettings) + { + _settings = advancedSettings; + } - public override IEstimator Reconcile(IHostEnvironment env, - PipelineColumn[] toOutput, - IReadOnlyDictionary inputNames, - IReadOnlyDictionary outputNames, - IReadOnlyCollection usedNames) - { - Contracts.Assert(toOutput.Length == 1); + public override IEstimator Reconcile(IHostEnvironment env, + PipelineColumn[] toOutput, + IReadOnlyDictionary inputNames, + IReadOnlyDictionary outputNames, + IReadOnlyCollection usedNames) + { + Contracts.Assert(toOutput.Length == 1); - var outCol = (OutPipelineColumn)toOutput[0]; - var inputs = outCol.Inputs.Select(x => inputNames[x]); - return new TextTransform(env, inputs, outputNames[outCol], _settings); - } - } + var outCol = (OutPipelineColumn)toOutput[0]; + var inputs = outCol.Inputs.Select(x => inputNames[x]); + return new TextTransform(env, inputs, outputNames[outCol], _settings); + } +} } /// /// Extension methods for the static-pipeline over objects. /// public static class TextFeaturizerStaticPipe +{ + /// + /// Accept text data and converts it to array which represent combinations of ngram/skip-gram token counts. + /// + /// Input data. + /// Additional data. + /// Delegate which allows you to set transformation settings. + /// + public static Vector FeaturizeText(this Scalar input, Scalar[] otherInputs = null, Action advancedSettings = null) { - /// - /// Accept text data and converts it to array which represent combinations of ngram/skip-gram token counts. - /// - /// Input data. - /// Additional data. - /// Delegate which allows you to set transformation settings. - /// - public static Vector FeaturizeText(this Scalar input, Scalar[] otherInputs = null, Action advancedSettings = null) - { - Contracts.CheckValue(input, nameof(input)); - Contracts.CheckValueOrNull(otherInputs); - otherInputs = otherInputs ?? new Scalar[0]; - return new TextTransform.OutPipelineColumn(new[] { input }.Concat(otherInputs), advancedSettings); - } + Contracts.CheckValue(input, nameof(input)); + Contracts.CheckValueOrNull(otherInputs); + otherInputs = otherInputs ?? new Scalar[0]; + return new TextTransform.OutPipelineColumn(new[] { input }.Concat(otherInputs), advancedSettings); } } +} diff --git a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs index 6126fb8358..6cd2f878a9 100644 --- a/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs +++ b/src/Microsoft.ML.Transforms/Text/WrappedTextTransformers.cs @@ -7,12 +7,8 @@ using Microsoft.ML.Runtime.Internal.Utilities; using Microsoft.ML.Runtime.TextAnalytics; using System; -using System.Collections.Generic; using System.Linq; -using System.Text; -using static Microsoft.ML.Runtime.TextAnalytics.LdaTransform; using static Microsoft.ML.Runtime.TextAnalytics.StopWordsRemoverTransform; -using static Microsoft.ML.Runtime.TextAnalytics.TextNormalizerTransform; namespace Microsoft.ML.Transforms { @@ -180,87 +176,6 @@ private static TransformWrapper MakeTransformer(IHostEnvironment env, (string in } } - /// - /// Text normalizer allows normalizing text by changing case (Upper/Lower case), removing diacritical marks, punctuation marks and/or numbers. - /// - public sealed class TextNormalizer : TrivialWrapperEstimator - { - /// - /// Normalizes incoming text in by changing case, removing diacritical marks, punctuation marks and/or numbers - /// and outputs new text as . - /// - /// The environment. - /// The column containing text to normalize. - /// The column containing output tokens. Null means is replaced. - /// Casing text using the rules of the invariant culture. - /// Whether to keep diacritical marks or remove them. - /// Whether to keep punctuation marks or remove them. - /// Whether to keep numbers or remove them. - public TextNormalizer(IHostEnvironment env, - string inputColumn, - string outputColumn = null, - CaseNormalizationMode textCase = CaseNormalizationMode.Lower, - bool keepDiacritics = false, - bool keepPunctuations = true, - bool keepNumbers = true) - : this(env, new[] { (inputColumn, outputColumn ?? inputColumn) }, textCase, keepDiacritics, keepPunctuations, keepNumbers) - { - } - - /// - /// Normalizes incoming text in input columns by changing case, removing diacritical marks, punctuation marks and/or numbers - /// and outputs new text as output columns. - /// - /// The environment. - /// Pairs of columns to run the text normalization on. - /// Casing text using the rules of the invariant culture. - /// Whether to keep diacritical marks or remove them. - /// Whether to keep punctuation marks or remove them. - /// Whether to keep numbers or remove them. - public TextNormalizer(IHostEnvironment env, - (string input, string output)[] columns, - CaseNormalizationMode textCase = CaseNormalizationMode.Lower, - bool keepDiacritics = false, - bool keepPunctuations = true, - bool keepNumbers = true) - : base(Contracts.CheckRef(env, nameof(env)).Register(nameof(TextNormalizer)), - MakeTransformer(env, columns, textCase, keepDiacritics, keepPunctuations, keepNumbers)) - { - } - - private static TransformWrapper MakeTransformer(IHostEnvironment env, - (string input, string output)[] columns, - CaseNormalizationMode textCase, - bool keepDiacritics, - bool keepPunctuations, - bool keepNumbers) - { - Contracts.AssertValue(env); - env.CheckNonEmpty(columns, nameof(columns)); - foreach (var (input, output) in columns) - { - env.CheckValue(input, nameof(input)); - env.CheckValue(output, nameof(input)); - } - - // Create arguments. - var args = new TextNormalizerTransform.Arguments - { - Column = columns.Select(x => new TextNormalizerTransform.Column { Source = x.input, Name = x.output }).ToArray(), - TextCase = textCase, - KeepDiacritics = keepDiacritics, - KeepPunctuations = keepPunctuations, - KeepNumbers = keepNumbers - }; - - // Create a valid instance of data. - var schema = new Schema(columns.Select(x => new Schema.Column(x.input, TextType.Instance, null))); - var emptyData = new EmptyDataView(env, schema); - - return new TransformWrapper(env, new TextNormalizerTransform(env, args, emptyData)); - } - } - /// /// Produces a bag of counts of ngrams (sequences of consecutive words) in a given text. /// It does so by building a dictionary of ngrams and using the id in the dictionary as the index in the bag. diff --git a/test/BaselineOutput/Common/Text/Normalized.tsv b/test/BaselineOutput/Common/Text/Normalized.tsv new file mode 100644 index 0000000000..da9c302237 --- /dev/null +++ b/test/BaselineOutput/Common/Text/Normalized.tsv @@ -0,0 +1,17 @@ +#@ TextLoader{ +#@ header+ +#@ sep=tab +#@ col=label:BL:0 +#@ col=text:TX:1 +#@ col=NormText:TX:2 +#@ col=UpperText:TX:3 +#@ col=WithDiacriticsText:TX:4 +#@ col=NoNumberText:TX:5 +#@ col=NoPuncText:TX:6 +#@ } +label text NormText UpperText WithDiacriticsText NoNumberText NoPuncText +1 ==RUDE== Dude, you are rude upload that carl picture back, or else. ==rude== dude, you are rude upload that carl picture back, or else. ==RUDE== DUDE, YOU ARE RUDE UPLOAD THAT CARL PICTURE BACK, OR ELSE. ==rude== dude, you are rude upload that carl picture back, or else. ==rude== dude, you are rude upload that carl picture back, or else. ==rude== dude you are rude upload that carl picture back or else +1 == OK! == IM GOING TO VANDALIZE WILD ONES WIKI THEN!!! == ok! == im going to vandalize wild ones wiki then!!! == OK! == IM GOING TO VANDALIZE WILD ONES WIKI THEN!!! == ok! == im going to vandalize wild ones wiki then!!! == ok! == im going to vandalize wild ones wiki then!!! == ok == im going to vandalize wild ones wiki then +1 Stop trolling, zapatancas, calling me a liar merely demonstartes that you arer Zapatancas. You may choose to chase every legitimate editor from this site and ignore me but I am an editor with a record that isnt 99% trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. The consensus is overwhelmingly against you and your trollin g lover Zapatancas, stop trolling, zapatancas, calling me a liar merely demonstartes that you arer zapatancas. you may choose to chase every legitimate editor from this site and ignore me but i am an editor with a record that isnt 99% trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. the consensus is overwhelmingly against you and your trollin g lover zapatancas, STOP TROLLING, ZAPATANCAS, CALLING ME A LIAR MERELY DEMONSTARTES THAT YOU ARER ZAPATANCAS. YOU MAY CHOOSE TO CHASE EVERY LEGITIMATE EDITOR FROM THIS SITE AND IGNORE ME BUT I AM AN EDITOR WITH A RECORD THAT ISNT 99% TROLLING AND THEREFORE MY WISHES ARE NOT TO BE COMPLETELY IGNORED BY A SOCKPUPPET LIKE YOURSELF. THE CONSENSUS IS OVERWHELMINGLY AGAINST YOU AND YOUR TROLLIN G LOVER ZAPATANCAS, stop trolling, zapatancas, calling me a liar merely demonstartes that you arer zapatancas. you may choose to chase every legitimate editor from this site and ignore me but i am an editor with a record that isnt 99% trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. the consensus is overwhelmingly against you and your trollin g lover zapatancas, stop trolling, zapatancas, calling me a liar merely demonstartes that you arer zapatancas. you may choose to chase every legitimate editor from this site and ignore me but i am an editor with a record that isnt % trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself. the consensus is overwhelmingly against you and your trollin g lover zapatancas, stop trolling zapatancas calling me a liar merely demonstartes that you arer zapatancas you may choose to chase every legitimate editor from this site and ignore me but i am an editor with a record that isnt 99 trolling and therefore my wishes are not to be completely ignored by a sockpuppet like yourself the consensus is overwhelmingly against you and your trollin g lover zapatancas +1 ==You're cool== You seem like a really cool guy... *bursts out laughing at sarcasm*. ==you're cool== you seem like a really cool guy... *bursts out laughing at sarcasm*. ==YOU'RE COOL== YOU SEEM LIKE A REALLY COOL GUY... *BURSTS OUT LAUGHING AT SARCASM*. ==you're cool== you seem like a really cool guy... *bursts out laughing at sarcasm*. ==you're cool== you seem like a really cool guy... *bursts out laughing at sarcasm*. ==youre cool== you seem like a really cool guy bursts out laughing at sarcasm +1 "::::: Why are you threatening me? I'm not being disruptive, its you who is being disruptive. " "::::: why are you threatening me? i'm not being disruptive, its you who is being disruptive. " "::::: WHY ARE YOU THREATENING ME? I'M NOT BEING DISRUPTIVE, ITS YOU WHO IS BEING DISRUPTIVE. " "::::: why are you threatening me? i'm not being disruptive, its you who is being disruptive. " "::::: why are you threatening me? i'm not being disruptive, its you who is being disruptive. " " why are you threatening me im not being disruptive its you who is being disruptive " diff --git a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs index 75a24935a3..79a325b561 100644 --- a/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs +++ b/test/Microsoft.ML.Benchmarks/StochasticDualCoordinateAscentClassifierBench.cs @@ -97,10 +97,8 @@ public void TrainSentiment() Name = "WordEmbeddings", Source = new[] { "SentimentText" } }, - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower, OutputTokens = true, + KeepPunctuations=false, StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(), VectorNormalizer = TextTransform.TextNormKind.None, CharFeatureExtractor = null, diff --git a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs index d5c7558e45..93259f6e3a 100644 --- a/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs +++ b/test/Microsoft.ML.FSharp.Tests/SmokeTests.fs @@ -102,9 +102,7 @@ module SmokeTest1 = pipeline.Add( TextFeaturizer( "Features", [| "SentimentText" |], - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, OutputTokens = true, VectorNormalizer = TextTransformTextNormKind.L2 )) @@ -171,9 +169,7 @@ module SmokeTest2 = pipeline.Add( TextFeaturizer( "Features", [| "SentimentText" |], - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, OutputTokens = true, VectorNormalizer = TextTransformTextNormKind.L2 )) @@ -237,9 +233,7 @@ module SmokeTest3 = pipeline.Add( TextFeaturizer( "Features", [| "SentimentText" |], - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, OutputTokens = true, VectorNormalizer = TextTransformTextNormKind.L2 )) diff --git a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs index b3fd733912..8561500d14 100644 --- a/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs +++ b/test/Microsoft.ML.StaticPipelineTesting/StaticPipeTests.cs @@ -771,5 +771,46 @@ public void PrincipalComponentAnalysis() var type = schema.GetColumnType(pcaCol); Assert.True(type.IsVector && type.IsKnownSizeVector && type.ItemType.IsNumber); } + + [Fact] + public void TextNormalizeStatic() + { + var env = new ConsoleEnvironment(seed: 0); + var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); + var reader = TextLoader.CreateReader(env, ctx => ( + label: ctx.LoadBool(0), + text: ctx.LoadText(1)), hasHeader: true); + var dataSource = new MultiFileSource(dataPath); + var data = reader.Read(dataSource); + + var est = data.MakeNewEstimator() + .Append(r => ( + r.label, + norm: r.text.NormalizeText(), + norm_Upper: r.text.NormalizeText(textCase: TextNormalizerEstimator.CaseNormalizationMode.Upper), + norm_KeepDiacritics: r.text.NormalizeText(keepDiacritics: true), + norm_NoPuctuations: r.text.NormalizeText(keepPunctuations: false), + norm_NoNumbers: r.text.NormalizeText(keepNumbers: false))); + var tdata = est.Fit(data).Transform(data); + var schema = tdata.AsDynamic.Schema; + + Assert.True(schema.TryGetColumnIndex("norm", out int norm)); + var type = schema.GetColumnType(norm); + Assert.True(!type.IsVector && type.ItemType.IsText); + + Assert.True(schema.TryGetColumnIndex("norm_Upper", out int normUpper)); + type = schema.GetColumnType(normUpper); + Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(schema.TryGetColumnIndex("norm_KeepDiacritics", out int diacritics)); + type = schema.GetColumnType(diacritics); + Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(schema.TryGetColumnIndex("norm_NoPuctuations", out int punct)); + type = schema.GetColumnType(punct); + Assert.True(!type.IsVector && type.ItemType.IsText); + Assert.True(schema.TryGetColumnIndex("norm_NoNumbers", out int numbers)); + type = schema.GetColumnType(numbers); + Assert.True(!type.IsVector && type.ItemType.IsText); + + } } } \ No newline at end of file diff --git a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs index 67278ee418..96df9dc8db 100644 --- a/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/Api/SimpleTrainAndPredict.cs @@ -7,7 +7,6 @@ using Microsoft.ML.Runtime.Learners; using Xunit; using System.Linq; -using Microsoft.ML.Runtime.FastTree; using Microsoft.ML.Runtime.RunTests; namespace Microsoft.ML.Tests.Scenarios.Api @@ -70,10 +69,8 @@ private static TextTransform.Arguments MakeSentimentTextTransformArgs(bool norma Name = "Features", Source = new[] { "SentimentText" } }, - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower, OutputTokens = true, + KeepPunctuations=false, StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(), VectorNormalizer = normalize ? TextTransform.TextNormKind.L2 : TextTransform.TextNormKind.None, CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false }, diff --git a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs index aea4a55f18..26f5709f33 100644 --- a/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/ClusteringTests.cs @@ -19,9 +19,7 @@ public void PredictNewsCluster() pipeline.Add(new ColumnConcatenator("AllText", "Subject", "Content")); pipeline.Add(new TextFeaturizer("Features", "AllText") { - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, StopWordsRemover = new PredefinedStopWordsRemover(), VectorNormalizer = TextTransformTextNormKind.L2, CharFeatureExtractor = new NGramNgramExtractor() { NgramLength = 3, AllLengths = false }, diff --git a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs index bf34be3a80..9ac8230a65 100644 --- a/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs +++ b/test/Microsoft.ML.Tests/Scenarios/PipelineApi/SimpleTrainAndPredict.cs @@ -41,9 +41,7 @@ private static TextFeaturizer MakeSentimentTextTransform() { return new TextFeaturizer("Features", "SentimentText") { - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, OutputTokens = true, StopWordsRemover = new PredefinedStopWordsRemover(), VectorNormalizer = TextTransformTextNormKind.L2, diff --git a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs index 00e7997dd5..c30b062549 100644 --- a/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/Scenarios/SentimentPredictionTests.cs @@ -318,9 +318,7 @@ private Legacy.LearningPipeline PreparePipeline() pipeline.Add(new TextFeaturizer("Features", "SentimentText") { - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, OutputTokens = true, StopWordsRemover = new PredefinedStopWordsRemover(), VectorNormalizer = TextTransformTextNormKind.L2, @@ -367,9 +365,7 @@ private LearningPipeline PreparePipelineLightGBM() pipeline.Add(new TextFeaturizer("Features", "SentimentText") { - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, OutputTokens = true, StopWordsRemover = new PredefinedStopWordsRemover(), VectorNormalizer = TextTransformTextNormKind.L2, @@ -416,9 +412,7 @@ private LearningPipeline PreparePipelineSymSGD() pipeline.Add(new TextFeaturizer("Features", "SentimentText") { - KeepDiacritics = false, KeepPunctuations = false, - TextCase = TextNormalizerTransformCaseNormalizationMode.Lower, OutputTokens = true, StopWordsRemover = new PredefinedStopWordsRemover(), VectorNormalizer = TextTransformTextNormKind.L2, diff --git a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs index cbff47e674..3c24a03b05 100644 --- a/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs +++ b/test/Microsoft.ML.Tests/ScenariosWithDirectInstantiation/SentimentPredictionTests.cs @@ -43,10 +43,8 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTest() Name = "Features", Source = new[] { "SentimentText" } }, - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower, OutputTokens = true, + KeepPunctuations = false, StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(), VectorNormalizer = TextTransform.TextNormKind.L2, CharFeatureExtractor = new NgramExtractorTransform.NgramExtractorArguments() { NgramLength = 3, AllLengths = false }, @@ -108,10 +106,8 @@ public void TrainAndPredictSentimentModelWithDirectionInstantiationTestWithWordE Name = "WordEmbeddings", Source = new[] { "SentimentText" } }, - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower, OutputTokens = true, + KeepPunctuations= false, StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(), VectorNormalizer = TextTransform.TextNormKind.None, CharFeatureExtractor = null, diff --git a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs index 59a439ec61..023c5ddd73 100644 --- a/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/TextFeaturizerTests.cs @@ -2,12 +2,11 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.StaticPipe; using Microsoft.ML.Runtime.Data; using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.RunTests; -using Microsoft.ML.Runtime.Tools; using Microsoft.ML.Transforms; +using Microsoft.ML.Transforms.Text; using System.IO; using Xunit; using Xunit.Abstractions; @@ -141,7 +140,7 @@ public void TextNormalizationAndStopwordRemoverWorkout() text: ctx.LoadFloat(1)), hasHeader: true) .Read(new MultiFileSource(sentimentDataPath)); - var est = new TextNormalizer(Env,"text") + var est = new TextNormalizerEstimator(Env,"text") .Append(new WordTokenizer(Env, "text", "words")) .Append(new StopwordRemover(Env, "words", "words_without_stopwords")); TestEstimatorCore(est, data.AsDynamic, invalidInput: invalidData.AsDynamic); diff --git a/test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs b/test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs new file mode 100644 index 0000000000..0a0d70fc05 --- /dev/null +++ b/test/Microsoft.ML.Tests/Transformers/TextNormalizer.cs @@ -0,0 +1,101 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. +// See the LICENSE file in the project root for more information. + +using Microsoft.ML.Runtime.Api; +using Microsoft.ML.Runtime.Data; +using Microsoft.ML.Runtime.Data.IO; +using Microsoft.ML.Runtime.Model; +using Microsoft.ML.Runtime.RunTests; +using Microsoft.ML.Runtime.Tools; +using Microsoft.ML.Transforms.Text; +using System.IO; +using Xunit; +using Xunit.Abstractions; + +namespace Microsoft.ML.Tests.Transformers +{ + public class TextNormalizerTests : TestDataPipeBase + { + public TextNormalizerTests(ITestOutputHelper output) : base(output) + { + } + + private class TestClass + { + public string A; + [VectorType(2)] + public string[] B; + } + + + private class TestClassB + { + public float A; + [VectorType(2)] + public float[] B; + } + + [Fact] + public void TextNormalizerWorkout() + { + var data = new[] { new TestClass() { A = "A 1, b. c! йЁ 24 ", B = new string[2] { "~``ё 52ds й vc", "6ksj94 vd ё dakl Юds Ё q й" } }, + new TestClass() { A = null, B =new string[2] { null, string.Empty } } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new TextNormalizerEstimator(Env, columns: new[] { ("A", "NormA"), ("B", "NormB") }); + + var invalidData = new[] { new TestClassB() { A = 1, B = new float[2] { 1,4 } }, + new TestClassB() { A = 2, B =new float[2] { 3,4 } } }; + var invalidDataView = ComponentCreation.CreateDataView(Env, invalidData); + TestEstimatorCore(pipe, dataView, invalidInput: invalidDataView); + + var dataPath = GetDataPath("wikipedia-detox-250-line-data.tsv"); + var reader = TextLoader.CreateReader(Env, ctx => ( + label: ctx.LoadBool(0), + text: ctx.LoadText(1)), hasHeader: true); + var dataSource = new MultiFileSource(dataPath); + dataView = reader.Read(dataSource).AsDynamic; + + var pipeVariations = new TextNormalizerEstimator(Env, columns: new[] { ("text", "NormText") }).Append( + new TextNormalizerEstimator(Env, textCase: TextNormalizerEstimator.CaseNormalizationMode.Upper, columns: new[] { ("text", "UpperText") })).Append( + new TextNormalizerEstimator(Env, keepDiacritics: true, columns: new[] { ("text", "WithDiacriticsText") })).Append( + new TextNormalizerEstimator(Env, keepNumbers: false, columns: new[] { ("text", "NoNumberText") })).Append( + new TextNormalizerEstimator(Env, keepPunctuations: false, columns: new[] { ("text", "NoPuncText") })); + + var outputPath = GetOutputPath("Text", "Normalized.tsv"); + using (var ch = Env.Start("save")) + { + var saver = new TextSaver(Env, new TextSaver.Arguments { Silent = true }); + var savedData = TakeFilter.Create(Env, pipeVariations.Fit(dataView).Transform(dataView), 5); + using (var fs = File.Create(outputPath)) + DataSaverUtils.SaveDataView(ch, saver, savedData, fs, keepHidden: true); + } + + CheckEquality("Text", "Normalized.tsv"); + Done(); + } + + [Fact] + public void TestCommandLine() + { + Assert.Equal(Maml.Main(new[] { @"showschema loader=Text{col=A:TX:0} xf=TextNorm{col=B:A} in=f:\2.txt" }), (int)0); + } + + [Fact] + public void TestOldSavingAndLoading() + { + var data = new[] { new TestClass() { A = "A 1, b. c! йЁ 24 ", B = new string[2] { "~``ё 52ds й vc", "6ksj94 vd ё dakl Юds Ё q й" } } }; + var dataView = ComponentCreation.CreateDataView(Env, data); + var pipe = new TextNormalizerEstimator(Env, columns: new[] { ("A", "NormA"), ("B", "NormB") }); + + var result = pipe.Fit(dataView).Transform(dataView); + var resultRoles = new RoleMappedData(result); + using (var ms = new MemoryStream()) + { + TrainUtils.SaveModel(Env, Env.Start("saving"), ms, null, resultRoles); + ms.Position = 0; + var loadedView = ModelFileUtils.LoadTransforms(Env, dataView, ms); + } + } + } +} diff --git a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs index f886924463..87948ebcf3 100644 --- a/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs +++ b/test/Microsoft.ML.Tests/Transformers/WordEmbeddingsTests.cs @@ -2,12 +2,10 @@ // The .NET Foundation licenses this file to you under the MIT license. // See the LICENSE file in the project root for more information. -using Microsoft.ML.StaticPipe; using Microsoft.ML.Runtime.Data; -using Microsoft.ML.Runtime.Data.IO; using Microsoft.ML.Runtime.RunTests; using Microsoft.ML.Scenarios; -using System.IO; +using Microsoft.ML.StaticPipe; using Xunit; using Xunit.Abstractions; @@ -38,10 +36,8 @@ public void TestWordEmbeddings() Name = "SentimentText_Features", Source = new[] { "SentimentText" } }, - KeepDiacritics = false, - KeepPunctuations = false, - TextCase = Runtime.TextAnalytics.TextNormalizerTransform.CaseNormalizationMode.Lower, OutputTokens = true, + KeepPunctuations = false, StopWordsRemover = new Runtime.TextAnalytics.PredefinedStopWordsRemoverFactory(), VectorNormalizer = TextTransform.TextNormKind.None, CharFeatureExtractor = null,