Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compare ops for both tflite and onnx. #571

Merged
merged 4 commits into from
Apr 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions docs/onnx_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
| Gemm | ✅ |
| GlobalAveragePool | ✅ |
| GlobalMaxPool | ✅ |
| Greater | ✅ |
| GreaterOrEqual | ✅ |
| Hardmax | ✅ |
| HardSigmoid | ✅ |
| HardSwish | ✅ |
| Identity | ✅ |
| InstanceNormalization | ✅ |
| LpNormalization | ✅ |
| LeakyRelu | ✅ |
| Less | ✅ |
| LessOrEqual | ✅ |
| Log | ✅ |
| LogSoftmax | ✅ |
| LRN | ✅ |
Expand Down
8 changes: 8 additions & 0 deletions docs/tflite_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@
| CONCATENATION | ✅ |
| CONV_2D | ✅ |
| COS | ✅ |
| CUSTOM | ✅ |
| DEPTHWISE_CONV_2D | ✅ |
| DIV | ✅ |
| EQUAL | ✅ |
| EXP | ✅ |
| EXPAND_DIMS | ✅ |
| FLOOR | ✅ |
| FLOOR_DIV | ✅ |
| FLOOR_MOD | ✅ |
| FULLY_CONNECTED | ✅ |
| GREATER | ✅ |
| GREATER_EQUAL | ✅ |
| L2_NORMALIZATION | ✅ |
| LEAKY_RELU | ✅ |
| LESS | ✅ |
| LESS_EQUAL | ✅ |
| LOG | ✅ |
| LOGISTIC | ✅ |
| MAX_POOL_2D | ✅ |
Expand All @@ -32,6 +38,7 @@
| MINIMUM | ✅ |
| MUL | ✅ |
| NEG | ✅ |
| NOT_EQUAL | ✅ |
| PAD | ✅ |
| PADV2 | ✅ |
| MIRROR_PAD | ✅ |
Expand Down Expand Up @@ -71,3 +78,4 @@
| SQUARED_DIFFERENCE | ✅ |
| LOG_SOFTMAX | ✅ |
| SPLIT | ✅ |
| HARD_SWISH | ✅ |
38 changes: 20 additions & 18 deletions include/nncase/codegen/stackvm/op_writer.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2/16/2022 4:18:59 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -952,6 +952,24 @@ struct op_writer<nncase::runtime::stackvm::tensor_call_op_t>
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_compare_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_compare_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src1);
writer.write(op.rstride_src1);
writer.write(op.rshape_src2);
writer.write(op.rstride_src2);
writer.write(op.rshape_dest);
writer.write(op.rstride_dest);
writer.write(static_cast<uint8_t>(op.compare_op));
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_conv2d_op_t>
{
Expand Down Expand Up @@ -1035,22 +1053,6 @@ struct op_writer<nncase::runtime::stackvm::tensor_dequantize_op_t>
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_equal_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_equal_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src1);
writer.write(op.rstride_src1);
writer.write(op.rshape_src2);
writer.write(op.rstride_src2);
writer.write(op.rstride_dest);
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_gather_op_t>
{
Expand Down Expand Up @@ -1535,12 +1537,12 @@ class NNCASE_API op_builder
void tensor_broadcast_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_dest, uint8_t rstride_dest);
void tensor_binary_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, binary_op_t binary_op, float fused_clamp_low, float fused_clamp_high);
void tensor_call_(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst);
void tensor_compare_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, compare_op_t compare_op);
void tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
void tensor_copy_(datatype_t datatype, uint8_t rshape, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_convert_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_cumsum_(datatype_t datatype, uint8_t rshape_src, int32_t axis, bool exclusive, bool reverse);
void tensor_dequantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_equal_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest);
void tensor_gather_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t axis);
void tensor_gather_nd_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t batch_dims);
void tensor_hardmax_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, int32_t axis);
Expand Down
2 changes: 1 addition & 1 deletion include/nncase/ir/opcode.def
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ DEFINE_NEUTRAL_OPCODE(topk, TopK, 0x123)
DEFINE_NEUTRAL_OPCODE(trilu, Trilu, 0x124)
DEFINE_NEUTRAL_OPCODE(sigmoid, Sigmoid, 0x125)
DEFINE_NEUTRAL_OPCODE(roi_align, RoiAlign, 0x126)
DEFINE_NEUTRAL_OPCODE(equal, Equal, 0x127)
DEFINE_NEUTRAL_OPCODE(compare, Compare, 0x127)
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@

namespace nncase::ir
{
class NNCASE_API equal : public node
class NNCASE_API compare : public node
{
public:
DEFINE_NODE_OPCODE(op_equal);
DEFINE_NODE_OPCODE(op_compare);

input_connector &input_a() { return input_at(0); }
input_connector &input_b() { return input_at(1); }
output_connector &output() { return output_at(0); }

equal(datatype_t input_type, shape_t input_a_shape, shape_t input_b_shape);
compare_op_t compare_op() const noexcept { return compare_op_; }
compare(compare_op_t compare_op, datatype_t input_type, shape_t input_a_shape, shape_t input_b_shape);

protected:
bool properties_equal(node &other) const override;

private:
compare_op_t compare_op_;
};
}
7 changes: 4 additions & 3 deletions include/nncase/kernels/cpu/reference/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ NNCASE_API result<void> dequantize(datatype_t in_type, datatype_t out_type, cons
kernel_context &context) noexcept;

template <typename T>
NNCASE_API result<void> equal(const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides, const runtime_shape_t &in_b_shape,
const runtime_shape_t &in_b_strides, const runtime_shape_t &out_strides) noexcept;
NNCASE_API result<void> compare(compare_op_t op, const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides,
const runtime_shape_t &in_b_shape, const runtime_shape_t &in_b_strides,
const runtime_shape_t &out_shape, const runtime_shape_t &out_strides) noexcept;

NNCASE_API result<void> lut1d(datatype_t type, const gsl::byte *input, const gsl::byte *table, gsl::byte *output, const runtime_shape_t &shape,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, const scalar &min, const scalar &max) noexcept;
Expand Down
7 changes: 4 additions & 3 deletions include/nncase/kernels/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ NNCASE_API result<void> dequantize(datatype_t in_type, datatype_t out_type, cons
kernel_context &context = default_kernel_context()) noexcept;

template <typename T>
NNCASE_API result<void> equal(const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides, const runtime_shape_t &in_b_shape,
const runtime_shape_t &in_b_strides, const runtime_shape_t &out_strides) noexcept;
NNCASE_API result<void> compare(compare_op_t op, const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides,
const runtime_shape_t &in_b_shape, const runtime_shape_t &in_b_strides,
const runtime_shape_t &out_shape, const runtime_shape_t &out_strides) noexcept;

NNCASE_API result<void> lut1d(datatype_t type, const gsl::byte *input, const gsl::byte *table, gsl::byte *output, const runtime_shape_t &shape,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, const scalar &min, const scalar &max) noexcept;
Expand Down
30 changes: 30 additions & 0 deletions include/nncase/runtime/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,36 @@ inline std::string unary_op_to_string(unary_op_t op)
return "unknown";
}

typedef enum _compare_op
{
compare_equal,
compare_not_equal,
compare_greater,
compare_greater_equal,
compare_less,
compare_less_equal
} compare_op_t;

inline std::string compare_op_to_string(compare_op_t op)
{
switch (op)
{
case compare_equal:
return "compare_equal";
case compare_not_equal:
return "compare_not_equal";
case compare_greater:
return "compare_greater";
case compare_greater_equal:
return "compare_greater_equal";
case compare_less:
return "compare_less";
case compare_less_equal:
return "compare_less_equal";
}
return "unknown";
}

typedef enum _image_resize_mode
{
image_resize_bilinear,
Expand Down
42 changes: 22 additions & 20 deletions include/nncase/runtime/stackvm/op_reader.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2/16/2022 4:18:59 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1150,6 +1150,26 @@ struct op_reader<tensor_call_op_t>
}
};

template <>
struct op_reader<tensor_compare_op_t>
{
tensor_compare_op_t operator()(span_reader &reader) const
{
tensor_compare_op_t op(default_init);
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
op.datatype = static_cast<datatype_t>(reader.read_unaligned<uint8_t>());
op.rshape_src1 = reader.read_unaligned<uint8_t>();
op.rstride_src1 = reader.read_unaligned<uint8_t>();
op.rshape_src2 = reader.read_unaligned<uint8_t>();
op.rstride_src2 = reader.read_unaligned<uint8_t>();
op.rshape_dest = reader.read_unaligned<uint8_t>();
op.rstride_dest = reader.read_unaligned<uint8_t>();
op.compare_op = static_cast<compare_op_t>(reader.read_unaligned<uint8_t>());
return op;
}
};

template <>
struct op_reader<tensor_conv2d_op_t>
{
Expand Down Expand Up @@ -1243,24 +1263,6 @@ struct op_reader<tensor_dequantize_op_t>
}
};

template <>
struct op_reader<tensor_equal_op_t>
{
tensor_equal_op_t operator()(span_reader &reader) const
{
tensor_equal_op_t op(default_init);
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
op.datatype = static_cast<datatype_t>(reader.read_unaligned<uint8_t>());
op.rshape_src1 = reader.read_unaligned<uint8_t>();
op.rstride_src1 = reader.read_unaligned<uint8_t>();
op.rshape_src2 = reader.read_unaligned<uint8_t>();
op.rstride_src2 = reader.read_unaligned<uint8_t>();
op.rstride_dest = reader.read_unaligned<uint8_t>();
return op;
}
};

template <>
struct op_reader<tensor_gather_op_t>
{
Expand Down Expand Up @@ -1796,12 +1798,12 @@ class NNCASE_API op_visitor
virtual result<void> visit(NNCASE_UNUSED const tensor_broadcast_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_binary_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_call_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_compare_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_conv2d_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_copy_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_convert_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_cumsum_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_dequantize_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_equal_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_gather_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_gather_nd_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_hardmax_op_t &op) noexcept { return ok(); }
Expand Down
56 changes: 29 additions & 27 deletions include/nncase/runtime/stackvm/opcode.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2/16/2022 4:18:58 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -126,14 +126,14 @@ enum class tensor_function_t
BINARY = 0x0001,
BROADCAST = 0x0002,
CALL = 0x0003,
CLAMP = 0x0004,
CONV2D = 0x0005,
CONV2D_TRANSPOSE = 0x0006,
CONVERT = 0x0007,
COPY = 0x0008,
CUMSUM = 0x0009,
DEQUANTIZE = 0x000A,
EQUAL = 0x000B,
COMPARE = 0x0004,
CLAMP = 0x0005,
CONV2D = 0x0006,
CONV2D_TRANSPOSE = 0x0007,
CONVERT = 0x0008,
COPY = 0x0009,
CUMSUM = 0x000A,
DEQUANTIZE = 0x000B,
GATHER = 0x000C,
GATHER_ND = 0x000D,
HARDMAX = 0x000E,
Expand Down Expand Up @@ -1290,6 +1290,26 @@ struct tensor_call_op_t
}
};

struct tensor_compare_op_t
{
opcode_t opcode;
tensor_function_t funct;
datatype_t datatype;
uint8_t rshape_src1;
uint8_t rstride_src1;
uint8_t rshape_src2;
uint8_t rstride_src2;
uint8_t rshape_dest;
uint8_t rstride_dest;
compare_op_t compare_op;

tensor_compare_op_t(default_init_t) noexcept { }
explicit tensor_compare_op_t(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, compare_op_t compare_op) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::COMPARE), datatype(datatype), rshape_src1(rshape_src1), rstride_src1(rstride_src1), rshape_src2(rshape_src2), rstride_src2(rstride_src2), rshape_dest(rshape_dest), rstride_dest(rstride_dest), compare_op(compare_op)
{
}
};

struct tensor_conv2d_op_t
{
opcode_t opcode;
Expand Down Expand Up @@ -1383,24 +1403,6 @@ struct tensor_dequantize_op_t
}
};

struct tensor_equal_op_t
{
opcode_t opcode;
tensor_function_t funct;
datatype_t datatype;
uint8_t rshape_src1;
uint8_t rstride_src1;
uint8_t rshape_src2;
uint8_t rstride_src2;
uint8_t rstride_dest;

tensor_equal_op_t(default_init_t) noexcept { }
explicit tensor_equal_op_t(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::EQUAL), datatype(datatype), rshape_src1(rshape_src1), rstride_src1(rstride_src1), rshape_src2(rshape_src2), rstride_src2(rstride_src2), rstride_dest(rstride_dest)
{
}
};

struct tensor_gather_op_t
{
opcode_t opcode;
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/stackvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ set(SRCS module_builder.cpp
ops/binary.cpp
ops/broadcast.cpp
ops/call.cpp
ops/compare.cpp
ops/conv2d.cpp
ops/convert.cpp
ops/copy.cpp
ops/cumsum.cpp
ops/dequantize.cpp
ops/equal.cpp
ops/gather.cpp
ops/gather_nd.cpp
ops/hardmax.cpp
Expand Down
Loading