diff --git a/src/Nncase.Cli/Commands/Compile.cs b/src/Nncase.Cli/Commands/Compile.cs index 69a8cf0c54..edb64f2b3e 100644 --- a/src/Nncase.Cli/Commands/Compile.cs +++ b/src/Nncase.Cli/Commands/Compile.cs @@ -86,6 +86,50 @@ public Compile() alias: "--calib-method", description: $"model quant options, default is {Quantization.CalibMethod.Kld}", getDefaultValue: () => Quantization.CalibMethod.Kld)); + AddOption(new Option( + alias: "--pre-process", + description: "whether enable pre process, default is False", + getDefaultValue: () => false)); + AddOption(new Option( + alias: "--input-layout", + description: "the model input data layout, default is empty. eg. NCHW/NHWC", + getDefaultValue: () => string.Empty)); + AddOption(new Option( + alias: "--output-layout", + description: "the model output data layout, default is empty. eg. NCHW/NHWC", + getDefaultValue: () => string.Empty)); + AddOption(new Option( + alias: "--input-type", + description: "the model input data value type, default is Float32", + getDefaultValue: () => InputType.Float32)); + AddOption(new Option>( + alias: "--input-shape", + description: "the model input data shape, default is []. eg. `--input-shape 1 2 3 4`", + getDefaultValue: () => Array.Empty())); + AddOption(new Option>( + alias: "--input-range", + description: "the model input data value range, default is []. eg `--input-range -100.3 200.4`", + getDefaultValue: () => Array.Empty())); + AddOption(new Option( + alias: "--swap-rb", + description: "whether swap the model input data channel R and B", + getDefaultValue: () => false)); + AddOption(new Option( + alias: "--letter-box-value", + description: "letterbox value, default 0.0", + getDefaultValue: () => 0.0f)); + AddOption(new Option>( + alias: "--mean", + description: "the model input data mean, default []", + getDefaultValue: () => Array.Empty())); + AddOption(new Option>( + alias: "--std", + description: "the model input data std, default []", + getDefaultValue: () => Array.Empty())); + AddOption(new Option( + alias: "--model-layout", + description: "the model's input layout, default is empty. eg. NCHW/NHWC", + getDefaultValue: () => string.Empty)); AddOption(new Option( alias: "--benchmark-only", description: $"benchmark only", @@ -142,6 +186,17 @@ private async Task RunAsync(CliCompileOptions cliOptions, IHost host) }, ModelQuantMode = cliOptions.ModelQuantMode, }, + PreProcess = cliOptions.PreProcess, + InputLayout = cliOptions.InputLayout, + OutputLayout = cliOptions.OutputLayout, + InputType = cliOptions.InputType, + InputShape = cliOptions.InputShape.ToArray(), + InputRange = cliOptions.InputRange.ToArray(), + SwapRB = cliOptions.SwapRB, + LetterBoxValue = cliOptions.LetterBoxValue, + Mean = cliOptions.Mean.ToArray(), + Std = cliOptions.Std.ToArray(), + ModelLayout = cliOptions.ModelLayout, IsBenchmarkOnly = cliOptions.BenchmarkOnly, }; @@ -209,6 +264,28 @@ internal sealed class CliCompileOptions public DatasetFormat DatasetFormat { get; set; } public bool BenchmarkOnly { get; set; } + + public bool PreProcess { get; set; } + + public string InputLayout { get; set; } + + public string OutputLayout { get; set; } + + public InputType InputType { get; set; } + + public List InputShape { get; set; } + + public List InputRange { get; set; } + + public bool SwapRB { get; set; } + + public float LetterBoxValue { get; set; } + + public List Mean { get; set; } + + public List Std { get; set; } + + public string ModelLayout { get; set; } } #pragma warning restore CS8618 diff --git a/src/Nncase.Core/Passes/RunPassContext.cs b/src/Nncase.Core/Passes/RunPassContext.cs index 52becdcda0..4ecaf68aaa 100644 --- a/src/Nncase.Core/Passes/RunPassContext.cs +++ b/src/Nncase.Core/Passes/RunPassContext.cs @@ -36,6 +36,11 @@ public record RunPassContext /// public int Index { get; set; } + /// + /// Gets this pass's driver. + /// + public IPass? Driver { get; init; } + /// /// Gets or sets a value indicating whether control rewrite once or not. /// when RewriteOnce is true, the rule will only apply once, then restart rewrite from first rule. diff --git a/src/Nncase.Core/PatternMatch/IMatchResult.cs b/src/Nncase.Core/PatternMatch/IMatchResult.cs index f6f9446b9e..8f6118d4df 100644 --- a/src/Nncase.Core/PatternMatch/IMatchResult.cs +++ b/src/Nncase.Core/PatternMatch/IMatchResult.cs @@ -30,6 +30,13 @@ public interface IMatchResult : IEnumerable> /// Match result. object this[IPattern pattern] { get; } + /// + /// Get match result by name, default is null. + /// + /// Pattern name. + /// match result. + object GetValueOrDefault(string name); + /// /// Get match result by pattern. /// diff --git a/src/Nncase.Core/PatternMatch/MatchResult.cs b/src/Nncase.Core/PatternMatch/MatchResult.cs index c03ab84d1b..eb1bb1f651 100644 --- a/src/Nncase.Core/PatternMatch/MatchResult.cs +++ b/src/Nncase.Core/PatternMatch/MatchResult.cs @@ -43,6 +43,9 @@ where kv.Key.Name is not null /// public object this[string name] => _stringMap[name]; + /// + public object GetValueOrDefault(string name) => _stringMap.GetValueOrDefault(name, null!); + /// public IEnumerator> GetEnumerator() { diff --git a/src/Nncase.Core/Utilities/ReplaceUtility.cs b/src/Nncase.Core/Utilities/ReplaceUtility.cs index c07cd9850e..fa1331f80b 100644 --- a/src/Nncase.Core/Utilities/ReplaceUtility.cs +++ b/src/Nncase.Core/Utilities/ReplaceUtility.cs @@ -97,11 +97,6 @@ public static Call ReplaceCallParams(Expr target, IReadOnlyList oldParams, return new Call(target, ReplaceItems(oldParams, pairs)); } - public static Call ReplaceCallParams(Call call, params (int, Expr)[] pairs) - { - return new Call(call.Target, ReplaceItems(call.Arguments.ToArray(), pairs)); - } - /// /// replace the call params with parameter info. /// @@ -122,12 +117,7 @@ public static Call ReplaceCallParams(Expr target, IReadOnlyList oldParams, /// expr. /// new Call. public static Call ReplaceCallFirstParam(Expr target, IReadOnlyList oldParams, Expr expr) => - ReplaceCallParams(target, oldParams, (oldParams[0], expr)); - - public static Expr ReplaceCallFirstParam(Call call, Expr expr) - { - return ReplaceCallFirstParam(call.Target, call.Arguments.ToArray(), expr); - } + ReplaceCallParams(target, oldParams, (0, expr)); /// /// Replace target in body with expr. diff --git a/src/Nncase.EGraph/Passes/RewriteProvider.cs b/src/Nncase.EGraph/Passes/RewriteProvider.cs index fa4558226a..7642a5395e 100644 --- a/src/Nncase.EGraph/Passes/RewriteProvider.cs +++ b/src/Nncase.EGraph/Passes/RewriteProvider.cs @@ -43,8 +43,6 @@ public Expr ERewrite(Expr expr, IEnumerable rules, RunPassContext public IEGraph ERewrite(IEGraph eGraph, IEnumerable rules, RunPassContext context) { var last_version = eGraph.Version; - int count = 0; - while (true) { var matches = rules. @@ -59,10 +57,12 @@ public IEGraph ERewrite(IEGraph eGraph, IEnumerable rules, RunPass if (DumpScope.Current.IsEnabled(DumpFlags.Rewrite)) { - foreach (var (rule, results) in matches.Where(p => p.Item2.Count != 0)) + using var fs = DumpScope.Current.OpenFile(Path.Combine("Matches", $"V{eGraph.Version}.txt")); + using var writer = new StreamWriter(fs); + writer.WriteLine("rule, results"); + foreach (var (rule, results) in matches) { - using var fs = DumpScope.Current.OpenFile(Path.Combine("Matches", $"V{eGraph.Version}_{count++}_{rule.GetType().Name}.dot")); - EGraphPrinter.DumpEgraphAsDot(eGraph, results, fs); + writer.WriteLine($"{rule.GetType().Name}, {results.Count}"); } } diff --git a/src/Nncase.Importer/Onnx/Conv2D.cs b/src/Nncase.Importer/Onnx/Conv2D.cs index e82002d0ee..5d66995704 100644 --- a/src/Nncase.Importer/Onnx/Conv2D.cs +++ b/src/Nncase.Importer/Onnx/Conv2D.cs @@ -39,7 +39,7 @@ private Expr VisitConv2D(in NodeProto op) var pads = AutoPad(op, autoPad, input, weights, strides.ToArray(), dilation.ToArray(), isConv1D); pads.InferenceType(); var conv = F.NN.Conv2D(input, weights, bias, strides.ToArray(), pads, dilation.ToArray(), PadMode.Constant, group); - List outputNames = new() { op.Name }; + List outputNames = new() { op.Output[0] }; conv.Metadata.OutputNames = outputNames; if (isConv1D) { diff --git a/src/Nncase.Importer/Onnx/MatMul.cs b/src/Nncase.Importer/Onnx/MatMul.cs index 8344d11b2e..5f7a354593 100644 --- a/src/Nncase.Importer/Onnx/MatMul.cs +++ b/src/Nncase.Importer/Onnx/MatMul.cs @@ -14,7 +14,7 @@ private Expr VisitMatMul(in NodeProto op) { var (a, b) = GetInputExprs(op, 0, 1); var matmul = IR.F.Math.MatMul(a, b); - List outputNames = new() { op.Name }; + List outputNames = new() { op.Output[0] }; matmul.Metadata.OutputNames = outputNames; return matmul; } diff --git a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs index e096b9b8c8..688cfe1f0b 100644 --- a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs +++ b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs @@ -54,7 +54,7 @@ private Expr VisitConv2DTranspose(in tflite.Operator op) dilation, PadMode.Constant, 1); - List outputNames = new() { GetInputTensor(op, 0).Name }; + List outputNames = new() { GetOutputTensor(op, 0).Name }; conv2DTranspose.Metadata.OutputNames = outputNames; return F.Tensors.NCHWToNHWC(F.Math.Clamp( conv2DTranspose, diff --git a/src/Nncase.Importer/TFLite/MatMul.cs b/src/Nncase.Importer/TFLite/MatMul.cs index 0472cb6ac8..6b1724e5f5 100644 --- a/src/Nncase.Importer/TFLite/MatMul.cs +++ b/src/Nncase.Importer/TFLite/MatMul.cs @@ -66,10 +66,10 @@ private Expr VisitMatMul(in tflite.Operator op, bool isFullyConnected = true) : Expand(Cast(0, GetDataType(GetInputTensor(op, 0).Type)), new[] { otherTensor.Shape(0) }).Evaluate().AsTensor(); var matmul = MatMul(lhs, rhs); - List outputNames = new() { GetInputTensor(op, 0).Name + "_matmul" }; + List outputNames = new() { GetOutputTensor(op, 0).Name + "_matmul" }; matmul.Metadata.OutputNames = outputNames; outputNames.Clear(); - outputNames.Add(GetInputTensor(op, 0).Name + "_bias"); + outputNames.Add(GetOutputTensor(op, 0).Name); bias.Metadata.OutputNames = outputNames; var mm = matmul + bias; diff --git a/src/Nncase.Passes/PassManager.cs b/src/Nncase.Passes/PassManager.cs index df66312e53..9377511e50 100644 --- a/src/Nncase.Passes/PassManager.cs +++ b/src/Nncase.Passes/PassManager.cs @@ -287,6 +287,7 @@ public RunPassContextWithAnalysis(IAnalyzerManager analyzerManager, Either { - public override Pattern Pattern => IsCallWildcard( - "outer", - IsWildcard(), - InputPattern); + public override Pattern Pattern => IsCall("outer", IsWildcard("outerTarget"), IsVArgsRepeat("outerParams", exprs => + { + var patterns = new Pattern[exprs.Length]; + for (int i = 0; i < exprs.Length; i++) + { + patterns[i] = GetInputPattern(i); + } + + return patterns; + })); - public Pattern InputPattern => IsCallWildcard( - "call", - IsWildcard(), - IsRangeOfMarker( - "marker", - IsWildcard(), - IsWildcard())); + public Pattern GetInputPattern(int i) => + IsAlt( + IsCallWildcard( + $"input_{i}", + IsOp($"input_target_{i}", NotChangeRangeOp), + IsRangeOfMarker($"input_marker_{i}", IsWildcard($"marker_target_{i}"), IsWildcard($"marker_attribute_{i}"))), + IsWildcard($"input_{i}")); - public Expr? GetReplace(Call outer, Call call, Marker marker) + public Expr? GetReplace(Call outer, Expr outerTarget, IReadOnlyList outerParams, IMatchResult result) { - if (!NotChangeRangeOp(call.Target)) + if (!Enumerable.Range(0, outerParams.Count).Select(i => result.GetValueOrDefault($"input_marker_{i}")).Any(e => e is not null)) { return null; } - if (outer.Target is MatMul && CompilerServices.TryMatchRoot(outer.Arguments[1], InputPattern, new(), out var matchResult)) + var newArgs = new Expr[outerParams.Count]; + for (int i = 0; i < outerParams.Count; i++) { - var rhsMarker = (Marker)matchResult["marker"]; - var rhsCall = (Call)matchResult["call"]; - var lhs = marker.With(target: ReplaceCallFirstParam(call, marker)); - var rhs = rhsMarker.With(target: ReplaceCallFirstParam(rhsCall, rhsMarker)); - return ReplaceCallParams(outer, (0, lhs), (1, rhs)); + if (result.GetValueOrDefault($"input_marker_{i}") is Marker marker && result[$"marker_target_{i}"] is Expr target && result[$"marker_attribute_{i}"] is Expr range) + { + newArgs[i] = IR.F.Math.RangeOfMarker(outerParams[i], range).With(mixQuantInfo: marker.MixQuantInfo, adaQuantInfo: marker.AdaQuantInfo); + } + else + { + newArgs[i] = outerParams[i]; + } } - return ReplaceCallFirstParam(outer, marker.With(target: ReplaceCallFirstParam(call, marker))); + return new Call(outerTarget, newArgs); } } @@ -56,23 +68,18 @@ public partial class BroadcastOutputMarker : RewriteRule { public override Pattern Pattern => IsRangeOfMarker( "marker", - IsCallWildcard("input", IsWildcard(), IsCallWildcard(null, IsWildcard())), - IsWildcard()); + IsCallWildcard("output", IsOp("outputTarget", NotChangeRangeOp), IsCallWildcard("input", IsWildcard("inputTarget"))), + IsWildcard("range")); - public Expr? GetReplace(Call input, Marker marker) + public Expr? GetReplace(Marker marker, Expr range, Call output, Op outputTarget, IReadOnlyList outputParams) { - if (!NotChangeRangeOp(input.Target)) - { - return null; - } - - return ReplaceCallFirstParam(input, marker.With(target: input.Arguments[0])); + return ReplaceCallFirstParam(outputTarget, outputParams, IR.F.Math.RangeOfMarker(outputParams[0], range).With(adaQuantInfo: marker.AdaQuantInfo, mixQuantInfo: marker.MixQuantInfo)); } } internal static class BroadcastMarkerHelper { - public static bool NotChangeRangeOp(Expr op) + public static bool NotChangeRangeOp(Op op) { return op is Squeeze || op is Unsqueeze || op is Reshape || op is Broadcast; } diff --git a/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs b/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs index c065fe65c7..e046301045 100644 --- a/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs +++ b/src/Nncase.Passes/Rules/Neutral/CombineQuantize.cs @@ -40,14 +40,17 @@ public sealed partial class CombineQuantizeConcat : RewriteRule private Expr? GetReplace(Quantize quantize, IReadOnlyList tupleInputs, Expr axis, Expr quantParam, RunPassContext options) { - var userAnalysis = options.GetAnalysis(); - - // see UnitTestCombineQuantize.TestCombineQuantizeConcatNegative - foreach (var e in tupleInputs) + if (options.Driver is DataflowPass) { - if (userAnalysis[e].Count() > 1) + var userAnalysis = options.GetAnalysis(); + + // see UnitTestCombineQuantize.TestCombineQuantizeConcatNegative + foreach (var e in tupleInputs) { - return null; + if (userAnalysis[e].Count() > 1) + { + return null; + } } } @@ -61,49 +64,43 @@ public sealed partial class CombineQuantizeConcat : RewriteRule [RuleGenerator] public sealed partial class CombineQuantizeReshape : RewriteRule { - private readonly bool _checkShapeSize; - - public CombineQuantizeReshape() - { - _checkShapeSize = false; - } - /// /// Initializes a new instance of the class. /// /// if true, skip pass. - public CombineQuantizeReshape(bool checkShapeSize = false) + public CombineQuantizeReshape(bool checkShapeSize) { - _checkShapeSize = checkShapeSize; + Pattern = IsQuantize( + "quantize", + _ => true, + IsReshape( + "reshape", + "reshapeCall", + IsWildcard("input") with { TypePattern = HasShape(sp => !(checkShapeSize && sp.ToValueArray().Any(s => s >= 65536)), "CheckedShape") }, + IsWildcard("shape")), + IsWildcard("quantParam")); } - /// - public override Pattern Pattern { get; } = IsQuantize( - "quantize", - _ => true, - IsReshape( - "reshape", - "reshapeCall", - IsWildcard("input"), - IsWildcard("shape")), - IsWildcard("quantParam")); - - private Expr? GetReplace(Quantize quantize, Call reshapeCall, Expr input, Expr shape, Expr quantParam, RunPassContext options) + public CombineQuantizeReshape() + : this(false) { - var userAnalysis = options.GetAnalysis(); + } - if (userAnalysis[reshapeCall].Count() > 1) - { - return null; - } + /// + public override Pattern Pattern { get; } - if (_checkShapeSize && input.CheckedShape.ToValueArray().Any(s => s >= 65536)) + private Expr? GetReplace(Quantize quantize, Call reshapeCall, Expr input, Expr shape, Expr quantParam, RunPassContext context) + { + if (context.Driver is DataflowPass) { - return null; + var userAnalysis = context.GetAnalysis(); + if (userAnalysis[reshapeCall].Count() > 1) + { + return null; + } } var output = Reshape(Quantize(input, quantParam, quantize.TargetType), shape); - output.InferenceType(); return output; } } @@ -160,15 +157,18 @@ public sealed partial class CombineQuantizeTranspose : RewriteRule private Expr? GetReplace(Quantize quantize, Call transposeCall, Expr input, Expr perm, Expr quantParam, RunPassContext options) { - var userAnalysis = options.GetAnalysis(); - - if (userAnalysis[transposeCall].Count() > 1) + try + { + var userAnalysis = options.GetAnalysis(); + if (userAnalysis[transposeCall].Count() > 1) + { + return null; + } + } + catch (System.Exception) { - return null; } - var output = Transpose(Quantize(input, quantParam, quantize.TargetType), perm); - output.InferenceType(); - return output; + return Transpose(Quantize(input, quantParam, quantize.TargetType), perm); } } diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs index 2eb9c0f39a..a0cf03380a 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs @@ -441,12 +441,14 @@ protected override Expr ReplaceVarsWithArg(Var[] fusionVars, Expr[] args, Expr n { var convTranspose = (Call)CallMarker!.Target; var c = ReplaceCallFirstParam( - convTranspose, + convTranspose.Target, + convTranspose.Arguments.ToArray(), _transposeInputMarker!.With(target: ReplaceCallFirstParam( - _transpose!, + _transpose!.Target, + _transpose!.Arguments.ToArray(), _transposeInputMarker.With(target: - ReplaceCallFirstParam(_originCall!, fusionVars[0]))))); + ReplaceCallFirstParam(_originCall!.Target, _originCall!.Arguments.ToArray(), fusionVars[0]))))); return CallMarker.With(target: base.ReplaceVarsWithArg(fusionVars, args, c)); } diff --git a/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs b/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs index 71f6c49bc3..66b040ec0f 100644 --- a/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs +++ b/src/Nncase.Quantization/Quantization/PytestCalibrationDatasetProvider.cs @@ -99,8 +99,8 @@ private bool TryParseSample(string fileName, [System.Diagnostics.CodeAnalysis.Ma if (match.Success) { string name = match.Groups[1].Value; - int n = int.Parse(match.Groups[2].Value); - int i = int.Parse(match.Groups[3].Value); + int i = int.Parse(match.Groups[2].Value); + int n = int.Parse(match.Groups[3].Value); item = new(name, n, i); return true; } @@ -111,11 +111,11 @@ private bool TryParseSample(string fileName, [System.Diagnostics.CodeAnalysis.Ma private sealed record Sample(string Name, int Number, int InputIndex) { - public string FileName => $"{Name}_{Number}_{InputIndex}.bin"; + public string FileName => $"{Name}_{InputIndex}_{Number}.bin"; public int[] GetShape() { - using var stream = File.OpenRead($"{Name}_{Number}_{InputIndex}.txt"); + using var stream = File.OpenRead($"{Name}_{InputIndex}_{Number}.txt"); using var reader = new StreamReader(stream); var line = reader.ReadLine(); int[] shape = Array.Empty(); diff --git a/src/Nncase.Quantization/Quantization/Quantizer.cs b/src/Nncase.Quantization/Quantization/Quantizer.cs index af1f837e97..961ae8b716 100644 --- a/src/Nncase.Quantization/Quantization/Quantizer.cs +++ b/src/Nncase.Quantization/Quantization/Quantizer.cs @@ -417,10 +417,12 @@ private IDictionary[]> GetRangesFromConfig(QuantScheme foreach (var rangeOf in _rangeOfs) { + bool getRange = false; for (int i = 0; i < quantScheme!.Outputs!.Length; i++) { if (rangeOf.Expr.Metadata.OutputNames?[0] == quantScheme!.Outputs[i].Name) { + getRange = true; if (((RangeOf)((Call)rangeOf.Expr).Target).IsRangeOfWeight == true && quantScheme!.Outputs[i].DataRangeMode == "by_tensor") { var oc = ((Call)rangeOf.Expr).Operands[1].CheckedShape[0].FixedValue; @@ -457,6 +459,21 @@ private IDictionary[]> GetRangesFromConfig(QuantScheme } } } + + if (getRange == false && _quantizeOptions.QuantScheme != string.Empty && _quantizeOptions.QuantSchemeStrictMode == true) + { + if (((RangeOf)((Call)rangeOf.Expr).Target).IsRangeOfWeight == true) + { + var oc = ((Call)rangeOf.Expr).Operands[1].CheckedShape[0].FixedValue; + var valueRanges = new ValueRange[oc]; + ranges.Add(rangeOf, valueRanges); + } + else + { + var valueRanges = new ValueRange[1]; + ranges.Add(rangeOf, valueRanges); + } + } } return ranges; @@ -466,8 +483,10 @@ private void AssignDataTypeFromConfig(QuantScheme quantScheme) { foreach (var marker in _markers) { + bool getRange = false; for (int i = 0; i < quantScheme!.Outputs!.Length; i++) { + getRange = true; if (marker.Expr.Metadata.OutputNames?[0] == quantScheme.Outputs[i].Name) { var markerExpr = (Marker)marker.Expr; @@ -480,6 +499,12 @@ private void AssignDataTypeFromConfig(QuantScheme quantScheme) markerExpr.MixQuantInfo!.MarkerQuantType = dataType; } } + + if (getRange == false && _quantizeOptions.QuantScheme != string.Empty && _quantizeOptions.QuantSchemeStrictMode == true) + { + var markerExpr = (Marker)marker.Expr; + markerExpr.MixQuantInfo!.MarkerQuantType = DataTypes.Float16; + } } } diff --git a/src/Nncase.Tests.TestFixture/TransformTestBase.cs b/src/Nncase.Tests.TestFixture/TransformTestBase.cs index 9f1f8435c3..49083076d4 100644 --- a/src/Nncase.Tests.TestFixture/TransformTestBase.cs +++ b/src/Nncase.Tests.TestFixture/TransformTestBase.cs @@ -67,7 +67,7 @@ public Expr TestMatchedCore(Function pre, IReadOnlyDictionary? feed } var preHashCode = pre.GetHashCode(); - var post = (Function)CompilerServices.Rewrite(pre, rules, new() { AnalysisResults = analysis }); + var post = (Function)CompilerServices.Rewrite(pre, rules, new() { AnalysisResults = analysis, Driver = new DataflowPass() }); if (isNotMatch) { Assert.Equal(preHashCode, post.GetHashCode()); @@ -97,7 +97,7 @@ public Expr TestMatchedCore(Expr pre, IReadOnlyDictionary? feeds = var preHashCode = pre.GetHashCode(); var v1 = pre.Evaluate(feeds); - var post = CompilerServices.Rewrite(pre, rules, new()); + var post = CompilerServices.Rewrite(pre, rules, new() { Driver = new DataflowPass() }); Assert.NotEqual(preHashCode, post.GetHashCode()); var v2 = post.Evaluate(feeds); if (!Comparator.AllEqual(v1, v2)) @@ -112,7 +112,7 @@ public void TestNotMatch(Expr pre, params IRewriteRule[] rules) { pre.InferenceType(); var preHashCode = pre.GetHashCode(); - var post = CompilerServices.Rewrite(pre, rules, new()); + var post = CompilerServices.Rewrite(pre, rules, new() { Driver = new DataflowPass() }); Assert.Equal(preHashCode, post.GetHashCode()); } diff --git a/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs b/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs index 17ec26a05e..6d2f29ce71 100644 --- a/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs +++ b/src/Nncase.Tests/Quant/UnitTestPytestCalibrationDatasetProvider.cs @@ -110,9 +110,9 @@ private static void DumpTensors(Tensor[] tensorValue, string dir, int sample = 1 for (var t = 0; t < tensorValue.Length; t++) { var value = tensorValue[t]; - var sr1 = new StreamWriter(Path.Join(dir, $"input_{s}_{t}.txt")); + var sr1 = new StreamWriter(Path.Join(dir, $"input_{t}_{s}.txt")); DumpTxt(value, sr1); - var sr2 = Path.Join(dir, $"input_{s}_{t}.bin"); + var sr2 = Path.Join(dir, $"input_{t}_{s}.bin"); DumpBin(value, sr2); } } diff --git a/src/Nncase.Tests/Rewrite/RewriteBase.cs b/src/Nncase.Tests/Rewrite/RewriteBase.cs index 3fd4de87fe..33dc381532 100644 --- a/src/Nncase.Tests/Rewrite/RewriteBase.cs +++ b/src/Nncase.Tests/Rewrite/RewriteBase.cs @@ -2874,3 +2874,35 @@ public PReluTransposeCase() public Dictionary FeedDict { get; } } + +/// +/// egraph extract bad case. +/// +public sealed class FoldReshapeWithBranch : IRewriteCase +{ + public FoldReshapeWithBranch() + { + var v1070 = new Var(new TensorType(DataTypes.Float32, new[] { 1, 1, 2, 8400 })); + { + var v1071 = Unary(UnaryOp.Cos, v1070); // f32[1,1,2,8400] + var v1072 = Reshape(v1071, new[] { 1, 2, 8400 }); // f32[1,2,8400] + var v1073 = Reshape(v1072, new[] { 1, 1, 2, 8400 }); // f32[1,1,2,8400] + var v1078 = Unary(UnaryOp.Sin, v1073); // f32[1,1,2,8400] + var v1079 = Reshape(v1078, new[] { 1, 2, 8400 }); // f32[1,2,8400] + var v1080 = Sub(v1072, IR.F.Random.Normal(DataTypes.Float32, new[] { 1, 2, 8400 }).Evaluate().AsTensor()); // f32[1,2,8400] + var v1081 = new IR.Tuple(v1079, v1080); // (f32[1,2,8400], f32[1,2,8400]) + PreExpr = new Function(v1081, new[] { v1070 }); + } + + FeedDict = new() { { v1070, IR.F.Random.Normal(new[] { 1, 1, 2, 8400 }).Evaluate() } }; + } + + public Function PreExpr { get; } + + public IEnumerable Rules => new[] { + typeof(FoldNopReshape), + typeof(FoldTwoReshapes), + }; + + public Dictionary FeedDict { get; } +} diff --git a/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs b/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs index 61db6ea870..9a1d8faee7 100644 --- a/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs +++ b/src/Nncase.Tests/Rewrite/UnitTestEGraphRewriteFactory.cs @@ -111,6 +111,7 @@ public UnitTestEGraphRewriteFactory() new ResizeImageCase(), new ProdCase(), new MultiReshapeCase(), + new PReluTransposeCase(), }; [Theory] diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestBatchNormToBinary.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestBatchNormToBinary.cs index 061fd0c2eb..76e1899d31 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestBatchNormToBinary.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestBatchNormToBinary.cs @@ -15,6 +15,7 @@ using Nncase.Passes; using Nncase.Passes.Rules.Neutral; using Nncase.PatternMatch; +using Nncase.Tests.TestFixture; using Xunit; using static Nncase.IR.F.NN; using ITuple = Nncase.IR.ITuple; @@ -25,6 +26,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestBatchNormToBinary : TransformTestBase { public static readonly TheoryData BatchNormToBinaryPositiveData = new() diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineUnary.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineUnary.cs index 49a2a389cd..5ffd3c7eb9 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestCombineUnary.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestCombineUnary.cs @@ -15,6 +15,7 @@ using Nncase.Passes; using Nncase.Passes.Rules.Neutral; using Nncase.PatternMatch; +using Nncase.Tests.TestFixture; using Xunit; using static Nncase.IR.F.NN; using ITuple = Nncase.IR.ITuple; @@ -25,6 +26,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestCombineUnary : TransformTestBase { // TODO: CombinePadUnary diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestExpandToBinary.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestExpandToBinary.cs index edbd37a310..2e63db64b6 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestExpandToBinary.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestExpandToBinary.cs @@ -11,6 +11,7 @@ using Nncase.IR; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Dimension = Nncase.IR.Dimension; using Math = Nncase.IR.F.Math; @@ -19,6 +20,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestExpandToBroadcast : TransformTestBase { public static IEnumerable TestExpandToBroadcastPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs index dc20de6667..e0a88f473e 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFlattenToReshape.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; @@ -17,6 +18,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFlattenToReshape : TransformTestBase { public static IEnumerable TestFlattenToReshapePositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldBinary.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldBinary.cs index f46203635b..f05b366d18 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldBinary.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldBinary.cs @@ -14,12 +14,14 @@ using Nncase.IR.NN; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldBinary : TransformTestBase { public static IEnumerable TestFoldNopBinaryNegativeData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldCast.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldCast.cs index 54f6a59530..a5c32a88fb 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldCast.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldCast.cs @@ -11,11 +11,13 @@ using Nncase.IR.F; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldCast : TransformTestBase { public static IEnumerable TestFoldTwoCastsPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldClamp.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldClamp.cs index 72cb0deb35..3117c03360 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldClamp.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldClamp.cs @@ -11,12 +11,14 @@ using Nncase.IR.F; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldClamp : TransformTestBase { public static IEnumerable TestFoldNopClampPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldPad.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldPad.cs index 0b8e607def..a2aa51fa92 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldPad.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldPad.cs @@ -11,12 +11,14 @@ using Nncase.IR.F; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldPad : TransformTestBase { public static IEnumerable TestFoldNopPadPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldQuant.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldQuant.cs index ed8684b323..1488cbd89a 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldQuant.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldQuant.cs @@ -6,11 +6,13 @@ using Nncase.IR; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldQuant : TransformTestBase { public static TheoryData FoldQuantDequantData => new() diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReduces.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReduces.cs index ffe5ecb25a..0abebd3aa2 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReduces.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReduces.cs @@ -12,12 +12,14 @@ using Nncase.IR.F; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldReduce : TransformTestBase { public static IEnumerable TestFoldTwoReducesPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs index 2108420fa7..b4af452d8e 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldReshape.cs @@ -11,12 +11,14 @@ using Nncase.IR.F; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestFoldReshape : TransformTestBase { public static IEnumerable TestFoldNopReshapePositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs index 066c1f3755..d2fdea0d34 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestReshapeBatchMatmul.cs @@ -20,6 +20,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestReshapeBatchMatmul : TransformTestBase { public static IEnumerable TestReshapeBatchMatmulPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSimplifyBinary.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSimplifyBinary.cs index 1a99065b8d..00abe1266b 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestSimplifyBinary.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSimplifyBinary.cs @@ -14,12 +14,14 @@ using Nncase.IR.NN; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestSimplifyBinary : TransformTestBase { public static IEnumerable TestReassociateMulPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs index d814e4bcf5..4fa0dd0c0a 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSpaceToBatchTransform.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using NN = Nncase.IR.F.NN; @@ -17,6 +18,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestSpaceToBatchToPad : TransformTestBase { public static IEnumerable TestSpaceToBatchToPadPositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs index 9b137e93d8..ece1de55f1 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeToReshape.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; @@ -17,6 +18,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestSqueezeToReshape : TransformTestBase { public static IEnumerable TestSqueezeToReshapePositiveData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeTransposeShape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeTransposeShape.cs index 166cc666d4..fecb671000 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeTransposeShape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSqueezeTransposeShape.cs @@ -62,6 +62,7 @@ public void TestSqueezeTransposeShapeNegative(int[] shape, int[] perm) } } +[AutoSetupTestMethod(InitSession = true)] public class UnitTestSqueezeBinaryShape : TransformTestBase { public static IEnumerable TestSqueezeBinaryShapePosivateData => diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs index bef44a77cb..f6b4f9c9bb 100644 --- a/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestUnSqueezeToReshape.cs @@ -10,6 +10,7 @@ using System.Threading.Tasks; using Nncase.Passes; using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; using Xunit; using Math = Nncase.IR.F.Math; using Random = Nncase.IR.F.Random; @@ -17,6 +18,7 @@ namespace Nncase.Tests.Rules.NeutralTest; +[AutoSetupTestMethod(InitSession = true)] public class UnitTestUnSqueezeToReshape : TransformTestBase { public static IEnumerable TestUnSqueezeToReshapePositiveData =>