Refactor matmul ir and add BATCH_MATMUL op. (#577)
* Refactor matmul ir and add BATCH_MATMUL op.

* Fix binary/SQUEEZE ops and add TILE op.

* Fix build error.

* Update test_tile.py

* Add default bias for FULLY_CONNECTED.
zhangyang2057 authored Apr 27, 2022
1 parent 6acbe83 commit 13f52d9
Showing 23 changed files with 701 additions and 197 deletions.
2 changes: 2 additions & 0 deletions docs/tflite_ops.md
@@ -9,6 +9,7 @@
 | ARG_MAX | ✅ |
 | ARG_MIN | ✅ |
 | AVERAGE_POOL_2D | ✅ |
+| BATCH_MATMUL | ✅ |
 | CAST | ✅ |
 | CEIL | ✅ |
 | CONCATENATION | ✅ |
@@ -68,6 +69,7 @@
 | SUB | ✅ |
 | SUM | ✅ |
 | TANH | ✅ |
+| TILE | ✅ |
 | TRANSPOSE | ✅ |
 | TRANSPOSE_CONV | ✅ |
 | QUANTIZE | ✅ |
34 changes: 34 additions & 0 deletions include/nncase/ir/op_utils.h
@@ -303,6 +303,40 @@ inline shape_t get_strided_slice_output_shape(const axis_t &begin, const axis_t
     return new_shape.size() ? new_shape : shape_t { 1 };
 }
 
+inline shape_t get_matmul_output_shape(const shape_t &input_a_shape, const shape_t &input_b_shape)
+{
+    shape_t b_shape = input_b_shape;
+    b_shape[b_shape.size() - 2] = input_a_shape[input_a_shape.size() - 2];
+    b_shape[b_shape.size() - 1] = input_a_shape[input_a_shape.size() - 1];
+    shape_t out_shape;
+
+    const auto dest_dims = (int32_t)std::max(input_a_shape.size(), b_shape.size());
+    const auto in_a_ext = dest_dims - (int32_t)input_a_shape.size();
+    const auto in_b_ext = dest_dims - (int32_t)b_shape.size();
+
+    for (int32_t i = 0; i < dest_dims; i++)
+    {
+        const auto in_a_dim = i - (int32_t)in_a_ext;
+        const auto in_b_dim = i - (int32_t)in_b_ext;
+
+        const auto in_a = in_a_dim < 0 ? 1 : input_a_shape[in_a_dim];
+        const auto in_b = in_b_dim < 0 ? 1 : b_shape[in_b_dim];
+        if (in_a == in_b)
+            out_shape.push_back(in_a);
+        else if (in_a == 1)
+            out_shape.push_back(in_b);
+        else if (in_b == 1)
+            out_shape.push_back(in_a);
+        else
+            throw std::invalid_argument("inputs are not compatible to broadcast");
+    }
+
+    out_shape[out_shape.size() - 2] = input_a_shape[input_a_shape.size() - 2];
+    out_shape[out_shape.size() - 1] = input_b_shape.back();
+
+    return out_shape;
+}
+
 inline bool is_copy_slice(const axis_t &strides)
 {
     return std::all_of(strides.begin(), strides.end(), [](int32_t stride) { return stride == 1; });
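The helper broadcasts the batch dimensions NumPy-style, then pins the trailing two dimensions of the result to [M, N]: rows from A, columns from B. A quick worked example (a sketch; shape_t is the shape container nncase already uses):

    shape_t a_shape { 3, 1, 4, 5 }; // batch dims [3, 1], M = 4, K = 5
    shape_t b_shape { 2, 5, 6 };    // batch dim  [2],    K = 5, N = 6
    auto out = get_matmul_output_shape(a_shape, b_shape);
    // batch dims broadcast to [3, 2]; the last two come from M and N,
    // so out == shape_t { 3, 2, 4, 6 }

Copying A's trailing two dimensions into the temporary b_shape before the loop makes those axes trivially compatible, so one generic broadcast pass can run over every axis; the real [M, N] is written back afterwards.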
29 changes: 29 additions & 0 deletions include/nncase/transforms/neutral/fold_matmul_add.h
@@ -0,0 +1,29 @@
+/* Copyright 2019-2021 Canaan Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+#include "../transform.h"
+
+namespace nncase::ir::transforms
+{
+class NNCASE_API fold_matmul_add_transform : public transform
+{
+public:
+    void process(transform_context &context) override;
+
+protected:
+    bool skip_self_contained_check() const noexcept override { return true; }
+    bool on_try_match(ir::node &node, transform_context &context) override;
+};
+}
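The header only declares the pass, but its name and the new default bias on FULLY_CONNECTED suggest the intent: when a matmul with an all-zero bias feeds an elementwise add with a constant, fold that constant into the bias. A hypothetical before/after sketch (the real matching logic lives in the pass's .cpp, which is outside this excerpt):

    // before: y = add(matmul(A, B, bias = 0), C)   // C constant, broadcastable to [N]
    // after:  y = matmul(A, B, bias = C)
    //
    // on_try_match would recognize the matmul -> binary(add, constant) pattern;
    // process would rebuild the matmul with C wired to its bias input.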
19 changes: 11 additions & 8 deletions src/evaluator/ops/neutral/neutral_ops.cpp
@@ -293,15 +293,18 @@ void register_neutral_evaluators()
 
         assert(rnode.input_a().type() == dt_float32);
         assert(rnode.input_b().type() == dt_float32);
-        auto input_a = context.memory_at(rnode.input_a()).buffer().as_span<float>();
-        auto input_b = context.memory_at(rnode.input_b()).buffer().as_span<float>();
-        auto bias = context.memory_at(rnode.bias()).buffer().as_span<float>();
-        auto output = context.memory_at(rnode.output()).buffer().as_span<float>();
-
-        auto &a_shape = rnode.input_a().shape();
-        auto &b_shape = rnode.input_b().shape();
+        auto input_a = context.memory_at(rnode.input_a());
+        auto input_b = context.memory_at(rnode.input_b());
+        auto bias = context.memory_at(rnode.bias());
+        auto output = context.memory_at(rnode.output());
+        auto input_a_mem = input_a.buffer().as_span<float>();
+        auto input_b_mem = input_b.buffer().as_span<float>();
+        auto bias_mem = bias.buffer().as_span<float>();
+        auto output_mem = output.buffer().as_span<float>();
 
-        neutral::matmul(input_a.data(), input_b.data(), output.data(), bias.data(), (int32_t)a_shape[0], (int32_t)a_shape[1], (int32_t)b_shape[1], rnode.fused_activation());
+        kernels::matmul(input_a_mem.data(), input_b_mem.data(), bias_mem.data(), output_mem.data(), input_a.shape(), input_a.strides(),
+            input_b.shape(), input_b.strides(), output.shape(), output.strides(), rnode.fused_activation())
+            .unwrap_or_throw();
     });
 
     register_evaluator(op_pad, [](ir::node &node, function_evaluate_context &context) {
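The evaluator no longer hard-codes a 2-D [M, K] x [K, N] product; it hands shapes and strides to a kernel that handles batched, broadcast inputs and returns a result object (hence .unwrap_or_throw()). As a rough illustration of the computation such a kernel performs — a standalone sketch, not nncase's actual kernels::matmul, assuming dense row-major layout with both inputs already broadcast to a common batch size:

    #include <algorithm>
    #include <cstddef>

    // out[b] = A[b] x B[b] + bias, then clamp to the fused activation range.
    void batched_matmul_ref(const float *a, const float *b, const float *bias,
                            float *out, std::size_t batch, std::size_t M,
                            std::size_t K, std::size_t N, float act_min, float act_max)
    {
        for (std::size_t bi = 0; bi < batch; bi++)
        {
            const float *pa = a + bi * M * K;
            const float *pb = b + bi * K * N;
            float *po = out + bi * M * N;
            for (std::size_t m = 0; m < M; m++)
            {
                for (std::size_t n = 0; n < N; n++)
                {
                    float acc = bias[n]; // bias broadcasts along the rows
                    for (std::size_t k = 0; k < K; k++)
                        acc += pa[m * K + k] * pb[k * N + n];
                    // fused activation modeled as a clamp (value_range semantics)
                    po[m * N + n] = std::min(std::max(acc, act_min), act_max);
                }
            }
        }
    }

The real kernel additionally takes explicit shapes and strides, letting it operate on non-contiguous views directly instead of materializing copies.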
94 changes: 19 additions & 75 deletions src/importer/onnx/ops/matmul.cpp
@@ -45,7 +45,7 @@ void onnx_importer::convert_op_MatMul(const NodeProto &node)
     const auto &output = node.output()[0];
     const auto &output_shape = get_shape(output);
 
-    // reshape A to [batch, n, k]
+    // reshape A to [batch, m, k]
     shape_t new_a_shape { 1, 1, 1 };
     auto input_a_shape_size = input_a_shape.size();
     if (input_a_shape_size == 1)
@@ -75,7 +75,7 @@ void onnx_importer::convert_op_MatMul(const NodeProto &node)
     auto bc_a_3d = graph_.emplace<bitcast>(input_type, input_a_shape, new_a_shape);
     bc_a_3d->name(op_name + ".bitcast_A_3d(MatMul)");
 
-    // reshape B to [batch, k, m]
+    // reshape B to [batch, k, n]
     shape_t new_b_shape = { 1, 1, 1 };
     auto input_b_shape_size = input_b_shape.size();
     if (input_b_shape_size == 1)
@@ -115,7 +115,7 @@ void onnx_importer::convert_op_MatMul(const NodeProto &node)
     shape_t new_output_shape { 1, new_a_shape[1], new_b_shape[2] };
     new_output_shape[0] = new_a_shape[0] == 1 ? new_b_shape[0] : new_a_shape[0];
 
-    bitcast *bc_a = nullptr;
+    bitcast *bc_a = bc_a_3d;
     if (new_a_shape[0] == 1)
     {
         // reshape to 2D
@@ -125,7 +125,7 @@ void onnx_importer::convert_op_MatMul(const NodeProto &node)
         bc_a->input().connect(bc_a_3d->output());
     }
 
-    bitcast *bc_b = nullptr;
+    bitcast *bc_b = bc_b_3d;
     if (new_b_shape[0] == 1)
     {
         // reshape to 2D
@@ -135,80 +135,24 @@
         bc_b->input().connect(bc_b_3d->output());
     }
 
-    // concat
-    std::vector<ir::shape_t> concat_shape(new_output_shape[0], ir::shape_t { 1, new_output_shape[1], new_output_shape[2] });
-    auto con = graph_.emplace<concat>(input_type, concat_shape, 0);
-    con->name(op_name + ".concat(MatMul)");
+    // bias
+    auto b = bc_b->output().shape().back();
+    std::vector<float> bias_value(b, 0.f);
+    shape_t bias_shape = { b };
+    auto bias = graph_.emplace<constant>(dt_float32, bias_shape, bias_value);
+    bias->name(op_name + ".bias(MatMul)");
 
-    for (auto i = 0; i < new_output_shape[0]; i++)
-    {
-        // A
-        bitcast *bc_a_2d = bc_a;
-        if (new_a_shape[0] != 1)
-        {
-            // slice batch
-            auto sl = graph_.emplace<slice>(input_type, bc_a_3d->output().shape(),
-                axis_t { i, 0, 0 },
-                axis_t { static_cast<int32_t>(i + 1), static_cast<int32_t>(new_a_shape[1]), static_cast<int32_t>(new_a_shape[2]) },
-                axis_t { 1, 1, 1 }, 0, 0, 0, 0, 0);
-            sl->name(op_name + ".slice_A_" + std::to_string(i) + "(MatMul)");
-            sl->input().connect(bc_a_3d->output());
-
-            // reshape to 2D
-            shape_t new_shape { new_a_shape[1], new_a_shape[2] };
-            bc_a_2d = graph_.emplace<bitcast>(input_type, sl->output().shape(), new_shape);
-            bc_a_2d->name(op_name + ".bitcast_A_2d(MatMul)");
-            bc_a_2d->input().connect(sl->output());
-        }
-
-        // B
-        bitcast *bc_b_2d = bc_b;
-        if (new_b_shape[0] != 1)
-        {
-            // slice batch
-            auto sl = graph_.emplace<slice>(input_type, bc_b_3d->output().shape(),
-                axis_t { i, 0, 0 },
-                axis_t { static_cast<int32_t>(i + 1), static_cast<int32_t>(new_b_shape[1]), static_cast<int32_t>(new_b_shape[2]) },
-                axis_t { 1, 1, 1 }, 0, 0, 0, 0, 0);
-            sl->name(op_name + ".slice_B_" + std::to_string(i) + "(MatMul)");
-            sl->input().connect(bc_b_3d->output());
-
-            // reshape to 2D
-            shape_t new_shape { new_b_shape[1], new_b_shape[2] };
-            bc_b_2d = graph_.emplace<bitcast>(input_type, sl->output().shape(), new_shape);
-            bc_b_2d->name(op_name + ".bitcast_B_slice(MatMul)");
-            bc_b_2d->input().connect(sl->output());
-        }
-
-        // bias
-        auto b = bc_b_2d->output().shape().back();
-        std::vector<float> bias_value(b, 0.f);
-        shape_t bias_shape = { b };
-        auto bias = graph_.emplace<constant>(dt_float32, bias_shape, bias_value);
-        bias->name(op_name + ".bias(MatMul)");
-
-        // matmul
-        auto mm = graph_.emplace<matmul>(bc_a_2d->output().shape(), bc_b_2d->output().shape(), value_range<float>::full());
-        mm->name(op_name + ".matmul(MatMul)");
-        mm->input_a().connect(bc_a_2d->output());
-        mm->input_b().connect(bc_b_2d->output());
-        mm->bias().connect(bias->output());
-
-        // reshape to 3D
-        auto mm_shape = mm->output().shape();
-        shape_t bc_mm_shape { 1, mm_shape[0], mm_shape[1] };
-        auto bc_mm_3d = graph_.emplace<bitcast>(input_type, mm_shape, bc_mm_shape);
-        bc_mm_3d->name(op_name + ".bitcast_mm_3d(MatMul)");
-        bc_mm_3d->input().connect(mm->output());
-
-        // concat at axis 0
-        con->input_at(i).connect(bc_mm_3d->output());
-    }
+    // matmul
+    auto mm = graph_.emplace<matmul>(bc_a->output().shape(), bc_b->output().shape(), value_range<float>::full());
+    mm->name(op_name + ".matmul(MatMul)");
+    mm->input_a().connect(bc_a->output());
+    mm->input_b().connect(bc_b->output());
+    mm->bias().connect(bias->output());
 
     // reshape to output
-    auto bc_output = graph_.emplace<bitcast>(input_type, con->output().shape(), output_shape);
-    bc_output->name(op_name + ".bitcast_concat(MatMul)");
-    bc_output->input().connect(con->output());
+    auto bc_output = graph_.emplace<bitcast>(input_type, mm->output().shape(), output_shape);
+    bc_output->name(op_name + ".bitcast(MatMul)");
+    bc_output->input().connect(mm->output());
 
     input_tensors_.emplace(&bc_a_3d->input(), input_a);
     input_tensors_.emplace(&bc_b_3d->input(), input_b);
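With the refactored matmul IR accepting batched (3-D) inputs directly, the importer no longer slices each batch, runs a 2-D matmul, and concatenates the results; it emits a single matmul with a zero-filled bias constant. Structurally, per this diff:

    // before (one 2-D matmul per batch element i):
    //   bc_a_3d -> slice_A_i -> bitcast_A_2d \
    //                                         matmul_i -> bitcast_mm_3d -> concat -> bitcast
    //   bc_b_3d -> slice_B_i -> bitcast_B_2d /
    //
    // after (single batched matmul):
    //   bc_a_3d \
    //            matmul(bias = 0) -> bitcast
    //   bc_b_3d /

This removes O(batch) graph nodes per MatMul and lets the backend see one batched operation instead of a chain of slices and concats.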
1 change: 1 addition & 0 deletions src/importer/tflite/CMakeLists.txt
@@ -22,6 +22,7 @@ set(SRCS tflite_importer.cpp
              ops/softmax.cpp
              ops/slice.cpp
              ops/resize_image.cpp
+             ops/tile.cpp
              ops/transpose.cpp
              ops/space_to_batch.cpp
              ops/l2_normalization.cpp
2 changes: 2 additions & 0 deletions src/importer/tflite/opcode.def
@@ -3,6 +3,7 @@ DEFINE_OPCODE(ADD)
 DEFINE_OPCODE(ARG_MAX)
 DEFINE_OPCODE(ARG_MIN)
 DEFINE_OPCODE(AVERAGE_POOL_2D)
+DEFINE_OPCODE(BATCH_MATMUL)
 DEFINE_OPCODE(CAST)
 DEFINE_OPCODE(CEIL)
 DEFINE_OPCODE(CONCATENATION)
@@ -62,6 +63,7 @@ DEFINE_OPCODE(SQUARE)
 DEFINE_OPCODE(SUB)
 DEFINE_OPCODE(SUM)
 DEFINE_OPCODE(TANH)
+DEFINE_OPCODE(TILE)
 DEFINE_OPCODE(TRANSPOSE)
 DEFINE_OPCODE(TRANSPOSE_CONV)
 DEFINE_OPCODE(QUANTIZE)
4 changes: 2 additions & 2 deletions src/importer/tflite/ops/binary.cpp
@@ -129,7 +129,7 @@ void tflite_importer::convert_binary(const tflite::Operator &op, binary_op_t bin
     dequantize *input_a_dequant, *input_b_dequant;
     quantize *output_quant;
     // input_a dequantize
-    if (input_type != dt_float32)
+    if (input_type == dt_uint8 || input_type == dt_int8)
     {
         quant_param_t input_a_paras = to_quant_param(input_a.quantization());
         input_a_dequant = graph_.emplace<dequantize>(to_data_type(input_a.type()), get_shape(input_a.shape()), dt_float32, input_a_paras);
@@ -143,7 +143,7 @@ void tflite_importer::convert_binary(const tflite::Operator &op, binary_op_t bin
     }
 
     //input_b dequantize
-    if (input_type != dt_float32)
+    if (input_type == dt_uint8 || input_type == dt_int8)
     {
         quant_param_t input_b_paras = to_quant_param(input_b.quantization());
         input_b_dequant = graph_.emplace<dequantize>(to_data_type(input_b.type()), get_shape(input_b.shape()), dt_float32, input_b_paras);
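A plausible reading of this fix: the old guard `input_type != dt_float32` also matched non-quantized integer inputs such as int32, wrapping them in a spurious dequantize/quantize round-trip, whereas checking for dt_uint8/dt_int8 limits the float detour to genuinely quantized tensors.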

