Skip to content

Commit

Permalink
Add compare ops for both tflite and onnx. (#571)
Browse files Browse the repository at this point in the history
* Add compare ops for both tflite and onnx.

* Disable multiprocess test for importer.

* Remove --dist=load.

* Fix pytest with multiprocess fail issue.
  • Loading branch information
zhangyang2057 authored Apr 22, 2022
1 parent f2c0a1e commit 885ac18
Show file tree
Hide file tree
Showing 38 changed files with 613 additions and 234 deletions.
4 changes: 4 additions & 0 deletions docs/onnx_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,17 @@
| Gemm ||
| GlobalAveragePool ||
| GlobalMaxPool ||
| Greater ||
| GreaterOrEqual ||
| Hardmax ||
| HardSigmoid ||
| HardSwish ||
| Identity ||
| InstanceNormalization ||
| LpNormalization ||
| LeakyRelu ||
| Less ||
| LessOrEqual ||
| Log ||
| LogSoftmax ||
| LRN ||
Expand Down
8 changes: 8 additions & 0 deletions docs/tflite_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@
| CONCATENATION ||
| CONV_2D ||
| COS ||
| CUSTOM ||
| DEPTHWISE_CONV_2D ||
| DIV ||
| EQUAL ||
| EXP ||
| EXPAND_DIMS ||
| FLOOR ||
| FLOOR_DIV ||
| FLOOR_MOD ||
| FULLY_CONNECTED ||
| GREATER ||
| GREATER_EQUAL ||
| L2_NORMALIZATION ||
| LEAKY_RELU ||
| LESS ||
| LESS_EQUAL ||
| LOG ||
| LOGISTIC ||
| MAX_POOL_2D ||
Expand All @@ -32,6 +38,7 @@
| MINIMUM ||
| MUL ||
| NEG ||
| NOT_EQUAL ||
| PAD ||
| PADV2 ||
| MIRROR_PAD ||
Expand Down Expand Up @@ -71,3 +78,4 @@
| SQUARED_DIFFERENCE ||
| LOG_SOFTMAX ||
| SPLIT ||
| HARD_SWISH ||
38 changes: 20 additions & 18 deletions include/nncase/codegen/stackvm/op_writer.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2/16/2022 4:18:59 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -952,6 +952,24 @@ struct op_writer<nncase::runtime::stackvm::tensor_call_op_t>
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_compare_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_compare_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src1);
writer.write(op.rstride_src1);
writer.write(op.rshape_src2);
writer.write(op.rstride_src2);
writer.write(op.rshape_dest);
writer.write(op.rstride_dest);
writer.write(static_cast<uint8_t>(op.compare_op));
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_conv2d_op_t>
{
Expand Down Expand Up @@ -1035,22 +1053,6 @@ struct op_writer<nncase::runtime::stackvm::tensor_dequantize_op_t>
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_equal_op_t>
{
void operator()(const nncase::runtime::stackvm::tensor_equal_op_t &op, binary_writer &writer) const
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src1);
writer.write(op.rstride_src1);
writer.write(op.rshape_src2);
writer.write(op.rstride_src2);
writer.write(op.rstride_dest);
}
};

template <>
struct op_writer<nncase::runtime::stackvm::tensor_gather_op_t>
{
Expand Down Expand Up @@ -1535,12 +1537,12 @@ class NNCASE_API op_builder
void tensor_broadcast_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_dest, uint8_t rstride_dest);
void tensor_binary_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, binary_op_t binary_op, float fused_clamp_low, float fused_clamp_high);
void tensor_call_(uint32_t function_id, uint16_t module_id, uint8_t num_src, uint8_t num_dst);
void tensor_compare_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, compare_op_t compare_op);
void tensor_conv2d_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rshape_kernel, uint8_t rstride_kernel, uint8_t rstride_bias, uint8_t rstride_dest, uint16_t groups, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
void tensor_copy_(datatype_t datatype, uint8_t rshape, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_convert_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_cumsum_(datatype_t datatype, uint8_t rshape_src, int32_t axis, bool exclusive, bool reverse);
void tensor_dequantize_(datatype_t in_datatype, datatype_t dst_datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest);
void tensor_equal_(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest);
void tensor_gather_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t axis);
void tensor_gather_nd_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_indices, uint8_t batch_dims);
void tensor_hardmax_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, int32_t axis);
Expand Down
2 changes: 1 addition & 1 deletion include/nncase/ir/opcode.def
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ DEFINE_NEUTRAL_OPCODE(topk, TopK, 0x123)
DEFINE_NEUTRAL_OPCODE(trilu, Trilu, 0x124)
DEFINE_NEUTRAL_OPCODE(sigmoid, Sigmoid, 0x125)
DEFINE_NEUTRAL_OPCODE(roi_align, RoiAlign, 0x126)
DEFINE_NEUTRAL_OPCODE(equal, Equal, 0x127)
DEFINE_NEUTRAL_OPCODE(compare, Compare, 0x127)
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,22 @@

namespace nncase::ir
{
class NNCASE_API equal : public node
class NNCASE_API compare : public node
{
public:
DEFINE_NODE_OPCODE(op_equal);
DEFINE_NODE_OPCODE(op_compare);

input_connector &input_a() { return input_at(0); }
input_connector &input_b() { return input_at(1); }
output_connector &output() { return output_at(0); }

equal(datatype_t input_type, shape_t input_a_shape, shape_t input_b_shape);
compare_op_t compare_op() const noexcept { return compare_op_; }
compare(compare_op_t compare_op, datatype_t input_type, shape_t input_a_shape, shape_t input_b_shape);

protected:
bool properties_equal(node &other) const override;

private:
compare_op_t compare_op_;
};
}
7 changes: 4 additions & 3 deletions include/nncase/kernels/cpu/reference/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ NNCASE_API result<void> dequantize(datatype_t in_type, datatype_t out_type, cons
kernel_context &context) noexcept;

template <typename T>
NNCASE_API result<void> equal(const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides, const runtime_shape_t &in_b_shape,
const runtime_shape_t &in_b_strides, const runtime_shape_t &out_strides) noexcept;
NNCASE_API result<void> compare(compare_op_t op, const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides,
const runtime_shape_t &in_b_shape, const runtime_shape_t &in_b_strides,
const runtime_shape_t &out_shape, const runtime_shape_t &out_strides) noexcept;

NNCASE_API result<void> lut1d(datatype_t type, const gsl::byte *input, const gsl::byte *table, gsl::byte *output, const runtime_shape_t &shape,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, const scalar &min, const scalar &max) noexcept;
Expand Down
7 changes: 4 additions & 3 deletions include/nncase/kernels/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,10 @@ NNCASE_API result<void> dequantize(datatype_t in_type, datatype_t out_type, cons
kernel_context &context = default_kernel_context()) noexcept;

template <typename T>
NNCASE_API result<void> equal(const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides, const runtime_shape_t &in_b_shape,
const runtime_shape_t &in_b_strides, const runtime_shape_t &out_strides) noexcept;
NNCASE_API result<void> compare(compare_op_t op, const T *input_a, const T *input_b, bool *output,
const runtime_shape_t &in_a_shape, const runtime_shape_t &in_a_strides,
const runtime_shape_t &in_b_shape, const runtime_shape_t &in_b_strides,
const runtime_shape_t &out_shape, const runtime_shape_t &out_strides) noexcept;

NNCASE_API result<void> lut1d(datatype_t type, const gsl::byte *input, const gsl::byte *table, gsl::byte *output, const runtime_shape_t &shape,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, const scalar &min, const scalar &max) noexcept;
Expand Down
30 changes: 30 additions & 0 deletions include/nncase/runtime/datatypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,36 @@ inline std::string unary_op_to_string(unary_op_t op)
return "unknown";
}

typedef enum _compare_op
{
compare_equal,
compare_not_equal,
compare_greater,
compare_greater_equal,
compare_less,
compare_less_equal
} compare_op_t;

inline std::string compare_op_to_string(compare_op_t op)
{
switch (op)
{
case compare_equal:
return "compare_equal";
case compare_not_equal:
return "compare_not_equal";
case compare_greater:
return "compare_greater";
case compare_greater_equal:
return "compare_greater_equal";
case compare_less:
return "compare_less";
case compare_less_equal:
return "compare_less_equal";
}
return "unknown";
}

typedef enum _image_resize_mode
{
image_resize_bilinear,
Expand Down
42 changes: 22 additions & 20 deletions include/nncase/runtime/stackvm/op_reader.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2/16/2022 4:18:59 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1150,6 +1150,26 @@ struct op_reader<tensor_call_op_t>
}
};

template <>
struct op_reader<tensor_compare_op_t>
{
tensor_compare_op_t operator()(span_reader &reader) const
{
tensor_compare_op_t op(default_init);
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
op.datatype = static_cast<datatype_t>(reader.read_unaligned<uint8_t>());
op.rshape_src1 = reader.read_unaligned<uint8_t>();
op.rstride_src1 = reader.read_unaligned<uint8_t>();
op.rshape_src2 = reader.read_unaligned<uint8_t>();
op.rstride_src2 = reader.read_unaligned<uint8_t>();
op.rshape_dest = reader.read_unaligned<uint8_t>();
op.rstride_dest = reader.read_unaligned<uint8_t>();
op.compare_op = static_cast<compare_op_t>(reader.read_unaligned<uint8_t>());
return op;
}
};

template <>
struct op_reader<tensor_conv2d_op_t>
{
Expand Down Expand Up @@ -1243,24 +1263,6 @@ struct op_reader<tensor_dequantize_op_t>
}
};

template <>
struct op_reader<tensor_equal_op_t>
{
tensor_equal_op_t operator()(span_reader &reader) const
{
tensor_equal_op_t op(default_init);
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
op.datatype = static_cast<datatype_t>(reader.read_unaligned<uint8_t>());
op.rshape_src1 = reader.read_unaligned<uint8_t>();
op.rstride_src1 = reader.read_unaligned<uint8_t>();
op.rshape_src2 = reader.read_unaligned<uint8_t>();
op.rstride_src2 = reader.read_unaligned<uint8_t>();
op.rstride_dest = reader.read_unaligned<uint8_t>();
return op;
}
};

template <>
struct op_reader<tensor_gather_op_t>
{
Expand Down Expand Up @@ -1796,12 +1798,12 @@ class NNCASE_API op_visitor
virtual result<void> visit(NNCASE_UNUSED const tensor_broadcast_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_binary_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_call_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_compare_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_conv2d_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_copy_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_convert_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_cumsum_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_dequantize_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_equal_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_gather_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_gather_nd_op_t &op) noexcept { return ok(); }
virtual result<void> visit(NNCASE_UNUSED const tensor_hardmax_op_t &op) noexcept { return ok(); }
Expand Down
56 changes: 29 additions & 27 deletions include/nncase/runtime/stackvm/opcode.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 2/16/2022 4:18:58 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -126,14 +126,14 @@ enum class tensor_function_t
BINARY = 0x0001,
BROADCAST = 0x0002,
CALL = 0x0003,
CLAMP = 0x0004,
CONV2D = 0x0005,
CONV2D_TRANSPOSE = 0x0006,
CONVERT = 0x0007,
COPY = 0x0008,
CUMSUM = 0x0009,
DEQUANTIZE = 0x000A,
EQUAL = 0x000B,
COMPARE = 0x0004,
CLAMP = 0x0005,
CONV2D = 0x0006,
CONV2D_TRANSPOSE = 0x0007,
CONVERT = 0x0008,
COPY = 0x0009,
CUMSUM = 0x000A,
DEQUANTIZE = 0x000B,
GATHER = 0x000C,
GATHER_ND = 0x000D,
HARDMAX = 0x000E,
Expand Down Expand Up @@ -1290,6 +1290,26 @@ struct tensor_call_op_t
}
};

struct tensor_compare_op_t
{
opcode_t opcode;
tensor_function_t funct;
datatype_t datatype;
uint8_t rshape_src1;
uint8_t rstride_src1;
uint8_t rshape_src2;
uint8_t rstride_src2;
uint8_t rshape_dest;
uint8_t rstride_dest;
compare_op_t compare_op;

tensor_compare_op_t(default_init_t) noexcept { }
explicit tensor_compare_op_t(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rshape_dest, uint8_t rstride_dest, compare_op_t compare_op) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::COMPARE), datatype(datatype), rshape_src1(rshape_src1), rstride_src1(rstride_src1), rshape_src2(rshape_src2), rstride_src2(rstride_src2), rshape_dest(rshape_dest), rstride_dest(rstride_dest), compare_op(compare_op)
{
}
};

struct tensor_conv2d_op_t
{
opcode_t opcode;
Expand Down Expand Up @@ -1383,24 +1403,6 @@ struct tensor_dequantize_op_t
}
};

struct tensor_equal_op_t
{
opcode_t opcode;
tensor_function_t funct;
datatype_t datatype;
uint8_t rshape_src1;
uint8_t rstride_src1;
uint8_t rshape_src2;
uint8_t rstride_src2;
uint8_t rstride_dest;

tensor_equal_op_t(default_init_t) noexcept { }
explicit tensor_equal_op_t(datatype_t datatype, uint8_t rshape_src1, uint8_t rstride_src1, uint8_t rshape_src2, uint8_t rstride_src2, uint8_t rstride_dest) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::EQUAL), datatype(datatype), rshape_src1(rshape_src1), rstride_src1(rstride_src1), rshape_src2(rshape_src2), rstride_src2(rstride_src2), rstride_dest(rstride_dest)
{
}
};

struct tensor_gather_op_t
{
opcode_t opcode;
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/stackvm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ set(SRCS module_builder.cpp
ops/binary.cpp
ops/broadcast.cpp
ops/call.cpp
ops/compare.cpp
ops/conv2d.cpp
ops/convert.cpp
ops/copy.cpp
ops/cumsum.cpp
ops/dequantize.cpp
ops/equal.cpp
ops/gather.cpp
ops/gather_nd.cpp
ops/hardmax.cpp
Expand Down
Loading

0 comments on commit 885ac18

Please sign in to comment.