Skip to content

Commit

Permalink
Feature/egraph fakepass (#1111)
Browse files Browse the repository at this point in the history
* add cli preprocess options/fix pattern match utility

* fix egraph

* add driver in run pass context

* fix muli branch equal enode extract

* revert to history version

* change op name to output tensor name

* fallback to old unit test

* change for algo quant tools

* add pass to test match

* revise typo for BroadcastInputMarker

* add AutoSetupTestMethod for pad test

* add AutoSetupTestMethod for squeeze

* add AutoSetupTestMethod to some unittest

* change order for calibration test

* change vulkansdk-linux version from 182 to 189

---------

Co-authored-by: 郑启航 <597323109@qq.com>
Co-authored-by: guodongliang <guodongliang@canaan-creative.com>
  • Loading branch information
3 people authored Oct 30, 2023
1 parent 9a7268c commit ef3d74f
Show file tree
Hide file tree
Showing 37 changed files with 287 additions and 105 deletions.
77 changes: 77 additions & 0 deletions src/Nncase.Cli/Commands/Compile.cs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,50 @@ public Compile()
alias: "--calib-method",
description: $"model quant options, default is {Quantization.CalibMethod.Kld}",
getDefaultValue: () => Quantization.CalibMethod.Kld));
AddOption(new Option<bool>(
alias: "--pre-process",
description: "whether enable pre process, default is False",
getDefaultValue: () => false));
AddOption(new Option(
alias: "--input-layout",
description: "the model input data layout, default is empty. eg. NCHW/NHWC",
getDefaultValue: () => string.Empty));
AddOption(new Option(
alias: "--output-layout",
description: "the model output data layout, default is empty. eg. NCHW/NHWC",
getDefaultValue: () => string.Empty));
AddOption(new Option(
alias: "--input-type",
description: "the model input data value type, default is Float32",
getDefaultValue: () => InputType.Float32));
AddOption(new Option<IEnumerable<int>>(
alias: "--input-shape",
description: "the model input data shape, default is []. eg. `--input-shape 1 2 3 4`",
getDefaultValue: () => Array.Empty<int>()));
AddOption(new Option<IEnumerable<float>>(
alias: "--input-range",
description: "the model input data value range, default is []. eg `--input-range -100.3 200.4`",
getDefaultValue: () => Array.Empty<float>()));
AddOption(new Option<bool>(
alias: "--swap-rb",
description: "whether swap the model input data channel R and B",
getDefaultValue: () => false));
AddOption(new Option(
alias: "--letter-box-value",
description: "letterbox value, default 0.0",
getDefaultValue: () => 0.0f));
AddOption(new Option<IEnumerable<float>>(
alias: "--mean",
description: "the model input data mean, default []",
getDefaultValue: () => Array.Empty<float>()));
AddOption(new Option<IEnumerable<float>>(
alias: "--std",
description: "the model input data std, default []",
getDefaultValue: () => Array.Empty<float>()));
AddOption(new Option(
alias: "--model-layout",
description: "the model's input layout, default is empty. eg. NCHW/NHWC",
getDefaultValue: () => string.Empty));
AddOption(new Option<bool>(
alias: "--benchmark-only",
description: $"benchmark only",
Expand Down Expand Up @@ -142,6 +186,17 @@ private async Task RunAsync(CliCompileOptions cliOptions, IHost host)
},
ModelQuantMode = cliOptions.ModelQuantMode,
},
PreProcess = cliOptions.PreProcess,
InputLayout = cliOptions.InputLayout,
OutputLayout = cliOptions.OutputLayout,
InputType = cliOptions.InputType,
InputShape = cliOptions.InputShape.ToArray(),
InputRange = cliOptions.InputRange.ToArray(),
SwapRB = cliOptions.SwapRB,
LetterBoxValue = cliOptions.LetterBoxValue,
Mean = cliOptions.Mean.ToArray(),
Std = cliOptions.Std.ToArray(),
ModelLayout = cliOptions.ModelLayout,
IsBenchmarkOnly = cliOptions.BenchmarkOnly,
};

Expand Down Expand Up @@ -209,6 +264,28 @@ internal sealed class CliCompileOptions
public DatasetFormat DatasetFormat { get; set; }

public bool BenchmarkOnly { get; set; }

public bool PreProcess { get; set; }

public string InputLayout { get; set; }

public string OutputLayout { get; set; }

public InputType InputType { get; set; }

public List<int> InputShape { get; set; }

public List<float> InputRange { get; set; }

public bool SwapRB { get; set; }

public float LetterBoxValue { get; set; }

public List<float> Mean { get; set; }

public List<float> Std { get; set; }

public string ModelLayout { get; set; }
}

#pragma warning restore CS8618
5 changes: 5 additions & 0 deletions src/Nncase.Core/Passes/RunPassContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ public record RunPassContext
/// </summary>
public int Index { get; set; }

/// <summary>
/// Gets this pass's driver.
/// </summary>
public IPass? Driver { get; init; }

/// <summary>
/// Gets or sets a value indicating whether control rewrite once or not.
/// when RewriteOnce is true, the rule will only apply once, then restart rewrite from first rule.
Expand Down
7 changes: 7 additions & 0 deletions src/Nncase.Core/PatternMatch/IMatchResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ public interface IMatchResult : IEnumerable<KeyValuePair<IPattern, object>>
/// <returns>Match result.</returns>
object this[IPattern pattern] { get; }

/// <summary>
/// Get match result by name, default is null.
/// </summary>
/// <param name="name">Pattern name.</param>
/// <returns>match result.</returns>
object GetValueOrDefault(string name);

/// <summary>
/// Get match result by pattern.
/// </summary>
Expand Down
3 changes: 3 additions & 0 deletions src/Nncase.Core/PatternMatch/MatchResult.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ where kv.Key.Name is not null
/// <inheritdoc/>
public object this[string name] => _stringMap[name];

/// <inheritdoc/>
public object GetValueOrDefault(string name) => _stringMap.GetValueOrDefault(name, null!);

/// <inheritdoc/>
public IEnumerator<KeyValuePair<IPattern, object>> GetEnumerator()
{
Expand Down
12 changes: 1 addition & 11 deletions src/Nncase.Core/Utilities/ReplaceUtility.cs
Original file line number Diff line number Diff line change
Expand Up @@ -97,11 +97,6 @@ public static Call ReplaceCallParams(Expr target, IReadOnlyList<Expr> oldParams,
return new Call(target, ReplaceItems(oldParams, pairs));
}

public static Call ReplaceCallParams(Call call, params (int, Expr)[] pairs)
{
return new Call(call.Target, ReplaceItems(call.Arguments.ToArray(), pairs));
}

/// <summary>
/// replace the call params with parameter info.
/// </summary>
Expand All @@ -122,12 +117,7 @@ public static Call ReplaceCallParams(Expr target, IReadOnlyList<Expr> oldParams,
/// <param name="expr">expr.</param>
/// <returns>new Call.</returns>
public static Call ReplaceCallFirstParam(Expr target, IReadOnlyList<Expr> oldParams, Expr expr) =>
ReplaceCallParams(target, oldParams, (oldParams[0], expr));

public static Expr ReplaceCallFirstParam(Call call, Expr expr)
{
return ReplaceCallFirstParam(call.Target, call.Arguments.ToArray(), expr);
}
ReplaceCallParams(target, oldParams, (0, expr));

/// <summary>
/// Replace target in body with expr.
Expand Down
10 changes: 5 additions & 5 deletions src/Nncase.EGraph/Passes/RewriteProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@ public Expr ERewrite(Expr expr, IEnumerable<IRewriteRule> rules, RunPassContext
public IEGraph ERewrite(IEGraph eGraph, IEnumerable<IRewriteRule> rules, RunPassContext context)
{
var last_version = eGraph.Version;
int count = 0;

while (true)
{
var matches = rules.
Expand All @@ -59,10 +57,12 @@ public IEGraph ERewrite(IEGraph eGraph, IEnumerable<IRewriteRule> rules, RunPass

if (DumpScope.Current.IsEnabled(DumpFlags.Rewrite))
{
foreach (var (rule, results) in matches.Where(p => p.Item2.Count != 0))
using var fs = DumpScope.Current.OpenFile(Path.Combine("Matches", $"V{eGraph.Version}.txt"));
using var writer = new StreamWriter(fs);
writer.WriteLine("rule, results");
foreach (var (rule, results) in matches)
{
using var fs = DumpScope.Current.OpenFile(Path.Combine("Matches", $"V{eGraph.Version}_{count++}_{rule.GetType().Name}.dot"));
EGraphPrinter.DumpEgraphAsDot(eGraph, results, fs);
writer.WriteLine($"{rule.GetType().Name}, {results.Count}");
}
}

Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Importer/Onnx/Conv2D.cs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private Expr VisitConv2D(in NodeProto op)
var pads = AutoPad(op, autoPad, input, weights, strides.ToArray<long>(), dilation.ToArray(), isConv1D);
pads.InferenceType();
var conv = F.NN.Conv2D(input, weights, bias, strides.ToArray(), pads, dilation.ToArray(), PadMode.Constant, group);
List<string> outputNames = new() { op.Name };
List<string> outputNames = new() { op.Output[0] };
conv.Metadata.OutputNames = outputNames;
if (isConv1D)
{
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Importer/Onnx/MatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ private Expr VisitMatMul(in NodeProto op)
{
var (a, b) = GetInputExprs(op, 0, 1);
var matmul = IR.F.Math.MatMul(a, b);
List<string> outputNames = new() { op.Name };
List<string> outputNames = new() { op.Output[0] };
matmul.Metadata.OutputNames = outputNames;
return matmul;
}
Expand Down
2 changes: 1 addition & 1 deletion src/Nncase.Importer/TFLite/Conv2DTranspose.cs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ private Expr VisitConv2DTranspose(in tflite.Operator op)
dilation,
PadMode.Constant,
1);
List<string> outputNames = new() { GetInputTensor(op, 0).Name };
List<string> outputNames = new() { GetOutputTensor(op, 0).Name };
conv2DTranspose.Metadata.OutputNames = outputNames;
return F.Tensors.NCHWToNHWC(F.Math.Clamp(
conv2DTranspose,
Expand Down
4 changes: 2 additions & 2 deletions src/Nncase.Importer/TFLite/MatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ private Expr VisitMatMul(in tflite.Operator op, bool isFullyConnected = true)
: Expand(Cast(0, GetDataType(GetInputTensor(op, 0).Type)), new[] { otherTensor.Shape(0) }).Evaluate().AsTensor();

var matmul = MatMul(lhs, rhs);
List<string> outputNames = new() { GetInputTensor(op, 0).Name + "_matmul" };
List<string> outputNames = new() { GetOutputTensor(op, 0).Name + "_matmul" };
matmul.Metadata.OutputNames = outputNames;
outputNames.Clear();
outputNames.Add(GetInputTensor(op, 0).Name + "_bias");
outputNames.Add(GetOutputTensor(op, 0).Name);
bias.Metadata.OutputNames = outputNames;
var mm = matmul + bias;

Expand Down
1 change: 1 addition & 0 deletions src/Nncase.Passes/PassManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ public RunPassContextWithAnalysis(IAnalyzerManager analyzerManager, Either<BaseF
_analyzers = populater.Analyzers;
AnalysisResults = populater.AnalysisResults;
RewriteOnce = _analyzers.Count != 0;
Driver = pass;
}

private struct AnalyzerPopulater
Expand Down
67 changes: 37 additions & 30 deletions src/Nncase.Passes/Rules/Lower/BroadcastMarker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Collections.Generic;
using System.Linq;
using Nncase.IR;
using Nncase.IR.Math;
using Nncase.IR.Tensors;
Expand All @@ -17,36 +19,46 @@ namespace Nncase.Passes.Rules.Lower;
[RuleGenerator]
public partial class BroadcastInputMarker : RewriteRule<Pattern>
{
public override Pattern Pattern => IsCallWildcard(
"outer",
IsWildcard(),
InputPattern);
public override Pattern Pattern => IsCall("outer", IsWildcard("outerTarget"), IsVArgsRepeat("outerParams", exprs =>
{
var patterns = new Pattern[exprs.Length];
for (int i = 0; i < exprs.Length; i++)
{
patterns[i] = GetInputPattern(i);
}

return patterns;
}));

public Pattern InputPattern => IsCallWildcard(
"call",
IsWildcard(),
IsRangeOfMarker(
"marker",
IsWildcard(),
IsWildcard()));
public Pattern GetInputPattern(int i) =>
IsAlt(
IsCallWildcard(
$"input_{i}",
IsOp<Op>($"input_target_{i}", NotChangeRangeOp),
IsRangeOfMarker($"input_marker_{i}", IsWildcard($"marker_target_{i}"), IsWildcard($"marker_attribute_{i}"))),
IsWildcard($"input_{i}"));

public Expr? GetReplace(Call outer, Call call, Marker marker)
public Expr? GetReplace(Call outer, Expr outerTarget, IReadOnlyList<Expr> outerParams, IMatchResult result)
{
if (!NotChangeRangeOp(call.Target))
if (!Enumerable.Range(0, outerParams.Count).Select(i => result.GetValueOrDefault($"input_marker_{i}")).Any(e => e is not null))
{
return null;
}

if (outer.Target is MatMul && CompilerServices.TryMatchRoot(outer.Arguments[1], InputPattern, new(), out var matchResult))
var newArgs = new Expr[outerParams.Count];
for (int i = 0; i < outerParams.Count; i++)
{
var rhsMarker = (Marker)matchResult["marker"];
var rhsCall = (Call)matchResult["call"];
var lhs = marker.With(target: ReplaceCallFirstParam(call, marker));
var rhs = rhsMarker.With(target: ReplaceCallFirstParam(rhsCall, rhsMarker));
return ReplaceCallParams(outer, (0, lhs), (1, rhs));
if (result.GetValueOrDefault($"input_marker_{i}") is Marker marker && result[$"marker_target_{i}"] is Expr target && result[$"marker_attribute_{i}"] is Expr range)
{
newArgs[i] = IR.F.Math.RangeOfMarker(outerParams[i], range).With(mixQuantInfo: marker.MixQuantInfo, adaQuantInfo: marker.AdaQuantInfo);
}
else
{
newArgs[i] = outerParams[i];
}
}

return ReplaceCallFirstParam(outer, marker.With(target: ReplaceCallFirstParam(call, marker)));
return new Call(outerTarget, newArgs);
}
}

Expand All @@ -56,23 +68,18 @@ public partial class BroadcastOutputMarker : RewriteRule<Pattern>
{
public override Pattern Pattern => IsRangeOfMarker(
"marker",
IsCallWildcard("input", IsWildcard(), IsCallWildcard(null, IsWildcard())),
IsWildcard());
IsCallWildcard("output", IsOp<Op>("outputTarget", NotChangeRangeOp), IsCallWildcard("input", IsWildcard("inputTarget"))),
IsWildcard("range"));

public Expr? GetReplace(Call input, Marker marker)
public Expr? GetReplace(Marker marker, Expr range, Call output, Op outputTarget, IReadOnlyList<Expr> outputParams)
{
if (!NotChangeRangeOp(input.Target))
{
return null;
}

return ReplaceCallFirstParam(input, marker.With(target: input.Arguments[0]));
return ReplaceCallFirstParam(outputTarget, outputParams, IR.F.Math.RangeOfMarker(outputParams[0], range).With(adaQuantInfo: marker.AdaQuantInfo, mixQuantInfo: marker.MixQuantInfo));
}
}

internal static class BroadcastMarkerHelper
{
public static bool NotChangeRangeOp(Expr op)
public static bool NotChangeRangeOp(Op op)
{
return op is Squeeze || op is Unsqueeze || op is Reshape || op is Broadcast;
}
Expand Down
Loading

0 comments on commit ef3d74f

Please sign in to comment.