Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ReshapeMatMul #1015

Merged
merged 9 commits into from
Jul 24, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldHardSwish5>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.FocusFull>();
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
});
passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
{
Expand Down
119 changes: 1 addition & 118 deletions src/Nncase.Importer/Onnx/MatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,124 +13,7 @@ public partial class OnnxImporter
private Expr VisitMatMul(in NodeProto op)
{
var (a, b) = GetInputExprs(op, 0, 1);
var shapeA = IR.F.Tensors.ShapeOf(a);
var shapeB = IR.F.Tensors.ShapeOf(b);
if (a.CheckedShape.IsUnranked || b.CheckedShape.IsUnranked)
{
return IR.F.Tensors.MatMul(a, b);
}

if (a.CheckedShape.Rank > 2 && b.CheckedShape.Rank > 2)
{
return IR.F.Tensors.MatMul(a, b);
}

var lhs = a;
if (a.CheckedShape.Rank > 2)
{
var newShapeA = new Expr[] { -1L, shapeA[-2], shapeA[-1] };
lhs = IR.F.Tensors.Reshape(a, IR.F.Tensors.Stack(new IR.Tuple(newShapeA), 0));
}

if (a.CheckedShape.Rank == 1)
{
var newShapeA = new Expr[] { 1L, shapeA[0] };
lhs = IR.F.Tensors.Reshape(a, IR.F.Tensors.Stack(new IR.Tuple(newShapeA), 0));
}

var rhs = b;
if (b.CheckedShape.Rank > 2)
{
var newShapeB = new Expr[] { -1L, shapeB[-2], shapeB[-1] };
rhs = IR.F.Tensors.Reshape(b, IR.F.Tensors.Stack(new IR.Tuple(newShapeB), 0));
}

if (b.CheckedShape.Rank == 1)
{
var newShapeB = new Expr[] { shapeB[0], 1L };
rhs = IR.F.Tensors.Reshape(b, IR.F.Tensors.Stack(new IR.Tuple(newShapeB), 0));
}

var maxRank = Math.Max(a.CheckedShape.Rank, b.CheckedShape.Rank);
var outputShape = new Expr[maxRank];

if (maxRank == 1)
{
outputShape[0] = 1L;
}
else if (maxRank == 2)
{
if (a.CheckedShape.Rank == 1 && b.CheckedShape.Rank == 2)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeB[1];
}
else if (a.CheckedShape.Rank == 2 && b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeA[0];
}
else
{
outputShape[0] = shapeA[0];
outputShape[1] = shapeB[1];
}
}
else
{
if (maxRank == a.CheckedShape.Rank)
{
if (b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeA[i];
}

outputShape[^1] = shapeA[-2];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = a.CheckedShape.Rank - b.CheckedShape.Rank;
var dimB = i < diff ? 1L : shapeB[i - diff];
outputShape[i] = IR.F.Math.Max(shapeA[i], dimB);
}

outputShape[^2] = shapeA[-2];
outputShape[^1] = shapeB[-1];
}
}
else if (maxRank == b.CheckedShape.Rank)
{
if (a.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeB[i];
}

outputShape[^1] = shapeB[-1];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = b.CheckedShape.Rank - a.CheckedShape.Rank;
var dimA = i < diff ? 1L : shapeA[i - diff];
outputShape[i] = IR.F.Math.Max(shapeB[i], dimA);
}

outputShape[^2] = shapeA[-2];
outputShape[^1] = shapeB[-1];
}
}
}

return IR.F.Tensors.Reshape(F.Tensors.MatMul(lhs, rhs), IR.F.Tensors.Stack(new IR.Tuple(outputShape), 0));
return IR.F.Math.MatMul(a, b);
}
}
}
147 changes: 147 additions & 0 deletions src/Nncase.Passes/Rules/Neutral/ReshapeMatMul.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
using static Nncase.Utilities.MetadataUtility;

namespace Nncase.Passes.Rules.Neutral;

[RuleGenerator]
public partial class ReshapeMatMul : RewriteRule<Pattern>
{
public override Pattern Pattern => IsMatMul(
null,
"matmul",
IsWildcard("a") with { TypePattern = HasFixedShape() },
IsWildcard("b") with { TypePattern = HasFixedShape() });

private Expr? GetReplace(Expr matmul, Expr a, Expr b, RunPassContext context)
{
if (a.CheckedShape.Rank > 2 && b.CheckedShape.Rank > 2)
{
return null;
}

var lhs = a;
var shapeA = a.CheckedShape.ToValueArray();
if (a.CheckedShape.Rank == 4)
{
var c = shapeA.Take(a.CheckedShape.Rank - 2).Aggregate(1, (sum, x) => x * sum);
var newShapeA = new long[] { c, shapeA[^2], shapeA[^1] };
lhs = IR.F.Tensors.Reshape(a, newShapeA);
}
else if (a.CheckedShape.Rank == 1)
{
var newShapeA = new long[] { 1L, shapeA[0] };
lhs = IR.F.Tensors.Reshape(a, newShapeA);
}

var rhs = b;
var shapeB = b.CheckedShape.ToValueArray();
if (b.CheckedShape.Rank == 4)
{
var c = shapeB.Take(b.CheckedShape.Rank - 2).Aggregate(1, (sum, x) => x * sum);
var newShapeB = new long[] { c, shapeB[^2], shapeB[^1] };
rhs = IR.F.Tensors.Reshape(b, newShapeB);
}
else if (b.CheckedShape.Rank == 1)
{
var newShapeB = new long[] { shapeB[0], 1L };
rhs = IR.F.Tensors.Reshape(b, newShapeB);
}

if (lhs == a && rhs == b)
{
return null;
}

var maxRank = Math.Max(a.CheckedShape.Rank, b.CheckedShape.Rank);
var outputShape = new long[maxRank];

if (maxRank == 1)
{
outputShape[0] = 1L;
}
else if (maxRank == 2)
{
if (a.CheckedShape.Rank == 1 && b.CheckedShape.Rank == 2)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeB[1];
}
else if (a.CheckedShape.Rank == 2 && b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeA[0];
}
else
{
outputShape[0] = shapeA[0];
outputShape[1] = shapeB[1];
}
}
else
{
if (maxRank == a.CheckedShape.Rank)
{
if (b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeA[i];
}

outputShape[^1] = shapeA[^2];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = a.CheckedShape.Rank - b.CheckedShape.Rank;
var dimB = i < diff ? 1L : shapeB[i - diff];
outputShape[i] = Math.Max(shapeA[i], dimB);
}

outputShape[^2] = shapeA[^2];
outputShape[^1] = shapeB[^1];
}
}
else if (maxRank == b.CheckedShape.Rank)
{
if (a.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeB[i];
}

outputShape[^1] = shapeB[^1];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = b.CheckedShape.Rank - a.CheckedShape.Rank;
var dimA = i < diff ? 1L : shapeA[i - diff];
outputShape[i] = Math.Max(shapeB[i], dimA);
}

outputShape[^2] = shapeA[^2];
outputShape[^1] = shapeB[^1];
}
}
}

var end = IR.F.Tensors.Reshape(IR.F.Tensors.MatMul(lhs, rhs), outputShape);
return end;
}
}
51 changes: 51 additions & 0 deletions src/Nncase.Tests/Rules/Neutral/UnitTestReshapeMatMul.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System.Collections;
using System.Collections.Generic;
using Nncase.Passes.Rules.Neutral;
using Nncase.Tests.TestFixture;
using Xunit;
using static Nncase.IR.F.Math;

namespace Nncase.Tests.Rules.NeutralTest;

[AutoSetupTestMethod(InitSession = true)]
public class UnitTestReshapeMatMul : TransformTestBase
{
public static IEnumerable<object[]> MatMulShapeData => new[]
{
new object[] { new[] { 2, 3, 7, 9 }, new[] { 9, 7 } },
new object[] { new[] { 7, 9 }, new[] { 2, 3, 9, 7 } },
new object[] { new[] { 3, 7 }, new[] { 7 } },
new object[] { new[] { 7 }, new[] { 7, 3 } },
new object[] { new[] { 2, 3, 7 }, new[] { 7 } },
new object[] { new[] { 3 }, new[] { 2, 3, 7 } },
};

public static IEnumerable<object[]> NopMatMulShapeData => new[]
{
new object[] { new[] { 1, 3, 24, 24 }, new[] { 3, 24, 24 } },
new object[] { new[] { 7, 3 }, new[] { 3, 7 } },
};

[Theory]
[MemberData(nameof(MatMulShapeData))]
public void TestTo3D(int[] shapeA, int[] shapeB)
{
var lhs = Testing.Rand<float>(shapeA);
var rhs = Testing.Rand<float>(shapeB);
var mm = MatMul(lhs, rhs);
TestMatched<ReshapeMatMul>(mm);
}

[Theory]
[MemberData(nameof(NopMatMulShapeData))]
public void TestNop(int[] shapeA, int[] shapeB)
{
var lhs = Testing.Rand<float>(shapeA);
var rhs = Testing.Rand<float>(shapeB);
var mm = MatMul(lhs, rhs);
TestNotMatch<ReshapeMatMul>(mm);
}
}