Skip to content

Commit

Permalink
Fix FULLY_CONNECTED and REDUCE ops. (#574)
Browse files Browse the repository at this point in the history
* Fix FULLY_CONNECTED and REDUCE ops.

* Add reduce_prod for tflite, supporting both float and int32_t.
  • Loading branch information
zhangyang2057 authored Apr 24, 2022
1 parent 885ac18 commit 6acbe83
Show file tree
Hide file tree
Showing 42 changed files with 469 additions and 107 deletions.
1 change: 1 addition & 0 deletions docs/tflite_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
| POW ||
| REDUCE_MAX ||
| REDUCE_MIN ||
| REDUCE_PROD ||
| RELU ||
| PRELU ||
| RELU6 ||
Expand Down
5 changes: 3 additions & 2 deletions include/nncase/codegen/stackvm/op_writer.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/22/2022 10:57:45 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1254,6 +1254,7 @@ struct op_writer<nncase::runtime::stackvm::tensor_reduce_prod_op_t>
{
writer.write(static_cast<uint8_t>(op.opcode));
writer.write(static_cast<uint16_t>(op.funct));
writer.write(static_cast<uint8_t>(op.datatype));
writer.write(op.rshape_src);
writer.write(op.rstride_src);
writer.write(op.rstride_dest);
Expand Down Expand Up @@ -1555,7 +1556,7 @@ class NNCASE_API op_builder
void tensor_random_uniform_(datatype_t datatype_dest, uint8_t rshape_dest, float low, float high, float seed);
void tensor_reduce_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, reduce_op_t reduce_op, uint8_t rshape_axis, bool keep_dims);
void tensor_reduce_arg_(datatype_t datatype_src, uint8_t rshape_src, uint8_t rstride_src, datatype_t datatype_dest, uint8_t rstride_dest, reduce_arg_op_t reduce_arg_op, uint8_t rshape_axis, bool keep_dims, bool select_last_idx);
void tensor_reduce_prod_(uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims);
void tensor_reduce_prod_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims);
void tensor_reduce_window2d_(datatype_t datatype, reduce_op_t reduce_op, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t filter_h, uint16_t filter_w, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high);
void tensor_resize_image_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, bool align_corners, bool half_pixel_centers, image_resize_mode_t image_resize_mode);
void tensor_roi_align_(datatype_t datatype, uint8_t rshape_src, uint8_t rshape_dest, roi_align_mode_t mode, float spatial_scale, int64_t sampling_ratio);
Expand Down
2 changes: 1 addition & 1 deletion include/nncase/ir/ops/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class NNCASE_API reduce : public node
float init_value() const noexcept { return init_value_; }
bool keep_dims() const noexcept { return keep_dims_; }

reduce(reduce_op_t reduce_op, shape_t input_shape, axis_t axis, float init_value, bool keep_dims);
reduce(reduce_op_t reduce_op, datatype_t input_type, shape_t input_shape, axis_t axis, float init_value, bool keep_dims);

protected:
bool properties_equal(node &other) const override;
Expand Down
2 changes: 1 addition & 1 deletion include/nncase/ir/ops/reduce_prod.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class NNCASE_API reduce_prod : public node
const axis_t &axis() const noexcept { return axis_; }
bool keep_dims() const noexcept { return keep_dims_; }

reduce_prod(shape_t input_shape, axis_t axis, bool keep_dims);
reduce_prod(datatype_t input_type, shape_t input_shape, axis_t axis, bool keep_dims);

protected:
bool properties_equal(node &other) const override;
Expand Down
3 changes: 2 additions & 1 deletion include/nncase/kernels/cpu/reference/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ NNCASE_API result<void> quantize(datatype_t in_type, datatype_t out_type, const
NNCASE_API result<void> unary(unary_op_t op, const float *input, float *output, const runtime_shape_t &shape,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, kernel_context &context) noexcept;

NNCASE_API result<void> reduce(reduce_op_t op, float init_value, const float *input, float *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
template <typename T>
NNCASE_API result<void> reduce(reduce_op_t op, T init_value, const T *input, T *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context) noexcept;

template <typename T>
Expand Down
3 changes: 2 additions & 1 deletion include/nncase/kernels/tensor_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ NNCASE_API result<void> quantize(datatype_t in_type, datatype_t out_type, const
NNCASE_API result<void> unary(unary_op_t op, const float *input, float *output, const runtime_shape_t &shape,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, kernel_context &context = default_kernel_context()) noexcept;

NNCASE_API result<void> reduce(reduce_op_t op, float init_value, const float *input, float *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
template <typename T>
NNCASE_API result<void> reduce(reduce_op_t op, T init_value, const T *input, T *output, const runtime_shape_t &in_shape, const runtime_shape_t &axis,
const runtime_shape_t &in_strides, const runtime_shape_t &out_strides, bool keep_dims, kernel_context &context = default_kernel_context()) noexcept;

template <typename T>
Expand Down
3 changes: 2 additions & 1 deletion include/nncase/runtime/stackvm/op_reader.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/22/2022 10:57:45 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1489,6 +1489,7 @@ struct op_reader<tensor_reduce_prod_op_t>
tensor_reduce_prod_op_t op(default_init);
op.opcode = static_cast<opcode_t>(reader.read_unaligned<uint8_t>());
op.funct = static_cast<tensor_function_t>(reader.read_unaligned<uint16_t>());
op.datatype = static_cast<datatype_t>(reader.read_unaligned<uint8_t>());
op.rshape_src = reader.read_unaligned<uint8_t>();
op.rstride_src = reader.read_unaligned<uint8_t>();
op.rstride_dest = reader.read_unaligned<uint8_t>();
Expand Down
7 changes: 4 additions & 3 deletions include/nncase/runtime/stackvm/opcode.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/22/2022 10:57:45 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -1625,15 +1625,16 @@ struct tensor_reduce_prod_op_t
{
opcode_t opcode;
tensor_function_t funct;
datatype_t datatype;
uint8_t rshape_src;
uint8_t rstride_src;
uint8_t rstride_dest;
uint8_t rshape_axes;
bool keep_dims;

tensor_reduce_prod_op_t(default_init_t) noexcept { }
explicit tensor_reduce_prod_op_t(uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::REDUCE_PROD), rshape_src(rshape_src), rstride_src(rstride_src), rstride_dest(rstride_dest), rshape_axes(rshape_axes), keep_dims(keep_dims)
explicit tensor_reduce_prod_op_t(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims) noexcept
: opcode(opcode_t::TENSOR), funct(tensor_function_t::REDUCE_PROD), datatype(datatype), rshape_src(rshape_src), rstride_src(rstride_src), rstride_dest(rstride_dest), rshape_axes(rshape_axes), keep_dims(keep_dims)
{
}
};
Expand Down
6 changes: 3 additions & 3 deletions src/codegen/stackvm/op_writer.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/* This file is generated by tools/stackvm_gen/IsaGen at 4/20/2022 2:35:53 PM +08:00.
/* This file is generated by tools/stackvm_gen/IsaGen at 4/22/2022 10:57:45 PM +08:00.
*
* Copyright 2019-2021 Canaan Inc.
*
Expand Down Expand Up @@ -608,9 +608,9 @@ void op_builder::tensor_reduce_arg_(datatype_t datatype_src, uint8_t rshape_src,
op_writer<tensor_reduce_arg_op_t>()(tensor_reduce_arg_op_t(datatype_src, rshape_src, rstride_src, datatype_dest, rstride_dest, reduce_arg_op, rshape_axis, keep_dims, select_last_idx), writer_);
}

void op_builder::tensor_reduce_prod_(uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims)
void op_builder::tensor_reduce_prod_(datatype_t datatype, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint8_t rshape_axes, bool keep_dims)
{
op_writer<tensor_reduce_prod_op_t>()(tensor_reduce_prod_op_t(rshape_src, rstride_src, rstride_dest, rshape_axes, keep_dims), writer_);
op_writer<tensor_reduce_prod_op_t>()(tensor_reduce_prod_op_t(datatype, rshape_src, rstride_src, rstride_dest, rshape_axes, keep_dims), writer_);
}

void op_builder::tensor_reduce_window2d_(datatype_t datatype, reduce_op_t reduce_op, uint8_t rshape_src, uint8_t rstride_src, uint8_t rstride_dest, uint16_t filter_h, uint16_t filter_w, uint16_t stride_h, uint16_t stride_w, uint16_t dilation_h, uint16_t dilation_w, float fused_clamp_low, float fused_clamp_high)
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/stackvm/ops/reduce_prod.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,5 +30,5 @@ void stackvm_module_builder::emit(reduce_prod &node, stackvm_op_builder &builder
builder.stshape(1, input.strides);
builder.stshape(2, output.strides);
builder.staxis(3, node.axis());
builder.tensor_reduce_prod_(0, 1, 2, 3, node.keep_dims());
builder.tensor_reduce_prod_(node.input().type(), 0, 1, 2, 3, node.keep_dims());
}
46 changes: 32 additions & 14 deletions src/evaluator/ops/neutral/neutral_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -340,16 +340,25 @@ void register_neutral_evaluators()

register_evaluator(op_reduce, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<reduce &>(node);

assert(rnode.input().type() == dt_float32);
auto input = context.memory_at(rnode.input());
auto output = context.memory_at(rnode.output());
auto input_mem = input.buffer().as_span<float>();
auto output_mem = output.buffer().as_span<float>();

kernels::reduce(rnode.reduce_op(), rnode.init_value(), input_mem.data(), output_mem.data(), input.shape(),
to(rnode.axis()), input.strides(), output.strides(), rnode.keep_dims())
.unwrap_or_throw();
auto input_type = rnode.input().type();
switch (input_type)
{
case dt_float32:
kernels::reduce(rnode.reduce_op(), static_cast<float>(rnode.init_value()), input.buffer().as_span<float>().data(),
output.buffer().as_span<float>().data(), input.shape(), to(rnode.axis()), input.strides(), output.strides(), rnode.keep_dims())
.unwrap_or_throw();
break;
case dt_int32:
kernels::reduce(rnode.reduce_op(), static_cast<int32_t>(rnode.init_value()), input.buffer().as_span<int32_t>().data(),
output.buffer().as_span<int32_t>().data(), input.shape(), to(rnode.axis()), input.strides(), output.strides(), rnode.keep_dims())
.unwrap_or_throw();
break;
default:
std::cerr << "unsupported dtype for reduce: " + std::string(datatype_names(input_type));
}
});

register_evaluator(op_reduce_arg, [](ir::node &node, function_evaluate_context &context) {
Expand Down Expand Up @@ -381,16 +390,25 @@ void register_neutral_evaluators()

register_evaluator(op_reduce_prod, [](ir::node &node, function_evaluate_context &context) {
auto &rnode = static_cast<reduce_prod &>(node);

assert(rnode.input().type() == dt_float32);
auto input = context.memory_at(rnode.input());
auto output = context.memory_at(rnode.output());
auto input_mem = input.buffer().as_span<float>();
auto output_mem = output.buffer().as_span<float>();

kernels::reduce_prod(input_mem.data(), output_mem.data(), input.shape(),
input.strides(), output.strides(), to(rnode.axis()), rnode.keep_dims())
.unwrap_or_throw();
auto input_type = rnode.input().type();
switch (input_type)
{
case dt_float32:
kernels::reduce_prod(input.buffer().as_span<float>().data(), output.buffer().as_span<float>().data(), input.shape(),
input.strides(), output.strides(), to(rnode.axis()), rnode.keep_dims())
.unwrap_or_throw();
break;
case dt_int32:
kernels::reduce_prod(input.buffer().as_span<int32_t>().data(), output.buffer().as_span<int32_t>().data(), input.shape(),
input.strides(), output.strides(), to(rnode.axis()), rnode.keep_dims())
.unwrap_or_throw();
break;
default:
std::cerr << "unsupported dtype for reduce_prod: " + std::string(datatype_names(input_type));
}
});

register_evaluator(op_reduce_window2d, [](ir::node &node, function_evaluate_context &context) {
Expand Down
4 changes: 2 additions & 2 deletions src/importer/caffe/ops/softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ DEFINE_CAFFE_LOWER(Softmax)
axis_t reduce_axis;
reduce_axis.push_back(param.axis());

auto max = graph_.emplace<reduce>(reduce_max, input.shape(), reduce_axis, std::numeric_limits<float>::lowest(), true);
auto max = graph_.emplace<reduce>(reduce_max, dt_float32, input.shape(), reduce_axis, std::numeric_limits<float>::lowest(), true);
max->name(op.name() + "/max");
auto sub = graph_.emplace<binary>(binary_sub, dt_float32, input.shape(), max->output().shape(), value_range<float>::full());
sub->name(op.name() + "/sub");
auto exp = graph_.emplace<unary>(unary_exp, sub->output().shape());
exp->name(op.name() + "/exp");
auto sum = graph_.emplace<reduce>(reduce_sum, exp->output().shape(), reduce_axis, 0.f, true);
auto sum = graph_.emplace<reduce>(reduce_sum, dt_float32, exp->output().shape(), reduce_axis, 0.f, true);
sum->name(op.name() + "/sum");
auto div = graph_.emplace<binary>(binary_div, dt_float32, exp->output().shape(), sum->output().shape(), value_range<float>::full());
div->name(op.name() + "/div");
Expand Down
4 changes: 2 additions & 2 deletions src/importer/onnx/ops/instancenorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ void onnx_importer::convert_op_InstanceNormalization(const NodeProto &node)
}
float init_value = 0.f;
bool keepdims = true;
auto mean = graph_.emplace<reduce>(reduce_mean, input_shape, axes, init_value, keepdims);
auto mean = graph_.emplace<reduce>(reduce_mean, input_type, input_shape, axes, init_value, keepdims);
mean->name(op_name + ".reduce_mean(InstanceNormalization)");

// x - mean
Expand All @@ -80,7 +80,7 @@ void onnx_importer::convert_op_InstanceNormalization(const NodeProto &node)
// variance
auto square = graph_.emplace<unary>(unary_square, sub->output().shape());
square->name(op_name + ".square(InstanceNormalization)");
auto variance = graph_.emplace<reduce>(reduce_mean, square->output().shape(), axes, init_value, keepdims);
auto variance = graph_.emplace<reduce>(reduce_mean, input_type, square->output().shape(), axes, init_value, keepdims);
variance->name(op_name + ".reduce(InstanceNormalization)");

// sqrt(variance + epsilon)
Expand Down
4 changes: 2 additions & 2 deletions src/importer/onnx/ops/lpnorm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ void onnx_importer::convert_op_LpNormalization(const NodeProto &node)
{
auto abs = graph_.emplace<unary>(unary_abs, input_shape);
abs->name(op_name + ".abs(L1Normalization)");
auto sum = graph_.emplace<reduce>(reduce_sum, abs->output().shape(), reduce_axis, 0.f, true);
auto sum = graph_.emplace<reduce>(reduce_sum, input_type, abs->output().shape(), reduce_axis, 0.f, true);
sum->name(op_name + ".reduce_sum(L1Normalization)");
auto div = graph_.emplace<binary>(binary_div, input_type, input_shape, sum->output().shape(), value_range<float>::full());
div->name(op_name + ".div(L1Normalization)");
Expand All @@ -64,7 +64,7 @@ void onnx_importer::convert_op_LpNormalization(const NodeProto &node)
{
auto square = graph_.emplace<unary>(unary_square, input_shape);
square->name(op_name + ".square(L2Normalization)");
auto sum = graph_.emplace<reduce>(reduce_sum, square->output().shape(), reduce_axis, 0.f, true);
auto sum = graph_.emplace<reduce>(reduce_sum, input_type, square->output().shape(), reduce_axis, 0.f, true);
sum->name(op_name + ".reduce_sum(L2Normalization)");
auto epsilon = graph_.emplace<constant>(1e-10f);
epsilon->name(op_name + ".eps(L2Normalization)");
Expand Down
2 changes: 1 addition & 1 deletion src/importer/onnx/ops/lrn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ void onnx_importer::convert_op_LRN(const NodeProto &node)
sl->name(op_name + ".slice_" + std::to_string(i) + "(LRN)");
sl->input().connect(square->output());

auto r_sum = graph_.emplace<reduce>(reduce_sum, sl->output().shape(), axis_t { 1 }, 0.f, true);
auto r_sum = graph_.emplace<reduce>(reduce_sum, input_type, sl->output().shape(), axis_t { 1 }, 0.f, true);
r_sum->name(op_name + ".reduce_sum_" + std::to_string(i) + "(LRN)");
r_sum->input().connect(sl->output());
con->input_at(i).connect(r_sum->output());
Expand Down
Loading

0 comments on commit 6acbe83

Please sign in to comment.