From 85011d810954dc0694c882b28d9c43635013c55d Mon Sep 17 00:00:00 2001 From: FusionBolt <59008347+FusionBolt@users.noreply.github.com> Date: Fri, 15 Sep 2023 17:42:46 +0800 Subject: [PATCH] GNNE-1891 ShapeBucket Optimize (#1088) * Optimize ShapeBucket --------- Co-authored-by: FusionBolt --- .../CodeGen/StackVM/CodeGenVisitor.g.cs | 17 +- .../CodeGen/StackVM/StackVMEmitter.g.cs | 153 +++-- .../nncase/kernels/stackvm/tensor_ops.h | 24 +- .../nncase/runtime/stackvm/op_reader.h | 62 +- .../include/nncase/runtime/stackvm/opcode.h | 143 +++-- .../src/kernels/stackvm/reference/pad.cpp | 14 +- src/Native/src/kernels/stackvm/shape_ops.cpp | 108 +++- src/Native/src/runtime/stackvm/op_reader.cpp | 17 +- src/Native/src/runtime/stackvm/ops/tensor.cpp | 81 ++- .../runtime/stackvm/runtime_function_ops.h | 7 +- src/Nncase.Compiler/Compiler.cs | 51 +- src/Nncase.Core/IR/ShapeExpr/Functional.cs | 10 + src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs | 40 ++ src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs | 24 + src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs | 24 + .../IR/ShapeExpr/TransposeShape.cs | 24 + .../IR/ShapeExpr/UnsqueezeShape.cs | 21 + src/Nncase.Core/Utilities/ShapeExprUtility.cs | 25 +- src/Nncase.Evaluator/Math/Binary.cs | 2 +- src/Nncase.Evaluator/Math/MatMul.cs | 2 +- src/Nncase.Evaluator/Math/Reduce.cs | 6 +- src/Nncase.Evaluator/NN/BatchToSpace.cs | 6 +- src/Nncase.Evaluator/NN/Conv2D.cs | 8 +- src/Nncase.Evaluator/NN/Pad.cs | 6 +- src/Nncase.Evaluator/NN/SpaceToBatch.cs | 6 +- src/Nncase.Evaluator/ShapeEvaluateContext.cs | 19 +- src/Nncase.Evaluator/ShapeEvaluateVisitor.cs | 2 +- .../ShapeEvaluatorProvider.cs | 3 +- .../ShapeExpr/BroadcastShape.cs | 5 +- src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs | 89 +++ .../ShapeExpr/ReshapeShape.cs | 50 ++ .../ShapeExpr/ShapeExprModule.cs | 5 + .../ShapeExpr/SqueezeShape.cs | 57 ++ .../ShapeExpr/TransposeShape.cs | 51 ++ .../ShapeExpr/UnsqueezeShape.cs | 57 ++ src/Nncase.Evaluator/Tensors/Concat.cs | 4 +- src/Nncase.Evaluator/Tensors/Range.cs | 2 +- src/Nncase.Evaluator/Tensors/Reshape.cs | 71 +-- src/Nncase.Evaluator/Tensors/Slice.cs | 2 +- src/Nncase.Evaluator/Tensors/Squeeze.cs | 23 +- src/Nncase.Evaluator/Tensors/Stack.cs | 4 +- src/Nncase.Evaluator/Tensors/Tile.cs | 2 +- src/Nncase.Evaluator/Tensors/Transpose.cs | 12 +- src/Nncase.Evaluator/Tensors/UnSqueeze.cs | 24 +- src/Nncase.Importer/TFLite/Conv2D.cs | 14 +- src/Nncase.Importer/TFLite/Conv2DTranspose.cs | 8 +- src/Nncase.Importer/Util.cs | 8 +- .../Rules/Neutral/FoldReshape.cs | 5 + .../Rules/Neutral/FoldSqueeze.cs | 62 ++ .../Rules/Neutral/FoldTranspose.cs | 5 + .../Rules/Neutral/SplitSpaceToBatch.cs | 180 ++++++ .../Rules/ShapeBucket/FoldBucketReshape.cs | 44 -- .../Rules/ShapeBucket/FoldNopTuple.cs | 64 +++ .../Rules/ShapeBucket/MergeBucketFusion.cs | 48 +- .../Rules/ShapeBucket/MergeCallToFusion.cs | 44 +- .../Rules/ShapeBucket/RecordFusionShape.cs | 101 ++-- .../Rules/ShapeBucket/ShapeBucket.cs | 536 +++++++++++++----- .../Rules/ShapeBucket/ShapeBucketHelper.cs | 249 +++++++- .../Rules/ShapeExpr/FoldBroadcastShape.cs | 76 +++ .../Rules/ShapeExpr/FoldGetItemShapeOf.cs | 41 ++ .../Rules/ShapeExpr/FoldSplitShapeOf.cs | 57 ++ .../Rules/ShapeExpr/GatherToGetItem.cs | 24 + .../Rules/ShapeExpr/SliceToGetItem.cs | 44 ++ .../TransformTestBase.cs | 1 + src/Nncase.Tests/Importer/UnitTestUtil.cs | 17 - .../Rules/Neutral/UnitTestFoldSqueeze.cs | 33 ++ .../Neutral/UnitTestSplitSpaceToBatch.cs | 39 ++ .../Rules/ShapeBucket/ShapeBucketTest.cs | 168 +++++- .../ShapeExpr/UnitTestFoldBroadcastShape.cs | 31 + .../ShapeExpr/UnitTestFoldGetItemShapeOf.cs | 45 ++ .../ShapeExpr/UnitTestFoldSplitShapeOf.cs | 28 + .../ShapeExpr/UnitTestGatherToGetItem.cs | 35 ++ .../Rules/ShapeExpr/UnitTestSliceToGetItem.cs | 36 ++ test.md | 75 --- testEnvironments.json | 17 - 75 files changed, 2768 insertions(+), 730 deletions(-) create mode 100644 src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs create mode 100644 src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs create mode 100644 src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs create mode 100644 src/Nncase.Core/IR/ShapeExpr/TransposeShape.cs create mode 100644 src/Nncase.Core/IR/ShapeExpr/UnsqueezeShape.cs create mode 100644 src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs create mode 100644 src/Nncase.Evaluator/ShapeExpr/ReshapeShape.cs create mode 100644 src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs create mode 100644 src/Nncase.Evaluator/ShapeExpr/TransposeShape.cs create mode 100644 src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs create mode 100644 src/Nncase.Passes/Rules/Neutral/FoldSqueeze.cs create mode 100644 src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs delete mode 100644 src/Nncase.Passes/Rules/ShapeBucket/FoldBucketReshape.cs create mode 100644 src/Nncase.Passes/Rules/ShapeBucket/FoldNopTuple.cs create mode 100644 src/Nncase.Passes/Rules/ShapeExpr/FoldBroadcastShape.cs create mode 100644 src/Nncase.Passes/Rules/ShapeExpr/FoldGetItemShapeOf.cs create mode 100644 src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs create mode 100644 src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs create mode 100644 src/Nncase.Passes/Rules/ShapeExpr/SliceToGetItem.cs create mode 100644 src/Nncase.Tests/Rules/Neutral/UnitTestFoldSqueeze.cs create mode 100644 src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs create mode 100644 src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldBroadcastShape.cs create mode 100644 src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs create mode 100644 src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs create mode 100644 src/Nncase.Tests/Rules/ShapeExpr/UnitTestGatherToGetItem.cs create mode 100644 src/Nncase.Tests/Rules/ShapeExpr/UnitTestSliceToGetItem.cs delete mode 100644 test.md delete mode 100644 testEnvironments.json diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs index b900856080..2cfd50e458 100644 --- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs +++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +08:00. */ +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */ using System; using System.Collections.Generic; @@ -271,9 +271,24 @@ private void EmitTensorCall(Op op) case IR.ShapeExpr.Conv2DTransposeShape top: Emitter.T.Conv2DTransposeShape(); break; + case IR.ShapeExpr.GetPaddings top: + Emitter.T.GetPaddings(); + break; case IR.ShapeExpr.MatMulShape top: Emitter.T.MatMulShape(); break; + case IR.ShapeExpr.ReshapeShape top: + Emitter.T.ReshapeShape(); + break; + case IR.ShapeExpr.SqueezeShape top: + Emitter.T.SqueezeShape(); + break; + case IR.ShapeExpr.TransposeShape top: + Emitter.T.TransposeShape(); + break; + case IR.ShapeExpr.UnsqueezeShape top: + Emitter.T.UnsqueezeShape(); + break; case IR.Random.Normal top: Emitter.T.Normal(top.Type); break; diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs index d69739ded8..6e2184c5ea 100644 --- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs +++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs @@ -1,6 +1,6 @@ // Copyright (c) Canaan Inc. All rights reserved. // Licensed under the Apache license. See LICENSE file in the project root for full license information. -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +08:00. */ +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */ using System; using System.Collections.Generic; @@ -876,52 +876,59 @@ public void GetItem() } ///. - public void Hardmax() + public void GetPaddings() { _emitter.Write((byte)100); _emitter.Write((ushort)32); } ///. - public void HardSigmoid() + public void Hardmax() { _emitter.Write((byte)100); _emitter.Write((ushort)33); } ///. - public void HardSwish() + public void HardSigmoid() { _emitter.Write((byte)100); _emitter.Write((ushort)34); } ///. - public void IndexOf() + public void HardSwish() { _emitter.Write((byte)100); _emitter.Write((ushort)35); } ///. - public void InstanceNormalization() + public void IndexOf() { _emitter.Write((byte)100); _emitter.Write((ushort)36); } ///. - public void L2Normalization() + public void InstanceNormalization() { _emitter.Write((byte)100); _emitter.Write((ushort)37); } ///. - public void LayerNorm(int axis, float epsilon) + public void L2Normalization() { _emitter.Write((byte)100); _emitter.Write((ushort)38); + } + + ///. + public void LayerNorm(int axis, float epsilon) + { + _emitter.Write((byte)100); + _emitter.Write((ushort)39); _emitter.Write(axis); _emitter.Write(epsilon); } @@ -930,35 +937,35 @@ public void LayerNorm(int axis, float epsilon) public void LeakyRelu() { _emitter.Write((byte)100); - _emitter.Write((ushort)39); + _emitter.Write((ushort)40); } ///. public void LogSoftmax() { _emitter.Write((byte)100); - _emitter.Write((ushort)40); + _emitter.Write((ushort)41); } ///. public void LpNormalization() { _emitter.Write((byte)100); - _emitter.Write((ushort)41); + _emitter.Write((ushort)42); } ///. public void LRN() { _emitter.Write((byte)100); - _emitter.Write((ushort)42); + _emitter.Write((ushort)43); } ///. public void LSTM(LSTMDirection direction, LSTMLayout layout, string[] activations) { _emitter.Write((byte)100); - _emitter.Write((ushort)43); + _emitter.Write((ushort)44); _emitter.Write((int)direction); _emitter.Write((int)layout); _emitter.Write(activations); @@ -968,21 +975,21 @@ public void LSTM(LSTMDirection direction, LSTMLayout layout, string[] activation public void MatMul() { _emitter.Write((byte)100); - _emitter.Write((ushort)44); + _emitter.Write((ushort)45); } ///. public void MatMulShape() { _emitter.Write((byte)100); - _emitter.Write((ushort)45); + _emitter.Write((ushort)46); } ///. public void Normal(DataType type) { _emitter.Write((byte)100); - _emitter.Write((ushort)46); + _emitter.Write((ushort)47); _emitter.Write(type); } @@ -990,7 +997,7 @@ public void Normal(DataType type) public void NormalLike(DataType type) { _emitter.Write((byte)100); - _emitter.Write((ushort)47); + _emitter.Write((ushort)48); _emitter.Write(type); } @@ -998,7 +1005,7 @@ public void NormalLike(DataType type) public void OneHot(OneHotMode oneHotMode) { _emitter.Write((byte)100); - _emitter.Write((ushort)48); + _emitter.Write((ushort)49); _emitter.Write((byte)oneHotMode); } @@ -1006,7 +1013,7 @@ public void OneHot(OneHotMode oneHotMode) public void Pad(PadMode padMode) { _emitter.Write((byte)100); - _emitter.Write((ushort)49); + _emitter.Write((ushort)50); _emitter.Write((byte)padMode); } @@ -1014,21 +1021,21 @@ public void Pad(PadMode padMode) public void PRelu() { _emitter.Write((byte)100); - _emitter.Write((ushort)50); + _emitter.Write((ushort)51); } ///. public void Prod() { _emitter.Write((byte)100); - _emitter.Write((ushort)51); + _emitter.Write((ushort)52); } ///. public void Quantize(DataType targetType) { _emitter.Write((byte)100); - _emitter.Write((ushort)52); + _emitter.Write((ushort)53); _emitter.Write(targetType); } @@ -1036,7 +1043,7 @@ public void Quantize(DataType targetType) public void QuantParamOf(QuantMode quantMode) { _emitter.Write((byte)100); - _emitter.Write((ushort)53); + _emitter.Write((ushort)54); _emitter.Write((int)quantMode); } @@ -1044,14 +1051,14 @@ public void QuantParamOf(QuantMode quantMode) public void Range() { _emitter.Write((byte)100); - _emitter.Write((ushort)54); + _emitter.Write((ushort)55); } ///. public void RangeOf(bool isRangeOfWeight) { _emitter.Write((byte)100); - _emitter.Write((ushort)55); + _emitter.Write((ushort)56); _emitter.Write(isRangeOfWeight); } @@ -1059,14 +1066,14 @@ public void RangeOf(bool isRangeOfWeight) public void Rank() { _emitter.Write((byte)100); - _emitter.Write((ushort)56); + _emitter.Write((ushort)57); } ///. public void Reduce(ReduceOp reduceOp) { _emitter.Write((byte)100); - _emitter.Write((ushort)57); + _emitter.Write((ushort)58); _emitter.Write((byte)reduceOp); } @@ -1074,7 +1081,7 @@ public void Reduce(ReduceOp reduceOp) public void ReduceArg(ReduceArgOp reduceArgOp, DataType destType) { _emitter.Write((byte)100); - _emitter.Write((ushort)58); + _emitter.Write((ushort)59); _emitter.Write((byte)reduceArgOp); _emitter.Write(destType); } @@ -1083,7 +1090,7 @@ public void ReduceArg(ReduceArgOp reduceArgOp, DataType destType) public void ReduceWindow2D(ReduceOp reduceOp) { _emitter.Write((byte)100); - _emitter.Write((ushort)59); + _emitter.Write((ushort)60); _emitter.Write((byte)reduceOp); } @@ -1091,21 +1098,21 @@ public void ReduceWindow2D(ReduceOp reduceOp) public void Relu() { _emitter.Write((byte)100); - _emitter.Write((ushort)60); + _emitter.Write((ushort)61); } ///. public void Relu6() { _emitter.Write((byte)100); - _emitter.Write((ushort)61); + _emitter.Write((ushort)62); } ///. public void Require(string message, bool canFoldConstCall) { _emitter.Write((byte)100); - _emitter.Write((ushort)62); + _emitter.Write((ushort)63); _emitter.Write(message); _emitter.Write(canFoldConstCall); } @@ -1114,14 +1121,21 @@ public void Require(string message, bool canFoldConstCall) public void Reshape() { _emitter.Write((byte)100); - _emitter.Write((ushort)63); + _emitter.Write((ushort)64); + } + + ///. + public void ReshapeShape() + { + _emitter.Write((byte)100); + _emitter.Write((ushort)65); } ///. public void ResizeImage(ImageResizeMode resizeMode, ImageResizeTransformationMode transformationMode, ImageResizeNearestMode nearestMode, bool isTFResize) { _emitter.Write((byte)100); - _emitter.Write((ushort)64); + _emitter.Write((ushort)66); _emitter.Write((byte)resizeMode); _emitter.Write((int)transformationMode); _emitter.Write((int)nearestMode); @@ -1132,147 +1146,161 @@ public void ResizeImage(ImageResizeMode resizeMode, ImageResizeTransformationMod public void ReverseSequence() { _emitter.Write((byte)100); - _emitter.Write((ushort)65); + _emitter.Write((ushort)67); } ///. public void ScatterND() { _emitter.Write((byte)100); - _emitter.Write((ushort)66); + _emitter.Write((ushort)68); } ///. public void Select() { _emitter.Write((byte)100); - _emitter.Write((ushort)67); + _emitter.Write((ushort)69); } ///. public void Selu() { _emitter.Write((byte)100); - _emitter.Write((ushort)68); + _emitter.Write((ushort)70); } ///. public void ShapeOf() { _emitter.Write((byte)100); - _emitter.Write((ushort)69); + _emitter.Write((ushort)71); } ///. public void Sigmoid() { _emitter.Write((byte)100); - _emitter.Write((ushort)70); + _emitter.Write((ushort)72); } ///. public void SizeOf() { _emitter.Write((byte)100); - _emitter.Write((ushort)71); + _emitter.Write((ushort)73); } ///. public void Slice() { _emitter.Write((byte)100); - _emitter.Write((ushort)72); + _emitter.Write((ushort)74); } ///. public void Softmax() { _emitter.Write((byte)100); - _emitter.Write((ushort)73); + _emitter.Write((ushort)75); } ///. public void Softplus() { _emitter.Write((byte)100); - _emitter.Write((ushort)74); + _emitter.Write((ushort)76); } ///. public void Softsign() { _emitter.Write((byte)100); - _emitter.Write((ushort)75); + _emitter.Write((ushort)77); } ///. public void SpaceToBatch() { _emitter.Write((byte)100); - _emitter.Write((ushort)76); + _emitter.Write((ushort)78); } ///. public void Split() { _emitter.Write((byte)100); - _emitter.Write((ushort)77); + _emitter.Write((ushort)79); } ///. public void Squeeze() { _emitter.Write((byte)100); - _emitter.Write((ushort)78); + _emitter.Write((ushort)80); + } + + ///. + public void SqueezeShape() + { + _emitter.Write((byte)100); + _emitter.Write((ushort)81); } ///. public void Stack() { _emitter.Write((byte)100); - _emitter.Write((ushort)79); + _emitter.Write((ushort)82); } ///. public void Swish() { _emitter.Write((byte)100); - _emitter.Write((ushort)80); + _emitter.Write((ushort)83); } ///. public void Tile() { _emitter.Write((byte)100); - _emitter.Write((ushort)81); + _emitter.Write((ushort)84); } ///. public void TopK() { _emitter.Write((byte)100); - _emitter.Write((ushort)82); + _emitter.Write((ushort)85); } ///. public void Transpose() { _emitter.Write((byte)100); - _emitter.Write((ushort)83); + _emitter.Write((ushort)86); + } + + ///. + public void TransposeShape() + { + _emitter.Write((byte)100); + _emitter.Write((ushort)87); } ///. public void Trilu() { _emitter.Write((byte)100); - _emitter.Write((ushort)84); + _emitter.Write((ushort)88); } ///. public void Unary(UnaryOp unaryOp) { _emitter.Write((byte)100); - _emitter.Write((ushort)85); + _emitter.Write((ushort)89); _emitter.Write((byte)unaryOp); } @@ -1280,7 +1308,7 @@ public void Unary(UnaryOp unaryOp) public void Uniform(DataType type) { _emitter.Write((byte)100); - _emitter.Write((ushort)86); + _emitter.Write((ushort)90); _emitter.Write(type); } @@ -1288,7 +1316,7 @@ public void Uniform(DataType type) public void UniformLike(DataType type) { _emitter.Write((byte)100); - _emitter.Write((ushort)87); + _emitter.Write((ushort)91); _emitter.Write(type); } @@ -1296,14 +1324,21 @@ public void UniformLike(DataType type) public void Unsqueeze() { _emitter.Write((byte)100); - _emitter.Write((ushort)88); + _emitter.Write((ushort)92); + } + + ///. + public void UnsqueezeShape() + { + _emitter.Write((byte)100); + _emitter.Write((ushort)93); } ///. public void Where(bool isTfWhere) { _emitter.Write((byte)100); - _emitter.Write((ushort)89); + _emitter.Write((ushort)94); _emitter.Write(isTfWhere); } } diff --git a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h index db59425a29..e918f22f25 100644 --- a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h +++ b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 * +08:00. * * Copyright 2019-2021 Canaan Inc. @@ -178,6 +178,12 @@ NNCASE_API result get_item(value_t input, value_t index, value_t output = nullptr, kernel_context &context = default_kernel_context()); +NNCASE_API result +get_paddings(value_t input_shape, value_t weights_shape, value_t strides, + value_t dilations, value_t same, value_t lower, + value_t output = nullptr, + kernel_context &context = default_kernel_context()); + NNCASE_API result hard_sigmoid(value_t input, value_t alpha, value_t beta, value_t output = nullptr, @@ -330,6 +336,10 @@ NNCASE_API result reshape(value_t input, value_t shape, value_t output = nullptr, kernel_context &context = default_kernel_context()); +NNCASE_API result +reshape_shape(value_t input_shape, value_t shape, value_t output = nullptr, + kernel_context &context = default_kernel_context()); + NNCASE_API result resize_image( runtime::stackvm::image_resize_mode_t resize_mode, runtime::stackvm::image_resize_transformation_mode_t transformation_mode, @@ -400,6 +410,10 @@ NNCASE_API result squeeze(value_t input, value_t dim, value_t output = nullptr, kernel_context &context = default_kernel_context()); +NNCASE_API result +squeeze_shape(value_t input_shape, value_t dim, value_t output = nullptr, + kernel_context &context = default_kernel_context()); + NNCASE_API result stack(value_t inputs, value_t axis, value_t output = nullptr, kernel_context &context = default_kernel_context()); @@ -421,6 +435,10 @@ NNCASE_API result transpose(value_t input, value_t perm, value_t output = nullptr, kernel_context &context = default_kernel_context()); +NNCASE_API result +transpose_shape(value_t input_shape, value_t perm, value_t output = nullptr, + kernel_context &context = default_kernel_context()); + NNCASE_API result trilu(value_t input, value_t k, value_t upper, value_t output = nullptr, kernel_context &context = default_kernel_context()); @@ -444,6 +462,10 @@ NNCASE_API result unsqueeze(value_t input, value_t dim, value_t output = nullptr, kernel_context &context = default_kernel_context()); +NNCASE_API result +unsqueeze_shape(value_t input_shape, value_t dim, value_t output = nullptr, + kernel_context &context = default_kernel_context()); + NNCASE_API result where(bool is_tf_where, value_t cond, value_t x, value_t y, value_t output = nullptr, diff --git a/src/Native/include/nncase/runtime/stackvm/op_reader.h b/src/Native/include/nncase/runtime/stackvm/op_reader.h index 5de273cab1..80372463e4 100644 --- a/src/Native/include/nncase/runtime/stackvm/op_reader.h +++ b/src/Native/include/nncase/runtime/stackvm/op_reader.h @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 * +08:00. * * Copyright 2019-2021 Canaan Inc. @@ -997,6 +997,14 @@ template <> struct tensor_op_reader { } }; +template <> struct tensor_op_reader { + tensor_get_paddings_op_t + operator()(NNCASE_UNUSED span_reader &reader) const { + tensor_get_paddings_op_t op; + return op; + } +}; + template <> struct tensor_op_reader { tensor_hard_sigmoid_op_t operator()(NNCASE_UNUSED span_reader &reader) const { @@ -1257,6 +1265,14 @@ template <> struct tensor_op_reader { } }; +template <> struct tensor_op_reader { + tensor_reshape_shape_op_t + operator()(NNCASE_UNUSED span_reader &reader) const { + tensor_reshape_shape_op_t op; + return op; + } +}; + template <> struct tensor_op_reader { tensor_resize_image_op_t operator()(NNCASE_UNUSED span_reader &reader) const { @@ -1373,6 +1389,14 @@ template <> struct tensor_op_reader { } }; +template <> struct tensor_op_reader { + tensor_squeeze_shape_op_t + operator()(NNCASE_UNUSED span_reader &reader) const { + tensor_squeeze_shape_op_t op; + return op; + } +}; + template <> struct tensor_op_reader { tensor_stack_op_t operator()(NNCASE_UNUSED span_reader &reader) const { tensor_stack_op_t op; @@ -1408,6 +1432,14 @@ template <> struct tensor_op_reader { } }; +template <> struct tensor_op_reader { + tensor_transpose_shape_op_t + operator()(NNCASE_UNUSED span_reader &reader) const { + tensor_transpose_shape_op_t op; + return op; + } +}; + template <> struct tensor_op_reader { tensor_trilu_op_t operator()(NNCASE_UNUSED span_reader &reader) const { tensor_trilu_op_t op; @@ -1447,6 +1479,14 @@ template <> struct tensor_op_reader { } }; +template <> struct tensor_op_reader { + tensor_unsqueeze_shape_op_t + operator()(NNCASE_UNUSED span_reader &reader) const { + tensor_unsqueeze_shape_op_t op; + return op; + } +}; + template <> struct tensor_op_reader { tensor_where_op_t operator()(NNCASE_UNUSED span_reader &reader) const { tensor_where_op_t op; @@ -1589,6 +1629,10 @@ class NNCASE_API tensor_op_visitor { return default_visit(tensor_function_t::get_item, &op); } virtual result + visit(NNCASE_UNUSED const tensor_get_paddings_op_t &op) noexcept { + return default_visit(tensor_function_t::get_paddings, &op); + } + virtual result visit(NNCASE_UNUSED const tensor_hard_sigmoid_op_t &op) noexcept { return default_visit(tensor_function_t::hard_sigmoid, &op); } @@ -1717,6 +1761,10 @@ class NNCASE_API tensor_op_visitor { return default_visit(tensor_function_t::reshape, &op); } virtual result + visit(NNCASE_UNUSED const tensor_reshape_shape_op_t &op) noexcept { + return default_visit(tensor_function_t::reshape_shape, &op); + } + virtual result visit(NNCASE_UNUSED const tensor_resize_image_op_t &op) noexcept { return default_visit(tensor_function_t::resize_image, &op); } @@ -1777,6 +1825,10 @@ class NNCASE_API tensor_op_visitor { return default_visit(tensor_function_t::squeeze, &op); } virtual result + visit(NNCASE_UNUSED const tensor_squeeze_shape_op_t &op) noexcept { + return default_visit(tensor_function_t::squeeze_shape, &op); + } + virtual result visit(NNCASE_UNUSED const tensor_stack_op_t &op) noexcept { return default_visit(tensor_function_t::stack, &op); } @@ -1797,6 +1849,10 @@ class NNCASE_API tensor_op_visitor { return default_visit(tensor_function_t::transpose, &op); } virtual result + visit(NNCASE_UNUSED const tensor_transpose_shape_op_t &op) noexcept { + return default_visit(tensor_function_t::transpose_shape, &op); + } + virtual result visit(NNCASE_UNUSED const tensor_trilu_op_t &op) noexcept { return default_visit(tensor_function_t::trilu, &op); } @@ -1817,6 +1873,10 @@ class NNCASE_API tensor_op_visitor { return default_visit(tensor_function_t::unsqueeze, &op); } virtual result + visit(NNCASE_UNUSED const tensor_unsqueeze_shape_op_t &op) noexcept { + return default_visit(tensor_function_t::unsqueeze_shape, &op); + } + virtual result visit(NNCASE_UNUSED const tensor_where_op_t &op) noexcept { return default_visit(tensor_function_t::where, &op); } diff --git a/src/Native/include/nncase/runtime/stackvm/opcode.h b/src/Native/include/nncase/runtime/stackvm/opcode.h index 26e98927f1..5c17c82894 100644 --- a/src/Native/include/nncase/runtime/stackvm/opcode.h +++ b/src/Native/include/nncase/runtime/stackvm/opcode.h @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:38 +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 * +08:00. * * Copyright 2019-2021 Canaan Inc. @@ -136,29 +136,29 @@ enum class tensor_function_t : uint16_t { elu = 20, erf = 21, gelu = 30, - hardmax = 32, - hard_sigmoid = 33, - hard_swish = 34, - instance_normalization = 36, - l2_normalization = 37, - layer_norm = 38, - leaky_relu = 39, - log_softmax = 40, - lp_normalization = 41, - lrn = 42, - one_hot = 48, - pad = 49, - prelu = 50, - reduce_window2d = 59, - relu = 60, - relu6 = 61, - selu = 68, - sigmoid = 70, - softmax = 73, - softplus = 74, - softsign = 75, - space_to_batch = 76, - swish = 80, + hardmax = 33, + hard_sigmoid = 34, + hard_swish = 35, + instance_normalization = 37, + l2_normalization = 38, + layer_norm = 39, + leaky_relu = 40, + log_softmax = 41, + lp_normalization = 42, + lrn = 43, + one_hot = 49, + pad = 50, + prelu = 51, + reduce_window2d = 60, + relu = 61, + relu6 = 62, + selu = 70, + sigmoid = 72, + softmax = 75, + softplus = 76, + softsign = 77, + space_to_batch = 78, + swish = 83, binary = 2, clamp = 9, compare = 10, @@ -167,15 +167,15 @@ enum class tensor_function_t : uint16_t { dequantize = 19, fake_dequantize = 23, fake_quantize = 24, - mat_mul = 44, - quantize = 52, - quant_param_of = 53, - range_of = 55, - reduce = 57, - reduce_arg = 58, - require = 62, - select = 67, - unary = 85, + mat_mul = 45, + quantize = 53, + quant_param_of = 54, + range_of = 56, + reduce = 58, + reduce_arg = 59, + require = 63, + select = 69, + unary = 89, bitcast = 3, broadcast = 4, bucket_pad = 6, @@ -189,35 +189,40 @@ enum class tensor_function_t : uint16_t { gather_elements = 28, gather_nd = 29, get_item = 31, - index_of = 35, - lstm = 43, - prod = 51, - range = 54, - rank = 56, - reshape = 63, - reverse_sequence = 65, - scatter_nd = 66, - shape_of = 69, - size_of = 71, - slice = 72, - split = 77, - squeeze = 78, - stack = 79, - tile = 81, - top_k = 82, - transpose = 83, - trilu = 84, - unsqueeze = 88, - where = 89, + index_of = 36, + lstm = 44, + prod = 52, + range = 55, + rank = 57, + reshape = 64, + reverse_sequence = 67, + scatter_nd = 68, + shape_of = 71, + size_of = 73, + slice = 74, + split = 79, + squeeze = 80, + stack = 82, + tile = 84, + top_k = 85, + transpose = 86, + trilu = 88, + unsqueeze = 92, + where = 94, broadcast_shape = 5, conv2d_shape = 15, conv2d_transpose_shape = 17, - mat_mul_shape = 45, - normal = 46, - normal_like = 47, - uniform = 86, - uniform_like = 87, - resize_image = 64, + get_paddings = 32, + mat_mul_shape = 46, + reshape_shape = 65, + squeeze_shape = 81, + transpose_shape = 87, + unsqueeze_shape = 93, + normal = 47, + normal_like = 48, + uniform = 90, + uniform_like = 91, + resize_image = 66, }; enum class binary_op_t : uint8_t { @@ -663,6 +668,8 @@ struct tensor_gelu_op_t {}; struct tensor_get_item_op_t {}; +struct tensor_get_paddings_op_t {}; + struct tensor_hard_sigmoid_op_t {}; struct tensor_hard_swish_op_t {}; @@ -758,6 +765,8 @@ struct tensor_require_op_t { struct tensor_reshape_op_t {}; +struct tensor_reshape_shape_op_t {}; + struct tensor_resize_image_op_t { image_resize_mode_t resize_mode; image_resize_transformation_mode_t transformation_mode; @@ -793,6 +802,8 @@ struct tensor_split_op_t {}; struct tensor_squeeze_op_t {}; +struct tensor_squeeze_shape_op_t {}; + struct tensor_stack_op_t {}; struct tensor_swish_op_t {}; @@ -803,6 +814,8 @@ struct tensor_top_k_op_t {}; struct tensor_transpose_op_t {}; +struct tensor_transpose_shape_op_t {}; + struct tensor_trilu_op_t {}; struct tensor_unary_op_t { @@ -819,6 +832,8 @@ struct tensor_uniform_like_op_t { struct tensor_unsqueeze_op_t {}; +struct tensor_unsqueeze_shape_op_t {}; + struct tensor_where_op_t { bool is_tf_where; }; @@ -993,8 +1008,18 @@ inline std::string to_string(tensor_function_t tensor_funct) { return "conv2d_shape"; case tensor_function_t::conv2d_transpose_shape: return "conv2d_transpose_shape"; + case tensor_function_t::get_paddings: + return "get_paddings"; case tensor_function_t::mat_mul_shape: return "mat_mul_shape"; + case tensor_function_t::reshape_shape: + return "reshape_shape"; + case tensor_function_t::squeeze_shape: + return "squeeze_shape"; + case tensor_function_t::transpose_shape: + return "transpose_shape"; + case tensor_function_t::unsqueeze_shape: + return "unsqueeze_shape"; case tensor_function_t::normal: return "normal"; case tensor_function_t::normal_like: diff --git a/src/Native/src/kernels/stackvm/reference/pad.cpp b/src/Native/src/kernels/stackvm/reference/pad.cpp index 6186101e8a..2b27400fab 100644 --- a/src/Native/src/kernels/stackvm/reference/pad.cpp +++ b/src/Native/src/kernels/stackvm/reference/pad.cpp @@ -162,7 +162,7 @@ void padding_impl_opt(T *in, T *out, gsl::span in_shape, dh = out_shape[0]; hh = out_shape[1]; wh = out_shape[2]; - } else { + } else if (in_shape.size() == 4) { cl = in_shape[0]; dl = in_shape[1]; hl = in_shape[2]; @@ -171,6 +171,16 @@ void padding_impl_opt(T *in, T *out, gsl::span in_shape, dh = out_shape[1]; hh = out_shape[2]; wh = out_shape[3]; + } else // size ==2 + { + cl = 1; + dl = 1; + hl = in_shape[0]; + wl = in_shape[1]; + ch = 1; + dh = 1; + hh = out_shape[0]; + wh = out_shape[1]; } pad_data2(in, out, cl, dl, hl, wl, ch, dh, hh, wh, value); @@ -216,7 +226,7 @@ result nncase::kernels::stackvm::reference::pad( std::all_of( paddings.begin(), paddings.end(), [](const padding &p) { return p.before == 0 && p.after >= 0; }) && - mode == pad_mode_t::constant && in_shape.size() >= 3; + mode == pad_mode_t::constant && in_shape.size() >= 2; if (std::all_of(paddings.begin(), paddings.end(), [](const padding &p) { return p.interior == 0; })) { diff --git a/src/Native/src/kernels/stackvm/shape_ops.cpp b/src/Native/src/kernels/stackvm/shape_ops.cpp index 7677b10c2b..991b16950c 100644 --- a/src/Native/src/kernels/stackvm/shape_ops.cpp +++ b/src/Native/src/kernels/stackvm/shape_ops.cpp @@ -106,6 +106,12 @@ result nncase::kernels::stackvm::broadcast_shape(value_t inputs, KERNEL_FINISH; } +#define WRITE_OUT_SHAPE \ + try_output(out_mem, output, dt_int64, dims_t{out_shape.size()}); \ + for (int i = 0; i < out_shape.size(); ++i) { \ + OUT_CAST(int64_t, out_mem)[i] = out_shape[i]; \ + } + result nncase::kernels::stackvm::mat_mul_shape(value_t lhs, value_t rhs, value_t output, @@ -113,9 +119,105 @@ result nncase::kernels::stackvm::mat_mul_shape(value_t lhs, try_dims(lhs_shape, lhs); try_dims(rhs_shape, rhs); try_var(out_shape, matmul_infer_shape(lhs_shape, rhs_shape)); - try_output(out_mem, output, dt_int64, dims_t{out_shape.size()}); - for (int i = 0; i < out_shape.size(); ++i) { - OUT_CAST(int64_t, out_mem)[i] = out_shape[i]; + WRITE_OUT_SHAPE; + KERNEL_FINISH; +} + +inline int get_windowed_output_size(int size, int filter, int stride, + int dilation, bool same, bool ceilMode) { + auto effectiveFilterSize = ((filter - 1) * dilation) + 1; + auto falseBranch = !ceilMode + ? ((size - effectiveFilterSize + stride) / stride) + : ceil(size - effectiveFilterSize + stride / stride); + auto trueBranch = (size + stride - 1) / stride; + return same ? trueBranch : falseBranch; +} + +inline padding get_windowed_padding(int32_t input_size, int32_t output_size, + int32_t filter, int32_t stride, + int32_t dilation, bool lower) { + auto effective_filter_size = (filter - 1) * dilation + 1; + int padding = std::max(0, (output_size - 1) * stride + + effective_filter_size - input_size); + auto before = padding / 2; + auto after = padding - padding / 2; + if (lower) { + return {std::max(before, after), std::min(before, after)}; } + return {before, after}; +} + +result nncase::kernels::stackvm::get_paddings( + value_t input_shape, value_t weights_shape, value_t strides, + value_t dilations, value_t same, value_t lower, value_t output, + [[maybe_unused]] kernel_context &) { + try_dims(in_shape, input_shape); + try_dims(w_shape, weights_shape); + try_strides(strides_value, strides); + try_strides(dilations_value, dilations); + try_to_scalar_v(same, bool); + try_to_scalar_v(lower, bool); + auto out_h = + get_windowed_output_size(in_shape[2], w_shape[2], strides_value[0], + dilations_value[0], same_value, false); + auto out_w = + get_windowed_output_size(in_shape[3], w_shape[3], strides_value[1], + dilations_value[1], same_value, false); + auto pad_h = + get_windowed_padding(in_shape[2], out_h, w_shape[2], strides_value[0], + dilations_value[0], lower_value); + auto pad_w = + get_windowed_padding(in_shape[3], out_w, w_shape[3], strides_value[1], + dilations_value[1], lower_value); + auto out_shape = dims_t{2, 2}; + try_out_mem(output, dt_int64, out_shape); + OUT_CAST(int64_t, output_mem)[0] = pad_h.before; + OUT_CAST(int64_t, output_mem)[1] = pad_h.after; + OUT_CAST(int64_t, output_mem)[2] = pad_w.before; + OUT_CAST(int64_t, output_mem)[3] = pad_w.after; + KERNEL_FINISH; +} + +result nncase::kernels::stackvm::reshape_shape(value_t input_shape, + value_t shape, + value_t output, + kernel_context &) { + try_dims(in_shape, input_shape); + try_axes(shape_value, shape); + auto out_shape = reshape_shape_infer(in_shape, shape_value); + WRITE_OUT_SHAPE; + KERNEL_FINISH; +} + +result +nncase::kernels::stackvm::transpose_shape(value_t input_shape, value_t perm, + value_t output, + [[maybe_unused]] kernel_context &) { + try_dims(in_shape, input_shape); + try_dims(perm_value, perm); + auto out_shape = transpose_infer_shape(in_shape, perm_value); + WRITE_OUT_SHAPE; + KERNEL_FINISH; +} + +result +nncase::kernels::stackvm::squeeze_shape(value_t input_shape, value_t dim, + value_t output, + [[maybe_unused]] kernel_context &) { + try_dims(in_shape, input_shape); + try_positive_axes(dim_value, dim, in_shape.size()); + auto out_shape = squeeze_infer_shape(in_shape, dim_value); + WRITE_OUT_SHAPE; + KERNEL_FINISH; +} + +result +nncase::kernels::stackvm::unsqueeze_shape(value_t input_shape, value_t dim, + value_t output, + [[maybe_unused]] kernel_context &) { + try_dims(in_shape, input_shape); + try_axes(dim_value, dim); + auto out_shape = unsqueeze_infer_shape(in_shape, dim_value); + WRITE_OUT_SHAPE; KERNEL_FINISH; } \ No newline at end of file diff --git a/src/Native/src/runtime/stackvm/op_reader.cpp b/src/Native/src/runtime/stackvm/op_reader.cpp index 04776e0f7d..901f0f6125 100644 --- a/src/Native/src/runtime/stackvm/op_reader.cpp +++ b/src/Native/src/runtime/stackvm/op_reader.cpp @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 * +08:00. * * Copyright 2019-2021 Canaan Inc. @@ -99,6 +99,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct, return visit(tensor_op_reader()(reader)); case tensor_function_t::get_item: return visit(tensor_op_reader()(reader)); + case tensor_function_t::get_paddings: + return visit( + tensor_op_reader()(reader)); case tensor_function_t::hard_sigmoid: return visit( tensor_op_reader()(reader)); @@ -173,6 +176,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct, return visit(tensor_op_reader()(reader)); case tensor_function_t::reshape: return visit(tensor_op_reader()(reader)); + case tensor_function_t::reshape_shape: + return visit( + tensor_op_reader()(reader)); case tensor_function_t::resize_image: return visit( tensor_op_reader()(reader)); @@ -206,6 +212,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct, return visit(tensor_op_reader()(reader)); case tensor_function_t::squeeze: return visit(tensor_op_reader()(reader)); + case tensor_function_t::squeeze_shape: + return visit( + tensor_op_reader()(reader)); case tensor_function_t::stack: return visit(tensor_op_reader()(reader)); case tensor_function_t::swish: @@ -216,6 +225,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct, return visit(tensor_op_reader()(reader)); case tensor_function_t::transpose: return visit(tensor_op_reader()(reader)); + case tensor_function_t::transpose_shape: + return visit( + tensor_op_reader()(reader)); case tensor_function_t::trilu: return visit(tensor_op_reader()(reader)); case tensor_function_t::unary: @@ -227,6 +239,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct, tensor_op_reader()(reader)); case tensor_function_t::unsqueeze: return visit(tensor_op_reader()(reader)); + case tensor_function_t::unsqueeze_shape: + return visit( + tensor_op_reader()(reader)); case tensor_function_t::where: return visit(tensor_op_reader()(reader)); default: diff --git a/src/Native/src/runtime/stackvm/ops/tensor.cpp b/src/Native/src/runtime/stackvm/ops/tensor.cpp index 0439e7887c..6f09a7084c 100644 --- a/src/Native/src/runtime/stackvm/ops/tensor.cpp +++ b/src/Native/src/runtime/stackvm/ops/tensor.cpp @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 * +08:00. * * Copyright 2019-2021 Canaan Inc. @@ -564,6 +564,29 @@ result stackvm_runtime_function::visit( return ok(); } +result stackvm_runtime_function::visit( + [[maybe_unused]] const tensor_get_paddings_op_t &op) noexcept { + dump_op("get_paddings"); + try_var(input_shape, pop_value()); + dump_input(input_shape); + try_var(weights_shape, pop_value()); + dump_input(weights_shape); + try_var(strides, pop_value()); + dump_input(strides); + try_var(dilations, pop_value()); + dump_input(dilations); + try_var(same, pop_value()); + dump_input(same); + try_var(lower, pop_value()); + dump_input(lower); + try_var(output, kernels::stackvm::get_paddings( + input_shape, weights_shape, strides, dilations, same, + lower, nullptr, module().kernel_context())); + dump_output(output); + stack_.push(std::move(output)); + return ok(); +} + result stackvm_runtime_function::visit( [[maybe_unused]] const tensor_hard_sigmoid_op_t &op) noexcept { dump_op("hard_sigmoid"); @@ -1091,6 +1114,20 @@ result stackvm_runtime_function::visit( return ok(); } +result stackvm_runtime_function::visit( + [[maybe_unused]] const tensor_reshape_shape_op_t &op) noexcept { + dump_op("reshape_shape"); + try_var(input_shape, pop_value()); + dump_input(input_shape); + try_var(shape, pop_value()); + dump_input(shape); + try_var(output, kernels::stackvm::reshape_shape(input_shape, shape, nullptr, + module().kernel_context())); + dump_output(output); + stack_.push(std::move(output)); + return ok(); +} + result stackvm_runtime_function::visit( [[maybe_unused]] const tensor_resize_image_op_t &op) noexcept { dump_op("resize_image"); @@ -1327,6 +1364,20 @@ result stackvm_runtime_function::visit( return ok(); } +result stackvm_runtime_function::visit( + [[maybe_unused]] const tensor_squeeze_shape_op_t &op) noexcept { + dump_op("squeeze_shape"); + try_var(input_shape, pop_value()); + dump_input(input_shape); + try_var(dim, pop_value()); + dump_input(dim); + try_var(output, kernels::stackvm::squeeze_shape(input_shape, dim, nullptr, + module().kernel_context())); + dump_output(output); + stack_.push(std::move(output)); + return ok(); +} + result stackvm_runtime_function::visit( [[maybe_unused]] const tensor_stack_op_t &op) noexcept { dump_op("stack"); @@ -1402,6 +1453,20 @@ result stackvm_runtime_function::visit( return ok(); } +result stackvm_runtime_function::visit( + [[maybe_unused]] const tensor_transpose_shape_op_t &op) noexcept { + dump_op("transpose_shape"); + try_var(input_shape, pop_value()); + dump_input(input_shape); + try_var(perm, pop_value()); + dump_input(perm); + try_var(output, kernels::stackvm::transpose_shape( + input_shape, perm, nullptr, module().kernel_context())); + dump_output(output); + stack_.push(std::move(output)); + return ok(); +} + result stackvm_runtime_function::visit( [[maybe_unused]] const tensor_trilu_op_t &op) noexcept { dump_op("trilu"); @@ -1482,6 +1547,20 @@ result stackvm_runtime_function::visit( return ok(); } +result stackvm_runtime_function::visit( + [[maybe_unused]] const tensor_unsqueeze_shape_op_t &op) noexcept { + dump_op("unsqueeze_shape"); + try_var(input_shape, pop_value()); + dump_input(input_shape); + try_var(dim, pop_value()); + dump_input(dim); + try_var(output, kernels::stackvm::unsqueeze_shape( + input_shape, dim, nullptr, module().kernel_context())); + dump_output(output); + stack_.push(std::move(output)); + return ok(); +} + result stackvm_runtime_function::visit( [[maybe_unused]] const tensor_where_op_t &op) noexcept { dump_op("where"); diff --git a/src/Native/src/runtime/stackvm/runtime_function_ops.h b/src/Native/src/runtime/stackvm/runtime_function_ops.h index 02841ebe35..ae6944ef59 100644 --- a/src/Native/src/runtime/stackvm/runtime_function_ops.h +++ b/src/Native/src/runtime/stackvm/runtime_function_ops.h @@ -1,4 +1,4 @@ -/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29 * +08:00. * * Copyright 2019-2021 Canaan Inc. @@ -49,6 +49,7 @@ result visit(const tensor_gather_elements_op_t &op) noexcept override; result visit(const tensor_gather_nd_op_t &op) noexcept override; result visit(const tensor_gelu_op_t &op) noexcept override; result visit(const tensor_get_item_op_t &op) noexcept override; +result visit(const tensor_get_paddings_op_t &op) noexcept override; result visit(const tensor_hard_sigmoid_op_t &op) noexcept override; result visit(const tensor_hard_swish_op_t &op) noexcept override; result visit(const tensor_hardmax_op_t &op) noexcept override; @@ -82,6 +83,7 @@ result visit(const tensor_relu_op_t &op) noexcept override; result visit(const tensor_relu6_op_t &op) noexcept override; result visit(const tensor_require_op_t &op) noexcept override; result visit(const tensor_reshape_op_t &op) noexcept override; +result visit(const tensor_reshape_shape_op_t &op) noexcept override; result visit(const tensor_resize_image_op_t &op) noexcept override; result visit(const tensor_reverse_sequence_op_t &op) noexcept override; result visit(const tensor_scatter_nd_op_t &op) noexcept override; @@ -97,14 +99,17 @@ result visit(const tensor_softsign_op_t &op) noexcept override; result visit(const tensor_space_to_batch_op_t &op) noexcept override; result visit(const tensor_split_op_t &op) noexcept override; result visit(const tensor_squeeze_op_t &op) noexcept override; +result visit(const tensor_squeeze_shape_op_t &op) noexcept override; result visit(const tensor_stack_op_t &op) noexcept override; result visit(const tensor_swish_op_t &op) noexcept override; result visit(const tensor_tile_op_t &op) noexcept override; result visit(const tensor_top_k_op_t &op) noexcept override; result visit(const tensor_transpose_op_t &op) noexcept override; +result visit(const tensor_transpose_shape_op_t &op) noexcept override; result visit(const tensor_trilu_op_t &op) noexcept override; result visit(const tensor_unary_op_t &op) noexcept override; result visit(const tensor_uniform_op_t &op) noexcept override; result visit(const tensor_uniform_like_op_t &op) noexcept override; result visit(const tensor_unsqueeze_op_t &op) noexcept override; +result visit(const tensor_unsqueeze_shape_op_t &op) noexcept override; result visit(const tensor_where_op_t &op) noexcept override; diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs index 74eccd2190..f4014fa047 100644 --- a/src/Nncase.Compiler/Compiler.cs +++ b/src/Nncase.Compiler/Compiler.cs @@ -12,6 +12,7 @@ using Nncase.Evaluator; using Nncase.Hosting; using Nncase.IR; +using Nncase.IR.NN; using Nncase.IR.Tensors; using Nncase.Passes; using Nncase.Passes.Mutators; @@ -118,7 +119,29 @@ public void TargetIndependentPass(IPassManager passManager) p.Add(); p.Add(); p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); + p.Add(); }); + passManager.AddWithName("NeutralOptimizeTranspose").Configure(p => { p.Add(); @@ -180,24 +203,30 @@ public void TargetIndependentPass(IPassManager passManager) public void RegisterShapeBucket(IPassManager p) { var options = _compileSession.CompileOptions.ShapeBucketOptions; - var singleVar = options.VarMap.Values.SelectMany(x => x).OfType().ToHashSet().Count <= 1; if (!options.Enable) { return; } + var singleVar = options.VarMap.Values.SelectMany(x => x).OfType().ToHashSet().Count <= 1; CheckShapeBucketOptions(options); - ToFusion(p); - MergeOp(p); - LostToFusion(p, singleVar); - MergeOp(p); - ClearMarker(p); - - // MergeFusion(p, singleVar); - Bucket(p); - // Rebuild(p); - Simplify(p); + if (HasNotBucketOp(_module!.Entry!) || !singleVar) + { + ToFusion(p); + MergeOp(p, true); + LostToFusion(p, singleVar); + MergeOp(p, true); + ClearMarker(p); + MergeFusion(p, singleVar, true); + Bucket(p); + Rebuild(p, singleVar); + Simplify(p); + } + else + { + p.AddWithName("FullBucket"); + } } public void ClearFixShape(IPassManager p) diff --git a/src/Nncase.Core/IR/ShapeExpr/Functional.cs b/src/Nncase.Core/IR/ShapeExpr/Functional.cs index 8631e5e26f..555e9fa367 100644 --- a/src/Nncase.Core/IR/ShapeExpr/Functional.cs +++ b/src/Nncase.Core/IR/ShapeExpr/Functional.cs @@ -14,4 +14,14 @@ public static class ShapeExpr public static Call Conv2DTransposeShape(Expr input, Expr weights, Expr stride, Expr dilation, Expr padding, Expr outputPadding, Expr groups) => new(new Conv2DTransposeShape(), input, weights, stride, dilation, padding, outputPadding, groups); public static Call MatMulShape(Expr lhs, Expr rhs) => new(new MatMulShape(), lhs, rhs); + + public static Call GetPaddings(Expr input, Expr weights, Expr strides, Expr dilation, Expr same, Expr lower) => new(new GetPaddings(), input, weights, strides, dilation, same, lower); + + public static Call ReshapeShape(Expr input, Expr shape) => new(new ReshapeShape(), input, shape); + + public static Call SqueezeShape(Expr input, Expr dims) => new(new SqueezeShape(), input, dims); + + public static Call TransposeShape(Expr input, Expr perm) => new(new TransposeShape(), input, perm); + + public static Call UnsqueezeShape(Expr lhs, Expr rhs) => new(new UnsqueezeShape(), lhs, rhs); } diff --git a/src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs b/src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs new file mode 100644 index 0000000000..5bed9f4660 --- /dev/null +++ b/src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs @@ -0,0 +1,40 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; + +namespace Nncase.IR.ShapeExpr; + +[PatternFunctionalGenerator] +public class GetPaddings : Op +{ + /// + /// Gets Input. + /// + public static readonly ParameterInfo InputShape = new(typeof(GetPaddings), 0, "input_shape"); + + /// + /// Gets Weights. + /// + public static readonly ParameterInfo WeightsShape = new(typeof(GetPaddings), 1, "weights_shape"); + + /// + /// Gets Strides. + /// + public static readonly ParameterInfo Strides = new(typeof(GetPaddings), 2, "strides"); + + /// + /// Gets Dilations. + /// + public static readonly ParameterInfo Dilations = new(typeof(GetPaddings), 3, "dilations"); + + /// + /// Gets Same. + /// + public static readonly ParameterInfo Same = new(typeof(GetPaddings), 4, "same"); + + /// + /// Gets Lower. + /// + public static readonly ParameterInfo Lower = new(typeof(GetPaddings), 5, "lower"); +} diff --git a/src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs b/src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs new file mode 100644 index 0000000000..f6db6836a0 --- /dev/null +++ b/src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs @@ -0,0 +1,24 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.ShapeExpr; + +/// +/// Reshape expression. +/// +[PatternFunctionalGenerator] +public sealed partial class ReshapeShape : Op +{ + /// + /// Gets input shape. + /// + public static readonly ParameterInfo InputShape = new(typeof(ReshapeShape), 0, "input_shape"); + + /// + /// Gets shape. + /// + public static readonly ParameterInfo Shape = new(typeof(ReshapeShape), 1, "shape", HasRank(1)); +} diff --git a/src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs b/src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs new file mode 100644 index 0000000000..55338691b6 --- /dev/null +++ b/src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs @@ -0,0 +1,24 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.ShapeExpr; + +/// +/// Squeeze expression. +/// +[PatternFunctionalGenerator] +public sealed partial class SqueezeShape : Op +{ + /// + /// Gets input shape. + /// + public static readonly ParameterInfo InputShape = new(typeof(SqueezeShape), 0, "input_shape"); + + /// + /// Gets dimension. + /// + public static readonly ParameterInfo Dim = new(typeof(SqueezeShape), 1, "dim", HasRank(1) & IsIntegral()); +} diff --git a/src/Nncase.Core/IR/ShapeExpr/TransposeShape.cs b/src/Nncase.Core/IR/ShapeExpr/TransposeShape.cs new file mode 100644 index 0000000000..c658e15dc6 --- /dev/null +++ b/src/Nncase.Core/IR/ShapeExpr/TransposeShape.cs @@ -0,0 +1,24 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.ShapeExpr; + +/// +/// Gets input. +/// +[PatternFunctionalGenerator] +public sealed partial class TransposeShape : Op +{ + /// + /// Gets input. + /// + public static readonly ParameterInfo InputShape = new(typeof(TransposeShape), 0, "input"); + + /// + /// Gets perm. + /// + public static readonly ParameterInfo Perm = new(typeof(TransposeShape), 1, "perm", HasRank(1) & IsIntegral()); +} diff --git a/src/Nncase.Core/IR/ShapeExpr/UnsqueezeShape.cs b/src/Nncase.Core/IR/ShapeExpr/UnsqueezeShape.cs new file mode 100644 index 0000000000..ed1f396fef --- /dev/null +++ b/src/Nncase.Core/IR/ShapeExpr/UnsqueezeShape.cs @@ -0,0 +1,21 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; + +namespace Nncase.IR.ShapeExpr; + +[PatternFunctionalGenerator] +public sealed partial class UnsqueezeShape : Op +{ + /// + /// Gets input_shape. + /// + public static readonly ParameterInfo InputShape = new(typeof(UnsqueezeShape), 0, "input_shape"); + + /// + /// Gets dimension. + /// + public static readonly ParameterInfo Dim = new(typeof(UnsqueezeShape), 1, "dim", HasRank(1) & IsIntegral()); +} diff --git a/src/Nncase.Core/Utilities/ShapeExprUtility.cs b/src/Nncase.Core/Utilities/ShapeExprUtility.cs index 28d953a8ad..c1da04dca4 100644 --- a/src/Nncase.Core/Utilities/ShapeExprUtility.cs +++ b/src/Nncase.Core/Utilities/ShapeExprUtility.cs @@ -19,8 +19,8 @@ public static Expr BroadcastShape(Expr lhsShape, params Expr[] rhsShape) public static Expr Positive(Expr axis, Expr inShape) { var rank = new Call(new Rank(), inShape); - var i32Axis = Cast(axis, DataTypes.Int32); - return new If(i32Axis < 0, i32Axis + rank, i32Axis); + var i64Axis = Cast(axis, DataTypes.Int64); + return new If(i64Axis < 0L, i64Axis + rank, i64Axis); } public static Expr Slice(Expr shape, int begin, int end) @@ -35,28 +35,28 @@ public static Expr Slice(Expr shape, Expr begin, Expr end) public static Expr Replace(Expr shapeExpr, Expr index, Expr value) { - return SliceAndMerge(shapeExpr, index, value, 1); + return SliceAndMerge(shapeExpr, index, value, 1L); } public static Expr Insert(Expr shapeExpr, Expr index, Expr value) { if (shapeExpr.CheckedShape.IsScalar) { - return SliceAndMerge(StackScalar(shapeExpr), index, value, 0); + return SliceAndMerge(StackScalar(shapeExpr), index, value, 0L); } - return SliceAndMerge(shapeExpr, index, value, 0); + return SliceAndMerge(shapeExpr, index, value, 0L); } public static Expr ReplaceList(Expr shapeExpr, Expr list, Expr value) { - return SliceAndMerge(shapeExpr, list, value, 1, false); + return SliceAndMerge(shapeExpr, list, value, 1L, false); } public static Expr Remove(Expr shapeExpr, Expr index) { var front = Slice(shapeExpr, 0, index); - var last = Slice(shapeExpr, index + 1, int.MaxValue); + var last = Slice(shapeExpr, index + 1L, int.MaxValue); return Concat(new IR.Tuple(front, last), 0); } @@ -65,10 +65,17 @@ public static Expr StackOne(Expr expr) return Stack(new IR.Tuple(expr), 0); } - private static Expr SliceAndMerge(Expr shapeExpr, Expr index, Expr value, Expr indexOffset, bool valueIsList = true) + public static IValue GetShapeValue(Call call) { + call.InferenceType(); + return Value.FromTensor(call.CheckedShape.ToValueArray().Select(x => (long)x).ToArray()); + } + + private static Expr SliceAndMerge(Expr originShapeExpr, Expr index, Expr value, Expr indexOffset, bool valueIsList = true) + { + var shapeExpr = Cast(originShapeExpr, DataTypes.Int64); var front = Slice(shapeExpr, 0, index); - var last = Slice(shapeExpr, Cast(index, DataTypes.Int32) + indexOffset, int.MaxValue); + var last = Slice(shapeExpr, Cast(index, DataTypes.Int64) + indexOffset, int.MaxValue); var c = valueIsList ? StackOne(value) : value; if (c.CheckedShape.IsScalar) { diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs index 0b5515ca28..f5f9d3c65b 100755 --- a/src/Nncase.Evaluator/Math/Binary.cs +++ b/src/Nncase.Evaluator/Math/Binary.cs @@ -118,7 +118,7 @@ public Expr Visit(IShapeEvaluateContext context, Binary target) { var lhs = context.GetArgumentShape(target, Binary.Lhs); var rhs = context.GetArgumentShape(target, Binary.Rhs); - return IR.F.Tensors.Cast(ShapeExprUtility.BroadcastShape(lhs, rhs), DataTypes.Int32); + return ShapeExprUtility.BroadcastShape(lhs, rhs); } private int Compute(BinaryOp op, int a, int b) => op switch diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs index 444d221c28..4785e1e1c1 100644 --- a/src/Nncase.Evaluator/Math/MatMul.cs +++ b/src/Nncase.Evaluator/Math/MatMul.cs @@ -71,7 +71,7 @@ public Expr Visit(IShapeEvaluateContext context, MatMul target) { var lhs = context.GetArgumentShape(target, MatMul.Lhs); var rhs = context.GetArgumentShape(target, MatMul.Rhs); - return Cast(IR.F.ShapeExpr.MatMulShape(lhs, rhs), DataTypes.Int32); + return IR.F.ShapeExpr.MatMulShape(lhs, rhs); } private IRType Visit(TensorType lhs, TensorType rhs) diff --git a/src/Nncase.Evaluator/Math/Reduce.cs b/src/Nncase.Evaluator/Math/Reduce.cs index 69e512b717..ead03f4701 100644 --- a/src/Nncase.Evaluator/Math/Reduce.cs +++ b/src/Nncase.Evaluator/Math/Reduce.cs @@ -110,7 +110,7 @@ public Expr Visit(IShapeEvaluateContext context, Reduce target) { if (axes.Length == input.CheckedShape.Count && keepDimsValue == 0) { - return Array.Empty(); + return Array.Empty(); } } @@ -119,7 +119,7 @@ public Expr Visit(IShapeEvaluateContext context, Reduce target) var ax = ShapeExprUtility.Positive(axValue, inShape); if (keepDimsValue == 1) { - outShape = ShapeExprUtility.Replace(outShape, ax, 1); + outShape = ShapeExprUtility.Replace(outShape, ax, 1L); } else { @@ -127,7 +127,7 @@ public Expr Visit(IShapeEvaluateContext context, Reduce target) } } - return outShape; + return Cast(outShape, DataTypes.Int64); } throw new NotImplementedException(); diff --git a/src/Nncase.Evaluator/NN/BatchToSpace.cs b/src/Nncase.Evaluator/NN/BatchToSpace.cs index d60473b41a..e53297c7ab 100644 --- a/src/Nncase.Evaluator/NN/BatchToSpace.cs +++ b/src/Nncase.Evaluator/NN/BatchToSpace.cs @@ -115,13 +115,13 @@ public Expr Visit(IShapeEvaluateContext context, BatchToSpace target) inShape = Stack(new IR.Tuple(inShape[0], inShape[2], inShape[1]), 0); } - var blockShape = context.GetArgument(target, BatchToSpace.BlockShape); + var blockShape = Cast(context.GetArgument(target, BatchToSpace.BlockShape), DataTypes.Int64); if (!blockShape.CheckedShape.IsFixed) { throw new NotImplementedException(); } - var crops = context.GetArgument(target, BatchToSpace.Crops); + var crops = Cast(context.GetArgument(target, BatchToSpace.Crops), DataTypes.Int64); var blockSize = Prod(blockShape); var batch = inShape[0]; var d0 = batch / blockSize; @@ -131,7 +131,7 @@ public Expr Visit(IShapeEvaluateContext context, BatchToSpace target) var inRank = Cast(ShapeOf(inShape)[0], DataTypes.Int32); var remainSize = inRank - 1 - m; - var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty()); + var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty()); var outShapeList = Concat(new IR.Tuple(Stack(new IR.Tuple(new[] { d0 }), 0), Stack(new IR.Tuple(cropSection), 0), remainShape), 0); diff --git a/src/Nncase.Evaluator/NN/Conv2D.cs b/src/Nncase.Evaluator/NN/Conv2D.cs index 67f361adc8..c26e219821 100644 --- a/src/Nncase.Evaluator/NN/Conv2D.cs +++ b/src/Nncase.Evaluator/NN/Conv2D.cs @@ -92,10 +92,10 @@ public Expr Visit(IShapeEvaluateContext context, Conv2D target) { var input = context.GetArgumentShape(target, Conv2D.Input); var weights = context.GetArgumentShape(target, Conv2D.Weights); - var pad = Cast(context.GetArgument(target, Conv2D.Padding), DataTypes.Int32); - var stride = Cast(context.GetArgument(target, Conv2D.Stride), DataTypes.Int32); - var dilation = Cast(context.GetArgument(target, Conv2D.Dilation), DataTypes.Int32); - var groups = Cast(context.GetArgument(target, Conv2D.Groups), DataTypes.Int32); + var pad = context.GetArgument(target, Conv2D.Padding); + var stride = context.GetArgument(target, Conv2D.Stride); + var dilation = context.GetArgument(target, Conv2D.Dilation); + var groups = context.GetArgument(target, Conv2D.Groups); return IR.F.ShapeExpr.Conv2DShape(input, weights, pad, stride, dilation, groups); } diff --git a/src/Nncase.Evaluator/NN/Pad.cs b/src/Nncase.Evaluator/NN/Pad.cs index 13a4111456..5d33750659 100644 --- a/src/Nncase.Evaluator/NN/Pad.cs +++ b/src/Nncase.Evaluator/NN/Pad.cs @@ -118,11 +118,11 @@ public Expr Visit(IShapeEvaluateContext context, Pad target) var end = Slice(pads, new[] { 1 }, new[] { 2 }, new[] { 1 }, new[] { 1 }); // paddings = [4, 2] -> [4, 1] + [4, 1] - var paddings = front + end; + var paddings = Cast(front + end, DataTypes.Int64); // outShape = inShape + paddings - var padsSumShape = StackScalar(Cast(ShapeOf(paddings)[0], DataTypes.Int32)); - var outShape = inShape + Cast(Reshape(paddings, padsSumShape), DataTypes.Int32); + var padsSumShape = StackScalar(ShapeOf(paddings)[0]); + var outShape = inShape + Reshape(paddings, padsSumShape); return outShape; } diff --git a/src/Nncase.Evaluator/NN/SpaceToBatch.cs b/src/Nncase.Evaluator/NN/SpaceToBatch.cs index 756f8c1472..98ec949151 100644 --- a/src/Nncase.Evaluator/NN/SpaceToBatch.cs +++ b/src/Nncase.Evaluator/NN/SpaceToBatch.cs @@ -98,11 +98,11 @@ public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target) { var inShape = context.GetArgumentShape(target, SpaceToBatch.Input); var blockShape = context.GetArgument(target, SpaceToBatch.BlockShape); - var padding = context.GetArgument(target, SpaceToBatch.Paddings); + var padding = Cast(context.GetArgument(target, SpaceToBatch.Paddings), DataTypes.Int64); var input = context.GetArgument(target, SpaceToBatch.Input); if (blockShape is TensorConst blockConst) { - var blockShapeValue = blockConst.Value.ToArray(); + var blockShapeValue = blockConst.Value.ToArray(); var m = blockShapeValue.Length; var inRank = input.CheckedShape.Rank; @@ -122,7 +122,7 @@ public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target) }).ToArray(); var remainSize = inRank - 1 - m; - var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty()); + var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty()); var outLast = remainShape; var outShape = Concat(new IR.Tuple(Stack(new IR.Tuple(outFirst.Concat(outMid).ToArray()), 0), outLast), 0); return outShape; diff --git a/src/Nncase.Evaluator/ShapeEvaluateContext.cs b/src/Nncase.Evaluator/ShapeEvaluateContext.cs index e9919fb3d6..6f3e5ff94c 100644 --- a/src/Nncase.Evaluator/ShapeEvaluateContext.cs +++ b/src/Nncase.Evaluator/ShapeEvaluateContext.cs @@ -22,8 +22,13 @@ internal sealed class ShapeEvaluateContext : IShapeEvaluateContext public ShapeEvaluateContext(Dictionary memo, ShapeExprCache cache) { _memo = memo; - Cache = cache.Cache; + foreach (var (key, value) in cache.Cache) + { + _memo[key] = value; + } + VarMap = cache.VarMap; + Cache = new(); } public IReadOnlyDictionary VarMap { get; } @@ -55,7 +60,7 @@ public Expr GetArgumentShape(Op op, ParameterInfo parameter) var expr = GetArgument(op, parameter); if (expr is Tuple tuple) { - return new Tuple(tuple.Fields.ToArray().Select(v => Cast(GetResultFromMemo(v), DataTypes.Int32)).ToArray()); + return new Tuple(tuple.Fields.ToArray().Select(v => Cast(GetResultFromMemo(v), DataTypes.Int64)).ToArray()); } // call @@ -64,7 +69,7 @@ public Expr GetArgumentShape(Op op, ParameterInfo parameter) var shape = expr.EvaluateShapeExpr(new ShapeExprCache(VarMap)); if (shape is Call c && c.Target is IR.Math.Require && c.Arguments[IR.Math.Require.Value.Index] is Tuple tupleShapeExpr) { - return new Tuple(tupleShapeExpr.Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int32)).ToArray()); + return new Tuple(tupleShapeExpr.Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int64)).ToArray()); } // for split @@ -76,23 +81,23 @@ public Expr GetArgumentShape(Op op, ParameterInfo parameter) return new Tuple( Enumerable .Range(0, tupleType.Fields.Count) - .Select(i => Cast(shape[i], DataTypes.Int32)) + .Select(i => Cast(shape[i], DataTypes.Int64)) .ToArray()); } else { - return new Tuple(((Tuple)shape).Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int32)).ToArray()); + return new Tuple(((Tuple)shape).Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int64)).ToArray()); } } } var shapeExpr = GetResultFromMemo(expr); - return Cast(shapeExpr, DataTypes.Int32); + return Cast(shapeExpr, DataTypes.Int64); } public Expr GetArgumentRank(Op op, ParameterInfo parameter) { - return StackScalar(Cast(GetArgumentShape(op, parameter)[0], DataTypes.Int32)); + return StackScalar(Cast(GetArgumentShape(op, parameter)[0], DataTypes.Int64)); } private Call GetCurrentCall() => CurrentCall ?? throw new InvalidOperationException("Current call is not set."); diff --git a/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs b/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs index f5e8c3ee4d..04b24b5109 100644 --- a/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs +++ b/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs @@ -102,7 +102,7 @@ protected override Expr VisitLeafVar(Var expr) throw new InvalidOperationException(); } - var shapeExpr = shape.Select((x, i) => x.IsFixed ? x.FixedValue : _context.VarMap[expr][i]).Select(x => IR.F.Tensors.Cast(x, DataTypes.Int32)).ToArray(); + var shapeExpr = shape.Select((x, i) => x.IsFixed ? x.FixedValue : _context.VarMap[expr][i]).Select(x => IR.F.Tensors.Cast(x, DataTypes.Int64)).ToArray(); return IR.F.Tensors.Stack(new IR.Tuple(shapeExpr), 0); } diff --git a/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs b/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs index 6c9a785a5a..adb4d296d8 100644 --- a/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs +++ b/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs @@ -64,6 +64,7 @@ public Expr EvaluateOpShapeExpr(Op op, IShapeEvaluateContext context) var evaluatorType = typeof(IShapeEvaluator<>).MakeGenericType(op.GetType()); var evaluator = (IShapeEvaluator)_serviceProvider.GetRequiredService(evaluatorType); var result = evaluator.Visit(context, op); + var s = op.GetType().Name + "_" + op.DisplayProperty(); if (!result.InferenceType()) { if (DumpScope.Current.IsEnabled(DumpFlags.Compile)) @@ -71,7 +72,7 @@ public Expr EvaluateOpShapeExpr(Op op, IShapeEvaluateContext context) DumpScope.Current.DumpIR(result, "EvaluateOpShapeExprInvalidResult"); } - throw new InvalidOperationException(); + throw new InvalidOperationException(s); } return result; diff --git a/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs b/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs index 61cb3e9438..c7cb702bca 100644 --- a/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs +++ b/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs @@ -41,9 +41,8 @@ public Cost Visit(ICostEvaluateContext context, BroadcastShape target) public Expr Visit(IShapeEvaluateContext context, BroadcastShape target) { var inShape = context.GetArgumentShape(target, BroadcastShape.Inputs); - var len = ((IR.Tuple)inShape).Fields.ToArray().Aggregate((Expr)1, (i, call) => IR.F.Math.Max(i, call)); - var bn = IR.F.Tensors.Cast(len, DataTypes.Int32); - return bn; + var len = ((IR.Tuple)inShape).Fields.ToArray().Aggregate((Expr)1L, (i, call) => IR.F.Math.Max(i, call)); + return len; } public Metric Visit(IMetricEvaluateContext context, BroadcastShape target) diff --git a/src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs b/src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs new file mode 100644 index 0000000000..5d5d775ea8 --- /dev/null +++ b/src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs @@ -0,0 +1,89 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.CostModel; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.ShapeExpr; +using Nncase.IR.Tensors; +using Nncase.Utilities; +using static Nncase.IR.F.Tensors; + +namespace Nncase.Evaluator.ShapeExpr; + +public partial class GetPaddingsEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IShapeEvaluator, IMetricEvaluator +{ + public static Expr ConcatPadding(Expr[] padH, Expr[] padW) + { + // return [[padh_before, padh_after], + // [padw_before, padw_after]] + return Stack( + new IR.Tuple( + Stack(new IR.Tuple(padH), 0), + Stack(new IR.Tuple(padW), 0)), + 0); + } + + public IValue Visit(IEvaluateContext context, GetPaddings target) + { + var inShape = context.GetArgumentValueAsArray(target, GetPaddings.InputShape); + var wShape = context.GetArgumentValueAsArray(target, GetPaddings.WeightsShape); + var strides = context.GetArgumentValueAsArray(target, GetPaddings.Strides); + var dilations = context.GetArgumentValueAsArray(target, GetPaddings.Dilations); + var same = context.GetArgumentValueAsScalar(target, GetPaddings.Same); + var lower = context.GetArgumentValueAsScalar(target, GetPaddings.Lower); + var padH = GetWindowedPadding(inShape[2], wShape[2], strides[0], dilations[0], same, lower); + var padW = GetWindowedPadding(inShape[3], wShape[3], strides[1], dilations[1], same, lower); + return ConcatPadding(padH, padW).Evaluate(); + } + + public IRType Visit(ITypeInferenceContext context, GetPaddings target) + { + return new TensorType(DataTypes.Int64, new[] { 2, 2 }); + } + + public Cost Visit(ICostEvaluateContext context, GetPaddings target) + { + return CostUtility.GetShapeExprCost(); + } + + public Expr Visit(IShapeEvaluateContext context, GetPaddings target) + { + return new[] { 2, 2 }; + } + + public Metric Visit(IMetricEvaluateContext context, GetPaddings target) + { + var returnType = context.GetReturnType(); + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType), + }; + } + + private static Expr[] GetWindowedPadding(Expr inputSize, Expr filter, Expr stride, Expr dilation, bool same, bool lower = false) + { + var i32InputSize = Cast(inputSize, DataTypes.Int32); + var i32Filter = Cast(filter, DataTypes.Int32); + var i32Stride = Cast(stride, DataTypes.Int32); + var i32Dilation = Cast(dilation, DataTypes.Int32); + var outputSize = IR.Util.GetWindowedOutputSize(i32InputSize, i32Filter, i32Stride, i32Dilation, same, false); + return GetWindowedPaddingValue(i32InputSize, outputSize, i32Filter, i32Stride, i32Dilation, lower); + } + + private static Expr[] GetWindowedPaddingValue(Expr inputSize, Expr outputSize, Expr filter, Expr stride, Expr dilation, bool lower) + { + var effectiveFilterSize = ((filter - 1) * dilation) + 1; + var padding = IR.F.Math.Max(0, ((outputSize - 1) * stride) + effectiveFilterSize - inputSize); + var before = Cast(padding / 2, DataTypes.Int32); + var after = Cast(padding - (padding / 2), DataTypes.Int32); + if (lower) + { + return new[] { IR.F.Math.Max(before, after), IR.F.Math.Min(before, after) }; + } + + return new[] { before, after }; + } +} diff --git a/src/Nncase.Evaluator/ShapeExpr/ReshapeShape.cs b/src/Nncase.Evaluator/ShapeExpr/ReshapeShape.cs new file mode 100644 index 0000000000..c8b7d1a536 --- /dev/null +++ b/src/Nncase.Evaluator/ShapeExpr/ReshapeShape.cs @@ -0,0 +1,50 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.CostModel; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.ShapeExpr; +using Nncase.IR.Tensors; +using Nncase.Utilities; + +namespace Nncase.Evaluator.ShapeExpr; + +public partial class ReshapeShapeEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IShapeEvaluator, IMetricEvaluator +{ + public IValue Visit(IEvaluateContext context, ReshapeShape target) + { + var inShape = context.GetArgumentValueAsArray(target, ReshapeShape.InputShape); + var shape = context.GetArgumentValueAsTensor(target, ReshapeShape.Shape); + var t = IR.F.Tensors.Reshape(new Var(new TensorType(DataTypes.Float32, inShape)), shape); + return ShapeExprUtility.GetShapeValue(t); + } + + public IRType Visit(ITypeInferenceContext context, ReshapeShape target) + { + var shape = context.CheckArgumentType(target, ReshapeShape.Shape); + return new TensorType(DataTypes.Int64, shape.Shape.ToValueArray()); + } + + public Cost Visit(ICostEvaluateContext context, ReshapeShape target) + { + return CostUtility.GetShapeExprCost(); + } + + public Expr Visit(IShapeEvaluateContext context, ReshapeShape target) + { + var shape = context.GetArgument(target, ReshapeShape.Shape); + return shape.CheckedShape.ToValueArray(); + } + + public Metric Visit(IMetricEvaluateContext context, ReshapeShape target) + { + var returnType = context.GetReturnType(); + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType), + }; + } +} diff --git a/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs b/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs index bbc3631372..e0a1ec3505 100644 --- a/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs +++ b/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs @@ -20,5 +20,10 @@ public void ConfigureServices(IRegistrator registrator) registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); + registrator.RegisterManyInterface(reuse: Reuse.Singleton); } } diff --git a/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs b/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs new file mode 100644 index 0000000000..df1c668bdf --- /dev/null +++ b/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs @@ -0,0 +1,57 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.CostModel; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.ShapeExpr; +using Nncase.IR.Tensors; +using Nncase.Utilities; + +namespace Nncase.Evaluator.ShapeExpr; + +public partial class SqueezeShapeEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IShapeEvaluator, IMetricEvaluator +{ + public IValue Visit(IEvaluateContext context, SqueezeShape target) + { + var inShape = context.GetArgumentValueAsArray(target, SqueezeShape.InputShape); + var dims = context.GetArgumentValueAsTensor(target, SqueezeShape.Dim); + var t = IR.F.Tensors.Squeeze(new Var(new TensorType(DataTypes.Float32, inShape)), dims); + return ShapeExprUtility.GetShapeValue(t); + } + + public IRType Visit(ITypeInferenceContext context, SqueezeShape target) + { + var input = context.GetArgument(target, SqueezeShape.InputShape); + var dims = context.CheckArgumentType(target, SqueezeShape.Dim); + if (!input.CheckedShape.IsFixed) + { + return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + } + + return new TensorType(DataTypes.Int64, new[] { input.CheckedShape.Size - dims.Shape[0] }); + } + + public Cost Visit(ICostEvaluateContext context, SqueezeShape target) + { + return CostUtility.GetShapeExprCost(); + } + + public Expr Visit(IShapeEvaluateContext context, SqueezeShape target) + { + var input = context.GetArgument(target, SqueezeShape.InputShape); + var dims = context.GetArgument(target, SqueezeShape.Dim); + return new[] { input.CheckedShape.Size - dims.CheckedShape[0].FixedValue }; + } + + public Metric Visit(IMetricEvaluateContext context, SqueezeShape target) + { + var returnType = context.GetReturnType(); + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType), + }; + } +} diff --git a/src/Nncase.Evaluator/ShapeExpr/TransposeShape.cs b/src/Nncase.Evaluator/ShapeExpr/TransposeShape.cs new file mode 100644 index 0000000000..49a1046627 --- /dev/null +++ b/src/Nncase.Evaluator/ShapeExpr/TransposeShape.cs @@ -0,0 +1,51 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using NetFabric.Hyperlinq; +using Nncase.CostModel; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.ShapeExpr; +using Nncase.IR.Tensors; +using Nncase.Utilities; + +namespace Nncase.Evaluator.ShapeExpr; + +public partial class TransposeShapeEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IShapeEvaluator, IMetricEvaluator +{ + public IValue Visit(IEvaluateContext context, TransposeShape target) + { + var inShape = context.GetArgumentValueAsArray(target, TransposeShape.InputShape); + var perm = context.GetArgumentValueAsTensor(target, TransposeShape.Perm); + var t = IR.F.Tensors.Transpose(new Var(new TensorType(DataTypes.Float32, inShape)), perm); + return ShapeExprUtility.GetShapeValue(t); + } + + public IRType Visit(ITypeInferenceContext context, TransposeShape target) + { + var tt = context.CheckArgumentType(target, TransposeShape.InputShape); + return new TensorType(DataTypes.Int64, new[] { tt.Shape[0] }); + } + + public Cost Visit(ICostEvaluateContext context, TransposeShape target) + { + return CostUtility.GetShapeExprCost(); + } + + public Expr Visit(IShapeEvaluateContext context, TransposeShape target) + { + var input = context.GetArgument(target, TransposeShape.Perm); + return input.CheckedShape[0].FixedValue; + } + + public Metric Visit(IMetricEvaluateContext context, TransposeShape target) + { + var returnType = context.GetReturnType(); + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType), + }; + } +} diff --git a/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs b/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs new file mode 100644 index 0000000000..c195b6c2dc --- /dev/null +++ b/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs @@ -0,0 +1,57 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.CostModel; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.ShapeExpr; +using Nncase.IR.Tensors; +using Nncase.Utilities; + +namespace Nncase.Evaluator.ShapeExpr; + +public partial class UnsqueezeShapeEvaluator : IEvaluator, ITypeInferencer, ICostEvaluator, IShapeEvaluator, IMetricEvaluator +{ + public IValue Visit(IEvaluateContext context, UnsqueezeShape target) + { + var inShape = context.GetArgumentValueAsArray(target, UnsqueezeShape.InputShape); + var dims = context.GetArgumentValueAsTensor(target, UnsqueezeShape.Dim); + var t = IR.F.Tensors.Unsqueeze(new Var(new TensorType(DataTypes.Float32, inShape)), dims); + return ShapeExprUtility.GetShapeValue(t); + } + + public IRType Visit(ITypeInferenceContext context, UnsqueezeShape target) + { + var input = context.CheckArgumentType(target, UnsqueezeShape.InputShape); + var dims = context.CheckArgumentType(target, UnsqueezeShape.Dim); + if (!input.Shape.IsFixed) + { + return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown }); + } + + return new TensorType(DataTypes.Int64, new[] { input.Shape.Size + dims.Shape[0] }); + } + + public Cost Visit(ICostEvaluateContext context, UnsqueezeShape target) + { + return CostUtility.GetShapeExprCost(); + } + + public Expr Visit(IShapeEvaluateContext context, UnsqueezeShape target) + { + var input = context.GetArgument(target, UnsqueezeShape.InputShape); + var dims = context.GetArgument(target, UnsqueezeShape.Dim); + return IR.F.Tensors.Stack(new IR.Tuple(new[] { IR.F.Tensors.ShapeOf(input)[0] + (long)dims.CheckedShape[0].FixedValue }), 0); + } + + public Metric Visit(IMetricEvaluateContext context, UnsqueezeShape target) + { + var returnType = context.GetReturnType(); + return new() + { + [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType), + }; + } +} diff --git a/src/Nncase.Evaluator/Tensors/Concat.cs b/src/Nncase.Evaluator/Tensors/Concat.cs index 994e76de22..d1098ccf7b 100644 --- a/src/Nncase.Evaluator/Tensors/Concat.cs +++ b/src/Nncase.Evaluator/Tensors/Concat.cs @@ -54,8 +54,8 @@ public Expr Visit(IShapeEvaluateContext context, Concat target) var inShape = context.GetArgumentShape(target, Concat.Input); var axis = context.GetArgument(target, Concat.Axis); var axisV = ShapeExprUtility.Positive(axis, inShape[0]); - var inShapes = ((IR.Tuple)inShape).Fields.ToArray().Select(x => Cast(x, DataTypes.Int32)).ToArray(); - var dim = inShapes.ToArray().Aggregate((Expr)0, (sum, shape) => sum + shape[axisV]); + var inShapes = ((IR.Tuple)inShape).Fields.ToArray().Select(x => Cast(x, DataTypes.Int64)).ToArray(); + var dim = inShapes.ToArray().Aggregate((Expr)0L, (sum, shape) => sum + shape[axisV]); var outShape = ShapeExprUtility.Replace(inShapes[0], axisV, dim); return outShape; } diff --git a/src/Nncase.Evaluator/Tensors/Range.cs b/src/Nncase.Evaluator/Tensors/Range.cs index 5f185fe350..dac1afb1d1 100644 --- a/src/Nncase.Evaluator/Tensors/Range.cs +++ b/src/Nncase.Evaluator/Tensors/Range.cs @@ -95,6 +95,6 @@ public Expr Visit(IShapeEvaluateContext context, Range target) var begin = context.GetArgument(target, Range.Begin); var end = context.GetArgument(target, Range.End); var step = context.GetArgument(target, Range.Step); - return ShapeExprUtility.StackOne((end - begin) / step); + return IR.F.Tensors.Cast(StackOne((end - begin) / step), DataTypes.Int64); } } diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs index 615728f7c7..38c4d150ee 100644 --- a/src/Nncase.Evaluator/Tensors/Reshape.cs +++ b/src/Nncase.Evaluator/Tensors/Reshape.cs @@ -53,32 +53,9 @@ Cost ICostEvaluator.Visit(ICostEvaluateContext context, Reshape target) public Expr Visit(IShapeEvaluateContext context, Reshape target) { + var inShape = context.GetArgumentShape(target, Reshape.Input); var shape = context.GetArgument(target, Reshape.Shape); - var inputShape = Cast(context.GetArgumentShape(target, Reshape.Input), DataTypes.Int32); - if (shape is TensorConst shapeConst) - { - var shapeArray = shapeConst.Value.ToArray(); - var negIndex = shapeArray.IndexOf(-1); - if (negIndex < 0) - { - return shapeArray; - } - - var dim = Prod(inputShape) / System.Math.Abs(shapeArray.Aggregate((s, x) => x * s)); - var rhs = shapeArray.Select((_, i) => i == negIndex ? dim + 1 : (Expr)0).ToArray(); - var newShape = Stack(new IR.Tuple(rhs), 0); - - // dim = Product(inShape) / Produce(Reshape.Shape) - // e.g. [1, 3, -1, 24] + [dim + 1, 0] = [1, 3, dim, 24] - return newShape + shapeArray; - } - - shape = Cast(shape, DataTypes.Int32); - var iSize = Prod(inputShape); - var sSize = Prod(shape); - var negDimInfactValue = iSize / Abs(sSize); - var index = IndexOf(shape, -1); - return new If(sSize < 0, ShapeExprUtility.Replace(shape, index, negDimInfactValue), shape); + return IR.F.ShapeExpr.ReshapeShape(inShape, shape); } public Metric Visit(IMetricEvaluateContext context, Reshape target) @@ -93,12 +70,10 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i return input; } - if (context.GetArgument(target, Reshape.Shape) is TensorConst shapeConst && - input.Shape.IsFixed) + if (context.GetArgument(target, Reshape.Shape) is TensorConst shapeConst) { var shapeValue = shapeConst.Value.ToArray(); var negCount = shapeValue.Count(IsMinus1); - var inputSize = input.Shape.Prod().FixedValue; var shapeSize = shapeValue.Aggregate(1, (x, y) => x * y); if (negCount > 1) { @@ -106,27 +81,39 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i $"Reshape at most one dimension of the new shape can be -1," + $" shape:{shapeValue}"); } - else if (negCount < 1) + + if (input.Shape.IsFixed) { - if (inputSize != shapeSize) + var inputSize = input.Shape.Prod().FixedValue; + if (negCount < 1) { - return new InvalidType("Reshape input shape size and param shape size must be same," + - $" shape:{shapeValue.ToArray().Aggregate(string.Empty, (s, i) => s + i + " ")}, input shape${string.Join(",", input.Shape)}"); - } + if (inputSize != shapeSize) + { + return new InvalidType("Reshape input shape size and param shape size must be same," + + $" shape:{shapeValue.ToArray().Aggregate(string.Empty, (s, i) => s + i + " ")}, input shape${string.Join(",", input.Shape)}"); + } - return input with { Shape = new Shape(shapeValue) }; + return input with { Shape = new Shape(shapeValue) }; + } + else + { + shapeSize = -shapeSize; + var negIndex = shapeValue.Select((dim, index) => (dim, index)).First(x => IsMinus1(x.dim)).index; + if (inputSize % shapeSize != 0) + { + return new InvalidType("Reshape input size must be divisible by shapeSize when has -1"); + } + + shapeValue[negIndex] = inputSize / shapeSize; + return input with { Shape = new Shape(shapeValue) }; + } } else { - shapeSize = -shapeSize; - var negIndex = shapeValue.Select((dim, index) => (dim, index)).First(x => IsMinus1(x.dim)).index; - if (inputSize % shapeSize != 0) + return input with { - return new InvalidType("Reshape input size must be divisible by shapeSize when has -1"); - } - - shapeValue[negIndex] = inputSize / shapeSize; - return input with { Shape = new Shape(shapeValue) }; + Shape = new Shape(shapeValue.Select(x => x == -1 ? Dimension.Unknown : x).ToArray()), + }; } } diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs index 219894f0da..eada657d01 100644 --- a/src/Nncase.Evaluator/Tensors/Slice.cs +++ b/src/Nncase.Evaluator/Tensors/Slice.cs @@ -104,7 +104,7 @@ public Expr Visit(IShapeEvaluateContext context, Slice target) return Ceil(Abs(end - begin) / Abs(stride)); }).ToArray(); - return Stack(new IR.Tuple(outDims), 0); + return Cast(Stack(new IR.Tuple(outDims), 0), DataTypes.Int64); } /// Axis. diff --git a/src/Nncase.Evaluator/Tensors/Squeeze.cs b/src/Nncase.Evaluator/Tensors/Squeeze.cs index 05517c5b4d..19573e7208 100644 --- a/src/Nncase.Evaluator/Tensors/Squeeze.cs +++ b/src/Nncase.Evaluator/Tensors/Squeeze.cs @@ -38,28 +38,9 @@ public Cost Visit(ICostEvaluateContext context, Squeeze target) public Expr Visit(IShapeEvaluateContext context, Squeeze target) { - var inShape = context.GetArgumentShape(target, Squeeze.Input); - var input = context.GetArgument(target, Squeeze.Input); + var input = context.GetArgumentShape(target, Squeeze.Input); var dims = context.GetArgument(target, Squeeze.Dim); - if (dims is TensorConst dimConst) - { - var rank = input.CheckedShape.Count; - var dimValue = dimConst.Value.ToArray().Select(x => x < 0 ? x + rank : x).ToArray(); - var outDims = Enumerable.Range(0, rank).Where(i => !dimValue.Contains(i)).Select(i => inShape[i]).ToArray(); - if (outDims.Length == 0) - { - return 1; - } - - if (outDims.Length == input.CheckedShape.Rank) - { - throw new InvalidOperationException("Bad Squeeze Shape Expr"); - } - - return IR.F.Tensors.Stack(new IR.Tuple(outDims), 0); - } - - throw new NotImplementedException(); + return IR.F.ShapeExpr.SqueezeShape(input, dims); } public Metric Visit(IMetricEvaluateContext context, Squeeze target) diff --git a/src/Nncase.Evaluator/Tensors/Stack.cs b/src/Nncase.Evaluator/Tensors/Stack.cs index 00cb5c4472..4d07bda39e 100644 --- a/src/Nncase.Evaluator/Tensors/Stack.cs +++ b/src/Nncase.Evaluator/Tensors/Stack.cs @@ -58,10 +58,10 @@ public Cost Visit(ICostEvaluateContext context, Stack target) public Expr Visit(IShapeEvaluateContext context, Stack target) { var inShape = context.GetArgumentShape(target, Stack.Inputs); - Expr one = new[] { 1 }; + Expr one = new[] { 1L }; if (inShape[0].CheckedShape.IsScalar) { - one = 1; + one = 1L; } return IR.F.Tensors.Concat(new IR.Tuple(inShape[0], one), 0); diff --git a/src/Nncase.Evaluator/Tensors/Tile.cs b/src/Nncase.Evaluator/Tensors/Tile.cs index 122f122d5a..19bebffdfa 100644 --- a/src/Nncase.Evaluator/Tensors/Tile.cs +++ b/src/Nncase.Evaluator/Tensors/Tile.cs @@ -46,7 +46,7 @@ public Expr Visit(IShapeEvaluateContext context, Tile target) { var inShape = context.GetArgumentShape(target, Tile.Input); var repeats = context.GetArgument(target, Tile.Repeats); - return inShape * IR.F.Tensors.Cast(repeats, DataTypes.Int32); + return inShape * IR.F.Tensors.Cast(repeats, DataTypes.Int64); } public Metric Visit(IMetricEvaluateContext context, Tile target) diff --git a/src/Nncase.Evaluator/Tensors/Transpose.cs b/src/Nncase.Evaluator/Tensors/Transpose.cs index e4bd2ee45f..77e74d07f1 100644 --- a/src/Nncase.Evaluator/Tensors/Transpose.cs +++ b/src/Nncase.Evaluator/Tensors/Transpose.cs @@ -92,17 +92,9 @@ public Metric Visit(IMetricEvaluateContext context, Transpose target) public Expr Visit(IShapeEvaluateContext context, Transpose target) { - var perm = context.GetArgument(target, Transpose.Perm); - var rank = context.GetArgument(target, Transpose.Input).CheckedShape.Rank; - var permValue = IR.F.Tensors.Cast(perm, DataTypes.Int32); var inShape = context.GetArgumentShape(target, Transpose.Input); - var outShape = Enumerable.Range(0, rank).Select(i => inShape[i]).ToArray(); - for (int i = 0; i < rank; i++) - { - outShape[i] = inShape[permValue[i]]; - } - - return IR.F.Tensors.Stack(new IR.Tuple(outShape), 0); + var perm = context.GetArgument(target, Transpose.Perm); + return IR.F.ShapeExpr.TransposeShape(inShape, perm); } private IRType Visit(ITypeInferenceContext context, Transpose target, TensorType input) diff --git a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs index b8c937c44a..bd86940fee 100644 --- a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs +++ b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs @@ -43,29 +43,9 @@ public Cost Visit(ICostEvaluateContext context, Unsqueeze target) public Expr Visit(IShapeEvaluateContext context, Unsqueeze target) { + var input = context.GetArgumentShape(target, Unsqueeze.Input); var dims = context.GetArgument(target, Unsqueeze.Dim); - if (dims is TensorConst dimsConst) - { - var dimsValue = dimsConst.Value.ToArray(); - var outShape = context.GetArgumentShape(target, Unsqueeze.Input); - - foreach (var dimVal in dimsValue) - { - if (dimVal >= 0) - { - outShape = ShapeExprUtility.Insert(outShape, dimVal, 1); - } - else - { - var index = IR.F.Math.Max(IR.F.Tensors.Cast(IR.F.Tensors.ShapeOf(outShape)[0], DataTypes.Int32) + dimVal + 1, 0); - outShape = ShapeExprUtility.Insert(outShape, index, 1); - } - } - - return outShape; - } - - throw new NotImplementedException(); + return IR.F.ShapeExpr.UnsqueezeShape(input, dims); } public Metric Visit(IMetricEvaluateContext context, Unsqueeze target) => Metric.Zero; diff --git a/src/Nncase.Importer/TFLite/Conv2D.cs b/src/Nncase.Importer/TFLite/Conv2D.cs index 0fc41dfb7f..260ee76d0e 100644 --- a/src/Nncase.Importer/TFLite/Conv2D.cs +++ b/src/Nncase.Importer/TFLite/Conv2D.cs @@ -29,19 +29,14 @@ private Expr VisitConv2D(in tflite.Operator op) weights = F.Tensors.NHWCToNCHW(weights); var bias = GetInputExprs(op, 2); var options = op.BuiltinOptionsAsConv2DOptions(); - var (inH, inW) = Util.GetHW(input); - var (fH, fW) = Util.GetHW(weights); var strideH = options.StrideH; var strideW = options.StrideW; var dilationH = options.DilationHFactor; var dilationW = options.DilationWFactor; - var padH = Util.GetWindowedPadding(inH, fH, strideH, dilationH, options.Padding == tflite.Padding.SAME); - var padW = Util.GetWindowedPadding(inW, fW, strideW, dilationW, options.Padding == tflite.Padding.SAME); var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); - var padding = Util.ConcatPadding(padH, padW); + var padding = Util.GetPaddings(input, weights, stride, dilation, options.Padding == tflite.Padding.SAME, false); var clamp = ToFloatClampRange(options.FusedActivationFunction); - var inputQuantParams = GetInputQuantParams(op, 0); var weightsQuantParams = GetInputQuantParams(op, 1); var biasQuantParams = GetInputQuantParams(op, 2); @@ -123,19 +118,14 @@ private Expr VisitDepthwiseConv2D(in tflite.Operator op) var bias = GetInputExprs(op, 2); input = F.Tensors.NHWCToNCHW(input); weights = F.Tensors.Transpose(weights, new[] { 3, 0, 1, 2 }); - _ = GetTensorShape(GetInputTensor(op, 1)); var options = op.BuiltinOptionsAsDepthwiseConv2DOptions(); - var (inH, inW) = Util.GetHW(input); - var (fH, fW) = Util.GetHW(weights); var strideH = options.StrideH; var strideW = options.StrideW; var dilationH = options.DilationHFactor; var dilationW = options.DilationWFactor; - var padH = Util.GetWindowedPadding(inH, fH, strideH, dilationH, options.Padding == tflite.Padding.SAME); - var padW = Util.GetWindowedPadding(inW, fW, strideW, dilationW, options.Padding == tflite.Padding.SAME); var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); - var padding = Util.ConcatPadding(padH, padW); + var padding = Util.GetPaddings(input, weights, stride, dilation, options.Padding == tflite.Padding.SAME, false); var depthMul = options.DepthMultiplier; if (depthMul != 1) { diff --git a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs index d5687e7cdb..ec8e64f9af 100644 --- a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs +++ b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs @@ -32,17 +32,15 @@ private Expr VisitConv2DTranspose(in tflite.Operator op) } var options = op.BuiltinOptionsAsTransposeConvOptions(); - var (_, _) = Util.GetHW(input); - var (fH, fW) = Util.GetHW(weights, true); var strideH = options.StrideH; var strideW = options.StrideW; var dilationH = 1; var dilationW = 1; - var padH = Util.GetWindowedPadding(newOutShape[2], fH, strideH, dilationH, options.Padding == tflite.Padding.SAME); - var padW = Util.GetWindowedPadding(newOutShape[3], fW, strideW, dilationW, options.Padding == tflite.Padding.SAME); var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 }); var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 }); - var padding = Util.ConcatPadding(padH, padW); + var oldWShape = F.Tensors.ShapeOf(weights); + var wShape = F.Tensors.Stack(new IR.Tuple(oldWShape[0], oldWShape[3], oldWShape[1], oldWShape[2]), 0); + var padding = F.ShapeExpr.GetPaddings(F.Tensors.Stack(new IR.Tuple(newOutShape), 0), wShape, stride, dilation, options.Padding == tflite.Padding.SAME, false); var clamp = ValueRange.Full; return F.Tensors.NCHWToNHWC(F.Math.Clamp( diff --git a/src/Nncase.Importer/Util.cs b/src/Nncase.Importer/Util.cs index a3ca52d74e..762d503235 100644 --- a/src/Nncase.Importer/Util.cs +++ b/src/Nncase.Importer/Util.cs @@ -90,13 +90,9 @@ public static Expr[] GetWindowedPadding(Expr inputSize, Expr filter, Expr stride } // lower used for onnx when auto_pad attr is SAME_LOWER - public static Expr GetPaddings(Expr input, Expr weights, long[] stride, long[] dilation, bool same, bool lower = false) + public static Expr GetPaddings(Expr input, Expr weights, Expr stride, Expr dilation, bool same, bool lower = false) { - var (inH, inW) = GetHW(input); - var (fH, fW) = GetHW(weights); - var padH = GetWindowedPadding(inH, fH, (int)stride[0], (int)dilation[0], same, lower); - var padW = GetWindowedPadding(inW, fW, (int)stride[1], (int)dilation[1], same, lower); - return ConcatPadding(padH, padW); + return IR.F.ShapeExpr.GetPaddings(ShapeOf(input), ShapeOf(weights), stride, dilation, same, lower); } public static Expr ComputeSplit(Expr input, long outputSize, long axis) diff --git a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs index a5324fd11a..2d12883101 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs @@ -78,6 +78,11 @@ public sealed partial class ReshapeToTranspose : IRewriteRule private Expr? GetReplace(Expr input, Call call) { + if (input.CheckedShape.Rank <= 1) + { + return null; + } + var newShape = call.CheckedShape.ToValueArray(); var inShape = input.CheckedShape.ToValueArray(); var sigNewShape = newShape.Where(x => x != 1).ToArray(); diff --git a/src/Nncase.Passes/Rules/Neutral/FoldSqueeze.cs b/src/Nncase.Passes/Rules/Neutral/FoldSqueeze.cs new file mode 100644 index 0000000000..87437d0bdc --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/FoldSqueeze.cs @@ -0,0 +1,62 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.IR.Tensors; +using Nncase.Passes; +using Nncase.PatternMatch; +using static Nncase.IR.F.Math; +using static Nncase.IR.F.Tensors; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.Neutral; + +[RuleGenerator] +public partial class FoldUnsqueezeSqueeze : RewriteRule +{ + /// + public override Pattern Pattern => IsUnsqueeze( + "unsqu", + "output", + IsSqueeze(IsWildcard("input"), IsTensorConst("sqAxes")), + IsTensorConst("unsqAxes")); + + private Expr? GetReplace(Call output, Expr input) + { + if (output.CheckedShape.SequenceEqual(input.CheckedShape)) + { + return input; + } + + return null; + } +} + +[RuleGenerator] +public partial class FoldSqueezeUnsqueeze : RewriteRule +{ + /// + public override Pattern Pattern => IsSqueeze( + "sqOp", + "output", + IsUnsqueeze(IsWildcard("input"), IsTensorConst("unsqAxes")), + IsTensorConst("sqAxes")); + + private Expr? GetReplace(Call output, Expr input) + { + if (output.CheckedShape.SequenceEqual(input.CheckedShape)) + { + return input; + } + + return null; + } +} diff --git a/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs b/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs index c3611f11b7..d5e9d6256f 100644 --- a/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs +++ b/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs @@ -101,6 +101,11 @@ public sealed partial class TransposeToReshape : IRewriteRule private Expr? GetReplace(Expr input, Expr tp, Tensor perm, RunPassContext context) { + if (input.CheckedShape.Rank <= 1) + { + return null; + } + // If all significant dims remains ascending order, it can be converted to a reshape. var inShape = input.CheckedShape; var sigAxes = new HashSet(); diff --git a/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs new file mode 100644 index 0000000000..4361b2db15 --- /dev/null +++ b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs @@ -0,0 +1,180 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.Tensors; +using Nncase.Passes.Rules.ShapeExpr; +using Nncase.PatternMatch; +using Nncase.Utilities; +using OrtKISharp; +using static Nncase.IR.F.Math; +using static Nncase.IR.F.NN; +using static Nncase.IR.F.Tensors; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.F.NN; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.Neutral; + +[RuleGenerator] +public partial class SplitSpaceToBatch : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsSpaceToBatch( + IsWildcard("input") with { TypePattern = HasRank() }, + IsWildcard("blockShape") with { TypePattern = HasFixedShape() }, + IsWildcard("paddings")); + + public Expr? GetReplace(Expr input, Expr blockShape, Expr paddings) + { + var spatialSize = blockShape.CheckedShape.Size; + var remainShapeSize = input.CheckedShape.Rank - spatialSize - 1; + var newPaddings = Enumerable.Repeat((Expr)0, (1 + spatialSize + remainShapeSize) * 2).ToArray(); + for (int i = 0; i < spatialSize; i++) + { + newPaddings[1 + i] = paddings[i, 0]; + newPaddings[1 + (newPaddings.Length / 2) + i] = paddings[i, 1]; + } + + var tmpPaddings = Stack(new IR.Tuple(newPaddings), 0); + var newPaddingsTensor = Transpose(Reshape(tmpPaddings, new long[] { 2, 1 + spatialSize + remainShapeSize }), new long[] { 1, 0 }); + var p = Pad(input, newPaddingsTensor, PadMode.Constant, 0f); + + var padShape = Cast(ShapeOf(p), DataTypes.Int32); + var batchShape1 = StackScalar(padShape[0]); + var spatialShape1 = RangeExec( + spatialSize, + i => Stack(new IR.Tuple(padShape[i + 1] / blockShape[i], blockShape[i]), 0)) + .Aggregate((x, y) => Concat(new IR.Tuple(x, y), 0)); + var remainShape1 = Stack(new IR.Tuple(RangeExec(remainShapeSize, i => padShape[1 + spatialSize + i])), 0); + var reshappedShape1 = Concat( + new IR.Tuple( + batchShape1, + spatialShape1, + remainShape1), + 0); + + var perm = RangeExec(spatialSize, i => (i * 2) + 2) + .Concat(new[] { 0 }) + .Concat(RangeExec(spatialSize, i => (i * 2) + 1)) + .Concat(RangeExec(remainShapeSize, i => i + ((int)spatialSize * 2) + 1)) + .Select(x => (long)x) + .ToArray(); + + var reshappedShape2 = Concat( + input: new IR.Tuple( + StackScalar(padShape[0] * Prod(blockShape)), + Stack(new IR.Tuple(RangeExec(spatialSize, i => padShape[i + 1] / blockShape[i])), 0), + Stack(new IR.Tuple(RangeExec(remainShapeSize, i => padShape[1 + spatialSize + i])), 0)), + 0); + + var reshape1 = Reshape(p, reshappedShape1); + var rt = Transpose(reshape1, perm); + var reshape2 = Reshape(rt, reshappedShape2); + return reshape2; + } + + private T[] RangeExec(long end, Func f) + { + return EndRange(0, (int)end).Select(f).ToArray(); + } + + private IEnumerable EndRange(int begin, int end) + { + return Enumerable.Range(begin, end - begin); + } +} + +[RuleGenerator] +public partial class SplitBatchToSpace : RewriteRule +{ + /// + public override Pattern Pattern { get; } = IsBatchToSpace( + IsWildcard("input") with { TypePattern = HasRank() }, + IsWildcard("blockShape") with { TypePattern = HasFixedShape() }, + IsWildcard("crop")); + + public Expr? GetReplace(Expr input, Expr blockShape, Expr crop) + { + // to nhwc + var input0 = NCHWToNHWC(input); + var blockLen = blockShape.CheckedShape.Size; + var xLen = input0.CheckedShape.Rank; + var xShape = Cast(ShapeOf(input0), DataTypes.Int32); + var spatial = ShapeExprUtility.Slice(xShape, 1, blockLen + 1); + var depth = ShapeExprUtility.Slice(xShape, blockLen + 1, xLen); + var targetSpatial = spatial * blockShape; + + var ccat1 = Concat(new IR.Tuple(spatial, blockShape), 0); + var re1 = Reshape(ccat1, new[] { ccat1.CheckedShape[0].FixedValue / blockLen, blockLen }); + var interLeave = Reshape(Transpose(re1, new long[] { 1, 0 }), new long[] { -1 }); + var shape1 = Concat(new IR.Tuple(new int[] { -1 }, interLeave, depth), 0); + + var g1 = BoostRange(2, (2 * blockLen) + 1, 2); + var g2 = BoostRange(1, (2 * blockLen) + 1, 2); + var g3 = BoostRange(0, xLen + blockLen).ToArray()[1 + (2 * blockLen)]; + var indices = g1.Append(0).Concat(g2).Append(g3); + + var perm = GetPerm(xLen, blockLen); + + var newShape = indices.Select(i => shape1[i]).ToArray(); + var x2 = Reshape(input0, Stack(new IR.Tuple(newShape), 0)); + var tr2 = Transpose(x2, perm); + var shape2 = Concat(new IR.Tuple(new[] { -1 }, targetSpatial, depth), 0); + var x3 = Reshape(tr2, shape2); + + var cropTransposed = Transpose(crop, new long[] { 1, 0 }); + var cropArray = Reshape(cropTransposed, new long[] { -1 }); + var w = cropTransposed.CheckedShape[1].FixedValue; + var cropStart = ShapeExprUtility.Slice(cropArray, 0, w); + var cropEnd = ShapeExprUtility.Slice(cropArray, w, w + w); + var endRange = targetSpatial - cropEnd; + var axesConst = BoostRange(1, blockLen + 1).ToArray(); + var strideConst = Enumerable.Repeat(1, axesConst.Length).ToArray(); + var result = Slice(x3, cropStart, endRange, axesConst, strideConst); + + // to nchw + var transposeResult = NHWCToNCHW(result); + return transposeResult; + } + + private static IEnumerable BoostRange(int start, int end, int step = 1) + { + int x = start; + do + { + yield return x; + x += step; + if ((step < 0 && x <= end) || (step > 0 && end <= x)) + { + break; + } + } + while (true); + } + + private long[] GetPerm(int xLen, int blockLen) + { + var perm = Enumerable.Range(0, xLen + blockLen).ToArray(); + perm[0] = blockLen; + perm[1] = blockLen + 1; + perm[2] = 0; + foreach (var i in BoostRange(3, (blockLen * 2) + 1)) + { + perm[i] = perm[i - 2] + 1; + } + + return perm.Select(x => (long)x).ToArray(); + } + + private T[] ZipExec(T[] a, T[] b, Func f) + { + return a.Zip(b).Select(x => f(x.First, x.Second)).ToArray(); + } +} diff --git a/src/Nncase.Passes/Rules/ShapeBucket/FoldBucketReshape.cs b/src/Nncase.Passes/Rules/ShapeBucket/FoldBucketReshape.cs deleted file mode 100644 index 2c78c5ca35..0000000000 --- a/src/Nncase.Passes/Rules/ShapeBucket/FoldBucketReshape.cs +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright (c) Canaan Inc. All rights reserved. -// Licensed under the Apache license. See LICENSE file in the project root for full license information. - -using System.Linq; -using Nncase.IR; -using Nncase.IR.Tensors; -using Nncase.PatternMatch; -using Nncase.Utilities; -using static Nncase.IR.TypePatternUtility; -using static Nncase.PatternMatch.F.Tensors; -using static Nncase.PatternMatch.Utility; - -namespace Nncase.Passes.Rules.ShapeBucket; - -[RuleGenerator] -public sealed partial class FoldBucketPadReshape : RewriteRule -{ - // Reshape(Gather(Shape, 0, 0), new[] { 0 }) -> GetItem(Shape, 0) - public override Pattern Pattern => IsReshape( - IsBucketPad(null, "bucketPad", IsWildcard(), IsTensorConst()), - IsTensorConst("newShape")); - - private Expr? GetReplace(Call bucketPad, Expr newShape) - { - return ReplaceUtility.ReplaceCallParams(bucketPad, (BucketPad.Shape.Index, newShape)); - } -} - -// todo: squeeze -[RuleGenerator] -public sealed partial class FoldBucketPadUnsqueeze : RewriteRule -{ - // Reshape(Gather(Shape, 0, 0), new[] { 0 }) -> GetItem(Shape, 0) - public override Pattern Pattern => IsUnsqueeze( - null, - "unsqueeze", - IsBucketPad(null, "bucketPad", IsWildcard(), IsTensorConst()), - IsTensorConst()); - - private Expr? GetReplace(Call bucketPad, Call unsqueeze) - { - return ReplaceUtility.ReplaceCallParams(bucketPad, (BucketPad.Shape.Index, unsqueeze.CheckedShape.ToValueArray())); - } -} diff --git a/src/Nncase.Passes/Rules/ShapeBucket/FoldNopTuple.cs b/src/Nncase.Passes/Rules/ShapeBucket/FoldNopTuple.cs new file mode 100644 index 0000000000..edbfd66074 --- /dev/null +++ b/src/Nncase.Passes/Rules/ShapeBucket/FoldNopTuple.cs @@ -0,0 +1,64 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Linq; +using System.Reactive; +using System.Threading.Tasks; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using Nncase.Utilities; +using static Nncase.PatternMatch.Utility; +using static Nncase.Utilities.ReplaceUtility; + +namespace Nncase.Passes.Rules.ShapeBucket; + +public class FoldNopTuple : FunctionPass +{ + protected override Task RunCoreAsync(BaseFunction input, RunPassContext context) + { + int i = 0; + while (true) + { + var preHash = input.GetHashCode(); + DumpScope.Current.DumpIR(input, $"{i}_before"); + + new FoldNopTupleVisitior().Visit(input); + DumpScope.Current.DumpIR(input, $"{i++}_after_convert"); + var afterHash = input.GetHashCode(); + if (preHash == afterHash) + { + return Task.FromResult(input); + } + } + } + + internal class FoldNopTupleVisitior : ExprVisitor + { + private bool _changed; + + public FoldNopTupleVisitior() + : base(true) + { + } + + protected override Expr DefaultVisitLeaf(Expr expr) => expr; + + protected override Expr VisitLeafTuple(Tuple expr) + { + if (!_changed && expr.Users.All(user => user is Call { Target: GetItem })) + { + foreach (var user in expr.Users) + { + var index = ((TensorConst)((Call)user).Arguments[GetItem.Index.Index]).Value.ToScalar(); + ReplaceUtility.ReplaceAllUsesWith(user, expr.Fields[index]); + } + + _changed = true; + } + + return expr; + } + } +} diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs index f22e2d4c70..be7dad641b 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs @@ -25,16 +25,38 @@ namespace Nncase.Passes.Rules.ShapeBucket; public class MergeBucketFusionPass : FunctionPass { + private readonly bool _greedy; + + public MergeBucketFusionPass(bool greedy) + { + _greedy = greedy; + } + protected override async Task RunCoreAsync(BaseFunction input, RunPassContext context) { + // bool greedy and dynamic var main = (Function)input; + int i = 0; while (true) { var preHash = main.GetHashCode(); - CompilerServices.Rewrite(main, new IRewriteRule[] { new MultiUserCallToFusion(), new MergeTupleFusion() }, new()); - await new MergeSeqBucketFusion().RunAsync(main, context); - IRHelpers.DCE(main); - await new MergeMultiUsersFusion().RunAsync(main, context); + if (_greedy) + { + CompilerServices.Rewrite(main, new IRewriteRule[] { new MultiUserCallToFusion(false, _greedy), new MergeTupleFusion() }, new()); + await new MergeSeqBucketFusion().RunAsync(main, context); + IRHelpers.DCE(main); + await new MergeMultiUsersFusion().RunAsync(main, context); + DumpIR(main, $"{i}_before", "FoldNopTuple"); + await new FoldNopTuple().RunAsync(main, context); + } + else + { + await new MergeSeqBucketFusion().RunAsync(main, context); + IRHelpers.DCE(main); + } + + CheckRepeat(main); + CheckErrorVar(main, main.Parameters.ToArray()); var postHash = main.GetHashCode(); if (preHash == postHash) { @@ -42,6 +64,7 @@ protected override async Task RunCoreAsync(BaseFunction input, Run } } + DumpIR(main, "MergeBucketFusionEnd"); return main; } } @@ -176,10 +199,6 @@ public class MergeMultiUsersFusion : FunctionPass public static bool DetectedRing(Call outerCall, Expr[] users) { - // var users = outerCall.Users.ToArray(); - // todo: fix this,TestComplexExpr - // var userArgs = users.SelectMany(user => ((Call)user).Arguments.ToArray()).Except(users).ToArray(); - // 用这个不过,但是好像会引起其他问题?? var userArgs = users.SelectMany(user => ((Call)user).Arguments.ToArray()).ToArray(); foreach (var arg in userArgs) { @@ -231,23 +250,15 @@ private static (Expr? NewCall, UserInfo[] AllUsers) MergeMultiUserFusion(Call ou // todo: not support if (users.Any(user => user is Tuple)) { - // Console.WriteLine("HasTuple"); return notSupport; } var userInfos = CollectUsers(outerCall, users); - // todo: support only one user, because merge fusion rule is not enough - // maybe a error - // if (userInfos.Length < 2) - // { - // return null; - // } - // has invalid if (userInfos.Length != users.Distinct().ToArray().Length) { - Console.WriteLine("not all fusion call and getItemMode"); + // Console.WriteLine("not all fusion call and getItemMode"); return notSupport; } @@ -601,8 +612,7 @@ protected override Expr VisitLeafCall(Call expr) } // Console.WriteLine($"Match {fusion.Name} counter:{Counter}"); - DumpIR(Root, "OriginRoot", RelPath); - + // DumpIR(Root, "OriginRoot", RelPath); var (newCall, users) = MergeMultiUserFusion(outerCall, fusion); if (newCall != null) { diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs index 5118cb44e9..bac3b6f135 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs @@ -7,6 +7,7 @@ using System.Xml; using NetFabric.Hyperlinq; using Nncase.IR; +using Nncase.IR.Math; using Nncase.IR.Tensors; using Nncase.PatternMatch; using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper; @@ -35,11 +36,6 @@ public static bool AllConst(Call originCall) return false; } - - public bool ValidTarget(Expr target) - { - return CallValidator.ValidTarget(target); - } } [RuleGenerator] @@ -106,6 +102,17 @@ public partial class MergePrevMarkerToFusion : MergeFusionBase [RuleGenerator] public partial class MergeNextCallToFusion : MergeFusionBase { + private readonly bool _greedy = true; + + public MergeNextCallToFusion(bool greedy = true) + { + _greedy = greedy; + } + + public MergeNextCallToFusion() + { + } + public Pattern FusionCall => IsCall( "fusionOuterCall", IsFusion( @@ -127,13 +134,13 @@ public partial class MergeNextCallToFusion : MergeFusionBase // nextCall(marker(fusion(x))) -> fusion(nextCall(marker(x))) public Expr? GetReplace(Call nextCall, Expr maybeFusionCallMarker, Expr target, Call fusionOuterCall, BucketFusion fusion) { - var singleVar = CompileSession.CompileOptions.ShapeBucketOptions.VarMap.Values.SelectMany(x => x).OfType().ToHashSet().Count <= 1; + var singleVar = SingleDimVar(CompileSession.CompileOptions.ShapeBucketOptions); if (!singleVar && nextCall.Arguments.ToArray().OfType().Count() > 1) { return null; } - if (!ValidTarget(target)) + if (!CallValidator.ValidTarget(nextCall, _greedy)) { return null; } @@ -237,8 +244,25 @@ private bool SameEffectVar(Call originCall, Fusion fusion) [RuleGenerator] public partial class MergePrevCallToFusion : MergeFusionBase { + // 输入必须匹配marker,因为即便合并marker也是要在外面保留一份副本 + // fusion(marker(prevCall()) { var } -> fusion(var) { marker(prevCall()) } + // fusion((prevCall()) { var } -> fusion(var) { prevCall() } + private readonly bool _greedy = true; + + private readonly bool _mergeFusion; + private string _prevCallStr = string.Empty; + public MergePrevCallToFusion() + { + } + + public MergePrevCallToFusion(bool greedy = true, bool mergeFusion = false) + { + _greedy = greedy; + _mergeFusion = mergeFusion; + } + public override Pattern Pattern => IsCall( "fusionOuterCall", IsFusion( @@ -257,10 +281,6 @@ public Pattern MaybeMarker(string exprName, Pattern exprPatten) => IsAlt( IsRangeOfMarker(exprPatten, IsWildcard()), exprPatten); - // 输入必须匹配marker,因为即便合并marker也是要在外面保留一份副本 - // fusion(marker(prevCall()) { var } -> fusion(var) { marker(prevCall()) } - // fusion((prevCall()) { var } -> fusion(var) { prevCall() } - // dfs // xx | marker(xx)不行, 会先匹配到xx // xx(marker) | xx 可以 @@ -600,7 +620,7 @@ private bool IsInvalid(Call lhsPrevCall, Expr lhsTarget) return true; } - if (!ValidTarget(lhsTarget)) + if (!CallValidator.ValidTarget(lhsPrevCall, _greedy)) { return true; } diff --git a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs index 110e37026a..94f7977f2f 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs @@ -12,10 +12,22 @@ using Nncase.Evaluator; using Nncase.IR; using static Nncase.IR.F.Tensors; +using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper; namespace Nncase.Passes.Rules.ShapeBucket; -public record FusionShapeData(IValue Outshape, IValue[] InputShapes); +public class FusionShapeData +{ + public FusionShapeData(IValue outshape, IValue[] inputShapes) + { + Outshape = outshape; + InputShapes = inputShapes; + } + + public IValue Outshape { get; } + + public IValue[] InputShapes { get; } +} public class FusionShapeUpdater : ExprVisitor { @@ -26,7 +38,7 @@ public FusionShapeUpdater(Dictionary memo) _memo = memo; } - public Dictionary FusionShape { get; set; } = new(); + public Dictionary FusionShape { get; } = new(); protected override Expr DefaultVisitLeaf(Expr expr) => expr; @@ -34,7 +46,11 @@ protected override Expr VisitLeafCall(Call expr) { if (expr.Target is BucketFusion f) { - var argShape = expr.Arguments.ToArray().Select(arg => GetShape(_memo[arg])).ToArray(); + var argShape = expr.Arguments.ToArray().Select(arg => + { + var exp = arg is Marker m ? m.Target : arg; + return GetShape(_memo[exp]); + }).ToArray(); var shape = GetShape(_memo[expr]); FusionShape[f] = new FusionShapeData(shape, argShape); } @@ -54,25 +70,6 @@ private IValue GetShape(IValue value) } } -public class SimpleTimer : IDisposable -{ - private readonly DateTime _startTime; - private readonly string _name; - - public SimpleTimer(string name) - { - _startTime = System.DateTime.Now; - _name = name; - } - - public void Dispose() - { - var endTime = System.DateTime.Now; - var time = endTime - _startTime; - Console.WriteLine($"{_name} tooks {time.Seconds}"); - } -} - public class RecordFusionShape : FunctionPass { private Dictionary _dimVarValues = new(); @@ -84,11 +81,32 @@ public RecordFusionShape(Dictionary shapeList) public Dictionary FusionShapeInfo { get; set; } + // make dummy value from InputInfo + // VarInfo:(DimVar -> Value) + public static Dictionary + MakeDummyInput(IReadOnlyDictionary info, Dictionary varInfo) + { + return info.ToDictionary( + pair => pair.Key, + pair => + { + // todo: dummy input可能会有问题... + var shapeExpr = pair.Key.CheckedShape.IsScalar + ? (Expr)Array.Empty() + : Stack(new IR.Tuple(pair.Value.Select(x => Cast(x, DataTypes.Int64)).ToArray()), 0); + + var shape = shapeExpr.Evaluate(varInfo).AsTensor(); + return ConstantOfShape( + shape, + Cast(1, pair.Key.CheckedDataType)).Evaluate(varInfo); + }); + } + protected override Task RunCoreAsync(BaseFunction main, RunPassContext context) { var options = CompileSession.CompileOptions.ShapeBucketOptions; var varMap = options.VarMap; - _dimVarValues = ShapeBucketHelper.MakeVarValuesForAllSegment(options); + _dimVarValues = MakeVarValuesForAllSegment(options); // 一共有多组key seg var list = Enumerable.Range(0, _dimVarValues.First().Value.Length).Select(i => @@ -96,13 +114,14 @@ protected override Task RunCoreAsync(BaseFunction main, RunPassCon // 一组里面多个key seg return _dimVarValues.Select(pair => (pair.Key, Value: pair.Value[i])).ToArray(); }).ToArray(); + + var body = ((Function)main).Body; var tmpFusionShapeList = list.Select((seg, i) => { var varValues = seg.ToDictionary(pair => pair.Key, pair => (IValue)Value.FromTensor(pair.Value)); var exprValues = seg.ToDictionary(pair => (Expr)pair.Key, pair => (IValue)Value.FromTensor(pair.Value)); var input = MakeDummyInput(varMap, varValues); - var body = ((Function)main).Body; - var memo = EvaluatorUtil.GetMemo(body, input); + var memo = EvaluatorUtil.GetMemo(body, ConcatDictionary(input, varValues)); var f = new FusionShapeUpdater(ConcatDictionary(memo, exprValues)); f.Visit(main); return f.FusionShape; @@ -110,6 +129,7 @@ protected override Task RunCoreAsync(BaseFunction main, RunPassCon .ToLookup(x => x.Key, x => x.Value) .ToDictionary(pair => pair.Key, pair => pair.ToArray()); + GC.Collect(); foreach (var (f, shapeInfo) in tmpFusionShapeList) { FusionShapeInfo[f] = shapeInfo; @@ -117,35 +137,4 @@ protected override Task RunCoreAsync(BaseFunction main, RunPassCon return Task.FromResult(main); } - - private static Dictionary ConcatDictionary(Dictionary memo, Dictionary exprValues) - { - foreach (var (key, value) in exprValues) - { - memo[key] = value; - } - - return memo; - } - - // make dummy value from InputInfo - // VarInfo:(DimVar -> Value) - private static Dictionary - MakeDummyInput(IReadOnlyDictionary info, Dictionary varInfo) - { - return info.ToDictionary( - pair => pair.Key, - pair => - { - // todo: dummy input可能会有问题... - var shapeExpr = pair.Key.CheckedShape.IsScalar - ? (Expr)Array.Empty() - : Stack(new IR.Tuple(pair.Value.Select(x => Cast(x, DataTypes.Int32)).ToArray()), 0); - - var shape = shapeExpr.Evaluate(varInfo).AsTensor(); - return ConstantOfShape( - shape, - Cast(1, pair.Key.CheckedDataType)).Evaluate(varInfo); - }); - } } diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs index 197a5e0e8b..2eb9c0f39a 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucket.cs @@ -13,14 +13,17 @@ using System.Transactions; using DryIoc; using DryIoc.ImTools; +using GiGraph.Dot.Types.Geometry; using Microsoft.Extensions.DependencyInjection; using Microsoft.Toolkit.HighPerformance; using NetFabric.Hyperlinq; +using Nncase.CodeGen; using Nncase.Diagnostics; using Nncase.Evaluator; using Nncase.IR; using Nncase.IR.Math; using Nncase.IR.NN; +using Nncase.IR.ShapeExpr; using Nncase.IR.Tensors; using Nncase.Passes.Analysis; using Nncase.Passes.Rules.Lower; @@ -35,6 +38,7 @@ using static Nncase.PatternMatch.F.Tensors; using static Nncase.PatternMatch.Utility; using static Nncase.Utilities.ReplaceUtility; +using BaseFunction = Nncase.IR.BaseFunction; using Dimension = Nncase.IR.Dimension; using FoldConstCall = Nncase.Passes.Mutators.FoldConstCall; using Stack = Nncase.IR.Tensors.Stack; @@ -74,25 +78,6 @@ public BucketFusion(string moduleKind, Var[] effectVar, Expr body, params Var[] public Var[] EffectVar { get; set; } - public bool IsSimple - { - get - { - // todo: change list - var names = Name.Split("_"); - var list = new[] { "MatMul", "Conv2D", "Conv2DTranspose", "Transpose" }; - foreach (string name in names) - { - if (list.Contains(name)) - { - return false; - } - } - - return true; - } - } - public static BucketFusion FromNormalFusion(Fusion f, Var[] effectVars) { return new BucketFusion(f.Name, "stackvm", f.Body, f.Parameters.ToArray(), effectVars); @@ -139,8 +124,6 @@ public CallToFusion() public override Pattern Pattern => throw new InvalidOperationException(); - protected virtual bool MustHaveMarker => true; - private Call? CurrentCall { get; set; } private string Name => CurrentCall!.Target.GetType().Name; @@ -173,9 +156,6 @@ public virtual bool Check(Call call) Console.WriteLine(call.Target.GetType().Name); var argsMarkerData = CollectInputs(call); var args = argsMarkerData.Select(pair => pair.Item1).ToArray(); - - // var argsMarker = argsMarkerData.Select(pair => pair.Item1).ToArray(); - // var args = argsMarker.Select(arg => arg.Target).ToArray(); var varMap = CompileSession.CompileOptions.ShapeBucketOptions.VarMap; var set = MakeEffectVarArray(CompileSession, varMap, args); var fusionVars = MakeNewParam(args); @@ -351,9 +331,12 @@ protected override void Init(IMatchResult result) public class MultiUserCallToFusion : CallToFusion { - public MultiUserCallToFusion(bool onlyDynamic = false) + private readonly bool _greedy = true; + + public MultiUserCallToFusion(bool onlyDynamic = false, bool greedy = true) : base(onlyDynamic) { + _greedy = greedy; } public MultiUserCallToFusion() @@ -364,27 +347,7 @@ public MultiUserCallToFusion() { if (expr is Call c && c.Target is not BucketFusion) { - if (c.Target is Binary) - { - if (c.Arguments[0] is not Const && c.Arguments[1] is not Const) - { - return false; - } - - return true; - } - - if (c.Target is IR.Tensors.Reshape) - { - if (c.Arguments[IR.Tensors.Reshape.Shape.Index] is TensorConst) - { - return CallValidator.ValidTarget(c.Target); - } - } - else - { - return CallValidator.ValidTarget(c.Target); - } + return CallValidator.ValidTarget(c, _greedy); } return false; @@ -536,8 +499,29 @@ public TransposeToFusion(bool isDynamic = false) : base(isDynamic) { } +} - protected override bool MustHaveMarker => false; +public class ReshapeToFusion : CallToFusion +{ + public ReshapeToFusion(bool isDynamic = false) + : base(isDynamic) + { + } + + public override Pattern Pattern => IsCallWildcard("call", IsOp()); + + protected override (Expr, int)[] CollectInputs(Call call) + { + var input = call.Arguments[IR.Tensors.Reshape.Input.Index]; + var inputPair = (input, IR.Tensors.Reshape.Input.Index); + var padPair = (call.Arguments[IR.Tensors.Reshape.Shape.Index], IR.Tensors.Reshape.Shape.Index); + if (padPair.Item1 is TensorConst) + { + return new[] { inputPair }; + } + + return new[] { inputPair, padPair }; + } } public class UnaryToFusion : MarkerCallToFusion @@ -560,7 +544,7 @@ public BinaryToFusion(bool isDynamic = false) { } - // public override bool Check(Call call) => call.CheckedShape.Rank > 1; + public override bool Check(Call call) => call.CheckedShape.Rank > 1; } [RuleGenerator] @@ -621,20 +605,23 @@ public partial class ClearFusionOuterMarker : RewriteRule public class FusionBucketContext { - public FusionBucketContext(Call outerCall, BucketFusion fusion, Dictionary varMap, Dictionary dimVarValues, ShapeExprCache cache) + private readonly int _index; + + public FusionBucketContext(Call outerCall, BucketFusion fusion, ShapeBucketOptions options, ShapeExprCache cache, int index, FusionShapeData[] shapeInfos) { OuterCall = outerCall; Fusion = fusion; - VarMap = varMap; + VarMap = options.VarMap; Cache = cache; - Cache.VarMap = varMap; - FusionInputShapeExpr = MakeFusionInputShapeExpr(outerCall, fusion, cache); - CheckAlive(FusionInputShapeExpr); - DimVarValues = dimVarValues; + Cache.VarMap = options.VarMap; + FusionInputShapeExpr = new(); + DimVarValues = MakeVarValuesForAllSegment(options); + Arguments = OuterCall.Arguments.ToArray(); Parameters = Fusion.Parameters.ToArray(); FixedShapeCache = new(); - SliceShape = ComputeSliceShape(); + SliceShape = ComputeSliceShape(shapeInfos); + _index = index; } public Expr SliceShape { get; } @@ -680,7 +667,6 @@ private static Dictionary MakeFusionInputShapeExpr(Call call, Bucke { var data = fusion.Parameters.ToArray().Zip(call.Arguments.ToArray().Select((arg, i) => { - // DumpIR(arg, "MakeFusionInputShapeExprArg"); var result = arg.EvaluateShapeExpr(cache); if (!result.InferenceType()) { @@ -698,31 +684,138 @@ private static Dictionary MakeFusionInputShapeExpr(Call call, Bucke return fusionInputData; } - private static void CheckAlive(Dictionary fusionInputInfo) + private static Expr ReplaceShapeOf(Dictionary fusionInputsShapeExpr, Dictionary varMap, Expr originShape, Var[] parameters, Var[] dimVarKeys, FusionBucketContext context, Dictionary dict) { - foreach (var value in fusionInputInfo.Values) + // return originShape; + // 拷贝shape表达式,以免被原始的计算引用 + var cloneShape = originShape.Clone(); + CompilerServices.Rewrite(cloneShape, new[] { new RemoveMarker() }, new()); + var f = new FindVar(); + f.Visit(cloneShape); + var newVars = f.Vars; + + // 可能在VarMap里面有,但是newVar中没有,所以把newVar转换为oldVar + var newDict = fusionInputsShapeExpr + .Concat(varMap) + .Where(pair => newVars.FindFirst(newVar => newVar.Name == pair.Key.Name) != null) + .ToDictionary( + pair => + { + var k = newVars.FindFirst(newVar => newVar.Name == pair.Key.Name); + return k; + }, + pair => + { + var v = pair.Value; + return v; + }) + .ToDictionary(x => x.Key, x => x.Value); + + var originVars = parameters + .ToArray() + .Concat(varMap.Keys) + .Concat(dimVarKeys) + .ToDictionary(v => v.Name, v => v); + + Task.Run(() => new FoldNopTuple().RunAsync(new Function(cloneShape), new())).Wait(); + Expr sliceShape = cloneShape; + var p = new ReplaceOfCollector(); + p.Visit(cloneShape); + var processList = p.List; + processList.Reverse(); + + var argCache = context.Arguments.ToDictionary(arg => arg, arg => (Expr)ShapeOf(arg)); + var exprs = argCache.SelectMany(pair => new[] { pair.Key, pair.Value }).ToArray(); + var pinner = new ExprPinner(exprs); + var cache = new ShapeExprCache(newDict, argCache); + + foreach (var call in processList) + { + var newShapeOf = call.Arguments[0].EvaluateShapeExpr(cache); + ReplaceUtility.ReplaceAllUsesWith(call, newShapeOf); + } + + foreach (var (key, value) in dict) { - foreach (var expr in value) + var mutator = new Passes.Mutators.Substitutor(e => { - if (!expr.IsAlive) + if (e is Var v1 && v1.Name == value.Name) { - throw new NotImplementedException(); + return key; } - } + + return null; + }); + mutator.Visit(sliceShape, Unit.Default); } + + newVars.ToArray().ForEach(newVar => + { + if (originVars.TryGetValue(newVar.Name, out var originVar)) + { + ReplaceExpr(sliceShape, newVar, originVar); + } + }); + + var body = sliceShape; + var simplifySliceShape = SimplifyShape(body); + return simplifySliceShape; } - private Expr ComputeSliceShape() + private static Expr SimplifyShape(Expr body) => + CompilerServices.Rewrite( + body, + new IRewriteRule[] + { + new FoldStackGetItem(), new FoldShapeOf(), new FoldTwoReshapes(), new FoldTwoCasts(), + new FoldTwoSlices(), new FoldNopBinary(), new FoldNopCast(), new Neutral.FoldConstCall(), + new FoldNopReshape(), new FoldNopSlice(), new FoldIf(), new FoldBroadcastShape(), new FoldSplitShapeOf(), + }, + new()); + + private Expr ComputeSliceShape(FusionShapeData[] shapeInfos) { var originBody = FusionBody; var shapeOfFusionInput = MakeShapeOfFusionInput(Parameters, Arguments); var originShape = originBody.EvaluateShapeExpr(shapeOfFusionInput); originShape.InferenceType(); - return originShape; + // complex check + // 判断是否需要replace,里面是否存在满足条件的shapeof + var args = Arguments.ToDictionary(x => x, x => new Var(x.CheckedType)); + var input = MakeShapeOfFusionInput(Parameters, args.Values.ToArray()); + var varShape = originBody.EvaluateShapeExpr(input); + var p = new ReplaceOfCollector(); + p.Visit(originBody); + if (p.List.Count == 0) + { + return SimplifyShape(originShape); + } + + return ReplaceShapeOf(shapeOfFusionInput, VarMap, varShape, Parameters, DimVarValues.Keys.ToArray(), this, args); } } +public class ReplaceOfCollector : ExprVisitor +{ + public List List { get; } = new(); + + protected override Expr VisitLeafCall(Call expr) + { + var input = expr.Arguments[0]; + + // input is marker or call + if (expr.Target is ShapeOf && input.CheckedShape.Rank > 2 && input is not Var) + { + List.Add(expr); + } + + return expr; + } + + protected override Expr DefaultVisitLeaf(Expr expr) => expr; +} + [RuleGenerator] public partial class FusionBucket : RewriteRule { @@ -748,17 +841,18 @@ public FusionBucket(Dictionary list) GenerateParameters(null)), GenerateParameters(null)); - internal Dictionary VarMap => CompileSession.CompileOptions.ShapeBucketOptions.VarMap; - public static Expr PreProcess(FusionBucketContext context, Var param, Dictionary inputInfo, Dictionary varValues, Dictionary fusionInputData, int segIndex, int inputIndex) { // Console.WriteLine($"seg index{segIndex}"); if (context.FixedShapeCache.TryGetValue(segIndex, out var cachedFixedShape)) { - // var cachedShape = cachedFixedShape[inputIndex]; - // Console.WriteLine(string.Join(",", cachedShape)); - // Console.WriteLine("Cache ok"); - return new Call(new BucketPad(), param, cachedFixedShape[inputIndex]); + var shape = cachedFixedShape[inputIndex]; + if ((param.CheckedShape.IsFixed && shape.SequenceEqual(param.CheckedShape.ToValueArray())) || param.CheckedShape.IsScalar) + { + return param; + } + + return new Call(new BucketPad(), param, shape); } throw new InvalidDataException("Shape Cache not found"); @@ -780,7 +874,38 @@ public static (Dictionary MinDict, Dictionary MaxDict) return (minDict, maxDict); } - public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary varInfo, int segIndex) + public static Expr Split(FusionBucketContext context) + { + var failure = MakeFailure(context.FusionBody); + + // todo: test this + var value = GetVarValue(context); + + int i = 0; + + // todo: only used for same range + var body = context.DimVarValues.First().Value.OrderByDescending(x => x).Aggregate( + failure, + (sum, seg) => + { + // 根据var,也就是target为这个fusion的call的参数来进行判断落在哪个段 + var cond = value <= (long)seg; + var sameCond = IR.F.Math.Equal(value, (long)seg); + + // select var value for current segment + var varInfo = context.DimVarValue(i); + var thenBody = MakeSplitEntry(context, varInfo, i, sameCond); + var elseBody = sum; + i++; + + var result = new If(cond, thenBody, elseBody); + return result; + }); + + return body; + } + + public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary varInfo, int segIndex, Expr sameCond, bool sameOpt = false) { var originBody = context.FusionBody; var fusionVars = context.Parameters; @@ -791,7 +916,6 @@ public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary fusion原始的var -> target为fusion的call的input // 本质上只是对这个body的所有输入做替换 // 避免这里的修改影响到原始的body,每个分支需要进行自己的修改,所以要clone处理 - // DumpIR(originBody, "originBody", _relPath); var call = ReplaceClone(originBody, fusionVars.Zip(fixInputs).ToArray()); if (!call.InferenceType()) { @@ -800,8 +924,26 @@ public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary pair.Value.Any(x => x is Var)).Select(pair => + { + var (v, dims) = pair; + var i = dims.IndexOf(x => x is Var); + return ShapeOf(v)[i]; + }).ToArray(); + + if (varList.Length > 1) + { + return varList.Aggregate((sum, x) => IR.F.Math.Max(sum, x)); + } + + return varList.First(); } public Expr? GetReplace(Call outerCall, BucketFusion fusion, Expr fusionBody) @@ -811,24 +953,11 @@ public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary(); if (!FusionShapeInfo.TryGetValue(fusion, out shapeInfos)) { @@ -842,6 +971,9 @@ public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary x.InputShapes.Select(iShape => iShape.AsTensor().ToArray().ToArray()).ToArray()).ToArray(); @@ -853,6 +985,7 @@ public static Expr MakeSplitEntry(FusionBucketContext context, Dictionary 1"); } - var info = ComputeSegmentInfo(counts, options); - var body = Split(context, info); + // 1. 普通情况不应该rebuild + // 2. rebuild的正确性 + if (ShouldBeRebuild(context)) + { + _counter++; + Console.WriteLine("Rebuild"); + var rebuild = RestoreBodyWithArgs(context.Arguments, context.Parameters, context.FusionBody); + DumpIR(rebuild, "Rebuild", _relPath); + return rebuild; + } + + var body = Split(context); body.InferenceType(); - if (body.Users.Count > 1 || body.CheckedType is InvalidType) + if (body.CheckedType is InvalidType) { + DumpIR(body, "InvalidBody"); throw new InvalidOperationException(); } - if (body is not If) + if (body.Users.Count > 1) { - _counter++; - DumpIR(body, "Rebuild", _relPath); - return body; + throw new InvalidOperationException(); } - // DumpIR(body, "newBodyBeforeReplace", _relPath); // FixInput Replace Var var newBody = ReplaceFusionVarWithCallArgs(fusion, context.Arguments, body); @@ -938,19 +1079,19 @@ private static void PrintShapeInfos(FusionShapeData[] shapeInfos) private static Expr MakeSlice(FusionBucketContext context, Expr call, Expr originBody) { + var sliceShape = context.SliceShape; if (call.CheckedType is TupleType tuple) { var fields = Enumerable.Range(0, tuple.Count) - .Select(i => MakeSliceForTensor(originBody[i], call[i], context)).ToArray(); + .Select(i => MakeSliceForTensor(sliceShape[i], call[i], context)).ToArray(); return new IR.Tuple(fields); } - return MakeSliceForTensor(originBody, call, context); + return MakeSliceForTensor(sliceShape, call, context); } - private static Expr MakeSliceForTensor(Expr originBody, Expr call, FusionBucketContext context) + private static Expr MakeSliceForTensor(Expr sliceShape, Expr call, FusionBucketContext context) { - var sliceShape = context.SliceShape; var rank = call.CheckedShape.Rank; var simplifyCall = CompilerServices.Rewrite( call, @@ -967,10 +1108,13 @@ private static Expr MakeSliceForTensor(Expr originBody, Expr call, FusionBucketC new FoldNopReshape(), new FoldNopSlice(), new FoldIf(), + new FoldBroadcastShape(), }, new()); - var body = (Expr)Slice(simplifyCall, Enumerable.Repeat(0, rank).ToArray(), Cast(sliceShape, DataTypes.Int32), rank); + var axes = Tensor.From(Enumerable.Range(0, rank).Select(x => (long)x).ToArray()); + var strides = Tensor.FromScalar(1L, rank); + var body = (Expr)Slice(simplifyCall, Enumerable.Repeat(0L, rank).ToArray(), Cast(sliceShape, DataTypes.Int64), axes, strides); return body; } @@ -980,11 +1124,26 @@ private static bool IsFixed(int totalCount, int[][] minFixedShapeList, int[][] m private static bool ShouldRestore(Call outerCall, BucketFusion fusion) { - return fusion.IsSimple || - outerCall.CheckedType is TupleType || - outerCall.CheckedShape.Rank == 0 || - outerCall.Arguments.ToArray().Any(arg => - arg.CheckedType is TupleType); + if (CallValidator.IsSimple(fusion)) + { + return true; + } + + if (outerCall.CheckedType is TupleType tt) + { + if (tt.Fields.All(f => f is TensorType t && t.Shape.Rank < 2)) + { + return true; + } + } + + if (outerCall.Arguments.ToArray().Any(arg => + arg.CheckedType is TupleType)) + { + return true; + } + + return false; } private static Expr RestoreBodyWithArgs(Expr[] args, Var[] parameters, Expr body) => @@ -1055,72 +1214,141 @@ private static Expr ReplaceFusionVarWithCallArgs(BucketFusion fusion, Expr[] arg return result; }); - private static Expr Split(FusionBucketContext context, SegmentInfo info) + private static Expr MakeFailure(Expr fusionBody) { - var fusionInputs = context.Arguments; - var (inputIndex, dimIndex, segments) = info; - var dim = ShapeOf(fusionInputs[inputIndex])[dimIndex]; - var failure = MakeFailure(context.FusionBody); - - int i = 0; - - // 1. 普通情况不应该rebuild - // 2. rebuild的正确性 - // if (ShouldBeRebuild(context)) - // { - // Console.WriteLine("Rebuild"); - // return RestoreBodyWithArgs(context.Arguments, context.Parameters, context.FusionBody); - // } - var body = segments.OrderByDescending(x => x).Aggregate( - failure, - (sum, seg) => - { - // 根据var,也就是target为这个fusion的call的参数来进行判断落在哪个段 - var cond = dim <= (long)seg; - - // select var value for current segment - var varInfo = context.DimVarValue(i); - var thenBody = MakeSplitEntry(context, varInfo, i); - var elseBody = sum; - i++; - - var result = new If(cond, thenBody, elseBody); - return result; - }); - - return body; + var failure = fusionBody.CheckedType switch + { + TupleType tuple => new IR.Tuple(tuple.Fields.ToArray() + .Select(x => + { + return ConstantOfShape(new[] { 1 }, Cast(0, ((TensorType)x).DType)); + }).ToArray()), + TensorType tensorType => (Expr)ConstantOfShape(new[] { 1 }, Cast(0, tensorType.DType)), + _ => throw new ArgumentOutOfRangeException("fusionBody"), + }; + return IR.F.Math.Require(false, failure, "input dim large than limit"); } private static bool ShouldBeRebuild(FusionBucketContext context) { var varInfo = context.DimVarValue(0); - var entry = MakeSplitEntry(context, varInfo, 0); + var entry = MakeSplitEntry(context, varInfo, 0, false, false); return entry switch { IR.Tuple tuple => tuple.Fields.ToArray().Any(ShouldBeRebuild), Call => ShouldBeRebuild(entry), - _ => throw new ArgumentOutOfRangeException("context"), + _ => DumpError(entry), }; } - private static bool ShouldBeRebuild(Expr entry) => entry is Call { Target: IR.Tensors.Slice } c && - (!c.Arguments[IR.Tensors.Slice.Input.Index].CheckedShape - .IsFixed); + private static bool DumpError(Expr entry) + { + DumpIR(entry, "FailedEntry"); + throw new InvalidOperationException(); + } + + private static bool ShouldBeRebuild(Expr entry) + { + if (entry is Call { Target: IR.Tensors.Slice } c) + { + var body = c.Arguments[IR.Tensors.Slice.Input.Index]; + if (body.CheckedShape.IsFixed) + { + var visitor = new DynamicCheckVisitor(); + visitor.Visit(body); + return visitor.HasDynamic; + } + } - private static Expr MakeFailure(Expr fusionBody) + return true; + } + + public class DynamicCheckVisitor : ExprVisitor { - var failure = fusionBody.CheckedType switch + private bool _hasDynamic; + + public bool HasDynamic => _hasDynamic; + + protected override Expr DefaultVisitLeaf(Expr expr) => expr; + + protected override Expr VisitLeafCall(Call expr) { - TupleType tuple => new IR.Tuple(tuple.Fields.ToArray() - .Select(x => + if (CallValidator.ForceConvert.Contains(expr.Target.GetType().TypeHandle)) + { + if (!expr.CheckedShape.IsFixed) { - return ConstantOfShape(new[] { 1 }, Cast(0, ((TensorType)x).DType)); - }).ToArray()), - TensorType tensorType => (Expr)ConstantOfShape(new[] { 1 }, Cast(0, tensorType.DType)), - _ => throw new ArgumentOutOfRangeException("fusionBody"), - }; - return IR.F.Math.Require(false, failure, "input dim large than limit"); + _hasDynamic = true; + } + } + + return expr; + } } } internal record SegmentInfo(int InputIndex, int DimIndex, int[] Segments); + +public class FullBucket : FunctionPass +{ + protected override Task RunCoreAsync(BaseFunction input, RunPassContext ctx) + { + if (!SingleDimVar(CompileSession.CompileOptions.ShapeBucketOptions)) + { + throw new NotImplementedException("Not Implement multi DimVar for FullBucket"); + } + + var main = (Function)input; + var replaceItem = main.Parameters.ToArray().Select(param => (param, (Expr)new Var(param.CheckedType))).ToArray(); + var cloneMain = (Function)ReplaceClone(main, replaceItem); + var options = CompileSession.CompileOptions.ShapeBucketOptions; + var tmpFusion = new BucketFusion("stackvm", cloneMain.Body, cloneMain.Parameters, Array.Empty()); + var call = new Call(tmpFusion, main.Parameters.ToArray()); + var dimVarValues = MakeVarValuesForAllSegment(options); + var list = InputConfList(dimVarValues); + var shapeData = MakeShapeData(list, options); + + var context = new FusionBucketContext(call, tmpFusion, options, new ShapeExprCache(options.VarMap), 0, shapeData); + + var allFixedShapes = shapeData + .Select(x => + x.InputShapes.Select(iShape => iShape.AsTensor().ToArray().ToArray()).ToArray()).ToArray(); + for (int i = 0; i < shapeData.Length; i++) + { + for (int j = 0; j < allFixedShapes.Length; j++) + { + context.FixedShapeCache[j] = allFixedShapes[j]; + } + } + + var newBody = FusionBucket.Split(context); + foreach (var (oldVar, tmpVar) in replaceItem) + { + ReplaceExpr(newBody, tmpVar, oldVar); + } + + return Task.FromResult((BaseFunction)main.With(body: newBody)); + } + + private static FusionShapeData[] MakeShapeData((Var Key, int Value)[][] list, ShapeBucketOptions options) => + list.Select(seg => + { + var varValues = seg.ToDictionary(pair => pair.Key, pair => (IValue)Value.FromTensor(pair.Value)); + var inShape = options.VarMap.Select(pair => + { + var shapeExpr = pair.Key.CheckedShape.IsScalar + ? (Expr)Array.Empty() + : Stack(new IR.Tuple(pair.Value.Select(x => Cast(x, DataTypes.Int64)).ToArray()), 0); + + var shape = shapeExpr.Evaluate(varValues).AsTensor(); + return shape; + }).ToArray(); + return new FusionShapeData(Value.None, inShape.Select(Value.FromTensor).ToArray()); + }).ToArray(); + + private static (Var Key, int Value)[][] InputConfList(Dictionary dimVarValues) => + Enumerable.Range(0, dimVarValues.First().Value.Length).Select(i => + { + // 一组里面多个key seg + return dimVarValues.Select(pair => (pair.Key, Value: pair.Value[i])).ToArray(); + }).ToArray(); +} diff --git a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs index 3d9cbd389e..0275771c3c 100644 --- a/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs +++ b/src/Nncase.Passes/Rules/ShapeBucket/ShapeBucketHelper.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Reactive; using DryIoc.ImTools; +using Nncase.CodeGen; using Nncase.Diagnostics; using Nncase.IR; using Nncase.IR.Math; @@ -22,23 +23,18 @@ namespace Nncase.Passes.Rules.ShapeBucket; public static class CallValidator { - private static readonly HashSet ForceConvert = new() + public static readonly HashSet ForceConvert = new() { typeof(Conv2D).TypeHandle, + typeof(Conv2DTranspose).TypeHandle, typeof(MatMul).TypeHandle, - typeof(Unsqueeze).TypeHandle, - typeof(Squeeze).TypeHandle, - typeof(Cast).TypeHandle, - typeof(Unary).TypeHandle, typeof(Transpose).TypeHandle, typeof(Pad).TypeHandle, + typeof(Tile).TypeHandle, }; - // todo: add debug mode private static readonly HashSet MaybeDynamic = new() { - // typeof(SpaceToBatch).TypeHandle, - // typeof(BatchToSpace).TypeHandle, typeof(Concat).TypeHandle, typeof(Stack).TypeHandle, typeof(Binary).TypeHandle, @@ -46,7 +42,12 @@ public static class CallValidator typeof(Gather).TypeHandle, typeof(ShapeOf).TypeHandle, - // typeof(Reshape).TypeHandle, + typeof(Unsqueeze).TypeHandle, + typeof(Squeeze).TypeHandle, + typeof(Cast).TypeHandle, + typeof(Unary).TypeHandle, + + typeof(Reshape).TypeHandle, typeof(Expand).TypeHandle, typeof(ConstantOfShape).TypeHandle, typeof(Where).TypeHandle, @@ -60,22 +61,64 @@ public static class CallValidator public static bool IsMaybeDynamic(Expr target) => MaybeDynamic.Contains(target.GetType().TypeHandle); - public static bool IsForceConvert(Expr target) => ForceConvert.Contains(target.GetType().TypeHandle); + public static bool IsForceConvert(Expr target) => ForceConvert.Contains(target.GetType().TypeHandle) || target is ActivationOp; - public static bool ValidTarget(Expr target) + public static bool ValidTarget(Call call, bool greedy) { - if (target is ActivationOp) + var target = call.Target; + + var singleVar = + ShapeBucketHelper.SingleDimVar( + CompileSessionScope.GetCurrentThrowIfNull().CompileOptions.ShapeBucketOptions); + + if (target is Binary && call.Arguments.ToArray().OfType().Any()) { return true; } - if (IsMaybeDynamic(target) || IsForceConvert(target)) + if (IsForceConvert(target)) + { + return true; + } + + // dynamic reshape cause dynamic shape call + if (!greedy && IsDynamicReshape(call)) + { + return false; + } + + if (singleVar && greedy && IsMaybeDynamic(target)) { return true; } return false; } + + public static bool IsSimple(BucketFusion fusion) + { + var v = new OpCollector(); + v.Visit(fusion.Body); + foreach (var type in v.Counter.Keys) + { + if (CallValidator.ForceConvert.Contains(type)) + { + return false; + } + } + + foreach (var op in v.OpSet) + { + if (op is ActivationOp) + { + return false; + } + } + + return true; + } + + private static bool IsDynamicReshape(Call call) => call.Target is Reshape && call.Arguments[Reshape.Shape.Index] is not Const; } public static class ShapeBucketRegister @@ -91,16 +134,25 @@ public static void CheckShapeBucketOptions(ShapeBucketOptions options) } } - public static void MergeOp(IPassManager iPassManager) + public static bool HasNotBucketOp(Expr entry) + { + var counter = new OpCollector(); + counter.Visit(entry); + var invalid = new[] { typeof(Softmax), typeof(LayerNorm) }; + var canFullBucket = invalid.Any(x => counter.Counter.Keys.Contains(x.TypeHandle)); + return canFullBucket; + } + + public static void MergeOp(IPassManager iPassManager, bool greedy) { iPassManager.AddWithName("MergeNextCall").Configure(c => { - c.Add(); + c.Add(greedy); c.Add(); }); iPassManager.AddWithName("MergePrevCall").Configure(c => { - c.Add(); + c.Add(greedy); c.Add(); }); } @@ -108,6 +160,7 @@ public static void MergeOp(IPassManager iPassManager) public static void ToFusion(IPassManager p, bool onlyDynamic = false) => p.AddWithName("ToFusion").Configure(c => { + c.Add(); c.Add(onlyDynamic); c.Add(onlyDynamic); c.Add(onlyDynamic); @@ -124,21 +177,31 @@ public static void Bucket(IPassManager p) }); } - public static void Rebuild(IPassManager p) + public static void Rebuild(IPassManager p, bool singleVar) { // rebuild ToFusion(p, true); + MergeOp(p, false); + + // todo: lost to fusion + p.AddWithName("LostToFusion").Configure(p => + { + p.Add(true); + p.Add(true); + }); + + MergeFusion(p, singleVar, false); Bucket(p); } - public static void MergeFusion(IPassManager p, bool singleVar) + public static void MergeFusion(IPassManager p, bool singleVar, bool greedy) { if (!singleVar) { return; } - p.AddWithName("MergeBucketFusionPass"); + p.AddWithName("MergeBucketFusionPass", greedy); } public static void LostToFusion(IPassManager p, bool singleVar) => @@ -163,6 +226,7 @@ public static void ClearMarker(IPassManager p) => public static void Simplify(IPassManager p) => p.AddWithName("Simplify").Configure(c => { + c.Add(); c.Add(); c.Add(); c.Add(); @@ -174,19 +238,39 @@ public static void Simplify(IPassManager p) => c.Add(); c.Add(); c.Add(); + c.Add(); + c.Add(); + c.Add(); }); } public static class ShapeBucketHelper { + public static Dictionary ConcatDictionary(Dictionary memo, Dictionary exprValues) + where T : Expr + { + foreach (var (key, value) in exprValues) + { + memo[key] = value; + } + + return memo; + } + public static Dictionary MakeVarValuesForAllSegment(ShapeBucketOptions options) { int segmentCount = options.SegmentsCount; var varRange = options.RangeInfo; var varMap = options.VarMap; + var staticShape = false; var varAndInputAllSegment = varRange.ToDictionary(pair => pair.Key, pair => { var (min, max) = pair.Value; + if (staticShape) + { + return Enumerable.Range(min, max - min).ToArray(); + } + var segments = ComputeSegmentList(segmentCount, min, max); return segments; }); @@ -328,6 +412,31 @@ public static void DumpIR(Expr expr, string prefix, string? reletivePath = null, DumpScope.Current.DumpIR(expr, s, reletivePath); } } + + public static void CheckRepeat(Expr call) + { + // todo: 检查所有fusion里面的param有没有重复名字的 + // todo: 检查有没有fusion名字重复的 + var c = new CheckFusionCallVisitor(); + c.Visit(call); + c.Check(); + } + + public static void CheckErrorVar(Expr body, Var[] vars) + { + var f = new FindVar(); + f.Visit(body); + if (!f.Vars.All(vars.Contains)) + { + Console.WriteLine(string.Join(", ", f.Vars.Select(x => x.Name).ToArray())); + throw new InvalidOperationException("Has Invalid Var In Body"); + } + } + + public static bool SingleDimVar(ShapeBucketOptions options) + { + return options.VarMap.Values.SelectMany(x => x).OfType().ToHashSet().Count <= 1; + } } public class FindExpr : ExprVisitor @@ -380,7 +489,7 @@ protected override Expr DispatchVisit(Expr expr) public class FindVar : ExprVisitor { - public HashSet Vars { get; set; } = new(); + public HashSet Vars { get; } = new(); // todo: if visit call(VarFusion), then return EffectVar protected override Expr VisitLeafVar(Var expr) @@ -412,6 +521,34 @@ public sealed partial class ForceConvertOpChecker : RewriteRule } } +public class OpCollector : ExprVisitor +{ + public Dictionary Counter { get; } = new(); + + public HashSet OpSet { get; } = new(); + + protected override Expr VisitCall(Call expr) + { + if (expr.Target is Op op) + { + var handle = expr.Target.GetType().TypeHandle; + if (Counter.ContainsKey(handle)) + { + Counter[handle] += 1; + } + else + { + Counter[handle] = 1; + OpSet.Add(op); + } + } + + return base.VisitCall(expr); + } + + protected override Expr DefaultVisitLeaf(Expr expr) => expr; +} + internal static class ExprArrayExtension { public static IEnumerable OfNoConst(this IEnumerable args) @@ -433,25 +570,79 @@ public int GetHashCode(KeyValuePair obj) } } -internal class OpCounter : ExprVisitor +internal sealed class CheckFusionCallVisitor : ExprWalker { - private readonly Dictionary _counter = new(); + private readonly HashSet _callName = new(); + private readonly Dictionary _errorFusion = new(); - protected override Expr VisitCall(Call expr) + private readonly HashSet _fusionName = new(); + private readonly HashSet _repeatFusion = new(); + + private readonly HashSet _fusionParamsName = new(); + private readonly HashSet _repeatParamFusion = new(); + + public void Check() { - if (expr.Target is Op) + var error = false; + if (_errorFusion.Count != 0) { - var handle = expr.Target.GetType().TypeHandle; - if (_counter.ContainsKey(handle)) + error = true; + Console.WriteLine("errorFusion"); + } + + if (_repeatFusion.Count != 0) + { + error = true; + Print("repeatFusion not zero", _repeatFusion); + } + + if (_repeatParamFusion.Count != 0) + { + error = true; + Print("repeatParamFusion not zero", _repeatParamFusion); + } + + if (error) + { + throw new InvalidOperationException(); + } + } + + protected override Unit VisitLeafFusion(Fusion fusion) + { + // 可能有多个user啊,每次进来访问 + if (fusion is BucketFusion bf) + { + if (_fusionName.Contains(bf.Name)) { - _counter[handle] += 1; + _repeatFusion.Add(bf.Name); } else { - _counter[handle] = 1; + _fusionName.Add(bf.Name); } + + var parameters = bf.Parameters.ToArray(); + foreach (var parameter in parameters) + { + if (_fusionParamsName.Contains(parameter.Name)) + { + _repeatParamFusion.Add(parameter.Name); + } + } + + _fusionParamsName.UnionWith(parameters.Select(p => p.Name).ToArray()); } - return base.VisitCall(expr); + return default; + } + + private void Print(string name, HashSet list) + { + Console.WriteLine(name); + foreach (string s in list) + { + Console.WriteLine(s); + } } } diff --git a/src/Nncase.Passes/Rules/ShapeExpr/FoldBroadcastShape.cs b/src/Nncase.Passes/Rules/ShapeExpr/FoldBroadcastShape.cs new file mode 100644 index 0000000000..c45cf17931 --- /dev/null +++ b/src/Nncase.Passes/Rules/ShapeExpr/FoldBroadcastShape.cs @@ -0,0 +1,76 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Linq; +using Google.OrTools.Sat; +using Nncase.Evaluator; +using Nncase.IR; +using Nncase.IR.ShapeExpr; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.ShapeExpr; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.ShapeExpr; + +[RuleGenerator] +public partial class FoldBroadcastShapeConst : RewriteRule +{ + public override Pattern Pattern => IsCall(IsOp(), IsTuple("input")); + + private Expr? GetReplace(IR.Tuple input) + { + var constFields = input.Fields.ToArray().OfType().ToArray(); + if (constFields.Length == 0) + { + return null; + } + + if (constFields.Length == 1) + { + return null; + } + + var shape = IR.F.ShapeExpr.BroadcastShape(constFields.Select(x => (Expr)x.Value).ToArray()).Evaluate().AsTensor(); + var exprFields = input.Fields.ToArray().Where(x => x is not TensorConst).ToArray(); + + if (exprFields.Length == 0) + { + return shape; + } + + if ((shape.Shape.Count == 0 || (shape.Shape.Count == 1 && shape.Shape[0] == 1)) && exprFields.Length != 0) + { + return IR.F.ShapeExpr.BroadcastShape(exprFields); + } + + return IR.F.ShapeExpr.BroadcastShape(exprFields.Append(shape).ToArray()); + } +} + +[RuleGenerator] +public partial class FoldBroadcastShape : RewriteRule +{ + public override Pattern Pattern => IsCall(IsOp(), IsTuple("input")); + + private Expr? GetReplace(IR.Tuple input) + { + var broadcastShapeList = input.Fields.ToArray().Where(field => field is Call c && c.Target is BroadcastShape).ToArray(); + if (broadcastShapeList.Length > 0) + { + var newFields = input.Fields.ToArray().SelectMany(field => + { + if (field is Call { Target: BroadcastShape } c) + { + return ((Tuple)c.Arguments[0]).Fields.ToArray(); + } + + return new[] { field }; + }).ToArray(); + return IR.F.ShapeExpr.BroadcastShape(newFields); + } + + return null; + } +} diff --git a/src/Nncase.Passes/Rules/ShapeExpr/FoldGetItemShapeOf.cs b/src/Nncase.Passes/Rules/ShapeExpr/FoldGetItemShapeOf.cs new file mode 100644 index 0000000000..e8e6072a78 --- /dev/null +++ b/src/Nncase.Passes/Rules/ShapeExpr/FoldGetItemShapeOf.cs @@ -0,0 +1,41 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.IR; +using Nncase.PatternMatch; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; +using GetItem = Nncase.IR.Tensors.GetItem; + +namespace Nncase.Passes.Rules.ShapeExpr +{ + [RuleGenerator] + public partial class FoldGetItemShapeOf : RewriteRule + { + public override Pattern Pattern => IsGetItem(null, "getItem", IsAlt(CastPattern, ShapeOfPattern), IsTensorConst("index")); + + public Pattern CastPattern => IsCast("cast", _ => true, ShapeOfPattern); + + public Pattern ShapeOfPattern => IsShapeOf(IsWildcard("input")); + + private Expr? GetReplace(Expr input, Tensor index, Call getItem) + { + DataType dt = DataTypes.Int64; + + if (getItem.Arguments[GetItem.Input.Index] is Call c && c.Target is IR.Tensors.Cast cast) + { + dt = cast.NewType; + } + + if (index.Shape.IsScalar) + { + var dim = input.CheckedShape[index.ToScalar()]; + return dim.IsFixed ? IR.F.Tensors.Cast(dim.FixedValue, dt) : null; + } + + return input; + } + } +} diff --git a/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs b/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs new file mode 100644 index 0000000000..737e371356 --- /dev/null +++ b/src/Nncase.Passes/Rules/ShapeExpr/FoldSplitShapeOf.cs @@ -0,0 +1,57 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Linq; +using Nncase.IR; +using Nncase.IR.Tensors; +using Nncase.PatternMatch; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; +using GetItem = Nncase.IR.Tensors.GetItem; + +namespace Nncase.Passes.Rules.ShapeExpr; + +// shape = ShapeOf(input) +// Stack(cast(shape[0]), cast(shape[1])) -> shape +[RuleGenerator] +public partial class FoldSplitShapeOf : RewriteRule +{ + public override Pattern Pattern => IsStack( + null, + "stack", + IsTuple( + "tuple", + new VArgsPattern( + list => + Enumerable.Range(0, list.Length) + .Select(_ => IsGetItem(InputPattern, IsTensorConst())) + .ToArray(), + "args")), + IsTensorConst(tensor => tensor.Value.ToScalar() == 0)); + + public Pattern InputPattern => IsShapeOf(IsWildcard()); + + private Expr? GetReplace(IR.Tuple tuple) + { + var getItemList = tuple.Fields.ToArray().OfType().ToArray(); + var getItemIndices = getItemList.Select(x => x.Arguments[GetItem.Index.Index]).OfType().Select(x => x.Value.ToScalar()).ToArray(); + if (getItemIndices.Length == 0) + { + return null; + } + + var shapeOf = getItemList[0].Arguments[GetItem.Input.Index]; + if (!shapeOf.CheckedShape[0].IsFixed) + { + return null; + } + + if (getItemIndices.SequenceEqual(Enumerable.Range(0, shapeOf.CheckedShape[0].FixedValue))) + { + return shapeOf; + } + + return null; + } +} diff --git a/src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs b/src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs new file mode 100644 index 0000000000..2a1453f3c9 --- /dev/null +++ b/src/Nncase.Passes/Rules/ShapeExpr/GatherToGetItem.cs @@ -0,0 +1,24 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Linq; +using Nncase.IR; +using Nncase.PatternMatch; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.ShapeExpr; + +[RuleGenerator] +public sealed partial class GatherToGetItem : RewriteRule +{ + // (Gather(input, 0, 0) -> GetItem(input) + public override Pattern Pattern => IsGather( + IsWildcard("input"), IsTensorConst("axis"), IsTensorConst("index") with { TypePattern = IsScalar() }); + + private Expr? GetReplace(Expr input, int axis, int index) + { + return input[index]; + } +} diff --git a/src/Nncase.Passes/Rules/ShapeExpr/SliceToGetItem.cs b/src/Nncase.Passes/Rules/ShapeExpr/SliceToGetItem.cs new file mode 100644 index 0000000000..dd158f54ff --- /dev/null +++ b/src/Nncase.Passes/Rules/ShapeExpr/SliceToGetItem.cs @@ -0,0 +1,44 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Collections.Immutable; +using System.Linq; +using Nncase.IR; +using Nncase.IR.Math; +using Nncase.IR.Tensors; +using Nncase.Passes; +using Nncase.PatternMatch; +using static Nncase.IR.F.Math; +using static Nncase.IR.F.Tensors; +using static Nncase.IR.TypePatternUtility; +using static Nncase.PatternMatch.F.Math; +using static Nncase.PatternMatch.F.Tensors; +using static Nncase.PatternMatch.Utility; + +namespace Nncase.Passes.Rules.Neutral; + +// Slice(shape, 1, 2, 1, 1) -> shape[1] +[RuleGenerator] +public partial class SliceToGetItem : RewriteRule +{ + public override Pattern Pattern => IsSqueeze( + IsSlice( + IsWildcard("input") with { TypePattern = HasRank(1) }, + IsTensorConst("begins"), + IsTensorConst("ends"), + IsTensorConst("axes"), + IsTensorConst("strides", strides => strides.Value.ToArray()[0] == 1)), + IsTensorConst("dims")); + + private Expr? GetReplace(Expr input, int[] begins, int[] ends) + { + if ((ends[0] - begins[0]) == 1) + { + return input[begins[0]]; + } + + return null; + } +} diff --git a/src/Nncase.Tests.TestFixture/TransformTestBase.cs b/src/Nncase.Tests.TestFixture/TransformTestBase.cs index 44abc38dac..9f1f8435c3 100644 --- a/src/Nncase.Tests.TestFixture/TransformTestBase.cs +++ b/src/Nncase.Tests.TestFixture/TransformTestBase.cs @@ -88,6 +88,7 @@ public Expr TestMatchedCore(Function pre, IReadOnlyDictionary? feed public Expr TestMatchedCore(Expr pre, IReadOnlyDictionary? feeds = null, params IRewriteRule[] rules) { + pre.InferenceType(); Assert.True(pre.InferenceType(), "TestInferFailed:" + pre.CheckedType); if (rules.Length == 0) { diff --git a/src/Nncase.Tests/Importer/UnitTestUtil.cs b/src/Nncase.Tests/Importer/UnitTestUtil.cs index 9e1bc419a5..30622d79d3 100644 --- a/src/Nncase.Tests/Importer/UnitTestUtil.cs +++ b/src/Nncase.Tests/Importer/UnitTestUtil.cs @@ -34,23 +34,6 @@ public void TestZeroTensor() Assert.Equal(new TensorConst(Tensor.From(new[] { 0 })), Util.ZeroTensor()); } - [Fact] - public void TestGetPaddings() - { - var input = OrtKI.Random(1, 2, 4, 8).ToTensor(); - var weights = OrtKI.Random(3, 3, 2, 2).ToTensor(); - var stride = new long[] { 1, 1, 1, 1 }; - var dilation = new long[] { 1, 1 }; - var expr = Util.GetPaddings(input, weights, stride, dilation, true); - - var (inH, inW) = Util.GetHW(input); - var (fH, fW) = Util.GetHW(weights); - var padH = Util.GetWindowedPadding(inH, fH, (int)stride[0], (int)dilation[0], true); - var padW = Util.GetWindowedPadding(inW, fW, (int)stride[1], (int)dilation[1], true); - var expect = Util.ConcatPadding(padH, padW); - Assert.Equal(expect, expr); - } - [Fact] public void TestComputeSplit() { diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestFoldSqueeze.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldSqueeze.cs new file mode 100644 index 0000000000..3f30ee7dc8 --- /dev/null +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestFoldSqueeze.cs @@ -0,0 +1,33 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System.Collections.Generic; +using Nncase.IR; +using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; +using Xunit; +using static Nncase.IR.F.Tensors; + +namespace Nncase.Tests.Rules.NeutralTest; + +[AutoSetupTestMethod(InitSession = true)] +public class UnitTestFoldSqueeze : TransformTestBase +{ + [Fact] + public void TestFoldSqueezeUnsqueeze() + { + var input = Testing.Rand(1, 3, 24); + var inputVar = new Var(new TensorType(input.ElementType, input.Shape)); + var expr = Squeeze(Unsqueeze(inputVar, new[] { -3 }), new[] { 1 }); + TestMatched(expr, new Dictionary { { inputVar, Value.FromTensor(input) } }); + } + + [Fact] + public void TestFoldUnsqueezeSqueeze() + { + var input = Testing.Rand(1, 1, 3, 24); + var inputVar = new Var(new TensorType(input.ElementType, input.Shape)); + var expr = Unsqueeze(Squeeze(inputVar, new[] { 1 }), new[] { -3 }); + TestMatched(expr, new Dictionary { { inputVar, Value.FromTensor(input) } }); + } +} diff --git a/src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs b/src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs new file mode 100644 index 0000000000..4ef0186aad --- /dev/null +++ b/src/Nncase.Tests/Rules/Neutral/UnitTestSplitSpaceToBatch.cs @@ -0,0 +1,39 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.IR.NN; +using Nncase.Passes.Rules.Neutral; +using Nncase.Tests.TestFixture; +using Xunit; +using static Nncase.IR.F.NN; + +namespace Nncase.Tests.Rules.NeutralTest; + +[AutoSetupTestMethod(InitSession = true)] +public class UnitTestSpaceToBatch : TransformTestBase +{ + [Fact] + public void TestSplitSpaceToBatch() + { + var i = SpaceToBatch(Testing.Rand(1, 206, 192), new[] { 3 }, new[,] { { 0, 1 } }); + var originEvaluateResult = i.Evaluate(); + var newBody = TestMatched(i); + var ev = newBody.Evaluate(); + _ = Comparator.CosSimilarity(originEvaluateResult, ev); + var dumpDir = Dumpper.Directory; + var (_, kmodel) = Testing.BuildKModel("kmodel", new IRModule(new Function(newBody, System.Array.Empty())), CompileSession); + var inputs = System.Array.Empty(); + var result = Testing.RunKModel(kmodel, dumpDir, inputs); + var v = Comparator.CosSimilarity(ev, result); + Assert.True(v[0] > 0.99f); + } + + [Fact] + public void TestSplitBatchToSpace() + { + var i = BatchToSpace(Testing.Rand(3, 192, 67), new[] { 3 }, new[,] { { 0, 1 } }); + TestMatched(i); + } +} diff --git a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs index b014ad2d6f..a65c520bf1 100644 --- a/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs +++ b/src/Nncase.Tests/Rules/ShapeBucket/ShapeBucketTest.cs @@ -24,6 +24,7 @@ using static Nncase.IR.F.NN; using static Nncase.IR.F.Tensors; using static Nncase.Tests.Rules.ShapeBucket.ShapeBucketTestHelper; +using Reshape = Nncase.IR.Tensors.Reshape; namespace Nncase.Tests.Rules.ShapeBucket; @@ -60,7 +61,157 @@ public void TestBucketPad() Assert.True(cos > 0.999); } - private Var Scalar(string name) => new Var(new TensorType(DataTypes.Int32, Shape.Scalar)); + [Fact] + public async Task TestSingleVarFusionBucket() + { + var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var dimVar = Scalar("dimVar"); + CompileOptions.ShapeBucketOptions.Enable = true; + CompileOptions.ShapeBucketOptions.SegmentsCount = 2; + CompileOptions.ShapeBucketOptions.RangeInfo = + new Dictionary { { "dimVar", (1, 20) } }; + CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } }; + + var input = Testing.Rand(1, 3, 24, 24); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar }); + var main = new Function("main", new Call(f, mainVar), mainVar); + var shape = new Dictionary(); + await new RecordFusionShape(shape).RunAsync(main, new()); + TestMatchedCore( + main.Body!, + new Dictionary { { mainVar, Value.FromTensor(input) } }, + new FusionBucket(shape)); + } + + [Fact] + public async Task TestRebuild() + { + var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var dimVar = Scalar("dimVar"); + CompileOptions.ShapeBucketOptions.Enable = true; + CompileOptions.ShapeBucketOptions.SegmentsCount = 2; + CompileOptions.ShapeBucketOptions.RangeInfo = + new Dictionary { { "dimVar", (1, 20) } }; + CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } }; + + var input = Testing.Rand(1, 3, 24, 24); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var shapeVar = new Var(new TensorType(DataTypes.Int64, new[] { 4 })); + var body = IR.F.Math.MatMul(Reshape(fusionVar, shapeVar), fusionVar); + var f = new BucketFusion("MatMul_0", "stackvm", body, new[] { fusionVar, shapeVar }, new[] { dimVar }); + var main = new Function("main", new Call(f, mainVar, Stack(new IR.Tuple(new[] { 1L, ShapeOf(mainVar)[1], 24L, 24L }), 0))); + var shape = new Dictionary(); + await new RecordFusionShape(shape).RunAsync(main, new()); + var newBody = TestMatchedCore( + main.Body!, + new Dictionary { { mainVar, Value.FromTensor(input) } }, + new FusionBucket(shape)); + Assert.True(newBody is Call { Target: IR.Math.MatMul }); + } + + [Fact] + public async Task TestTupleOutput() + { + var mainVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var dimVar = Scalar("dimVar"); + CompileOptions.ShapeBucketOptions.Enable = true; + CompileOptions.ShapeBucketOptions.SegmentsCount = 2; + CompileOptions.ShapeBucketOptions.RangeInfo = + new Dictionary { { "dimVar", (1, 20) } }; + CompileOptions.ShapeBucketOptions.VarMap = new Dictionary { { mainVar, new Expr[] { 1, dimVar, 24, 24 } } }; + + var input = Testing.Rand(1, 3, 24, 24); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var mm = IR.F.Math.MatMul(fusionVar, fusionVar); + var body = new IR.Tuple(mm, mm); + var f = new BucketFusion("MatMul_0", "stackvm", body, new[] { fusionVar }, new[] { dimVar }); + var main = new Function("main", new Call(f, mainVar), mainVar); + var shape = new Dictionary(); + await new RecordFusionShape(shape).RunAsync(main, new()); + TestMatchedCore( + main.Body!, + new Dictionary { { mainVar, Value.FromTensor(input) } }, + new FusionBucket(shape)); + } + + [Fact] + public async Task TestDoubleVarFusionBucket() + { + var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var dimVar1 = Scalar("dimVar1"); + var dimVar2 = Scalar("dimVar2"); + CompileOptions.ShapeBucketOptions.Enable = true; + CompileOptions.ShapeBucketOptions.SegmentsCount = 5; + CompileOptions.ShapeBucketOptions.RangeInfo = + new Dictionary + { + { "dimVar1", (1, 20) }, + { "dimVar2", (1, 20) }, + }; + CompileOptions.ShapeBucketOptions.VarMap = new Dictionary + { + { mainVarLhs, new Expr[] { 1, dimVar1, 24, 24 } }, + { mainVarRhs, new Expr[] { 1, dimVar2, 24, 24 } }, + }; + + var inputLhs = Testing.Rand(1, 3, 24, 24); + var inputRhs = Testing.Rand(1, 3, 24, 24); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar1 }); + var main = new Function("main", new Call(f, mainVarLhs, mainVarRhs), mainVarLhs, mainVarRhs); + var shape = new Dictionary(); + await new RecordFusionShape(shape).RunAsync(main, new()); + TestMatchedCore( + main.Body!, + new Dictionary + { + { mainVarLhs, Value.FromTensor(inputLhs) }, + { mainVarRhs, Value.FromTensor(inputRhs) }, + }, + new FusionBucket(shape)); + } + + [Fact] + public async Task TestDoubleVarWithMultiDimEffect() + { + var mainVarLhs = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var mainVarRhs = new Var(new TensorType(DataTypes.Float32, new[] { Dimension.Unknown, 1, 24, 24 })); + var dimVar1 = Scalar("dimVar1"); + var dimVar2 = Scalar("dimVar2"); + CompileOptions.ShapeBucketOptions.Enable = true; + CompileOptions.ShapeBucketOptions.SegmentsCount = 5; + CompileOptions.ShapeBucketOptions.RangeInfo = + new Dictionary + { + { "dimVar1", (1, 20) }, + { "dimVar2", (1, 20) }, + }; + CompileOptions.ShapeBucketOptions.VarMap = new Dictionary + { + { mainVarLhs, new Expr[] { 1, dimVar1, 24, 24 } }, + { mainVarRhs, new Expr[] { dimVar2, 1, 24, 24 } }, + }; + + var inputLhs = Testing.Rand(1, 3, 24, 24); + var inputRhs = Testing.Rand(3, 1, 24, 24); + var fusionVar = new Var(new TensorType(DataTypes.Float32, new[] { 1, Dimension.Unknown, 24, 24 })); + var f = new BucketFusion("MatMul_0", "stackvm", IR.F.Math.MatMul(fusionVar, fusionVar), new[] { fusionVar }, new[] { dimVar1 }); + var main = new Function("main", new Call(f, mainVarLhs, mainVarRhs), mainVarLhs, mainVarRhs); + var shape = new Dictionary(); + await new RecordFusionShape(shape).RunAsync(main, new()); + TestMatchedCore( + main.Body!, + new Dictionary + { + { mainVarLhs, Value.FromTensor(inputLhs) }, + { mainVarRhs, Value.FromTensor(inputRhs) }, + }, + new FusionBucket(shape)); + } + + private Var Scalar(string name) => new Var(name, new TensorType(DataTypes.Int32, Shape.Scalar)); } [AutoSetupTestMethod(InitSession = true)] @@ -121,7 +272,7 @@ public void TestBodyMultiInputMergeRight() }); } - [Fact(Skip = "Reshape is not stable")] + [Fact] public void TestPrevMultiInputForDynamicReshape() { // fusion @@ -199,7 +350,7 @@ public void TestAfterMergeSameInput() TestMatched(c, new Dictionary { { inputVar, Value.FromTensor(input) } }); } - [Fact(Skip = "Reshape is not stable")] + [Fact] public void TestMatMulReshape() { // 左边的表达式是右边表达式的一部分 @@ -354,6 +505,17 @@ public void TestMergeInputInTupleWhichHadBeMerged() Assert.Equal(1, call.Arguments.Length); } + [Fact] + public async Task TestFoldNopTuple() + { + var input = (Expr)new[] { 1, 3, 24, 48 }; + var t = new IR.Tuple(new[] { input[0], input[1] }); + var b = t[0] + t[1]; + Dumpper.DumpIR(b, "b"); + var result = await new FoldNopTuple().RunAsync(new Function(b), new()); + Dumpper.DumpIR(result, "result"); + } + private static BucketFusion GetResultFusion(Expr result) { var fusion = (BucketFusion)((Call)result).Target; diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldBroadcastShape.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldBroadcastShape.cs new file mode 100644 index 0000000000..dfb1b8c415 --- /dev/null +++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldBroadcastShape.cs @@ -0,0 +1,31 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using Nncase.IR; +using Nncase.Passes.Rules.ShapeExpr; +using Nncase.Tests.TestFixture; +using Xunit; +using static Nncase.IR.F.ShapeExpr; + +namespace Nncase.Tests.Rules.ShapeExpr; + +[AutoSetupTestMethod(InitSession = true)] +public class UnitTestFoldBroadcastShape : TransformTestBase +{ + [Fact] + public void TestFoldBroadcastShape() + { + var b1 = BroadcastShape(new[] { (Expr)Tensor.From(new[] { 1, 3 }), Tensor.From(new[] { 1 }) }); + var b2 = BroadcastShape(new[] { (Expr)b1, Tensor.From(new[] { 1, 1 }) }); + TestMatched(b2); + } + + [Fact] + public void TestFoldBroadcastShapeConst() + { + var input = Testing.Rand(1, 3, 1, 1); + var b = BroadcastShape(new Expr[] { new[] { 1, 3 }, Array.Empty(), IR.F.Tensors.ShapeOf(input) }); + TestMatched(b); + } +} diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs new file mode 100644 index 0000000000..bab7175f5a --- /dev/null +++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldGetItemShapeOf.cs @@ -0,0 +1,45 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Rules.Neutral; +using Nncase.Passes.Rules.ShapeExpr; +using Xunit; +using Xunit.Abstractions; +using static Nncase.IR.F.Tensors; + +namespace Nncase.Tests.Rules.ShapeExpr; + +[TestFixture.AutoSetupTestMethod(InitSession = true)] +public class UnitTestFoldGetItemShapeOf : TransformTestBase +{ + [Fact] + public void TestFoldGetItemShapeOf() + { + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 24 })); + var data = Testing.Rand(1, 3, 24, 24); + var dict = new Dictionary { { input, Value.FromTensor(data) } }; + TestMatched(ShapeOf(input)[1], dict); + } + + [Fact] + public void TestFoldGetItemShapeOfWithCast() + { + var input = new Var(new TensorType(DataTypes.Float32, new[] { 1, 3, Dimension.Unknown, 24 })); + var data = Testing.Rand(1, 3, 24, 24); + var dict = new Dictionary { { input, Value.FromTensor(data) } }; + TestMatched(Cast(ShapeOf(input), DataTypes.Int32)[1], dict); + } + + [Fact] + public void TestFoldGetItemShapeOfWithDynamic() + { + var input = new Var(new TensorType(DataTypes.Int32, new[] { 1, 3, Dimension.Unknown, 24 })); + TestNotMatch(ShapeOf(input)[2]); + } +} diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs new file mode 100644 index 0000000000..845755e2f8 --- /dev/null +++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestFoldSplitShapeOf.cs @@ -0,0 +1,28 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Rules.ShapeExpr; +using Xunit; +using Xunit.Abstractions; +using static Nncase.IR.F.Tensors; + +namespace Nncase.Tests.Rules.ShapeExpr; + +[TestFixture.AutoSetupTestMethod(InitSession = true)] +public class UnitTestFoldSplitShapeOf : TransformTestBase +{ + [Fact] + public void TestFoldSplitShapeOf() + { + var input = Testing.Rand(1, 3, 24, 24); + var shape = ShapeOf(input); + var newShape = Stack(new IR.Tuple(shape[0], shape[1], shape[2], shape[3]), 0); + TestMatched(newShape); + } +} diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestGatherToGetItem.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestGatherToGetItem.cs new file mode 100644 index 0000000000..4d2a78bf47 --- /dev/null +++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestGatherToGetItem.cs @@ -0,0 +1,35 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Rules.ShapeExpr; +using Xunit; +using Xunit.Abstractions; +using static Nncase.IR.F.Tensors; + +namespace Nncase.Tests.Rules.ShapeExpr; + +[TestFixture.AutoSetupTestMethod(InitSession = true)] +public class UnitTestGatherToGetItem : TransformTestBase +{ + [Fact] + public void TestGatherToGetItem() + { + var input = new[] { 1, 2, 3, 4 }; + var gather = Gather(input, 0, 0); + TestMatched(gather); + } + + [Fact] + public void TestIndexNotScalar() + { + var input = new[] { 1, 2, 3, 4 }; + var gather = Gather(input, 0, new[] { 1, 2 }); + TestNotMatch(gather); + } +} diff --git a/src/Nncase.Tests/Rules/ShapeExpr/UnitTestSliceToGetItem.cs b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestSliceToGetItem.cs new file mode 100644 index 0000000000..04c0d1c71b --- /dev/null +++ b/src/Nncase.Tests/Rules/ShapeExpr/UnitTestSliceToGetItem.cs @@ -0,0 +1,36 @@ +// Copyright (c) Canaan Inc. All rights reserved. +// Licensed under the Apache license. See LICENSE file in the project root for full license information. + +using System; +using System.Collections.Generic; +using System.Linq; +using System.Threading.Tasks; +using Nncase.Diagnostics; +using Nncase.IR; +using Nncase.Passes.Rules.Neutral; +using Nncase.Passes.Rules.ShapeExpr; +using Xunit; +using Xunit.Abstractions; +using static Nncase.IR.F.Tensors; + +namespace Nncase.Tests.Rules.ShapeExpr; + +[TestFixture.AutoSetupTestMethod(InitSession = true)] +public class UnitTestSliceToGetItem : TransformTestBase +{ + [Fact] + public void TestSliceToGetItem() + { + var input = new[] { 1, 2, 3, 4 }; + var gather = Squeeze(Slice(input, new[] { 1 }, new[] { 2 }, 1), new[] { 0 }); + TestMatched(gather); + } + + [Fact] + public void TestTooLong() + { + var input = new[] { 1, 2, 3, 4 }; + var gather = Slice(input, new[] { 1 }, new[] { 3 }, 1); + TestNotMatch(gather); + } +} diff --git a/test.md b/test.md deleted file mode 100644 index 05f5b33bb5..0000000000 --- a/test.md +++ /dev/null @@ -1,75 +0,0 @@ -| test | 完成情况 | -|-----------------------------|---------------| -| test_acosh.py | unary | -| test_and.py | err | -| test_argmax.py | err | -| test_argmin.py | err | -| test_asinh.py | unary | -| test_batchnorm.py | n | -| test_binary.py | tuple output | -| test_cast.py | numpy no bf16 | -| test_celu.py | y | -| test_clip.py | n | -| test_concat.py | tuple input | -| test_constantofshape.py | y | -| test_conv.py | err | -| test_conv_transpose.py | n | -| test_cosh.py | unary | -| test_cumsum.py | err | -| test_depthtospace.py | n | -| test_dequantizelinear.py | import | -| test_expand.py | n | -| test_flatten.py | n | -| test_gather_nd.py | err | -| test_gather.py | err | -| test_gemm2.py | x | -| test_gemm.py | x | -| test_hardmax.py | n | -| test_hardsigmoid.py | n | -| test_hardswish.py | y | -| test_identity.py | y | -| test_instancenorm.py | n | -| test_leakyrelu.py | y | -| test_logsoftmax.py | n | -| test_lrn.py | n | -| test_lstm.py | n | -| test_matmul.py | y | -| test_onehot.py | err | -| test_pad.py | err | -| test_pool.py | tuple output | -| test_prelu.py | y | -| test_quantizelinear.py | import | -| test_random_normal_like.py | n | -| test_random_normal.py | n | -| test_random_uniform_like.py | n | -| test_random_uniform.py | n | -| test_reducel1.py | y | -| test_reducel2.py | y | -| test_reduce_log_sum_exp.py | y | -| test_reduce_log_sum.py | y | -| test_reduce.py | y | -| test_reduce_sum_square.py | y | -| test_relu.py | y | -| test_reshape.py | | -| test_resize.py | | -| test_reverse_sequence.py | | -| test_selu.py | | -| test_shape.py | | -| test_sigmoid.py | | -| test_sign.py | | -| test_sinh.py | unary | -| test_size.py | | -| test_slice.py | | -| test_slice_to_conv2d.py | | -| test_softmax.py | | -| test_softplus.py | | -| test_softsign.py | | -| test_spacetodepth.py | | -| test_split.py | | -| test_squeeze.py | n | -| test_sum.py | | -| test_tile.py | n | -| test_transpose.py | err | -| test_unary.py | tuple output | -| test_unsqueeze.py | n | -| test_where.py | n | \ No newline at end of file diff --git a/testEnvironments.json b/testEnvironments.json deleted file mode 100644 index 2c5e92936e..0000000000 --- a/testEnvironments.json +++ /dev/null @@ -1,17 +0,0 @@ -{ - "version": "1", - "environments": [ - // 请参阅 https://aka.ms/remotetesting 获取更多信息 - // 了解如何配置远程环境。 - //{ - // "name": "WSL Ubuntu", - // "type": "wsl", - // "wslDistribution": "Ubuntu" - //}, - //{ - // "name": "Docker dotnet/sdk", - // "type": "docker", - // "dockerImage": "mcr.microsoft.com/dotnet/sdk" - //} - ] -} \ No newline at end of file