Skip to content

Commit

Permalink
add ReshapeMatMul (#1015)
Browse files Browse the repository at this point in the history
* add ReshapeMatMul

* Apply code-format changes

* Update Compiler.cs

* fix reshape matmul

* fix test

* Apply code-format changes

* fix test data

---------

Co-authored-by: FusionBolt <FusionBolt@users.noreply.github.com>
Co-authored-by: huochenghai <huochenghai@canaan-creative.com>
  • Loading branch information
3 people authored Jul 24, 2023
1 parent 0b642a2 commit ece6363
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 118 deletions.
1 change: 1 addition & 0 deletions src/Nncase.Compiler/Compiler.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add<Passes.Rules.Neutral.FoldHardSwish5>();
p.Add<Passes.Rules.Neutral.FoldTwoSlices>();
p.Add<Passes.Rules.Neutral.FocusFull>();
p.Add<Passes.Rules.Neutral.ReshapeMatMul>();
});
passManager.AddWithName<EGraphRulesPass>("NeutralOptimizeTranspose").Configure(p =>
{
Expand Down
119 changes: 1 addition & 118 deletions src/Nncase.Importer/Onnx/MatMul.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,124 +13,7 @@ public partial class OnnxImporter
private Expr VisitMatMul(in NodeProto op)
{
var (a, b) = GetInputExprs(op, 0, 1);
var shapeA = IR.F.Tensors.ShapeOf(a);
var shapeB = IR.F.Tensors.ShapeOf(b);
if (a.CheckedShape.IsUnranked || b.CheckedShape.IsUnranked)
{
return IR.F.Tensors.MatMul(a, b);
}

if (a.CheckedShape.Rank > 2 && b.CheckedShape.Rank > 2)
{
return IR.F.Tensors.MatMul(a, b);
}

var lhs = a;
if (a.CheckedShape.Rank > 2)
{
var newShapeA = new Expr[] { -1L, shapeA[-2], shapeA[-1] };
lhs = IR.F.Tensors.Reshape(a, IR.F.Tensors.Stack(new IR.Tuple(newShapeA), 0));
}

if (a.CheckedShape.Rank == 1)
{
var newShapeA = new Expr[] { 1L, shapeA[0] };
lhs = IR.F.Tensors.Reshape(a, IR.F.Tensors.Stack(new IR.Tuple(newShapeA), 0));
}

var rhs = b;
if (b.CheckedShape.Rank > 2)
{
var newShapeB = new Expr[] { -1L, shapeB[-2], shapeB[-1] };
rhs = IR.F.Tensors.Reshape(b, IR.F.Tensors.Stack(new IR.Tuple(newShapeB), 0));
}

if (b.CheckedShape.Rank == 1)
{
var newShapeB = new Expr[] { shapeB[0], 1L };
rhs = IR.F.Tensors.Reshape(b, IR.F.Tensors.Stack(new IR.Tuple(newShapeB), 0));
}

var maxRank = Math.Max(a.CheckedShape.Rank, b.CheckedShape.Rank);
var outputShape = new Expr[maxRank];

if (maxRank == 1)
{
outputShape[0] = 1L;
}
else if (maxRank == 2)
{
if (a.CheckedShape.Rank == 1 && b.CheckedShape.Rank == 2)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeB[1];
}
else if (a.CheckedShape.Rank == 2 && b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeA[0];
}
else
{
outputShape[0] = shapeA[0];
outputShape[1] = shapeB[1];
}
}
else
{
if (maxRank == a.CheckedShape.Rank)
{
if (b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeA[i];
}

outputShape[^1] = shapeA[-2];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = a.CheckedShape.Rank - b.CheckedShape.Rank;
var dimB = i < diff ? 1L : shapeB[i - diff];
outputShape[i] = IR.F.Math.Max(shapeA[i], dimB);
}

outputShape[^2] = shapeA[-2];
outputShape[^1] = shapeB[-1];
}
}
else if (maxRank == b.CheckedShape.Rank)
{
if (a.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeB[i];
}

outputShape[^1] = shapeB[-1];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = b.CheckedShape.Rank - a.CheckedShape.Rank;
var dimA = i < diff ? 1L : shapeA[i - diff];
outputShape[i] = IR.F.Math.Max(shapeB[i], dimA);
}

outputShape[^2] = shapeA[-2];
outputShape[^1] = shapeB[-1];
}
}
}

return IR.F.Tensors.Reshape(F.Tensors.MatMul(lhs, rhs), IR.F.Tensors.Stack(new IR.Tuple(outputShape), 0));
return IR.F.Math.MatMul(a, b);
}
}
}
147 changes: 147 additions & 0 deletions src/Nncase.Passes/Rules/Neutral/ReshapeMatMul.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System;
using System.Linq;
using Nncase.Diagnostics;
using Nncase.IR;
using Nncase.PatternMatch;
using static Nncase.IR.TypePatternUtility;
using static Nncase.PatternMatch.F.Math;
using static Nncase.PatternMatch.Utility;
using static Nncase.Utilities.MetadataUtility;

namespace Nncase.Passes.Rules.Neutral;

[RuleGenerator]
public partial class ReshapeMatMul : RewriteRule<Pattern>
{
public override Pattern Pattern => IsMatMul(
null,
"matmul",
IsWildcard("a") with { TypePattern = HasFixedShape() },
IsWildcard("b") with { TypePattern = HasFixedShape() });

private Expr? GetReplace(Expr matmul, Expr a, Expr b, RunPassContext context)
{
if (a.CheckedShape.Rank > 2 && b.CheckedShape.Rank > 2)
{
return null;
}

var lhs = a;
var shapeA = a.CheckedShape.ToValueArray();
if (a.CheckedShape.Rank == 4)
{
var c = shapeA.Take(a.CheckedShape.Rank - 2).Aggregate(1, (sum, x) => x * sum);
var newShapeA = new long[] { c, shapeA[^2], shapeA[^1] };
lhs = IR.F.Tensors.Reshape(a, newShapeA);
}
else if (a.CheckedShape.Rank == 1)
{
var newShapeA = new long[] { 1L, shapeA[0] };
lhs = IR.F.Tensors.Reshape(a, newShapeA);
}

var rhs = b;
var shapeB = b.CheckedShape.ToValueArray();
if (b.CheckedShape.Rank == 4)
{
var c = shapeB.Take(b.CheckedShape.Rank - 2).Aggregate(1, (sum, x) => x * sum);
var newShapeB = new long[] { c, shapeB[^2], shapeB[^1] };
rhs = IR.F.Tensors.Reshape(b, newShapeB);
}
else if (b.CheckedShape.Rank == 1)
{
var newShapeB = new long[] { shapeB[0], 1L };
rhs = IR.F.Tensors.Reshape(b, newShapeB);
}

if (lhs == a && rhs == b)
{
return null;
}

var maxRank = Math.Max(a.CheckedShape.Rank, b.CheckedShape.Rank);
var outputShape = new long[maxRank];

if (maxRank == 1)
{
outputShape[0] = 1L;
}
else if (maxRank == 2)
{
if (a.CheckedShape.Rank == 1 && b.CheckedShape.Rank == 2)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeB[1];
}
else if (a.CheckedShape.Rank == 2 && b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, 1);
outputShape[0] = shapeA[0];
}
else
{
outputShape[0] = shapeA[0];
outputShape[1] = shapeB[1];
}
}
else
{
if (maxRank == a.CheckedShape.Rank)
{
if (b.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeA[i];
}

outputShape[^1] = shapeA[^2];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = a.CheckedShape.Rank - b.CheckedShape.Rank;
var dimB = i < diff ? 1L : shapeB[i - diff];
outputShape[i] = Math.Max(shapeA[i], dimB);
}

outputShape[^2] = shapeA[^2];
outputShape[^1] = shapeB[^1];
}
}
else if (maxRank == b.CheckedShape.Rank)
{
if (a.CheckedShape.Rank == 1)
{
Array.Resize(ref outputShape, maxRank - 1);
for (var i = 0; i < maxRank - 2; i++)
{
outputShape[i] = shapeB[i];
}

outputShape[^1] = shapeB[^1];
}
else
{
for (var i = 0; i < maxRank - 2; i++)
{
var diff = b.CheckedShape.Rank - a.CheckedShape.Rank;
var dimA = i < diff ? 1L : shapeA[i - diff];
outputShape[i] = Math.Max(shapeB[i], dimA);
}

outputShape[^2] = shapeA[^2];
outputShape[^1] = shapeB[^1];
}
}
}

var end = IR.F.Tensors.Reshape(IR.F.Tensors.MatMul(lhs, rhs), outputShape);
return end;
}
}
51 changes: 51 additions & 0 deletions src/Nncase.Tests/Rules/Neutral/UnitTestReshapeMatMul.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.

using System.Collections;
using System.Collections.Generic;
using Nncase.Passes.Rules.Neutral;
using Nncase.Tests.TestFixture;
using Xunit;
using static Nncase.IR.F.Math;

namespace Nncase.Tests.Rules.NeutralTest;

[AutoSetupTestMethod(InitSession = true)]
public class UnitTestReshapeMatMul : TransformTestBase
{
public static IEnumerable<object[]> MatMulShapeData => new[]
{
new object[] { new[] { 2, 3, 7, 9 }, new[] { 9, 7 } },
new object[] { new[] { 7, 9 }, new[] { 2, 3, 9, 7 } },
new object[] { new[] { 3, 7 }, new[] { 7 } },
new object[] { new[] { 7 }, new[] { 7, 3 } },
new object[] { new[] { 2, 3, 7 }, new[] { 7 } },
new object[] { new[] { 3 }, new[] { 2, 3, 7 } },
};

public static IEnumerable<object[]> NopMatMulShapeData => new[]
{
new object[] { new[] { 1, 3, 24, 24 }, new[] { 3, 24, 24 } },
new object[] { new[] { 7, 3 }, new[] { 3, 7 } },
};

[Theory]
[MemberData(nameof(MatMulShapeData))]
public void TestTo3D(int[] shapeA, int[] shapeB)
{
var lhs = Testing.Rand<float>(shapeA);
var rhs = Testing.Rand<float>(shapeB);
var mm = MatMul(lhs, rhs);
TestMatched<ReshapeMatMul>(mm);
}

[Theory]
[MemberData(nameof(NopMatMulShapeData))]
public void TestNop(int[] shapeA, int[] shapeB)
{
var lhs = Testing.Rand<float>(shapeA);
var rhs = Testing.Rand<float>(shapeB);
var mm = MatMul(lhs, rhs);
TestNotMatch<ReshapeMatMul>(mm);
}
}

0 comments on commit ece6363

Please sign in to comment.