diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs
index b900856080..2cfd50e458 100644
--- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs
+++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/CodeGenVisitor.g.cs
@@ -1,6 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +08:00. */
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */
using System;
using System.Collections.Generic;
@@ -271,9 +271,24 @@ private void EmitTensorCall(Op op)
case IR.ShapeExpr.Conv2DTransposeShape top:
Emitter.T.Conv2DTransposeShape();
break;
+ case IR.ShapeExpr.GetPaddings top:
+ Emitter.T.GetPaddings();
+ break;
case IR.ShapeExpr.MatMulShape top:
Emitter.T.MatMulShape();
break;
+ case IR.ShapeExpr.ReshapeShape top:
+ Emitter.T.ReshapeShape();
+ break;
+ case IR.ShapeExpr.SqueezeShape top:
+ Emitter.T.SqueezeShape();
+ break;
+ case IR.ShapeExpr.TransposeShape top:
+ Emitter.T.TransposeShape();
+ break;
+ case IR.ShapeExpr.UnsqueezeShape top:
+ Emitter.T.UnsqueezeShape();
+ break;
case IR.Random.Normal top:
Emitter.T.Normal(top.Type);
break;
diff --git a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs
index d69739ded8..6e2184c5ea 100644
--- a/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs
+++ b/modules/Nncase.Modules.StackVM/CodeGen/StackVM/StackVMEmitter.g.cs
@@ -1,6 +1,6 @@
// Copyright (c) Canaan Inc. All rights reserved.
// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39 +08:00. */
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:30 +08:00. */
using System;
using System.Collections.Generic;
@@ -876,52 +876,59 @@ public void GetItem()
}
///.
- public void Hardmax()
+ public void GetPaddings()
{
_emitter.Write((byte)100);
_emitter.Write((ushort)32);
}
///.
- public void HardSigmoid()
+ public void Hardmax()
{
_emitter.Write((byte)100);
_emitter.Write((ushort)33);
}
///.
- public void HardSwish()
+ public void HardSigmoid()
{
_emitter.Write((byte)100);
_emitter.Write((ushort)34);
}
///.
- public void IndexOf()
+ public void HardSwish()
{
_emitter.Write((byte)100);
_emitter.Write((ushort)35);
}
///.
- public void InstanceNormalization()
+ public void IndexOf()
{
_emitter.Write((byte)100);
_emitter.Write((ushort)36);
}
///.
- public void L2Normalization()
+ public void InstanceNormalization()
{
_emitter.Write((byte)100);
_emitter.Write((ushort)37);
}
///.
- public void LayerNorm(int axis, float epsilon)
+ public void L2Normalization()
{
_emitter.Write((byte)100);
_emitter.Write((ushort)38);
+ }
+
+ ///.
+ public void LayerNorm(int axis, float epsilon)
+ {
+ _emitter.Write((byte)100);
+ _emitter.Write((ushort)39);
_emitter.Write(axis);
_emitter.Write(epsilon);
}
@@ -930,35 +937,35 @@ public void LayerNorm(int axis, float epsilon)
public void LeakyRelu()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)39);
+ _emitter.Write((ushort)40);
}
///.
public void LogSoftmax()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)40);
+ _emitter.Write((ushort)41);
}
///.
public void LpNormalization()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)41);
+ _emitter.Write((ushort)42);
}
///.
public void LRN()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)42);
+ _emitter.Write((ushort)43);
}
///.
public void LSTM(LSTMDirection direction, LSTMLayout layout, string[] activations)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)43);
+ _emitter.Write((ushort)44);
_emitter.Write((int)direction);
_emitter.Write((int)layout);
_emitter.Write(activations);
@@ -968,21 +975,21 @@ public void LSTM(LSTMDirection direction, LSTMLayout layout, string[] activation
public void MatMul()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)44);
+ _emitter.Write((ushort)45);
}
///.
public void MatMulShape()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)45);
+ _emitter.Write((ushort)46);
}
///.
public void Normal(DataType type)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)46);
+ _emitter.Write((ushort)47);
_emitter.Write(type);
}
@@ -990,7 +997,7 @@ public void Normal(DataType type)
public void NormalLike(DataType type)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)47);
+ _emitter.Write((ushort)48);
_emitter.Write(type);
}
@@ -998,7 +1005,7 @@ public void NormalLike(DataType type)
public void OneHot(OneHotMode oneHotMode)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)48);
+ _emitter.Write((ushort)49);
_emitter.Write((byte)oneHotMode);
}
@@ -1006,7 +1013,7 @@ public void OneHot(OneHotMode oneHotMode)
public void Pad(PadMode padMode)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)49);
+ _emitter.Write((ushort)50);
_emitter.Write((byte)padMode);
}
@@ -1014,21 +1021,21 @@ public void Pad(PadMode padMode)
public void PRelu()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)50);
+ _emitter.Write((ushort)51);
}
///.
public void Prod()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)51);
+ _emitter.Write((ushort)52);
}
///.
public void Quantize(DataType targetType)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)52);
+ _emitter.Write((ushort)53);
_emitter.Write(targetType);
}
@@ -1036,7 +1043,7 @@ public void Quantize(DataType targetType)
public void QuantParamOf(QuantMode quantMode)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)53);
+ _emitter.Write((ushort)54);
_emitter.Write((int)quantMode);
}
@@ -1044,14 +1051,14 @@ public void QuantParamOf(QuantMode quantMode)
public void Range()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)54);
+ _emitter.Write((ushort)55);
}
///.
public void RangeOf(bool isRangeOfWeight)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)55);
+ _emitter.Write((ushort)56);
_emitter.Write(isRangeOfWeight);
}
@@ -1059,14 +1066,14 @@ public void RangeOf(bool isRangeOfWeight)
public void Rank()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)56);
+ _emitter.Write((ushort)57);
}
///.
public void Reduce(ReduceOp reduceOp)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)57);
+ _emitter.Write((ushort)58);
_emitter.Write((byte)reduceOp);
}
@@ -1074,7 +1081,7 @@ public void Reduce(ReduceOp reduceOp)
public void ReduceArg(ReduceArgOp reduceArgOp, DataType destType)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)58);
+ _emitter.Write((ushort)59);
_emitter.Write((byte)reduceArgOp);
_emitter.Write(destType);
}
@@ -1083,7 +1090,7 @@ public void ReduceArg(ReduceArgOp reduceArgOp, DataType destType)
public void ReduceWindow2D(ReduceOp reduceOp)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)59);
+ _emitter.Write((ushort)60);
_emitter.Write((byte)reduceOp);
}
@@ -1091,21 +1098,21 @@ public void ReduceWindow2D(ReduceOp reduceOp)
public void Relu()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)60);
+ _emitter.Write((ushort)61);
}
///.
public void Relu6()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)61);
+ _emitter.Write((ushort)62);
}
///.
public void Require(string message, bool canFoldConstCall)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)62);
+ _emitter.Write((ushort)63);
_emitter.Write(message);
_emitter.Write(canFoldConstCall);
}
@@ -1114,14 +1121,21 @@ public void Require(string message, bool canFoldConstCall)
public void Reshape()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)63);
+ _emitter.Write((ushort)64);
+ }
+
+ ///.
+ public void ReshapeShape()
+ {
+ _emitter.Write((byte)100);
+ _emitter.Write((ushort)65);
}
///.
public void ResizeImage(ImageResizeMode resizeMode, ImageResizeTransformationMode transformationMode, ImageResizeNearestMode nearestMode, bool isTFResize)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)64);
+ _emitter.Write((ushort)66);
_emitter.Write((byte)resizeMode);
_emitter.Write((int)transformationMode);
_emitter.Write((int)nearestMode);
@@ -1132,147 +1146,161 @@ public void ResizeImage(ImageResizeMode resizeMode, ImageResizeTransformationMod
public void ReverseSequence()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)65);
+ _emitter.Write((ushort)67);
}
///.
public void ScatterND()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)66);
+ _emitter.Write((ushort)68);
}
///.
public void Select()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)67);
+ _emitter.Write((ushort)69);
}
///.
public void Selu()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)68);
+ _emitter.Write((ushort)70);
}
///.
public void ShapeOf()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)69);
+ _emitter.Write((ushort)71);
}
///.
public void Sigmoid()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)70);
+ _emitter.Write((ushort)72);
}
///.
public void SizeOf()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)71);
+ _emitter.Write((ushort)73);
}
///.
public void Slice()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)72);
+ _emitter.Write((ushort)74);
}
///.
public void Softmax()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)73);
+ _emitter.Write((ushort)75);
}
///.
public void Softplus()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)74);
+ _emitter.Write((ushort)76);
}
///.
public void Softsign()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)75);
+ _emitter.Write((ushort)77);
}
///.
public void SpaceToBatch()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)76);
+ _emitter.Write((ushort)78);
}
///.
public void Split()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)77);
+ _emitter.Write((ushort)79);
}
///.
public void Squeeze()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)78);
+ _emitter.Write((ushort)80);
+ }
+
+ ///.
+ public void SqueezeShape()
+ {
+ _emitter.Write((byte)100);
+ _emitter.Write((ushort)81);
}
///.
public void Stack()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)79);
+ _emitter.Write((ushort)82);
}
///.
public void Swish()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)80);
+ _emitter.Write((ushort)83);
}
///.
public void Tile()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)81);
+ _emitter.Write((ushort)84);
}
///.
public void TopK()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)82);
+ _emitter.Write((ushort)85);
}
///.
public void Transpose()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)83);
+ _emitter.Write((ushort)86);
+ }
+
+ ///.
+ public void TransposeShape()
+ {
+ _emitter.Write((byte)100);
+ _emitter.Write((ushort)87);
}
///.
public void Trilu()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)84);
+ _emitter.Write((ushort)88);
}
///.
public void Unary(UnaryOp unaryOp)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)85);
+ _emitter.Write((ushort)89);
_emitter.Write((byte)unaryOp);
}
@@ -1280,7 +1308,7 @@ public void Unary(UnaryOp unaryOp)
public void Uniform(DataType type)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)86);
+ _emitter.Write((ushort)90);
_emitter.Write(type);
}
@@ -1288,7 +1316,7 @@ public void Uniform(DataType type)
public void UniformLike(DataType type)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)87);
+ _emitter.Write((ushort)91);
_emitter.Write(type);
}
@@ -1296,14 +1324,21 @@ public void UniformLike(DataType type)
public void Unsqueeze()
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)88);
+ _emitter.Write((ushort)92);
+ }
+
+ ///.
+ public void UnsqueezeShape()
+ {
+ _emitter.Write((byte)100);
+ _emitter.Write((ushort)93);
}
///.
public void Where(bool isTfWhere)
{
_emitter.Write((byte)100);
- _emitter.Write((ushort)89);
+ _emitter.Write((ushort)94);
_emitter.Write(isTfWhere);
}
}
diff --git a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h
index db59425a29..e918f22f25 100644
--- a/src/Native/include/nncase/kernels/stackvm/tensor_ops.h
+++ b/src/Native/include/nncase/kernels/stackvm/tensor_ops.h
@@ -1,4 +1,4 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
*
* Copyright 2019-2021 Canaan Inc.
@@ -178,6 +178,12 @@ NNCASE_API result
get_item(value_t input, value_t index, value_t output = nullptr,
kernel_context &context = default_kernel_context());
+NNCASE_API result<value_t>
+get_paddings(value_t input_shape, value_t weights_shape, value_t strides,
+ value_t dilations, value_t same, value_t lower,
+ value_t output = nullptr,
+ kernel_context &context = default_kernel_context());
+
NNCASE_API result<value_t>
hard_sigmoid(value_t input, value_t alpha, value_t beta,
value_t output = nullptr,
@@ -330,6 +336,10 @@ NNCASE_API result
reshape(value_t input, value_t shape, value_t output = nullptr,
kernel_context &context = default_kernel_context());
+NNCASE_API result<value_t>
+reshape_shape(value_t input_shape, value_t shape, value_t output = nullptr,
+ kernel_context &context = default_kernel_context());
+
NNCASE_API result<value_t> resize_image(
runtime::stackvm::image_resize_mode_t resize_mode,
runtime::stackvm::image_resize_transformation_mode_t transformation_mode,
@@ -400,6 +410,10 @@ NNCASE_API result
squeeze(value_t input, value_t dim, value_t output = nullptr,
kernel_context &context = default_kernel_context());
+NNCASE_API result<value_t>
+squeeze_shape(value_t input_shape, value_t dim, value_t output = nullptr,
+ kernel_context &context = default_kernel_context());
+
NNCASE_API result<value_t>
stack(value_t inputs, value_t axis, value_t output = nullptr,
kernel_context &context = default_kernel_context());
@@ -421,6 +435,10 @@ NNCASE_API result
transpose(value_t input, value_t perm, value_t output = nullptr,
kernel_context &context = default_kernel_context());
+NNCASE_API result<value_t>
+transpose_shape(value_t input_shape, value_t perm, value_t output = nullptr,
+ kernel_context &context = default_kernel_context());
+
NNCASE_API result<value_t>
trilu(value_t input, value_t k, value_t upper, value_t output = nullptr,
kernel_context &context = default_kernel_context());
@@ -444,6 +462,10 @@ NNCASE_API result
unsqueeze(value_t input, value_t dim, value_t output = nullptr,
kernel_context &context = default_kernel_context());
+NNCASE_API result<value_t>
+unsqueeze_shape(value_t input_shape, value_t dim, value_t output = nullptr,
+ kernel_context &context = default_kernel_context());
+
NNCASE_API result<value_t>
where(bool is_tf_where, value_t cond, value_t x, value_t y,
value_t output = nullptr,
diff --git a/src/Native/include/nncase/runtime/stackvm/op_reader.h b/src/Native/include/nncase/runtime/stackvm/op_reader.h
index 5de273cab1..80372463e4 100644
--- a/src/Native/include/nncase/runtime/stackvm/op_reader.h
+++ b/src/Native/include/nncase/runtime/stackvm/op_reader.h
@@ -1,4 +1,4 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
*
* Copyright 2019-2021 Canaan Inc.
@@ -997,6 +997,14 @@ template <> struct tensor_op_reader {
}
};
+template <> struct tensor_op_reader<tensor_get_paddings_op_t> {
+ tensor_get_paddings_op_t
+ operator()(NNCASE_UNUSED span_reader &reader) const {
+ tensor_get_paddings_op_t op;
+ return op;
+ }
+};
+
template <> struct tensor_op_reader<tensor_hard_sigmoid_op_t> {
tensor_hard_sigmoid_op_t
operator()(NNCASE_UNUSED span_reader &reader) const {
@@ -1257,6 +1265,14 @@ template <> struct tensor_op_reader {
}
};
+template <> struct tensor_op_reader<tensor_reshape_shape_op_t> {
+ tensor_reshape_shape_op_t
+ operator()(NNCASE_UNUSED span_reader &reader) const {
+ tensor_reshape_shape_op_t op;
+ return op;
+ }
+};
+
template <> struct tensor_op_reader<tensor_resize_image_op_t> {
tensor_resize_image_op_t
operator()(NNCASE_UNUSED span_reader &reader) const {
@@ -1373,6 +1389,14 @@ template <> struct tensor_op_reader {
}
};
+template <> struct tensor_op_reader<tensor_squeeze_shape_op_t> {
+ tensor_squeeze_shape_op_t
+ operator()(NNCASE_UNUSED span_reader &reader) const {
+ tensor_squeeze_shape_op_t op;
+ return op;
+ }
+};
+
template <> struct tensor_op_reader<tensor_stack_op_t> {
tensor_stack_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
tensor_stack_op_t op;
@@ -1408,6 +1432,14 @@ template <> struct tensor_op_reader {
}
};
+template <> struct tensor_op_reader<tensor_transpose_shape_op_t> {
+ tensor_transpose_shape_op_t
+ operator()(NNCASE_UNUSED span_reader &reader) const {
+ tensor_transpose_shape_op_t op;
+ return op;
+ }
+};
+
template <> struct tensor_op_reader<tensor_trilu_op_t> {
tensor_trilu_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
tensor_trilu_op_t op;
@@ -1447,6 +1479,14 @@ template <> struct tensor_op_reader {
}
};
+template <> struct tensor_op_reader<tensor_unsqueeze_shape_op_t> {
+ tensor_unsqueeze_shape_op_t
+ operator()(NNCASE_UNUSED span_reader &reader) const {
+ tensor_unsqueeze_shape_op_t op;
+ return op;
+ }
+};
+
template <> struct tensor_op_reader<tensor_where_op_t> {
tensor_where_op_t operator()(NNCASE_UNUSED span_reader &reader) const {
tensor_where_op_t op;
@@ -1589,6 +1629,10 @@ class NNCASE_API tensor_op_visitor {
return default_visit(tensor_function_t::get_item, &op);
}
virtual result<void>
+ visit(NNCASE_UNUSED const tensor_get_paddings_op_t &op) noexcept {
+ return default_visit(tensor_function_t::get_paddings, &op);
+ }
+ virtual result<void>
visit(NNCASE_UNUSED const tensor_hard_sigmoid_op_t &op) noexcept {
return default_visit(tensor_function_t::hard_sigmoid, &op);
}
@@ -1717,6 +1761,10 @@ class NNCASE_API tensor_op_visitor {
return default_visit(tensor_function_t::reshape, &op);
}
virtual result<void>
+ visit(NNCASE_UNUSED const tensor_reshape_shape_op_t &op) noexcept {
+ return default_visit(tensor_function_t::reshape_shape, &op);
+ }
+ virtual result<void>
visit(NNCASE_UNUSED const tensor_resize_image_op_t &op) noexcept {
return default_visit(tensor_function_t::resize_image, &op);
}
@@ -1777,6 +1825,10 @@ class NNCASE_API tensor_op_visitor {
return default_visit(tensor_function_t::squeeze, &op);
}
virtual result<void>
+ visit(NNCASE_UNUSED const tensor_squeeze_shape_op_t &op) noexcept {
+ return default_visit(tensor_function_t::squeeze_shape, &op);
+ }
+ virtual result<void>
visit(NNCASE_UNUSED const tensor_stack_op_t &op) noexcept {
return default_visit(tensor_function_t::stack, &op);
}
@@ -1797,6 +1849,10 @@ class NNCASE_API tensor_op_visitor {
return default_visit(tensor_function_t::transpose, &op);
}
virtual result<void>
+ visit(NNCASE_UNUSED const tensor_transpose_shape_op_t &op) noexcept {
+ return default_visit(tensor_function_t::transpose_shape, &op);
+ }
+ virtual result<void>
visit(NNCASE_UNUSED const tensor_trilu_op_t &op) noexcept {
return default_visit(tensor_function_t::trilu, &op);
}
@@ -1817,6 +1873,10 @@ class NNCASE_API tensor_op_visitor {
return default_visit(tensor_function_t::unsqueeze, &op);
}
virtual result<void>
+ visit(NNCASE_UNUSED const tensor_unsqueeze_shape_op_t &op) noexcept {
+ return default_visit(tensor_function_t::unsqueeze_shape, &op);
+ }
+ virtual result<void>
visit(NNCASE_UNUSED const tensor_where_op_t &op) noexcept {
return default_visit(tensor_function_t::where, &op);
}
diff --git a/src/Native/include/nncase/runtime/stackvm/opcode.h b/src/Native/include/nncase/runtime/stackvm/opcode.h
index 26e98927f1..5c17c82894 100644
--- a/src/Native/include/nncase/runtime/stackvm/opcode.h
+++ b/src/Native/include/nncase/runtime/stackvm/opcode.h
@@ -1,4 +1,4 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:38
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
*
* Copyright 2019-2021 Canaan Inc.
@@ -136,29 +136,29 @@ enum class tensor_function_t : uint16_t {
elu = 20,
erf = 21,
gelu = 30,
- hardmax = 32,
- hard_sigmoid = 33,
- hard_swish = 34,
- instance_normalization = 36,
- l2_normalization = 37,
- layer_norm = 38,
- leaky_relu = 39,
- log_softmax = 40,
- lp_normalization = 41,
- lrn = 42,
- one_hot = 48,
- pad = 49,
- prelu = 50,
- reduce_window2d = 59,
- relu = 60,
- relu6 = 61,
- selu = 68,
- sigmoid = 70,
- softmax = 73,
- softplus = 74,
- softsign = 75,
- space_to_batch = 76,
- swish = 80,
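+ // Values below shift relative to the previous ISA because the new shape
+ // functions (get_paddings and the *_shape opcodes) are numbered in between.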
+ hardmax = 33,
+ hard_sigmoid = 34,
+ hard_swish = 35,
+ instance_normalization = 37,
+ l2_normalization = 38,
+ layer_norm = 39,
+ leaky_relu = 40,
+ log_softmax = 41,
+ lp_normalization = 42,
+ lrn = 43,
+ one_hot = 49,
+ pad = 50,
+ prelu = 51,
+ reduce_window2d = 60,
+ relu = 61,
+ relu6 = 62,
+ selu = 70,
+ sigmoid = 72,
+ softmax = 75,
+ softplus = 76,
+ softsign = 77,
+ space_to_batch = 78,
+ swish = 83,
binary = 2,
clamp = 9,
compare = 10,
@@ -167,15 +167,15 @@ enum class tensor_function_t : uint16_t {
dequantize = 19,
fake_dequantize = 23,
fake_quantize = 24,
- mat_mul = 44,
- quantize = 52,
- quant_param_of = 53,
- range_of = 55,
- reduce = 57,
- reduce_arg = 58,
- require = 62,
- select = 67,
- unary = 85,
+ mat_mul = 45,
+ quantize = 53,
+ quant_param_of = 54,
+ range_of = 56,
+ reduce = 58,
+ reduce_arg = 59,
+ require = 63,
+ select = 69,
+ unary = 89,
bitcast = 3,
broadcast = 4,
bucket_pad = 6,
@@ -189,35 +189,40 @@ enum class tensor_function_t : uint16_t {
gather_elements = 28,
gather_nd = 29,
get_item = 31,
- index_of = 35,
- lstm = 43,
- prod = 51,
- range = 54,
- rank = 56,
- reshape = 63,
- reverse_sequence = 65,
- scatter_nd = 66,
- shape_of = 69,
- size_of = 71,
- slice = 72,
- split = 77,
- squeeze = 78,
- stack = 79,
- tile = 81,
- top_k = 82,
- transpose = 83,
- trilu = 84,
- unsqueeze = 88,
- where = 89,
+ index_of = 36,
+ lstm = 44,
+ prod = 52,
+ range = 55,
+ rank = 57,
+ reshape = 64,
+ reverse_sequence = 67,
+ scatter_nd = 68,
+ shape_of = 71,
+ size_of = 73,
+ slice = 74,
+ split = 79,
+ squeeze = 80,
+ stack = 82,
+ tile = 84,
+ top_k = 85,
+ transpose = 86,
+ trilu = 88,
+ unsqueeze = 92,
+ where = 94,
broadcast_shape = 5,
conv2d_shape = 15,
conv2d_transpose_shape = 17,
- mat_mul_shape = 45,
- normal = 46,
- normal_like = 47,
- uniform = 86,
- uniform_like = 87,
- resize_image = 64,
+ get_paddings = 32,
+ mat_mul_shape = 46,
+ reshape_shape = 65,
+ squeeze_shape = 81,
+ transpose_shape = 87,
+ unsqueeze_shape = 93,
+ normal = 47,
+ normal_like = 48,
+ uniform = 90,
+ uniform_like = 91,
+ resize_image = 66,
};
enum class binary_op_t : uint8_t {
@@ -663,6 +668,8 @@ struct tensor_gelu_op_t {};
struct tensor_get_item_op_t {};
+struct tensor_get_paddings_op_t {};
+
struct tensor_hard_sigmoid_op_t {};
struct tensor_hard_swish_op_t {};
@@ -758,6 +765,8 @@ struct tensor_require_op_t {
struct tensor_reshape_op_t {};
+struct tensor_reshape_shape_op_t {};
+
struct tensor_resize_image_op_t {
image_resize_mode_t resize_mode;
image_resize_transformation_mode_t transformation_mode;
@@ -793,6 +802,8 @@ struct tensor_split_op_t {};
struct tensor_squeeze_op_t {};
+struct tensor_squeeze_shape_op_t {};
+
struct tensor_stack_op_t {};
struct tensor_swish_op_t {};
@@ -803,6 +814,8 @@ struct tensor_top_k_op_t {};
struct tensor_transpose_op_t {};
+struct tensor_transpose_shape_op_t {};
+
struct tensor_trilu_op_t {};
struct tensor_unary_op_t {
@@ -819,6 +832,8 @@ struct tensor_uniform_like_op_t {
struct tensor_unsqueeze_op_t {};
+struct tensor_unsqueeze_shape_op_t {};
+
struct tensor_where_op_t {
bool is_tf_where;
};
@@ -993,8 +1008,18 @@ inline std::string to_string(tensor_function_t tensor_funct) {
return "conv2d_shape";
case tensor_function_t::conv2d_transpose_shape:
return "conv2d_transpose_shape";
+ case tensor_function_t::get_paddings:
+ return "get_paddings";
case tensor_function_t::mat_mul_shape:
return "mat_mul_shape";
+ case tensor_function_t::reshape_shape:
+ return "reshape_shape";
+ case tensor_function_t::squeeze_shape:
+ return "squeeze_shape";
+ case tensor_function_t::transpose_shape:
+ return "transpose_shape";
+ case tensor_function_t::unsqueeze_shape:
+ return "unsqueeze_shape";
case tensor_function_t::normal:
return "normal";
case tensor_function_t::normal_like:
diff --git a/src/Native/src/kernels/stackvm/reference/pad.cpp b/src/Native/src/kernels/stackvm/reference/pad.cpp
index 6186101e8a..2b27400fab 100644
--- a/src/Native/src/kernels/stackvm/reference/pad.cpp
+++ b/src/Native/src/kernels/stackvm/reference/pad.cpp
@@ -162,7 +162,7 @@ void padding_impl_opt(T *in, T *out, gsl::span in_shape,
dh = out_shape[0];
hh = out_shape[1];
wh = out_shape[2];
- } else {
+ } else if (in_shape.size() == 4) {
cl = in_shape[0];
dl = in_shape[1];
hl = in_shape[2];
@@ -171,6 +171,16 @@ void padding_impl_opt(T *in, T *out, gsl::span in_shape,
dh = out_shape[1];
hh = out_shape[2];
wh = out_shape[3];
+ } else // in_shape.size() == 2: handled as (1, 1, H, W)
+ {
+ cl = 1;
+ dl = 1;
+ hl = in_shape[0];
+ wl = in_shape[1];
+ ch = 1;
+ dh = 1;
+ hh = out_shape[0];
+ wh = out_shape[1];
}
pad_data2(in, out, cl, dl, hl, wl, ch, dh, hh, wh, value);
@@ -216,7 +226,7 @@ result nncase::kernels::stackvm::reference::pad(
std::all_of(
paddings.begin(), paddings.end(),
[](const padding &p) { return p.before == 0 && p.after >= 0; }) &&
- mode == pad_mode_t::constant && in_shape.size() >= 3;
+ mode == pad_mode_t::constant && in_shape.size() >= 2;
if (std::all_of(paddings.begin(), paddings.end(),
[](const padding &p) { return p.interior == 0; })) {
diff --git a/src/Native/src/kernels/stackvm/shape_ops.cpp b/src/Native/src/kernels/stackvm/shape_ops.cpp
index 7677b10c2b..991b16950c 100644
--- a/src/Native/src/kernels/stackvm/shape_ops.cpp
+++ b/src/Native/src/kernels/stackvm/shape_ops.cpp
@@ -106,6 +106,12 @@ result nncase::kernels::stackvm::broadcast_shape(value_t inputs,
KERNEL_FINISH;
}
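+// Writes the locally computed `out_shape` into `output` as a rank-1 int64
+// tensor; shared epilogue of the shape kernels below.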
+#define WRITE_OUT_SHAPE \
+ try_output(out_mem, output, dt_int64, dims_t{out_shape.size()}); \
+ for (int i = 0; i < out_shape.size(); ++i) { \
+ OUT_CAST(int64_t, out_mem)[i] = out_shape[i]; \
+ }
+
result<value_t> nncase::kernels::stackvm::mat_mul_shape(value_t lhs,
value_t rhs,
value_t output,
@@ -113,9 +119,105 @@ result nncase::kernels::stackvm::mat_mul_shape(value_t lhs,
try_dims(lhs_shape, lhs);
try_dims(rhs_shape, rhs);
try_var(out_shape, matmul_infer_shape(lhs_shape, rhs_shape));
- try_output(out_mem, output, dt_int64, dims_t{out_shape.size()});
- for (int i = 0; i < out_shape.size(); ++i) {
- OUT_CAST(int64_t, out_mem)[i] = out_shape[i];
+ WRITE_OUT_SHAPE;
+ KERNEL_FINISH;
+}
+
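+// Computes the spatial output extent of a windowed op (conv/pool) for the
+// given filter, stride and dilation; `same` selects SAME-padding semantics.
+// ceilMode is always false for the callers in get_paddings below.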
+inline int get_windowed_output_size(int size, int filter, int stride,
+ int dilation, bool same, bool ceilMode) {
+ auto effectiveFilterSize = ((filter - 1) * dilation) + 1;
+ auto falseBranch = !ceilMode
+ ? ((size - effectiveFilterSize + stride) / stride)
+ : (int)ceil((size - effectiveFilterSize + stride) / (float)stride);
+ auto trueBranch = (size + stride - 1) / stride;
+ return same ? trueBranch : falseBranch;
+}
+
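+// Returns the {before, after} padding needed to produce `output_size`; when
+// `lower` is set, the larger half is placed in front.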
+inline padding get_windowed_padding(int32_t input_size, int32_t output_size,
+ int32_t filter, int32_t stride,
+ int32_t dilation, bool lower) {
+ auto effective_filter_size = (filter - 1) * dilation + 1;
+ int padding = std::max(0, (output_size - 1) * stride +
+ effective_filter_size - input_size);
+ auto before = padding / 2;
+ auto after = padding - padding / 2;
+ if (lower) {
+ return {std::max(before, after), std::min(before, after)};
}
+ return {before, after};
+}
+
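+// Shape kernel for IR.ShapeExpr.GetPaddings: expects NCHW input/weights shapes
+// and returns a [2, 2] int64 tensor {{pad_h.before, pad_h.after},
+// {pad_w.before, pad_w.after}}.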
+result<value_t> nncase::kernels::stackvm::get_paddings(
+ value_t input_shape, value_t weights_shape, value_t strides,
+ value_t dilations, value_t same, value_t lower, value_t output,
+ [[maybe_unused]] kernel_context &) {
+ try_dims(in_shape, input_shape);
+ try_dims(w_shape, weights_shape);
+ try_strides(strides_value, strides);
+ try_strides(dilations_value, dilations);
+ try_to_scalar_v(same, bool);
+ try_to_scalar_v(lower, bool);
+ auto out_h =
+ get_windowed_output_size(in_shape[2], w_shape[2], strides_value[0],
+ dilations_value[0], same_value, false);
+ auto out_w =
+ get_windowed_output_size(in_shape[3], w_shape[3], strides_value[1],
+ dilations_value[1], same_value, false);
+ auto pad_h =
+ get_windowed_padding(in_shape[2], out_h, w_shape[2], strides_value[0],
+ dilations_value[0], lower_value);
+ auto pad_w =
+ get_windowed_padding(in_shape[3], out_w, w_shape[3], strides_value[1],
+ dilations_value[1], lower_value);
+ auto out_shape = dims_t{2, 2};
+ try_out_mem(output, dt_int64, out_shape);
+ OUT_CAST(int64_t, output_mem)[0] = pad_h.before;
+ OUT_CAST(int64_t, output_mem)[1] = pad_h.after;
+ OUT_CAST(int64_t, output_mem)[2] = pad_w.before;
+ OUT_CAST(int64_t, output_mem)[3] = pad_w.after;
+ KERNEL_FINISH;
+}
+
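+// The remaining *_shape kernels evaluate the output shape of the corresponding
+// tensor op directly on shape tensors, reusing the existing *_infer_shape
+// helpers and writing the result as int64.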
+result<value_t> nncase::kernels::stackvm::reshape_shape(value_t input_shape,
+ value_t shape,
+ value_t output,
+ kernel_context &) {
+ try_dims(in_shape, input_shape);
+ try_axes(shape_value, shape);
+ auto out_shape = reshape_shape_infer(in_shape, shape_value);
+ WRITE_OUT_SHAPE;
+ KERNEL_FINISH;
+}
+
+result<value_t>
+nncase::kernels::stackvm::transpose_shape(value_t input_shape, value_t perm,
+ value_t output,
+ [[maybe_unused]] kernel_context &) {
+ try_dims(in_shape, input_shape);
+ try_dims(perm_value, perm);
+ auto out_shape = transpose_infer_shape(in_shape, perm_value);
+ WRITE_OUT_SHAPE;
+ KERNEL_FINISH;
+}
+
+result<value_t>
+nncase::kernels::stackvm::squeeze_shape(value_t input_shape, value_t dim,
+ value_t output,
+ [[maybe_unused]] kernel_context &) {
+ try_dims(in_shape, input_shape);
+ try_positive_axes(dim_value, dim, in_shape.size());
+ auto out_shape = squeeze_infer_shape(in_shape, dim_value);
+ WRITE_OUT_SHAPE;
+ KERNEL_FINISH;
+}
+
+result<value_t>
+nncase::kernels::stackvm::unsqueeze_shape(value_t input_shape, value_t dim,
+ value_t output,
+ [[maybe_unused]] kernel_context &) {
+ try_dims(in_shape, input_shape);
+ try_axes(dim_value, dim);
+ auto out_shape = unsqueeze_infer_shape(in_shape, dim_value);
+ WRITE_OUT_SHAPE;
KERNEL_FINISH;
}
\ No newline at end of file
diff --git a/src/Native/src/runtime/stackvm/op_reader.cpp b/src/Native/src/runtime/stackvm/op_reader.cpp
index 04776e0f7d..901f0f6125 100644
--- a/src/Native/src/runtime/stackvm/op_reader.cpp
+++ b/src/Native/src/runtime/stackvm/op_reader.cpp
@@ -1,4 +1,4 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
*
* Copyright 2019-2021 Canaan Inc.
@@ -99,6 +99,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct,
return visit(tensor_op_reader()(reader));
case tensor_function_t::get_item:
return visit(tensor_op_reader<tensor_get_item_op_t>()(reader));
+ case tensor_function_t::get_paddings:
+ return visit(
+ tensor_op_reader<tensor_get_paddings_op_t>()(reader));
case tensor_function_t::hard_sigmoid:
return visit(
tensor_op_reader<tensor_hard_sigmoid_op_t>()(reader));
@@ -173,6 +176,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct,
return visit(tensor_op_reader()(reader));
case tensor_function_t::reshape:
return visit(tensor_op_reader<tensor_reshape_op_t>()(reader));
+ case tensor_function_t::reshape_shape:
+ return visit(
+ tensor_op_reader<tensor_reshape_shape_op_t>()(reader));
case tensor_function_t::resize_image:
return visit(
tensor_op_reader<tensor_resize_image_op_t>()(reader));
@@ -206,6 +212,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct,
return visit(tensor_op_reader()(reader));
case tensor_function_t::squeeze:
return visit(tensor_op_reader<tensor_squeeze_op_t>()(reader));
+ case tensor_function_t::squeeze_shape:
+ return visit(
+ tensor_op_reader<tensor_squeeze_shape_op_t>()(reader));
case tensor_function_t::stack:
return visit(tensor_op_reader<tensor_stack_op_t>()(reader));
case tensor_function_t::swish:
@@ -216,6 +225,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct,
return visit(tensor_op_reader()(reader));
case tensor_function_t::transpose:
return visit(tensor_op_reader<tensor_transpose_op_t>()(reader));
+ case tensor_function_t::transpose_shape:
+ return visit(
+ tensor_op_reader<tensor_transpose_shape_op_t>()(reader));
case tensor_function_t::trilu:
return visit(tensor_op_reader<tensor_trilu_op_t>()(reader));
case tensor_function_t::unary:
@@ -227,6 +239,9 @@ result tensor_op_visitor::visit(tensor_function_t tensor_funct,
tensor_op_reader()(reader));
case tensor_function_t::unsqueeze:
return visit(tensor_op_reader<tensor_unsqueeze_op_t>()(reader));
+ case tensor_function_t::unsqueeze_shape:
+ return visit(
+ tensor_op_reader<tensor_unsqueeze_shape_op_t>()(reader));
case tensor_function_t::where:
return visit(tensor_op_reader<tensor_where_op_t>()(reader));
default:
diff --git a/src/Native/src/runtime/stackvm/ops/tensor.cpp b/src/Native/src/runtime/stackvm/ops/tensor.cpp
index 0439e7887c..6f09a7084c 100644
--- a/src/Native/src/runtime/stackvm/ops/tensor.cpp
+++ b/src/Native/src/runtime/stackvm/ops/tensor.cpp
@@ -1,4 +1,4 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
*
* Copyright 2019-2021 Canaan Inc.
@@ -564,6 +564,29 @@ result stackvm_runtime_function::visit(
return ok();
}
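+// Runtime handler for the new get_paddings opcode: pops its six operands,
+// calls the shape kernel and pushes the resulting [2, 2] padding tensor.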
+result<void> stackvm_runtime_function::visit(
+ [[maybe_unused]] const tensor_get_paddings_op_t &op) noexcept {
+ dump_op("get_paddings");
+ try_var(input_shape, pop_value());
+ dump_input(input_shape);
+ try_var(weights_shape, pop_value());
+ dump_input(weights_shape);
+ try_var(strides, pop_value());
+ dump_input(strides);
+ try_var(dilations, pop_value());
+ dump_input(dilations);
+ try_var(same, pop_value());
+ dump_input(same);
+ try_var(lower, pop_value());
+ dump_input(lower);
+ try_var(output, kernels::stackvm::get_paddings(
+ input_shape, weights_shape, strides, dilations, same,
+ lower, nullptr, module().kernel_context()));
+ dump_output(output);
+ stack_.push(std::move(output));
+ return ok();
+}
+
result<void> stackvm_runtime_function::visit(
[[maybe_unused]] const tensor_hard_sigmoid_op_t &op) noexcept {
dump_op("hard_sigmoid");
@@ -1091,6 +1114,20 @@ result stackvm_runtime_function::visit(
return ok();
}
+result<void> stackvm_runtime_function::visit(
+ [[maybe_unused]] const tensor_reshape_shape_op_t &op) noexcept {
+ dump_op("reshape_shape");
+ try_var(input_shape, pop_value());
+ dump_input(input_shape);
+ try_var(shape, pop_value());
+ dump_input(shape);
+ try_var(output, kernels::stackvm::reshape_shape(input_shape, shape, nullptr,
+ module().kernel_context()));
+ dump_output(output);
+ stack_.push(std::move(output));
+ return ok();
+}
+
result<void> stackvm_runtime_function::visit(
[[maybe_unused]] const tensor_resize_image_op_t &op) noexcept {
dump_op("resize_image");
@@ -1327,6 +1364,20 @@ result stackvm_runtime_function::visit(
return ok();
}
+result<void> stackvm_runtime_function::visit(
+ [[maybe_unused]] const tensor_squeeze_shape_op_t &op) noexcept {
+ dump_op("squeeze_shape");
+ try_var(input_shape, pop_value());
+ dump_input(input_shape);
+ try_var(dim, pop_value());
+ dump_input(dim);
+ try_var(output, kernels::stackvm::squeeze_shape(input_shape, dim, nullptr,
+ module().kernel_context()));
+ dump_output(output);
+ stack_.push(std::move(output));
+ return ok();
+}
+
result<void> stackvm_runtime_function::visit(
[[maybe_unused]] const tensor_stack_op_t &op) noexcept {
dump_op("stack");
@@ -1402,6 +1453,20 @@ result stackvm_runtime_function::visit(
return ok();
}
+result<void> stackvm_runtime_function::visit(
+ [[maybe_unused]] const tensor_transpose_shape_op_t &op) noexcept {
+ dump_op("transpose_shape");
+ try_var(input_shape, pop_value());
+ dump_input(input_shape);
+ try_var(perm, pop_value());
+ dump_input(perm);
+ try_var(output, kernels::stackvm::transpose_shape(
+ input_shape, perm, nullptr, module().kernel_context()));
+ dump_output(output);
+ stack_.push(std::move(output));
+ return ok();
+}
+
result<void> stackvm_runtime_function::visit(
[[maybe_unused]] const tensor_trilu_op_t &op) noexcept {
dump_op("trilu");
@@ -1482,6 +1547,20 @@ result stackvm_runtime_function::visit(
return ok();
}
+result<void> stackvm_runtime_function::visit(
+ [[maybe_unused]] const tensor_unsqueeze_shape_op_t &op) noexcept {
+ dump_op("unsqueeze_shape");
+ try_var(input_shape, pop_value());
+ dump_input(input_shape);
+ try_var(dim, pop_value());
+ dump_input(dim);
+ try_var(output, kernels::stackvm::unsqueeze_shape(
+ input_shape, dim, nullptr, module().kernel_context()));
+ dump_output(output);
+ stack_.push(std::move(output));
+ return ok();
+}
+
result<void> stackvm_runtime_function::visit(
[[maybe_unused]] const tensor_where_op_t &op) noexcept {
dump_op("where");
diff --git a/src/Native/src/runtime/stackvm/runtime_function_ops.h b/src/Native/src/runtime/stackvm/runtime_function_ops.h
index 02841ebe35..ae6944ef59 100644
--- a/src/Native/src/runtime/stackvm/runtime_function_ops.h
+++ b/src/Native/src/runtime/stackvm/runtime_function_ops.h
@@ -1,4 +1,4 @@
-/* This file is generated by tools/stackvm_gen/IsaGen at 2023/7/12 17:07:39
+/* This file is generated by tools/stackvm_gen/IsaGen at 2023/9/5 19:40:29
* +08:00.
*
* Copyright 2019-2021 Canaan Inc.
@@ -49,6 +49,7 @@ result visit(const tensor_gather_elements_op_t &op) noexcept override;
result<void> visit(const tensor_gather_nd_op_t &op) noexcept override;
result<void> visit(const tensor_gelu_op_t &op) noexcept override;
result<void> visit(const tensor_get_item_op_t &op) noexcept override;
+result<void> visit(const tensor_get_paddings_op_t &op) noexcept override;
result<void> visit(const tensor_hard_sigmoid_op_t &op) noexcept override;
result<void> visit(const tensor_hard_swish_op_t &op) noexcept override;
result<void> visit(const tensor_hardmax_op_t &op) noexcept override;
@@ -82,6 +83,7 @@ result visit(const tensor_relu_op_t &op) noexcept override;
result<void> visit(const tensor_relu6_op_t &op) noexcept override;
result<void> visit(const tensor_require_op_t &op) noexcept override;
result<void> visit(const tensor_reshape_op_t &op) noexcept override;
+result<void> visit(const tensor_reshape_shape_op_t &op) noexcept override;
result<void> visit(const tensor_resize_image_op_t &op) noexcept override;
result<void> visit(const tensor_reverse_sequence_op_t &op) noexcept override;
result<void> visit(const tensor_scatter_nd_op_t &op) noexcept override;
@@ -97,14 +99,17 @@ result visit(const tensor_softsign_op_t &op) noexcept override;
result<void> visit(const tensor_space_to_batch_op_t &op) noexcept override;
result<void> visit(const tensor_split_op_t &op) noexcept override;
result<void> visit(const tensor_squeeze_op_t &op) noexcept override;
+result<void> visit(const tensor_squeeze_shape_op_t &op) noexcept override;
result<void> visit(const tensor_stack_op_t &op) noexcept override;
result<void> visit(const tensor_swish_op_t &op) noexcept override;
result<void> visit(const tensor_tile_op_t &op) noexcept override;
result<void> visit(const tensor_top_k_op_t &op) noexcept override;
result<void> visit(const tensor_transpose_op_t &op) noexcept override;
+result<void> visit(const tensor_transpose_shape_op_t &op) noexcept override;
result<void> visit(const tensor_trilu_op_t &op) noexcept override;
result<void> visit(const tensor_unary_op_t &op) noexcept override;
result<void> visit(const tensor_uniform_op_t &op) noexcept override;
result<void> visit(const tensor_uniform_like_op_t &op) noexcept override;
result<void> visit(const tensor_unsqueeze_op_t &op) noexcept override;
+result<void> visit(const tensor_unsqueeze_shape_op_t &op) noexcept override;
result<void> visit(const tensor_where_op_t &op) noexcept override;
diff --git a/src/Nncase.Compiler/Compiler.cs b/src/Nncase.Compiler/Compiler.cs
index 74eccd2190..f4014fa047 100644
--- a/src/Nncase.Compiler/Compiler.cs
+++ b/src/Nncase.Compiler/Compiler.cs
@@ -12,6 +12,7 @@
using Nncase.Evaluator;
using Nncase.Hosting;
using Nncase.IR;
+using Nncase.IR.NN;
using Nncase.IR.Tensors;
using Nncase.Passes;
using Nncase.Passes.Mutators;
@@ -118,7 +119,29 @@ public void TargetIndependentPass(IPassManager passManager)
p.Add();
p.Add();
p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
+ p.Add();
});
+
passManager.AddWithName("NeutralOptimizeTranspose").Configure(p =>
{
p.Add();
@@ -180,24 +203,30 @@ public void TargetIndependentPass(IPassManager passManager)
public void RegisterShapeBucket(IPassManager p)
{
var options = _compileSession.CompileOptions.ShapeBucketOptions;
- var singleVar = options.VarMap.Values.SelectMany(x => x).OfType<Var>().ToHashSet().Count <= 1;
if (!options.Enable)
{
return;
}
+ var singleVar = options.VarMap.Values.SelectMany(x => x).OfType<Var>().ToHashSet().Count <= 1;
CheckShapeBucketOptions(options);
- ToFusion(p);
- MergeOp(p);
- LostToFusion(p, singleVar);
- MergeOp(p);
- ClearMarker(p);
-
- // MergeFusion(p, singleVar);
- Bucket(p);
- // Rebuild(p);
- Simplify(p);
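+ // Graphs that contain ops which cannot be bucketed, or that use more than one
+ // dynamic var, go through the fusion-based bucketing pipeline; otherwise the
+ // whole function is bucketed at once by the "FullBucket" pass.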
+ if (HasNotBucketOp(_module!.Entry!) || !singleVar)
+ {
+ ToFusion(p);
+ MergeOp(p, true);
+ LostToFusion(p, singleVar);
+ MergeOp(p, true);
+ ClearMarker(p);
+ MergeFusion(p, singleVar, true);
+ Bucket(p);
+ Rebuild(p, singleVar);
+ Simplify(p);
+ }
+ else
+ {
+ p.AddWithName("FullBucket");
+ }
}
public void ClearFixShape(IPassManager p)
diff --git a/src/Nncase.Core/IR/ShapeExpr/Functional.cs b/src/Nncase.Core/IR/ShapeExpr/Functional.cs
index 8631e5e26f..555e9fa367 100644
--- a/src/Nncase.Core/IR/ShapeExpr/Functional.cs
+++ b/src/Nncase.Core/IR/ShapeExpr/Functional.cs
@@ -14,4 +14,14 @@ public static class ShapeExpr
public static Call Conv2DTransposeShape(Expr input, Expr weights, Expr stride, Expr dilation, Expr padding, Expr outputPadding, Expr groups) => new(new Conv2DTransposeShape(), input, weights, stride, dilation, padding, outputPadding, groups);
public static Call MatMulShape(Expr lhs, Expr rhs) => new(new MatMulShape(), lhs, rhs);
+
+ public static Call GetPaddings(Expr input, Expr weights, Expr strides, Expr dilation, Expr same, Expr lower) => new(new GetPaddings(), input, weights, strides, dilation, same, lower);
+
+ public static Call ReshapeShape(Expr input, Expr shape) => new(new ReshapeShape(), input, shape);
+
+ public static Call SqueezeShape(Expr input, Expr dims) => new(new SqueezeShape(), input, dims);
+
+ public static Call TransposeShape(Expr input, Expr perm) => new(new TransposeShape(), input, perm);
+
+ public static Call UnsqueezeShape(Expr input, Expr dim) => new(new UnsqueezeShape(), input, dim);
}
diff --git a/src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs b/src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs
new file mode 100644
index 0000000000..5bed9f4660
--- /dev/null
+++ b/src/Nncase.Core/IR/ShapeExpr/GetPaddings.cs
@@ -0,0 +1,40 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.PatternMatch;
+
+namespace Nncase.IR.ShapeExpr;
+
+[PatternFunctionalGenerator]
+public class GetPaddings : Op
+{
+ /// <summary>
+ /// Gets Input.
+ /// </summary>
+ public static readonly ParameterInfo InputShape = new(typeof(GetPaddings), 0, "input_shape");
+
+ /// <summary>
+ /// Gets Weights.
+ /// </summary>
+ public static readonly ParameterInfo WeightsShape = new(typeof(GetPaddings), 1, "weights_shape");
+
+ /// <summary>
+ /// Gets Strides.
+ /// </summary>
+ public static readonly ParameterInfo Strides = new(typeof(GetPaddings), 2, "strides");
+
+ /// <summary>
+ /// Gets Dilations.
+ /// </summary>
+ public static readonly ParameterInfo Dilations = new(typeof(GetPaddings), 3, "dilations");
+
+ /// <summary>
+ /// Gets Same.
+ /// </summary>
+ public static readonly ParameterInfo Same = new(typeof(GetPaddings), 4, "same");
+
+ /// <summary>
+ /// Gets Lower.
+ /// </summary>
+ public static readonly ParameterInfo Lower = new(typeof(GetPaddings), 5, "lower");
+}
diff --git a/src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs b/src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs
new file mode 100644
index 0000000000..f6db6836a0
--- /dev/null
+++ b/src/Nncase.Core/IR/ShapeExpr/ReshapeShape.cs
@@ -0,0 +1,24 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.IR.ShapeExpr;
+
+/// <summary>
+/// ReshapeShape expression.
+/// </summary>
+[PatternFunctionalGenerator]
+public sealed partial class ReshapeShape : Op
+{
+ /// <summary>
+ /// Gets input shape.
+ /// </summary>
+ public static readonly ParameterInfo InputShape = new(typeof(ReshapeShape), 0, "input_shape");
+
+ /// <summary>
+ /// Gets shape.
+ /// </summary>
+ public static readonly ParameterInfo Shape = new(typeof(ReshapeShape), 1, "shape", HasRank(1));
+}
diff --git a/src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs b/src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs
new file mode 100644
index 0000000000..55338691b6
--- /dev/null
+++ b/src/Nncase.Core/IR/ShapeExpr/SqueezeShape.cs
@@ -0,0 +1,24 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.IR.ShapeExpr;
+
+/// <summary>
+/// SqueezeShape expression.
+/// </summary>
+[PatternFunctionalGenerator]
+public sealed partial class SqueezeShape : Op
+{
+ /// <summary>
+ /// Gets input shape.
+ /// </summary>
+ public static readonly ParameterInfo InputShape = new(typeof(SqueezeShape), 0, "input_shape");
+
+ /// <summary>
+ /// Gets dimension.
+ /// </summary>
+ public static readonly ParameterInfo Dim = new(typeof(SqueezeShape), 1, "dim", HasRank(1) & IsIntegral());
+}
diff --git a/src/Nncase.Core/IR/ShapeExpr/TransposeShape.cs b/src/Nncase.Core/IR/ShapeExpr/TransposeShape.cs
new file mode 100644
index 0000000000..c658e15dc6
--- /dev/null
+++ b/src/Nncase.Core/IR/ShapeExpr/TransposeShape.cs
@@ -0,0 +1,24 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.IR.ShapeExpr;
+
+/// <summary>
+/// TransposeShape expression.
+/// </summary>
+[PatternFunctionalGenerator]
+public sealed partial class TransposeShape : Op
+{
+ /// <summary>
+ /// Gets input shape.
+ /// </summary>
+ public static readonly ParameterInfo InputShape = new(typeof(TransposeShape), 0, "input");
+
+ /// <summary>
+ /// Gets perm.
+ /// </summary>
+ public static readonly ParameterInfo Perm = new(typeof(TransposeShape), 1, "perm", HasRank(1) & IsIntegral());
+}
diff --git a/src/Nncase.Core/IR/ShapeExpr/UnsqueezeShape.cs b/src/Nncase.Core/IR/ShapeExpr/UnsqueezeShape.cs
new file mode 100644
index 0000000000..ed1f396fef
--- /dev/null
+++ b/src/Nncase.Core/IR/ShapeExpr/UnsqueezeShape.cs
@@ -0,0 +1,21 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using Nncase.PatternMatch;
+using static Nncase.IR.TypePatternUtility;
+
+namespace Nncase.IR.ShapeExpr;
+
+[PatternFunctionalGenerator]
+public sealed partial class UnsqueezeShape : Op
+{
+ /// <summary>
+ /// Gets input shape.
+ /// </summary>
+ public static readonly ParameterInfo InputShape = new(typeof(UnsqueezeShape), 0, "input_shape");
+
+ /// <summary>
+ /// Gets dimension.
+ /// </summary>
+ public static readonly ParameterInfo Dim = new(typeof(UnsqueezeShape), 1, "dim", HasRank(1) & IsIntegral());
+}
diff --git a/src/Nncase.Core/Utilities/ShapeExprUtility.cs b/src/Nncase.Core/Utilities/ShapeExprUtility.cs
index 28d953a8ad..c1da04dca4 100644
--- a/src/Nncase.Core/Utilities/ShapeExprUtility.cs
+++ b/src/Nncase.Core/Utilities/ShapeExprUtility.cs
@@ -19,8 +19,8 @@ public static Expr BroadcastShape(Expr lhsShape, params Expr[] rhsShape)
public static Expr Positive(Expr axis, Expr inShape)
{
var rank = new Call(new Rank(), inShape);
- var i32Axis = Cast(axis, DataTypes.Int32);
- return new If(i32Axis < 0, i32Axis + rank, i32Axis);
+ var i64Axis = Cast(axis, DataTypes.Int64);
+ return new If(i64Axis < 0L, i64Axis + rank, i64Axis);
}
public static Expr Slice(Expr shape, int begin, int end)
@@ -35,28 +35,28 @@ public static Expr Slice(Expr shape, Expr begin, Expr end)
public static Expr Replace(Expr shapeExpr, Expr index, Expr value)
{
- return SliceAndMerge(shapeExpr, index, value, 1);
+ return SliceAndMerge(shapeExpr, index, value, 1L);
}
public static Expr Insert(Expr shapeExpr, Expr index, Expr value)
{
if (shapeExpr.CheckedShape.IsScalar)
{
- return SliceAndMerge(StackScalar(shapeExpr), index, value, 0);
+ return SliceAndMerge(StackScalar(shapeExpr), index, value, 0L);
}
- return SliceAndMerge(shapeExpr, index, value, 0);
+ return SliceAndMerge(shapeExpr, index, value, 0L);
}
public static Expr ReplaceList(Expr shapeExpr, Expr list, Expr value)
{
- return SliceAndMerge(shapeExpr, list, value, 1, false);
+ return SliceAndMerge(shapeExpr, list, value, 1L, false);
}
public static Expr Remove(Expr shapeExpr, Expr index)
{
var front = Slice(shapeExpr, 0, index);
- var last = Slice(shapeExpr, index + 1, int.MaxValue);
+ var last = Slice(shapeExpr, index + 1L, int.MaxValue);
return Concat(new IR.Tuple(front, last), 0);
}
@@ -65,10 +65,17 @@ public static Expr StackOne(Expr expr)
return Stack(new IR.Tuple(expr), 0);
}
- private static Expr SliceAndMerge(Expr shapeExpr, Expr index, Expr value, Expr indexOffset, bool valueIsList = true)
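+ /// <summary>
+ /// Runs type inference on <paramref name="call"/> and returns its fixed checked shape as an int64 tensor value.
+ /// </summary>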
+ public static IValue GetShapeValue(Call call)
{
+ call.InferenceType();
+ return Value.FromTensor(call.CheckedShape.ToValueArray().Select(x => (long)x).ToArray());
+ }
+
+ private static Expr SliceAndMerge(Expr originShapeExpr, Expr index, Expr value, Expr indexOffset, bool valueIsList = true)
+ {
+ var shapeExpr = Cast(originShapeExpr, DataTypes.Int64);
var front = Slice(shapeExpr, 0, index);
- var last = Slice(shapeExpr, Cast(index, DataTypes.Int32) + indexOffset, int.MaxValue);
+ var last = Slice(shapeExpr, Cast(index, DataTypes.Int64) + indexOffset, int.MaxValue);
var c = valueIsList ? StackOne(value) : value;
if (c.CheckedShape.IsScalar)
{
diff --git a/src/Nncase.Evaluator/Math/Binary.cs b/src/Nncase.Evaluator/Math/Binary.cs
index 0b5515ca28..f5f9d3c65b 100755
--- a/src/Nncase.Evaluator/Math/Binary.cs
+++ b/src/Nncase.Evaluator/Math/Binary.cs
@@ -118,7 +118,7 @@ public Expr Visit(IShapeEvaluateContext context, Binary target)
{
var lhs = context.GetArgumentShape(target, Binary.Lhs);
var rhs = context.GetArgumentShape(target, Binary.Rhs);
- return IR.F.Tensors.Cast(ShapeExprUtility.BroadcastShape(lhs, rhs), DataTypes.Int32);
+ return ShapeExprUtility.BroadcastShape(lhs, rhs);
}
private int Compute(BinaryOp op, int a, int b) => op switch
diff --git a/src/Nncase.Evaluator/Math/MatMul.cs b/src/Nncase.Evaluator/Math/MatMul.cs
index 444d221c28..4785e1e1c1 100644
--- a/src/Nncase.Evaluator/Math/MatMul.cs
+++ b/src/Nncase.Evaluator/Math/MatMul.cs
@@ -71,7 +71,7 @@ public Expr Visit(IShapeEvaluateContext context, MatMul target)
{
var lhs = context.GetArgumentShape(target, MatMul.Lhs);
var rhs = context.GetArgumentShape(target, MatMul.Rhs);
- return Cast(IR.F.ShapeExpr.MatMulShape(lhs, rhs), DataTypes.Int32);
+ return IR.F.ShapeExpr.MatMulShape(lhs, rhs);
}
private IRType Visit(TensorType lhs, TensorType rhs)
diff --git a/src/Nncase.Evaluator/Math/Reduce.cs b/src/Nncase.Evaluator/Math/Reduce.cs
index 69e512b717..ead03f4701 100644
--- a/src/Nncase.Evaluator/Math/Reduce.cs
+++ b/src/Nncase.Evaluator/Math/Reduce.cs
@@ -110,7 +110,7 @@ public Expr Visit(IShapeEvaluateContext context, Reduce target)
{
if (axes.Length == input.CheckedShape.Count && keepDimsValue == 0)
{
- return Array.Empty<int>();
+ return Array.Empty<long>();
}
}
@@ -119,7 +119,7 @@ public Expr Visit(IShapeEvaluateContext context, Reduce target)
var ax = ShapeExprUtility.Positive(axValue, inShape);
if (keepDimsValue == 1)
{
- outShape = ShapeExprUtility.Replace(outShape, ax, 1);
+ outShape = ShapeExprUtility.Replace(outShape, ax, 1L);
}
else
{
@@ -127,7 +127,7 @@ public Expr Visit(IShapeEvaluateContext context, Reduce target)
}
}
- return outShape;
+ return Cast(outShape, DataTypes.Int64);
}
throw new NotImplementedException();
diff --git a/src/Nncase.Evaluator/NN/BatchToSpace.cs b/src/Nncase.Evaluator/NN/BatchToSpace.cs
index d60473b41a..e53297c7ab 100644
--- a/src/Nncase.Evaluator/NN/BatchToSpace.cs
+++ b/src/Nncase.Evaluator/NN/BatchToSpace.cs
@@ -115,13 +115,13 @@ public Expr Visit(IShapeEvaluateContext context, BatchToSpace target)
inShape = Stack(new IR.Tuple(inShape[0], inShape[2], inShape[1]), 0);
}
- var blockShape = context.GetArgument(target, BatchToSpace.BlockShape);
+ var blockShape = Cast(context.GetArgument(target, BatchToSpace.BlockShape), DataTypes.Int64);
if (!blockShape.CheckedShape.IsFixed)
{
throw new NotImplementedException();
}
- var crops = context.GetArgument(target, BatchToSpace.Crops);
+ var crops = Cast(context.GetArgument(target, BatchToSpace.Crops), DataTypes.Int64);
var blockSize = Prod(blockShape);
var batch = inShape[0];
var d0 = batch / blockSize;
@@ -131,7 +131,7 @@ public Expr Visit(IShapeEvaluateContext context, BatchToSpace target)
var inRank = Cast(ShapeOf(inShape)[0], DataTypes.Int32);
var remainSize = inRank - 1 - m;
- var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty<int>());
+ var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty<long>());
var outShapeList = Concat(new IR.Tuple(Stack(new IR.Tuple(new[] { d0 }), 0), Stack(new IR.Tuple(cropSection), 0), remainShape), 0);
diff --git a/src/Nncase.Evaluator/NN/Conv2D.cs b/src/Nncase.Evaluator/NN/Conv2D.cs
index 67f361adc8..c26e219821 100644
--- a/src/Nncase.Evaluator/NN/Conv2D.cs
+++ b/src/Nncase.Evaluator/NN/Conv2D.cs
@@ -92,10 +92,10 @@ public Expr Visit(IShapeEvaluateContext context, Conv2D target)
{
var input = context.GetArgumentShape(target, Conv2D.Input);
var weights = context.GetArgumentShape(target, Conv2D.Weights);
- var pad = Cast(context.GetArgument(target, Conv2D.Padding), DataTypes.Int32);
- var stride = Cast(context.GetArgument(target, Conv2D.Stride), DataTypes.Int32);
- var dilation = Cast(context.GetArgument(target, Conv2D.Dilation), DataTypes.Int32);
- var groups = Cast(context.GetArgument(target, Conv2D.Groups), DataTypes.Int32);
+ var pad = context.GetArgument(target, Conv2D.Padding);
+ var stride = context.GetArgument(target, Conv2D.Stride);
+ var dilation = context.GetArgument(target, Conv2D.Dilation);
+ var groups = context.GetArgument(target, Conv2D.Groups);
return IR.F.ShapeExpr.Conv2DShape(input, weights, pad, stride, dilation, groups);
}
diff --git a/src/Nncase.Evaluator/NN/Pad.cs b/src/Nncase.Evaluator/NN/Pad.cs
index 13a4111456..5d33750659 100644
--- a/src/Nncase.Evaluator/NN/Pad.cs
+++ b/src/Nncase.Evaluator/NN/Pad.cs
@@ -118,11 +118,11 @@ public Expr Visit(IShapeEvaluateContext context, Pad target)
var end = Slice(pads, new[] { 1 }, new[] { 2 }, new[] { 1 }, new[] { 1 });
// paddings = [4, 2] -> [4, 1] + [4, 1]
- var paddings = front + end;
+ var paddings = Cast(front + end, DataTypes.Int64);
// outShape = inShape + paddings
- var padsSumShape = StackScalar(Cast(ShapeOf(paddings)[0], DataTypes.Int32));
- var outShape = inShape + Cast(Reshape(paddings, padsSumShape), DataTypes.Int32);
+ var padsSumShape = StackScalar(ShapeOf(paddings)[0]);
+ var outShape = inShape + Reshape(paddings, padsSumShape);
return outShape;
}
diff --git a/src/Nncase.Evaluator/NN/SpaceToBatch.cs b/src/Nncase.Evaluator/NN/SpaceToBatch.cs
index 756f8c1472..98ec949151 100644
--- a/src/Nncase.Evaluator/NN/SpaceToBatch.cs
+++ b/src/Nncase.Evaluator/NN/SpaceToBatch.cs
@@ -98,11 +98,11 @@ public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target)
{
var inShape = context.GetArgumentShape(target, SpaceToBatch.Input);
var blockShape = context.GetArgument(target, SpaceToBatch.BlockShape);
- var padding = context.GetArgument(target, SpaceToBatch.Paddings);
+ var padding = Cast(context.GetArgument(target, SpaceToBatch.Paddings), DataTypes.Int64);
var input = context.GetArgument(target, SpaceToBatch.Input);
if (blockShape is TensorConst blockConst)
{
- var blockShapeValue = blockConst.Value.ToArray<int>();
+ var blockShapeValue = blockConst.Value.ToArray<long>();
var m = blockShapeValue.Length;
var inRank = input.CheckedShape.Rank;
@@ -122,7 +122,7 @@ public Expr Visit(IShapeEvaluateContext context, SpaceToBatch target)
}).ToArray();
var remainSize = inRank - 1 - m;
- var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty<int>());
+ var remainShape = new If(remainSize > 0, ShapeExprUtility.Slice(inShape, 1 + m, int.MaxValue), Array.Empty<long>());
var outLast = remainShape;
var outShape = Concat(new IR.Tuple(Stack(new IR.Tuple(outFirst.Concat(outMid).ToArray()), 0), outLast), 0);
return outShape;
diff --git a/src/Nncase.Evaluator/ShapeEvaluateContext.cs b/src/Nncase.Evaluator/ShapeEvaluateContext.cs
index e9919fb3d6..6f3e5ff94c 100644
--- a/src/Nncase.Evaluator/ShapeEvaluateContext.cs
+++ b/src/Nncase.Evaluator/ShapeEvaluateContext.cs
@@ -22,8 +22,13 @@ internal sealed class ShapeEvaluateContext : IShapeEvaluateContext
public ShapeEvaluateContext(Dictionary<Expr, Expr> memo, ShapeExprCache cache)
{
_memo = memo;
- Cache = cache.Cache;
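+ // Seed the evaluation memo with the externally provided cache so cached shape
+ // expressions are reused instead of being re-evaluated.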
+ foreach (var (key, value) in cache.Cache)
+ {
+ _memo[key] = value;
+ }
+
VarMap = cache.VarMap;
+ Cache = new();
}
public IReadOnlyDictionary<Var, Expr[]> VarMap { get; }
@@ -55,7 +60,7 @@ public Expr GetArgumentShape(Op op, ParameterInfo parameter)
var expr = GetArgument(op, parameter);
if (expr is Tuple tuple)
{
- return new Tuple(tuple.Fields.ToArray().Select(v => Cast(GetResultFromMemo(v), DataTypes.Int32)).ToArray());
+ return new Tuple(tuple.Fields.ToArray().Select(v => Cast(GetResultFromMemo(v), DataTypes.Int64)).ToArray());
}
// call
@@ -64,7 +69,7 @@ public Expr GetArgumentShape(Op op, ParameterInfo parameter)
var shape = expr.EvaluateShapeExpr(new ShapeExprCache(VarMap));
if (shape is Call c && c.Target is IR.Math.Require && c.Arguments[IR.Math.Require.Value.Index] is Tuple tupleShapeExpr)
{
- return new Tuple(tupleShapeExpr.Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int32)).ToArray());
+ return new Tuple(tupleShapeExpr.Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int64)).ToArray());
}
// for split
@@ -76,23 +81,23 @@ public Expr GetArgumentShape(Op op, ParameterInfo parameter)
return new Tuple(
Enumerable
.Range(0, tupleType.Fields.Count)
- .Select(i => Cast(shape[i], DataTypes.Int32))
+ .Select(i => Cast(shape[i], DataTypes.Int64))
.ToArray());
}
else
{
- return new Tuple(((Tuple)shape).Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int32)).ToArray());
+ return new Tuple(((Tuple)shape).Fields.ToArray().Select(expr => Cast(expr, DataTypes.Int64)).ToArray());
}
}
}
var shapeExpr = GetResultFromMemo(expr);
- return Cast(shapeExpr, DataTypes.Int32);
+ return Cast(shapeExpr, DataTypes.Int64);
}
public Expr GetArgumentRank(Op op, ParameterInfo parameter)
{
- return StackScalar(Cast(GetArgumentShape(op, parameter)[0], DataTypes.Int32));
+ return StackScalar(Cast(GetArgumentShape(op, parameter)[0], DataTypes.Int64));
}
private Call GetCurrentCall() => CurrentCall ?? throw new InvalidOperationException("Current call is not set.");
diff --git a/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs b/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs
index f5e8c3ee4d..04b24b5109 100644
--- a/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs
+++ b/src/Nncase.Evaluator/ShapeEvaluateVisitor.cs
@@ -102,7 +102,7 @@ protected override Expr VisitLeafVar(Var expr)
throw new InvalidOperationException();
}
- var shapeExpr = shape.Select((x, i) => x.IsFixed ? x.FixedValue : _context.VarMap[expr][i]).Select(x => IR.F.Tensors.Cast(x, DataTypes.Int32)).ToArray();
+ var shapeExpr = shape.Select((x, i) => x.IsFixed ? x.FixedValue : _context.VarMap[expr][i]).Select(x => IR.F.Tensors.Cast(x, DataTypes.Int64)).ToArray();
return IR.F.Tensors.Stack(new IR.Tuple(shapeExpr), 0);
}
diff --git a/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs b/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs
index 6c9a785a5a..adb4d296d8 100644
--- a/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs
+++ b/src/Nncase.Evaluator/ShapeEvaluatorProvider.cs
@@ -64,6 +64,7 @@ public Expr EvaluateOpShapeExpr(Op op, IShapeEvaluateContext context)
var evaluatorType = typeof(IShapeEvaluator<>).MakeGenericType(op.GetType());
var evaluator = (IShapeEvaluator)_serviceProvider.GetRequiredService(evaluatorType);
var result = evaluator.Visit(context, op);
+ var s = op.GetType().Name + "_" + op.DisplayProperty();
if (!result.InferenceType())
{
if (DumpScope.Current.IsEnabled(DumpFlags.Compile))
@@ -71,7 +72,7 @@ public Expr EvaluateOpShapeExpr(Op op, IShapeEvaluateContext context)
DumpScope.Current.DumpIR(result, "EvaluateOpShapeExprInvalidResult");
}
- throw new InvalidOperationException();
+ throw new InvalidOperationException(s);
}
return result;
diff --git a/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs b/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs
index 61cb3e9438..c7cb702bca 100644
--- a/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs
+++ b/src/Nncase.Evaluator/ShapeExpr/BroadcastShape.cs
@@ -41,9 +41,8 @@ public Cost Visit(ICostEvaluateContext context, BroadcastShape target)
public Expr Visit(IShapeEvaluateContext context, BroadcastShape target)
{
var inShape = context.GetArgumentShape(target, BroadcastShape.Inputs);
- var len = ((IR.Tuple)inShape).Fields.ToArray().Aggregate((Expr)1, (i, call) => IR.F.Math.Max(i, call));
- var bn = IR.F.Tensors.Cast(len, DataTypes.Int32);
- return bn;
+ var len = ((IR.Tuple)inShape).Fields.ToArray().Aggregate((Expr)1L, (i, call) => IR.F.Math.Max(i, call));
+ return len;
}
public Metric Visit(IMetricEvaluateContext context, BroadcastShape target)
diff --git a/src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs b/src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs
new file mode 100644
index 0000000000..5d5d775ea8
--- /dev/null
+++ b/src/Nncase.Evaluator/ShapeExpr/GetPaddings.cs
@@ -0,0 +1,89 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Linq;
+using Nncase.CostModel;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.IR.ShapeExpr;
+using Nncase.IR.Tensors;
+using Nncase.Utilities;
+using static Nncase.IR.F.Tensors;
+
+namespace Nncase.Evaluator.ShapeExpr;
+
+public partial class GetPaddingsEvaluator : IEvaluator<GetPaddings>, ITypeInferencer<GetPaddings>, ICostEvaluator<GetPaddings>, IShapeEvaluator<GetPaddings>, IMetricEvaluator<GetPaddings>
+{
+ public static Expr ConcatPadding(Expr[] padH, Expr[] padW)
+ {
+ // return [[padh_before, padh_after],
+ // [padw_before, padw_after]]
+ return Stack(
+ new IR.Tuple(
+ Stack(new IR.Tuple(padH), 0),
+ Stack(new IR.Tuple(padW), 0)),
+ 0);
+ }
+
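+ // Computes conv2d-style paddings from the input/weights shapes, strides and dilations and
+ // returns a [2, 2] tensor: [[padH_before, padH_after], [padW_before, padW_after]].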
+ public IValue Visit(IEvaluateContext context, GetPaddings target)
+ {
+ var inShape = context.GetArgumentValueAsArray(target, GetPaddings.InputShape);
+ var wShape = context.GetArgumentValueAsArray(target, GetPaddings.WeightsShape);
+ var strides = context.GetArgumentValueAsArray(target, GetPaddings.Strides);
+ var dilations = context.GetArgumentValueAsArray(target, GetPaddings.Dilations);
+ var same = context.GetArgumentValueAsScalar<bool>(target, GetPaddings.Same);
+ var lower = context.GetArgumentValueAsScalar<bool>(target, GetPaddings.Lower);
+ var padH = GetWindowedPadding(inShape[2], wShape[2], strides[0], dilations[0], same, lower);
+ var padW = GetWindowedPadding(inShape[3], wShape[3], strides[1], dilations[1], same, lower);
+ return ConcatPadding(padH, padW).Evaluate();
+ }
+
+ public IRType Visit(ITypeInferenceContext context, GetPaddings target)
+ {
+ return new TensorType(DataTypes.Int64, new[] { 2, 2 });
+ }
+
+ public Cost Visit(ICostEvaluateContext context, GetPaddings target)
+ {
+ return CostUtility.GetShapeExprCost();
+ }
+
+ public Expr Visit(IShapeEvaluateContext context, GetPaddings target)
+ {
+ return new[] { 2, 2 };
+ }
+
+ public Metric Visit(IMetricEvaluateContext context, GetPaddings target)
+ {
+ var returnType = context.GetReturnType();
+ return new()
+ {
+ [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType),
+ };
+ }
+
+ private static Expr[] GetWindowedPadding(Expr inputSize, Expr filter, Expr stride, Expr dilation, bool same, bool lower = false)
+ {
+ var i32InputSize = Cast(inputSize, DataTypes.Int32);
+ var i32Filter = Cast(filter, DataTypes.Int32);
+ var i32Stride = Cast(stride, DataTypes.Int32);
+ var i32Dilation = Cast(dilation, DataTypes.Int32);
+ var outputSize = IR.Util.GetWindowedOutputSize(i32InputSize, i32Filter, i32Stride, i32Dilation, same, false);
+ return GetWindowedPaddingValue(i32InputSize, outputSize, i32Filter, i32Stride, i32Dilation, lower);
+ }
+
+ private static Expr[] GetWindowedPaddingValue(Expr inputSize, Expr outputSize, Expr filter, Expr stride, Expr dilation, bool lower)
+ {
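+ // Total padding needed so the sliding window covers the input:
+ // max(0, (outputSize - 1) * stride + effectiveFilterSize - inputSize),
+ // split into before/after halves; the "after" side gets the extra element when the total is odd.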
+ var effectiveFilterSize = ((filter - 1) * dilation) + 1;
+ var padding = IR.F.Math.Max(0, ((outputSize - 1) * stride) + effectiveFilterSize - inputSize);
+ var before = Cast(padding / 2, DataTypes.Int32);
+ var after = Cast(padding - (padding / 2), DataTypes.Int32);
+ if (lower)
+ {
+ return new[] { IR.F.Math.Max(before, after), IR.F.Math.Min(before, after) };
+ }
+
+ return new[] { before, after };
+ }
+}
diff --git a/src/Nncase.Evaluator/ShapeExpr/ReshapeShape.cs b/src/Nncase.Evaluator/ShapeExpr/ReshapeShape.cs
new file mode 100644
index 0000000000..c8b7d1a536
--- /dev/null
+++ b/src/Nncase.Evaluator/ShapeExpr/ReshapeShape.cs
@@ -0,0 +1,50 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Linq;
+using Nncase.CostModel;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.IR.ShapeExpr;
+using Nncase.IR.Tensors;
+using Nncase.Utilities;
+
+namespace Nncase.Evaluator.ShapeExpr;
+
+public partial class ReshapeShapeEvaluator : IEvaluator<ReshapeShape>, ITypeInferencer<ReshapeShape>, ICostEvaluator<ReshapeShape>, IShapeEvaluator<ReshapeShape>, IMetricEvaluator<ReshapeShape>
+{
+ public IValue Visit(IEvaluateContext context, ReshapeShape target)
+ {
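+ // Build a temporary Reshape call over a dummy input of the evaluated shape and return that call's
+ // shape, so Reshape's own handling of -1 dimensions is reused instead of re-implemented here.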
+ var inShape = context.GetArgumentValueAsArray(target, ReshapeShape.InputShape);
+ var shape = context.GetArgumentValueAsTensor(target, ReshapeShape.Shape);
+ var t = IR.F.Tensors.Reshape(new Var(new TensorType(DataTypes.Float32, inShape)), shape);
+ return ShapeExprUtility.GetShapeValue(t);
+ }
+
+ public IRType Visit(ITypeInferenceContext context, ReshapeShape target)
+ {
+ var shape = context.CheckArgumentType(target, ReshapeShape.Shape);
+ return new TensorType(DataTypes.Int64, shape.Shape.ToValueArray());
+ }
+
+ public Cost Visit(ICostEvaluateContext context, ReshapeShape target)
+ {
+ return CostUtility.GetShapeExprCost();
+ }
+
+ public Expr Visit(IShapeEvaluateContext context, ReshapeShape target)
+ {
+ var shape = context.GetArgument(target, ReshapeShape.Shape);
+ return shape.CheckedShape.ToValueArray();
+ }
+
+ public Metric Visit(IMetricEvaluateContext context, ReshapeShape target)
+ {
+ var returnType = context.GetReturnType();
+ return new()
+ {
+ [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType),
+ };
+ }
+}
diff --git a/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs b/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs
index bbc3631372..e0a1ec3505 100644
--- a/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs
+++ b/src/Nncase.Evaluator/ShapeExpr/ShapeExprModule.cs
@@ -20,5 +20,10 @@ public void ConfigureServices(IRegistrator registrator)
registrator.RegisterManyInterface(reuse: Reuse.Singleton);
registrator.RegisterManyInterface(reuse: Reuse.Singleton);
registrator.RegisterManyInterface(reuse: Reuse.Singleton);
+ registrator.RegisterManyInterface<GetPaddingsEvaluator>(reuse: Reuse.Singleton);
+ registrator.RegisterManyInterface<ReshapeShapeEvaluator>(reuse: Reuse.Singleton);
+ registrator.RegisterManyInterface<SqueezeShapeEvaluator>(reuse: Reuse.Singleton);
+ registrator.RegisterManyInterface<TransposeShapeEvaluator>(reuse: Reuse.Singleton);
+ registrator.RegisterManyInterface<UnsqueezeShapeEvaluator>(reuse: Reuse.Singleton);
}
}
diff --git a/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs b/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs
new file mode 100644
index 0000000000..df1c668bdf
--- /dev/null
+++ b/src/Nncase.Evaluator/ShapeExpr/SqueezeShape.cs
@@ -0,0 +1,57 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Linq;
+using Nncase.CostModel;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.IR.ShapeExpr;
+using Nncase.IR.Tensors;
+using Nncase.Utilities;
+
+namespace Nncase.Evaluator.ShapeExpr;
+
+public partial class SqueezeShapeEvaluator : IEvaluator<SqueezeShape>, ITypeInferencer<SqueezeShape>, ICostEvaluator<SqueezeShape>, IShapeEvaluator<SqueezeShape>, IMetricEvaluator<SqueezeShape>
+{
+ public IValue Visit(IEvaluateContext context, SqueezeShape target)
+ {
+ var inShape = context.GetArgumentValueAsArray(target, SqueezeShape.InputShape);
+ var dims = context.GetArgumentValueAsTensor(target, SqueezeShape.Dim);
+ var t = IR.F.Tensors.Squeeze(new Var(new TensorType(DataTypes.Float32, inShape)), dims);
+ return ShapeExprUtility.GetShapeValue(t);
+ }
+
+ public IRType Visit(ITypeInferenceContext context, SqueezeShape target)
+ {
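+ // Result is a 1-D shape tensor with rank(input) - len(dims) elements; the length is unknown
+ // when the rank of the input-shape argument itself is not fixed.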
+ var input = context.GetArgument(target, SqueezeShape.InputShape);
+ var dims = context.CheckArgumentType(target, SqueezeShape.Dim);
+ if (!input.CheckedShape.IsFixed)
+ {
+ return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown });
+ }
+
+ return new TensorType(DataTypes.Int64, new[] { input.CheckedShape.Size - dims.Shape[0] });
+ }
+
+ public Cost Visit(ICostEvaluateContext context, SqueezeShape target)
+ {
+ return CostUtility.GetShapeExprCost();
+ }
+
+ public Expr Visit(IShapeEvaluateContext context, SqueezeShape target)
+ {
+ var input = context.GetArgument(target, SqueezeShape.InputShape);
+ var dims = context.GetArgument(target, SqueezeShape.Dim);
+ return new[] { input.CheckedShape.Size - dims.CheckedShape[0].FixedValue };
+ }
+
+ public Metric Visit(IMetricEvaluateContext context, SqueezeShape target)
+ {
+ var returnType = context.GetReturnType();
+ return new()
+ {
+ [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType),
+ };
+ }
+}
diff --git a/src/Nncase.Evaluator/ShapeExpr/TransposeShape.cs b/src/Nncase.Evaluator/ShapeExpr/TransposeShape.cs
new file mode 100644
index 0000000000..49a1046627
--- /dev/null
+++ b/src/Nncase.Evaluator/ShapeExpr/TransposeShape.cs
@@ -0,0 +1,51 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Linq;
+using NetFabric.Hyperlinq;
+using Nncase.CostModel;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.IR.ShapeExpr;
+using Nncase.IR.Tensors;
+using Nncase.Utilities;
+
+namespace Nncase.Evaluator.ShapeExpr;
+
+public partial class TransposeShapeEvaluator : IEvaluator<TransposeShape>, ITypeInferencer<TransposeShape>, ICostEvaluator<TransposeShape>, IShapeEvaluator<TransposeShape>, IMetricEvaluator<TransposeShape>
+{
+ public IValue Visit(IEvaluateContext context, TransposeShape target)
+ {
+ var inShape = context.GetArgumentValueAsArray(target, TransposeShape.InputShape);
+ var perm = context.GetArgumentValueAsTensor(target, TransposeShape.Perm);
+ var t = IR.F.Tensors.Transpose(new Var(new TensorType(DataTypes.Float32, inShape)), perm);
+ return ShapeExprUtility.GetShapeValue(t);
+ }
+
+ public IRType Visit(ITypeInferenceContext context, TransposeShape target)
+ {
+ var tt = context.CheckArgumentType(target, TransposeShape.InputShape);
+ return new TensorType(DataTypes.Int64, new[] { tt.Shape[0] });
+ }
+
+ public Cost Visit(ICostEvaluateContext context, TransposeShape target)
+ {
+ return CostUtility.GetShapeExprCost();
+ }
+
+ public Expr Visit(IShapeEvaluateContext context, TransposeShape target)
+ {
+ var input = context.GetArgument(target, TransposeShape.Perm);
+ return input.CheckedShape[0].FixedValue;
+ }
+
+ public Metric Visit(IMetricEvaluateContext context, TransposeShape target)
+ {
+ var returnType = context.GetReturnType();
+ return new()
+ {
+ [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType),
+ };
+ }
+}
diff --git a/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs b/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs
new file mode 100644
index 0000000000..c195b6c2dc
--- /dev/null
+++ b/src/Nncase.Evaluator/ShapeExpr/UnsqueezeShape.cs
@@ -0,0 +1,57 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Linq;
+using Nncase.CostModel;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.IR.ShapeExpr;
+using Nncase.IR.Tensors;
+using Nncase.Utilities;
+
+namespace Nncase.Evaluator.ShapeExpr;
+
+public partial class UnsqueezeShapeEvaluator : IEvaluator<UnsqueezeShape>, ITypeInferencer<UnsqueezeShape>, ICostEvaluator<UnsqueezeShape>, IShapeEvaluator<UnsqueezeShape>, IMetricEvaluator<UnsqueezeShape>
+{
+ public IValue Visit(IEvaluateContext context, UnsqueezeShape target)
+ {
+ var inShape = context.GetArgumentValueAsArray(target, UnsqueezeShape.InputShape);
+ var dims = context.GetArgumentValueAsTensor(target, UnsqueezeShape.Dim);
+ var t = IR.F.Tensors.Unsqueeze(new Var(new TensorType(DataTypes.Float32, inShape)), dims);
+ return ShapeExprUtility.GetShapeValue(t);
+ }
+
+ public IRType Visit(ITypeInferenceContext context, UnsqueezeShape target)
+ {
+ var input = context.CheckArgumentType(target, UnsqueezeShape.InputShape);
+ var dims = context.CheckArgumentType(target, UnsqueezeShape.Dim);
+ if (!input.Shape.IsFixed)
+ {
+ return new TensorType(DataTypes.Int64, new[] { Dimension.Unknown });
+ }
+
+ return new TensorType(DataTypes.Int64, new[] { input.Shape.Size + dims.Shape[0] });
+ }
+
+ public Cost Visit(ICostEvaluateContext context, UnsqueezeShape target)
+ {
+ return CostUtility.GetShapeExprCost();
+ }
+
+ public Expr Visit(IShapeEvaluateContext context, UnsqueezeShape target)
+ {
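+ // Shape of the resulting shape tensor: rank(input) + number of unsqueezed axes.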
+ var input = context.GetArgument(target, UnsqueezeShape.InputShape);
+ var dims = context.GetArgument(target, UnsqueezeShape.Dim);
+ return IR.F.Tensors.Stack(new IR.Tuple(new[] { IR.F.Tensors.ShapeOf(input)[0] + (long)dims.CheckedShape[0].FixedValue }), 0);
+ }
+
+ public Metric Visit(IMetricEvaluateContext context, UnsqueezeShape target)
+ {
+ var returnType = context.GetReturnType();
+ return new()
+ {
+ [MetricFactorNames.OffChipMemoryTraffic] = CostUtility.GetMemoryAccess(returnType),
+ };
+ }
+}
diff --git a/src/Nncase.Evaluator/Tensors/Concat.cs b/src/Nncase.Evaluator/Tensors/Concat.cs
index 994e76de22..d1098ccf7b 100644
--- a/src/Nncase.Evaluator/Tensors/Concat.cs
+++ b/src/Nncase.Evaluator/Tensors/Concat.cs
@@ -54,8 +54,8 @@ public Expr Visit(IShapeEvaluateContext context, Concat target)
var inShape = context.GetArgumentShape(target, Concat.Input);
var axis = context.GetArgument(target, Concat.Axis);
var axisV = ShapeExprUtility.Positive(axis, inShape[0]);
- var inShapes = ((IR.Tuple)inShape).Fields.ToArray().Select(x => Cast(x, DataTypes.Int32)).ToArray();
- var dim = inShapes.ToArray().Aggregate((Expr)0, (sum, shape) => sum + shape[axisV]);
+ var inShapes = ((IR.Tuple)inShape).Fields.ToArray().Select(x => Cast(x, DataTypes.Int64)).ToArray();
+ var dim = inShapes.ToArray().Aggregate((Expr)0L, (sum, shape) => sum + shape[axisV]);
var outShape = ShapeExprUtility.Replace(inShapes[0], axisV, dim);
return outShape;
}
diff --git a/src/Nncase.Evaluator/Tensors/Range.cs b/src/Nncase.Evaluator/Tensors/Range.cs
index 5f185fe350..dac1afb1d1 100644
--- a/src/Nncase.Evaluator/Tensors/Range.cs
+++ b/src/Nncase.Evaluator/Tensors/Range.cs
@@ -95,6 +95,6 @@ public Expr Visit(IShapeEvaluateContext context, Range target)
var begin = context.GetArgument(target, Range.Begin);
var end = context.GetArgument(target, Range.End);
var step = context.GetArgument(target, Range.Step);
- return ShapeExprUtility.StackOne((end - begin) / step);
+ return IR.F.Tensors.Cast(StackOne((end - begin) / step), DataTypes.Int64);
}
}
diff --git a/src/Nncase.Evaluator/Tensors/Reshape.cs b/src/Nncase.Evaluator/Tensors/Reshape.cs
index 615728f7c7..38c4d150ee 100644
--- a/src/Nncase.Evaluator/Tensors/Reshape.cs
+++ b/src/Nncase.Evaluator/Tensors/Reshape.cs
@@ -53,32 +53,9 @@ Cost ICostEvaluator.Visit(ICostEvaluateContext context, Reshape target)
public Expr Visit(IShapeEvaluateContext context, Reshape target)
{
+ var inShape = context.GetArgumentShape(target, Reshape.Input);
var shape = context.GetArgument(target, Reshape.Shape);
- var inputShape = Cast(context.GetArgumentShape(target, Reshape.Input), DataTypes.Int32);
- if (shape is TensorConst shapeConst)
- {
- var shapeArray = shapeConst.Value.ToArray();
- var negIndex = shapeArray.IndexOf(-1);
- if (negIndex < 0)
- {
- return shapeArray;
- }
-
- var dim = Prod(inputShape) / System.Math.Abs(shapeArray.Aggregate((s, x) => x * s));
- var rhs = shapeArray.Select((_, i) => i == negIndex ? dim + 1 : (Expr)0).ToArray();
- var newShape = Stack(new IR.Tuple(rhs), 0);
-
- // dim = Product(inShape) / Produce(Reshape.Shape)
- // e.g. [1, 3, -1, 24] + [dim + 1, 0] = [1, 3, dim, 24]
- return newShape + shapeArray;
- }
-
- shape = Cast(shape, DataTypes.Int32);
- var iSize = Prod(inputShape);
- var sSize = Prod(shape);
- var negDimInfactValue = iSize / Abs(sSize);
- var index = IndexOf(shape, -1);
- return new If(sSize < 0, ShapeExprUtility.Replace(shape, index, negDimInfactValue), shape);
+ return IR.F.ShapeExpr.ReshapeShape(inShape, shape);
}
public Metric Visit(IMetricEvaluateContext context, Reshape target)
@@ -93,12 +70,10 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i
return input;
}
- if (context.GetArgument(target, Reshape.Shape) is TensorConst shapeConst &&
- input.Shape.IsFixed)
+ if (context.GetArgument(target, Reshape.Shape) is TensorConst shapeConst)
{
var shapeValue = shapeConst.Value.ToArray();
var negCount = shapeValue.Count(IsMinus1);
- var inputSize = input.Shape.Prod().FixedValue;
var shapeSize = shapeValue.Aggregate(1, (x, y) => x * y);
if (negCount > 1)
{
@@ -106,27 +81,39 @@ private IRType Visit(ITypeInferenceContext context, Reshape target, TensorType i
$"Reshape at most one dimension of the new shape can be -1," +
$" shape:{shapeValue}");
}
- else if (negCount < 1)
+
+ if (input.Shape.IsFixed)
{
- if (inputSize != shapeSize)
+ var inputSize = input.Shape.Prod().FixedValue;
+ if (negCount < 1)
{
- return new InvalidType("Reshape input shape size and param shape size must be same," +
- $" shape:{shapeValue.ToArray().Aggregate(string.Empty, (s, i) => s + i + " ")}, input shape${string.Join(",", input.Shape)}");
- }
+ if (inputSize != shapeSize)
+ {
+ return new InvalidType("Reshape input shape size and param shape size must be same," +
+ $" shape:{shapeValue.ToArray().Aggregate(string.Empty, (s, i) => s + i + " ")}, input shape${string.Join(",", input.Shape)}");
+ }
- return input with { Shape = new Shape(shapeValue) };
+ return input with { Shape = new Shape(shapeValue) };
+ }
+ else
+ {
+ shapeSize = -shapeSize;
+ var negIndex = shapeValue.Select((dim, index) => (dim, index)).First(x => IsMinus1(x.dim)).index;
+ if (inputSize % shapeSize != 0)
+ {
+ return new InvalidType("Reshape input size must be divisible by shapeSize when has -1");
+ }
+
+ shapeValue[negIndex] = inputSize / shapeSize;
+ return input with { Shape = new Shape(shapeValue) };
+ }
}
else
{
- shapeSize = -shapeSize;
- var negIndex = shapeValue.Select((dim, index) => (dim, index)).First(x => IsMinus1(x.dim)).index;
- if (inputSize % shapeSize != 0)
+ return input with
{
- return new InvalidType("Reshape input size must be divisible by shapeSize when has -1");
- }
-
- shapeValue[negIndex] = inputSize / shapeSize;
- return input with { Shape = new Shape(shapeValue) };
+ Shape = new Shape(shapeValue.Select(x => x == -1 ? Dimension.Unknown : x).ToArray()),
+ };
}
}
diff --git a/src/Nncase.Evaluator/Tensors/Slice.cs b/src/Nncase.Evaluator/Tensors/Slice.cs
index 219894f0da..eada657d01 100644
--- a/src/Nncase.Evaluator/Tensors/Slice.cs
+++ b/src/Nncase.Evaluator/Tensors/Slice.cs
@@ -104,7 +104,7 @@ public Expr Visit(IShapeEvaluateContext context, Slice target)
return Ceil(Abs(end - begin) / Abs(stride));
}).ToArray();
- return Stack(new IR.Tuple(outDims), 0);
+ return Cast(Stack(new IR.Tuple(outDims), 0), DataTypes.Int64);
}
/// Axis.
diff --git a/src/Nncase.Evaluator/Tensors/Squeeze.cs b/src/Nncase.Evaluator/Tensors/Squeeze.cs
index 05517c5b4d..19573e7208 100644
--- a/src/Nncase.Evaluator/Tensors/Squeeze.cs
+++ b/src/Nncase.Evaluator/Tensors/Squeeze.cs
@@ -38,28 +38,9 @@ public Cost Visit(ICostEvaluateContext context, Squeeze target)
public Expr Visit(IShapeEvaluateContext context, Squeeze target)
{
- var inShape = context.GetArgumentShape(target, Squeeze.Input);
- var input = context.GetArgument(target, Squeeze.Input);
+ var input = context.GetArgumentShape(target, Squeeze.Input);
var dims = context.GetArgument(target, Squeeze.Dim);
- if (dims is TensorConst dimConst)
- {
- var rank = input.CheckedShape.Count;
- var dimValue = dimConst.Value.ToArray().Select(x => x < 0 ? x + rank : x).ToArray();
- var outDims = Enumerable.Range(0, rank).Where(i => !dimValue.Contains(i)).Select(i => inShape[i]).ToArray();
- if (outDims.Length == 0)
- {
- return 1;
- }
-
- if (outDims.Length == input.CheckedShape.Rank)
- {
- throw new InvalidOperationException("Bad Squeeze Shape Expr");
- }
-
- return IR.F.Tensors.Stack(new IR.Tuple(outDims), 0);
- }
-
- throw new NotImplementedException();
+ return IR.F.ShapeExpr.SqueezeShape(input, dims);
}
public Metric Visit(IMetricEvaluateContext context, Squeeze target)
diff --git a/src/Nncase.Evaluator/Tensors/Stack.cs b/src/Nncase.Evaluator/Tensors/Stack.cs
index 00cb5c4472..4d07bda39e 100644
--- a/src/Nncase.Evaluator/Tensors/Stack.cs
+++ b/src/Nncase.Evaluator/Tensors/Stack.cs
@@ -58,10 +58,10 @@ public Cost Visit(ICostEvaluateContext context, Stack target)
public Expr Visit(IShapeEvaluateContext context, Stack target)
{
var inShape = context.GetArgumentShape(target, Stack.Inputs);
- Expr one = new[] { 1 };
+ Expr one = new[] { 1L };
if (inShape[0].CheckedShape.IsScalar)
{
- one = 1;
+ one = 1L;
}
return IR.F.Tensors.Concat(new IR.Tuple(inShape[0], one), 0);
diff --git a/src/Nncase.Evaluator/Tensors/Tile.cs b/src/Nncase.Evaluator/Tensors/Tile.cs
index 122f122d5a..19bebffdfa 100644
--- a/src/Nncase.Evaluator/Tensors/Tile.cs
+++ b/src/Nncase.Evaluator/Tensors/Tile.cs
@@ -46,7 +46,7 @@ public Expr Visit(IShapeEvaluateContext context, Tile target)
{
var inShape = context.GetArgumentShape(target, Tile.Input);
var repeats = context.GetArgument(target, Tile.Repeats);
- return inShape * IR.F.Tensors.Cast(repeats, DataTypes.Int32);
+ return inShape * IR.F.Tensors.Cast(repeats, DataTypes.Int64);
}
public Metric Visit(IMetricEvaluateContext context, Tile target)
diff --git a/src/Nncase.Evaluator/Tensors/Transpose.cs b/src/Nncase.Evaluator/Tensors/Transpose.cs
index e4bd2ee45f..77e74d07f1 100644
--- a/src/Nncase.Evaluator/Tensors/Transpose.cs
+++ b/src/Nncase.Evaluator/Tensors/Transpose.cs
@@ -92,17 +92,9 @@ public Metric Visit(IMetricEvaluateContext context, Transpose target)
public Expr Visit(IShapeEvaluateContext context, Transpose target)
{
- var perm = context.GetArgument(target, Transpose.Perm);
- var rank = context.GetArgument(target, Transpose.Input).CheckedShape.Rank;
- var permValue = IR.F.Tensors.Cast(perm, DataTypes.Int32);
var inShape = context.GetArgumentShape(target, Transpose.Input);
- var outShape = Enumerable.Range(0, rank).Select(i => inShape[i]).ToArray();
- for (int i = 0; i < rank; i++)
- {
- outShape[i] = inShape[permValue[i]];
- }
-
- return IR.F.Tensors.Stack(new IR.Tuple(outShape), 0);
+ var perm = context.GetArgument(target, Transpose.Perm);
+ return IR.F.ShapeExpr.TransposeShape(inShape, perm);
}
private IRType Visit(ITypeInferenceContext context, Transpose target, TensorType input)
diff --git a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs
index b8c937c44a..bd86940fee 100644
--- a/src/Nncase.Evaluator/Tensors/UnSqueeze.cs
+++ b/src/Nncase.Evaluator/Tensors/UnSqueeze.cs
@@ -43,29 +43,9 @@ public Cost Visit(ICostEvaluateContext context, Unsqueeze target)
public Expr Visit(IShapeEvaluateContext context, Unsqueeze target)
{
+ var input = context.GetArgumentShape(target, Unsqueeze.Input);
var dims = context.GetArgument(target, Unsqueeze.Dim);
- if (dims is TensorConst dimsConst)
- {
- var dimsValue = dimsConst.Value.ToArray();
- var outShape = context.GetArgumentShape(target, Unsqueeze.Input);
-
- foreach (var dimVal in dimsValue)
- {
- if (dimVal >= 0)
- {
- outShape = ShapeExprUtility.Insert(outShape, dimVal, 1);
- }
- else
- {
- var index = IR.F.Math.Max(IR.F.Tensors.Cast(IR.F.Tensors.ShapeOf(outShape)[0], DataTypes.Int32) + dimVal + 1, 0);
- outShape = ShapeExprUtility.Insert(outShape, index, 1);
- }
- }
-
- return outShape;
- }
-
- throw new NotImplementedException();
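+ // Delegate to the UnsqueezeShape shape-expr op so non-constant dims are handled at evaluation time.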
+ return IR.F.ShapeExpr.UnsqueezeShape(input, dims);
}
public Metric Visit(IMetricEvaluateContext context, Unsqueeze target) => Metric.Zero;
diff --git a/src/Nncase.Importer/TFLite/Conv2D.cs b/src/Nncase.Importer/TFLite/Conv2D.cs
index 0fc41dfb7f..260ee76d0e 100644
--- a/src/Nncase.Importer/TFLite/Conv2D.cs
+++ b/src/Nncase.Importer/TFLite/Conv2D.cs
@@ -29,19 +29,14 @@ private Expr VisitConv2D(in tflite.Operator op)
weights = F.Tensors.NHWCToNCHW(weights);
var bias = GetInputExprs(op, 2);
var options = op.BuiltinOptionsAsConv2DOptions();
- var (inH, inW) = Util.GetHW(input);
- var (fH, fW) = Util.GetHW(weights);
var strideH = options.StrideH;
var strideW = options.StrideW;
var dilationH = options.DilationHFactor;
var dilationW = options.DilationWFactor;
- var padH = Util.GetWindowedPadding(inH, fH, strideH, dilationH, options.Padding == tflite.Padding.SAME);
- var padW = Util.GetWindowedPadding(inW, fW, strideW, dilationW, options.Padding == tflite.Padding.SAME);
var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 });
var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 });
- var padding = Util.ConcatPadding(padH, padW);
+ var padding = Util.GetPaddings(input, weights, stride, dilation, options.Padding == tflite.Padding.SAME, false);
var clamp = ToFloatClampRange(options.FusedActivationFunction);
-
var inputQuantParams = GetInputQuantParams(op, 0);
var weightsQuantParams = GetInputQuantParams(op, 1);
var biasQuantParams = GetInputQuantParams(op, 2);
@@ -123,19 +118,14 @@ private Expr VisitDepthwiseConv2D(in tflite.Operator op)
var bias = GetInputExprs(op, 2);
input = F.Tensors.NHWCToNCHW(input);
weights = F.Tensors.Transpose(weights, new[] { 3, 0, 1, 2 });
- _ = GetTensorShape(GetInputTensor(op, 1));
var options = op.BuiltinOptionsAsDepthwiseConv2DOptions();
- var (inH, inW) = Util.GetHW(input);
- var (fH, fW) = Util.GetHW(weights);
var strideH = options.StrideH;
var strideW = options.StrideW;
var dilationH = options.DilationHFactor;
var dilationW = options.DilationWFactor;
- var padH = Util.GetWindowedPadding(inH, fH, strideH, dilationH, options.Padding == tflite.Padding.SAME);
- var padW = Util.GetWindowedPadding(inW, fW, strideW, dilationW, options.Padding == tflite.Padding.SAME);
var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 });
var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 });
- var padding = Util.ConcatPadding(padH, padW);
+ var padding = Util.GetPaddings(input, weights, stride, dilation, options.Padding == tflite.Padding.SAME, false);
var depthMul = options.DepthMultiplier;
if (depthMul != 1)
{
diff --git a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs
index d5687e7cdb..ec8e64f9af 100644
--- a/src/Nncase.Importer/TFLite/Conv2DTranspose.cs
+++ b/src/Nncase.Importer/TFLite/Conv2DTranspose.cs
@@ -32,17 +32,15 @@ private Expr VisitConv2DTranspose(in tflite.Operator op)
}
var options = op.BuiltinOptionsAsTransposeConvOptions();
- var (_, _) = Util.GetHW(input);
- var (fH, fW) = Util.GetHW(weights, true);
var strideH = options.StrideH;
var strideW = options.StrideW;
var dilationH = 1;
var dilationW = 1;
- var padH = Util.GetWindowedPadding(newOutShape[2], fH, strideH, dilationH, options.Padding == tflite.Padding.SAME);
- var padW = Util.GetWindowedPadding(newOutShape[3], fW, strideW, dilationW, options.Padding == tflite.Padding.SAME);
var stride = Tensor.From(new[] { strideH, strideW }, new[] { 2 });
var dilation = Tensor.From(new[] { dilationH, dilationW }, new[] { 2 });
- var padding = Util.ConcatPadding(padH, padW);
+ var oldWShape = F.Tensors.ShapeOf(weights);
+ var wShape = F.Tensors.Stack(new IR.Tuple(oldWShape[0], oldWShape[3], oldWShape[1], oldWShape[2]), 0);
+ var padding = F.ShapeExpr.GetPaddings(F.Tensors.Stack(new IR.Tuple(newOutShape), 0), wShape, stride, dilation, options.Padding == tflite.Padding.SAME, false);
var clamp = ValueRange.Full;
return F.Tensors.NCHWToNHWC(F.Math.Clamp(
diff --git a/src/Nncase.Importer/Util.cs b/src/Nncase.Importer/Util.cs
index a3ca52d74e..762d503235 100644
--- a/src/Nncase.Importer/Util.cs
+++ b/src/Nncase.Importer/Util.cs
@@ -90,13 +90,9 @@ public static Expr[] GetWindowedPadding(Expr inputSize, Expr filter, Expr stride
}
// lower used for onnx when auto_pad attr is SAME_LOWER
- public static Expr GetPaddings(Expr input, Expr weights, long[] stride, long[] dilation, bool same, bool lower = false)
+ public static Expr GetPaddings(Expr input, Expr weights, Expr stride, Expr dilation, bool same, bool lower = false)
{
- var (inH, inW) = GetHW(input);
- var (fH, fW) = GetHW(weights);
- var padH = GetWindowedPadding(inH, fH, (int)stride[0], (int)dilation[0], same, lower);
- var padW = GetWindowedPadding(inW, fW, (int)stride[1], (int)dilation[1], same, lower);
- return ConcatPadding(padH, padW);
+ return IR.F.ShapeExpr.GetPaddings(ShapeOf(input), ShapeOf(weights), stride, dilation, same, lower);
}
public static Expr ComputeSplit(Expr input, long outputSize, long axis)
diff --git a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs
index a5324fd11a..2d12883101 100644
--- a/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs
+++ b/src/Nncase.Passes/Rules/Neutral/FoldReshape.cs
@@ -78,6 +78,11 @@ public sealed partial class ReshapeToTranspose : IRewriteRule
private Expr? GetReplace(Expr input, Call call)
{
+ if (input.CheckedShape.Rank <= 1)
+ {
+ return null;
+ }
+
var newShape = call.CheckedShape.ToValueArray();
var inShape = input.CheckedShape.ToValueArray();
var sigNewShape = newShape.Where(x => x != 1).ToArray();
diff --git a/src/Nncase.Passes/Rules/Neutral/FoldSqueeze.cs b/src/Nncase.Passes/Rules/Neutral/FoldSqueeze.cs
new file mode 100644
index 0000000000..87437d0bdc
--- /dev/null
+++ b/src/Nncase.Passes/Rules/Neutral/FoldSqueeze.cs
@@ -0,0 +1,62 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using Nncase.IR;
+using Nncase.IR.Math;
+using Nncase.IR.Tensors;
+using Nncase.Passes;
+using Nncase.PatternMatch;
+using static Nncase.IR.F.Math;
+using static Nncase.IR.F.Tensors;
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.F.Tensors;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.Neutral;
+
+[RuleGenerator]
+public partial class FoldUnsqueezeSqueeze : RewriteRule<Pattern>
+{
+ /// <inheritdoc/>
+ public override Pattern Pattern => IsUnsqueeze(
+ "unsqu",
+ "output",
+ IsSqueeze(IsWildcard("input"), IsTensorConst("sqAxes")),
+ IsTensorConst("unsqAxes"));
+
+ private Expr? GetReplace(Call output, Expr input)
+ {
+ if (output.CheckedShape.SequenceEqual(input.CheckedShape))
+ {
+ return input;
+ }
+
+ return null;
+ }
+}
+
+[RuleGenerator]
+public partial class FoldSqueezeUnsqueeze : RewriteRule<Pattern>
+{
+ /// <inheritdoc/>
+ public override Pattern Pattern => IsSqueeze(
+ "sqOp",
+ "output",
+ IsUnsqueeze(IsWildcard("input"), IsTensorConst("unsqAxes")),
+ IsTensorConst("sqAxes"));
+
+ private Expr? GetReplace(Call output, Expr input)
+ {
+ if (output.CheckedShape.SequenceEqual(input.CheckedShape))
+ {
+ return input;
+ }
+
+ return null;
+ }
+}
diff --git a/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs b/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs
index c3611f11b7..d5e9d6256f 100644
--- a/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs
+++ b/src/Nncase.Passes/Rules/Neutral/FoldTranspose.cs
@@ -101,6 +101,11 @@ public sealed partial class TransposeToReshape : IRewriteRule
private Expr? GetReplace(Expr input, Expr tp, Tensor perm, RunPassContext context)
{
+ if (input.CheckedShape.Rank <= 1)
+ {
+ return null;
+ }
+
// If all significant dims remains ascending order, it can be converted to a reshape.
var inShape = input.CheckedShape;
var sigAxes = new HashSet<int>();
diff --git a/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs
new file mode 100644
index 0000000000..4361b2db15
--- /dev/null
+++ b/src/Nncase.Passes/Rules/Neutral/SplitSpaceToBatch.cs
@@ -0,0 +1,180 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.IR.Tensors;
+using Nncase.Passes.Rules.ShapeExpr;
+using Nncase.PatternMatch;
+using Nncase.Utilities;
+using OrtKISharp;
+using static Nncase.IR.F.Math;
+using static Nncase.IR.F.NN;
+using static Nncase.IR.F.Tensors;
+using static Nncase.IR.TypePatternUtility;
+using static Nncase.PatternMatch.F.Math;
+using static Nncase.PatternMatch.F.NN;
+using static Nncase.PatternMatch.F.Tensors;
+using static Nncase.PatternMatch.Utility;
+
+namespace Nncase.Passes.Rules.Neutral;
+
+[RuleGenerator]
+public partial class SplitSpaceToBatch : RewriteRule<Pattern>
+{
+ /// <inheritdoc/>
+ public override Pattern Pattern { get; } = IsSpaceToBatch(
+ IsWildcard("input") with { TypePattern = HasRank() },
+ IsWildcard("blockShape") with { TypePattern = HasFixedShape() },
+ IsWildcard("paddings"));
+
+ public Expr? GetReplace(Expr input, Expr blockShape, Expr paddings)
+ {
+ var spatialSize = blockShape.CheckedShape.Size;
+ var remainShapeSize = input.CheckedShape.Rank - spatialSize - 1;
+ var newPaddings = Enumerable.Repeat((Expr)0, (1 + spatialSize + remainShapeSize) * 2).ToArray();
+ for (int i = 0; i < spatialSize; i++)
+ {
+ newPaddings[1 + i] = paddings[i, 0];
+ newPaddings[1 + (newPaddings.Length / 2) + i] = paddings[i, 1];
+ }
+
+ var tmpPaddings = Stack(new IR.Tuple(newPaddings), 0);
+ var newPaddingsTensor = Transpose(Reshape(tmpPaddings, new long[] { 2, 1 + spatialSize + remainShapeSize }), new long[] { 1, 0 });
+ var p = Pad(input, newPaddingsTensor, PadMode.Constant, 0f);
+
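+ // Decompose SpaceToBatch as: pad -> reshape to [batch, s0/b0, b0, s1/b1, b1, ..., remain]
+ // -> transpose the block dims to the front -> reshape to [batch * prod(block), s0/b0, s1/b1, ..., remain].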
+ var padShape = Cast(ShapeOf(p), DataTypes.Int32);
+ var batchShape1 = StackScalar(padShape[0]);
+ var spatialShape1 = RangeExec(
+ spatialSize,
+ i => Stack(new IR.Tuple(padShape[i + 1] / blockShape[i], blockShape[i]), 0))
+ .Aggregate((x, y) => Concat(new IR.Tuple(x, y), 0));
+ var remainShape1 = Stack(new IR.Tuple(RangeExec(remainShapeSize, i => padShape[1 + spatialSize + i])), 0);
+ var reshappedShape1 = Concat(
+ new IR.Tuple(
+ batchShape1,
+ spatialShape1,
+ remainShape1),
+ 0);
+
+ var perm = RangeExec(spatialSize, i => (i * 2) + 2)
+ .Concat(new[] { 0 })
+ .Concat(RangeExec(spatialSize, i => (i * 2) + 1))
+ .Concat(RangeExec(remainShapeSize, i => i + ((int)spatialSize * 2) + 1))
+ .Select(x => (long)x)
+ .ToArray();
+
+ var reshappedShape2 = Concat(
+ input: new IR.Tuple(
+ StackScalar(padShape[0] * Prod(blockShape)),
+ Stack(new IR.Tuple(RangeExec(spatialSize, i => padShape[i + 1] / blockShape[i])), 0),
+ Stack(new IR.Tuple(RangeExec(remainShapeSize, i => padShape[1 + spatialSize + i])), 0)),
+ 0);
+
+ var reshape1 = Reshape(p, reshappedShape1);
+ var rt = Transpose(reshape1, perm);
+ var reshape2 = Reshape(rt, reshappedShape2);
+ return reshape2;
+ }
+
+ private T[] RangeExec<T>(long end, Func<int, T> f)
+ {
+ return EndRange(0, (int)end).Select(f).ToArray();
+ }
+
+ private IEnumerable<int> EndRange(int begin, int end)
+ {
+ return Enumerable.Range(begin, end - begin);
+ }
+}
+
+[RuleGenerator]
+public partial class SplitBatchToSpace : RewriteRule<Pattern>
+{
+ /// <inheritdoc/>
+ public override Pattern Pattern { get; } = IsBatchToSpace(
+ IsWildcard("input") with { TypePattern = HasRank() },
+ IsWildcard("blockShape") with { TypePattern = HasFixedShape() },
+ IsWildcard("crop"));
+
+ public Expr? GetReplace(Expr input, Expr blockShape, Expr crop)
+ {
+ // to nhwc
+ var input0 = NCHWToNHWC(input);
+ var blockLen = blockShape.CheckedShape.Size;
+ var xLen = input0.CheckedShape.Rank;
+ var xShape = Cast(ShapeOf(input0), DataTypes.Int32);
+ var spatial = ShapeExprUtility.Slice(xShape, 1, blockLen + 1);
+ var depth = ShapeExprUtility.Slice(xShape, blockLen + 1, xLen);
+ var targetSpatial = spatial * blockShape;
+
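+ // Inverse decomposition (in NHWC): split the block factors out of the batch dim, interleave them
+ // with the spatial dims, reshape back to [-1, s0*b0, s1*b1, ..., depth], then slice away the crops.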
+ var ccat1 = Concat(new IR.Tuple(spatial, blockShape), 0);
+ var re1 = Reshape(ccat1, new[] { ccat1.CheckedShape[0].FixedValue / blockLen, blockLen });
+ var interLeave = Reshape(Transpose(re1, new long[] { 1, 0 }), new long[] { -1 });
+ var shape1 = Concat(new IR.Tuple(new int[] { -1 }, interLeave, depth), 0);
+
+ var g1 = BoostRange(2, (2 * blockLen) + 1, 2);
+ var g2 = BoostRange(1, (2 * blockLen) + 1, 2);
+ var g3 = BoostRange(0, xLen + blockLen).ToArray()[1 + (2 * blockLen)];
+ var indices = g1.Append(0).Concat(g2).Append(g3);
+
+ var perm = GetPerm(xLen, blockLen);
+
+ var newShape = indices.Select(i => shape1[i]).ToArray();
+ var x2 = Reshape(input0, Stack(new IR.Tuple(newShape), 0));
+ var tr2 = Transpose(x2, perm);
+ var shape2 = Concat(new IR.Tuple(new[] { -1 }, targetSpatial, depth), 0);
+ var x3 = Reshape(tr2, shape2);
+
+ var cropTransposed = Transpose(crop, new long[] { 1, 0 });
+ var cropArray = Reshape(cropTransposed, new long[] { -1 });
+ var w = cropTransposed.CheckedShape[1].FixedValue;
+ var cropStart = ShapeExprUtility.Slice(cropArray, 0, w);
+ var cropEnd = ShapeExprUtility.Slice(cropArray, w, w + w);
+ var endRange = targetSpatial - cropEnd;
+ var axesConst = BoostRange(1, blockLen + 1).ToArray();
+ var strideConst = Enumerable.Repeat(1, axesConst.Length).ToArray();
+ var result = Slice(x3, cropStart, endRange, axesConst, strideConst);
+
+ // to nchw
+ var transposeResult = NHWCToNCHW(result);
+ return transposeResult;
+ }
+
+ private static IEnumerable<int> BoostRange(int start, int end, int step = 1)
+ {
+ int x = start;
+ do
+ {
+ yield return x;
+ x += step;
+ if ((step < 0 && x <= end) || (step > 0 && end <= x))
+ {
+ break;
+ }
+ }
+ while (true);
+ }
+
+ private long[] GetPerm(int xLen, int blockLen)
+ {
+ var perm = Enumerable.Range(0, xLen + blockLen).ToArray();
+ perm[0] = blockLen;
+ perm[1] = blockLen + 1;
+ perm[2] = 0;
+ foreach (var i in BoostRange(3, (blockLen * 2) + 1))
+ {
+ perm[i] = perm[i - 2] + 1;
+ }
+
+ return perm.Select(x => (long)x).ToArray();
+ }
+
+ private T[] ZipExec<T>(T[] a, T[] b, Func<T, T, T> f)
+ {
+ return a.Zip(b).Select(x => f(x.First, x.Second)).ToArray();
+ }
+}
diff --git a/src/Nncase.Passes/Rules/ShapeBucket/FoldBucketReshape.cs b/src/Nncase.Passes/Rules/ShapeBucket/FoldBucketReshape.cs
deleted file mode 100644
index 2c78c5ca35..0000000000
--- a/src/Nncase.Passes/Rules/ShapeBucket/FoldBucketReshape.cs
+++ /dev/null
@@ -1,44 +0,0 @@
-// Copyright (c) Canaan Inc. All rights reserved.
-// Licensed under the Apache license. See LICENSE file in the project root for full license information.
-
-using System.Linq;
-using Nncase.IR;
-using Nncase.IR.Tensors;
-using Nncase.PatternMatch;
-using Nncase.Utilities;
-using static Nncase.IR.TypePatternUtility;
-using static Nncase.PatternMatch.F.Tensors;
-using static Nncase.PatternMatch.Utility;
-
-namespace Nncase.Passes.Rules.ShapeBucket;
-
-[RuleGenerator]
-public sealed partial class FoldBucketPadReshape : RewriteRule<Pattern>
-{
- // Reshape(Gather(Shape, 0, 0), new[] { 0 }) -> GetItem(Shape, 0)
- public override Pattern Pattern => IsReshape(
- IsBucketPad(null, "bucketPad", IsWildcard(), IsTensorConst()),
- IsTensorConst("newShape"));
-
- private Expr? GetReplace(Call bucketPad, Expr newShape)
- {
- return ReplaceUtility.ReplaceCallParams(bucketPad, (BucketPad.Shape.Index, newShape));
- }
-}
-
-// todo: squeeze
-[RuleGenerator]
-public sealed partial class FoldBucketPadUnsqueeze : RewriteRule<Pattern>
-{
- // Reshape(Gather(Shape, 0, 0), new[] { 0 }) -> GetItem(Shape, 0)
- public override Pattern Pattern => IsUnsqueeze(
- null,
- "unsqueeze",
- IsBucketPad(null, "bucketPad", IsWildcard(), IsTensorConst()),
- IsTensorConst());
-
- private Expr? GetReplace(Call bucketPad, Call unsqueeze)
- {
- return ReplaceUtility.ReplaceCallParams(bucketPad, (BucketPad.Shape.Index, unsqueeze.CheckedShape.ToValueArray()));
- }
-}
diff --git a/src/Nncase.Passes/Rules/ShapeBucket/FoldNopTuple.cs b/src/Nncase.Passes/Rules/ShapeBucket/FoldNopTuple.cs
new file mode 100644
index 0000000000..edbfd66074
--- /dev/null
+++ b/src/Nncase.Passes/Rules/ShapeBucket/FoldNopTuple.cs
@@ -0,0 +1,64 @@
+// Copyright (c) Canaan Inc. All rights reserved.
+// Licensed under the Apache license. See LICENSE file in the project root for full license information.
+
+using System.Linq;
+using System.Reactive;
+using System.Threading.Tasks;
+using Nncase.Diagnostics;
+using Nncase.IR;
+using Nncase.IR.Tensors;
+using Nncase.PatternMatch;
+using Nncase.Utilities;
+using static Nncase.PatternMatch.Utility;
+using static Nncase.Utilities.ReplaceUtility;
+
+namespace Nncase.Passes.Rules.ShapeBucket;
+
+public class FoldNopTuple : FunctionPass
+{
+ protected override Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassContext context)
+ {
+ int i = 0;
+ while (true)
+ {
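+ // Repeat until the function hash stops changing; each visitor run folds at most one tuple.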
+ var preHash = input.GetHashCode();
+ DumpScope.Current.DumpIR(input, $"{i}_before");
+
+ new FoldNopTupleVisitior().Visit(input);
+ DumpScope.Current.DumpIR(input, $"{i++}_after_convert");
+ var afterHash = input.GetHashCode();
+ if (preHash == afterHash)
+ {
+ return Task.FromResult(input);
+ }
+ }
+ }
+
+ internal class FoldNopTupleVisitior : ExprVisitor<Expr, Unit>
+ {
+ private bool _changed;
+
+ public FoldNopTupleVisitior()
+ : base(true)
+ {
+ }
+
+ protected override Expr DefaultVisitLeaf(Expr expr) => expr;
+
+ protected override Expr VisitLeafTuple(Tuple expr)
+ {
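+ // A tuple whose users are all GetItem calls is a no-op: rewrite every GetItem(tuple, i)
+ // to tuple.Fields[i]. The _changed flag limits this to one tuple per pass iteration.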
+ if (!_changed && expr.Users.All(user => user is Call { Target: GetItem }))
+ {
+ foreach (var user in expr.Users)
+ {
+ var index = ((TensorConst)((Call)user).Arguments[GetItem.Index.Index]).Value.ToScalar<int>();
+ ReplaceUtility.ReplaceAllUsesWith(user, expr.Fields[index]);
+ }
+
+ _changed = true;
+ }
+
+ return expr;
+ }
+ }
+}
diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs
index f22e2d4c70..be7dad641b 100644
--- a/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs
+++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeBucketFusion.cs
@@ -25,16 +25,38 @@ namespace Nncase.Passes.Rules.ShapeBucket;
public class MergeBucketFusionPass : FunctionPass
{
+ private readonly bool _greedy;
+
+ public MergeBucketFusionPass(bool greedy)
+ {
+ _greedy = greedy;
+ }
+
protected override async Task<BaseFunction> RunCoreAsync(BaseFunction input, RunPassContext context)
{
+ // bool greedy and dynamic
var main = (Function)input;
+ int i = 0;
while (true)
{
var preHash = main.GetHashCode();
- CompilerServices.Rewrite(main, new IRewriteRule[] { new MultiUserCallToFusion(), new MergeTupleFusion() }, new());
- await new MergeSeqBucketFusion().RunAsync(main, context);
- IRHelpers.DCE(main);
- await new MergeMultiUsersFusion().RunAsync(main, context);
+ if (_greedy)
+ {
+ CompilerServices.Rewrite(main, new IRewriteRule[] { new MultiUserCallToFusion(false, _greedy), new MergeTupleFusion() }, new());
+ await new MergeSeqBucketFusion().RunAsync(main, context);
+ IRHelpers.DCE(main);
+ await new MergeMultiUsersFusion().RunAsync(main, context);
+ DumpIR(main, $"{i}_before", "FoldNopTuple");
+ await new FoldNopTuple().RunAsync(main, context);
+ }
+ else
+ {
+ await new MergeSeqBucketFusion().RunAsync(main, context);
+ IRHelpers.DCE(main);
+ }
+
+ CheckRepeat(main);
+ CheckErrorVar(main, main.Parameters.ToArray());
var postHash = main.GetHashCode();
if (preHash == postHash)
{
@@ -42,6 +64,7 @@ protected override async Task RunCoreAsync(BaseFunction input, Run
}
}
+ DumpIR(main, "MergeBucketFusionEnd");
return main;
}
}
@@ -176,10 +199,6 @@ public class MergeMultiUsersFusion : FunctionPass
public static bool DetectedRing(Call outerCall, Expr[] users)
{
- // var users = outerCall.Users.ToArray();
- // todo: fix this,TestComplexExpr
- // var userArgs = users.SelectMany(user => ((Call)user).Arguments.ToArray()).Except(users).ToArray();
- // Using this doesn't pass, but it seems to cause other problems??
var userArgs = users.SelectMany(user => ((Call)user).Arguments.ToArray()).ToArray();
foreach (var arg in userArgs)
{
@@ -231,23 +250,15 @@ private static (Expr? NewCall, UserInfo[] AllUsers) MergeMultiUserFusion(Call ou
// todo: not support
if (users.Any(user => user is Tuple))
{
- // Console.WriteLine("HasTuple");
return notSupport;
}
var userInfos = CollectUsers(outerCall, users);
- // todo: support only one user, because merge fusion rule is not enough
- // maybe a error
- // if (userInfos.Length < 2)
- // {
- // return null;
- // }
-
// has invalid
if (userInfos.Length != users.Distinct().ToArray().Length)
{
- Console.WriteLine("not all fusion call and getItemMode");
+ // Console.WriteLine("not all fusion call and getItemMode");
return notSupport;
}
@@ -601,8 +612,7 @@ protected override Expr VisitLeafCall(Call expr)
}
// Console.WriteLine($"Match {fusion.Name} counter:{Counter}");
- DumpIR(Root, "OriginRoot", RelPath);
-
+ // DumpIR(Root, "OriginRoot", RelPath);
var (newCall, users) = MergeMultiUserFusion(outerCall, fusion);
if (newCall != null)
{
diff --git a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs
index 5118cb44e9..bac3b6f135 100644
--- a/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs
+++ b/src/Nncase.Passes/Rules/ShapeBucket/MergeCallToFusion.cs
@@ -7,6 +7,7 @@
using System.Xml;
using NetFabric.Hyperlinq;
using Nncase.IR;
+using Nncase.IR.Math;
using Nncase.IR.Tensors;
using Nncase.PatternMatch;
using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper;
@@ -35,11 +36,6 @@ public static bool AllConst(Call originCall)
return false;
}
-
- public bool ValidTarget(Expr target)
- {
- return CallValidator.ValidTarget(target);
- }
}
[RuleGenerator]
@@ -106,6 +102,17 @@ public partial class MergePrevMarkerToFusion : MergeFusionBase
[RuleGenerator]
public partial class MergeNextCallToFusion : MergeFusionBase
{
+ private readonly bool _greedy = true;
+
+ public MergeNextCallToFusion(bool greedy = true)
+ {
+ _greedy = greedy;
+ }
+
+ public MergeNextCallToFusion()
+ {
+ }
+
public Pattern FusionCall => IsCall(
"fusionOuterCall",
IsFusion(
@@ -127,13 +134,13 @@ public partial class MergeNextCallToFusion : MergeFusionBase
// nextCall(marker(fusion(x))) -> fusion(nextCall(marker(x)))
public Expr? GetReplace(Call nextCall, Expr maybeFusionCallMarker, Expr target, Call fusionOuterCall, BucketFusion fusion)
{
- var singleVar = CompileSession.CompileOptions.ShapeBucketOptions.VarMap.Values.SelectMany(x => x).OfType<Var>().ToHashSet().Count <= 1;
+ var singleVar = SingleDimVar(CompileSession.CompileOptions.ShapeBucketOptions);
if (!singleVar && nextCall.Arguments.ToArray().OfType().Count() > 1)
{
return null;
}
- if (!ValidTarget(target))
+ if (!CallValidator.ValidTarget(nextCall, _greedy))
{
return null;
}
@@ -237,8 +244,25 @@ private bool SameEffectVar(Call originCall, Fusion fusion)
[RuleGenerator]
public partial class MergePrevCallToFusion : MergeFusionBase
{
+ // The input must match a marker, because even if the marker is merged, a copy still has to be kept outside
+ // fusion(marker(prevCall()) { var } -> fusion(var) { marker(prevCall()) }
+ // fusion((prevCall()) { var } -> fusion(var) { prevCall() }
+ private readonly bool _greedy = true;
+
+ private readonly bool _mergeFusion;
+
private string _prevCallStr = string.Empty;
+ public MergePrevCallToFusion()
+ {
+ }
+
+ public MergePrevCallToFusion(bool greedy = true, bool mergeFusion = false)
+ {
+ _greedy = greedy;
+ _mergeFusion = mergeFusion;
+ }
+
public override Pattern Pattern => IsCall(
"fusionOuterCall",
IsFusion(
@@ -257,10 +281,6 @@ public Pattern MaybeMarker(string exprName, Pattern exprPatten) => IsAlt(
IsRangeOfMarker(exprPatten, IsWildcard()),
exprPatten);
- // The input must match a marker, because even if the marker is merged, a copy still has to be kept outside
- // fusion(marker(prevCall()) { var } -> fusion(var) { marker(prevCall()) }
- // fusion((prevCall()) { var } -> fusion(var) { prevCall() }
-
// dfs
// xx | marker(xx) doesn't work: xx would be matched first
// xx(marker) | xx is fine
@@ -600,7 +620,7 @@ private bool IsInvalid(Call lhsPrevCall, Expr lhsTarget)
return true;
}
- if (!ValidTarget(lhsTarget))
+ if (!CallValidator.ValidTarget(lhsPrevCall, _greedy))
{
return true;
}
diff --git a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs
index 110e37026a..94f7977f2f 100644
--- a/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs
+++ b/src/Nncase.Passes/Rules/ShapeBucket/RecordFusionShape.cs
@@ -12,10 +12,22 @@
using Nncase.Evaluator;
using Nncase.IR;
using static Nncase.IR.F.Tensors;
+using static Nncase.Passes.Rules.ShapeBucket.ShapeBucketHelper;
namespace Nncase.Passes.Rules.ShapeBucket;
-public record FusionShapeData(IValue Outshape, IValue[] InputShapes);
+public class FusionShapeData
+{
+ public FusionShapeData(IValue outshape, IValue[] inputShapes)
+ {
+ Outshape = outshape;
+ InputShapes = inputShapes;
+ }
+
+ public IValue Outshape { get; }
+
+ public IValue[] InputShapes { get; }
+}
public class FusionShapeUpdater : ExprVisitor<Expr, Unit>
{
@@ -26,7 +38,7 @@ public FusionShapeUpdater(Dictionary memo)
_memo = memo;
}
- public Dictionary<BucketFusion, FusionShapeData> FusionShape { get; set; } = new();
+ public Dictionary<BucketFusion, FusionShapeData> FusionShape { get; } = new();
protected override Expr DefaultVisitLeaf(Expr expr) => expr;
@@ -34,7 +46,11 @@ protected override Expr VisitLeafCall(Call expr)
{
if (expr.Target is BucketFusion f)
{
- var argShape = expr.Arguments.ToArray().Select(arg => GetShape(_memo[arg])).ToArray();
+ var argShape = expr.Arguments.ToArray().Select(arg =>
+ {
+ var exp = arg is Marker m ? m.Target : arg;
+ return GetShape(_memo[exp]);
+ }).ToArray();
var shape = GetShape(_memo[expr]);
FusionShape[f] = new FusionShapeData(shape, argShape);
}
@@ -54,25 +70,6 @@ private IValue GetShape(IValue value)
}
}
-public class SimpleTimer : IDisposable
-{
- private readonly DateTime _startTime;
- private readonly string _name;
-
- public SimpleTimer(string name)
- {
- _startTime = System.DateTime.Now;
- _name = name;
- }
-
- public void Dispose()
- {
- var endTime = System.DateTime.Now;
- var time = endTime - _startTime;
- Console.WriteLine($"{_name} tooks {time.Seconds}");
- }
-}
-
public class RecordFusionShape : FunctionPass
{
private Dictionary