Skip to content

Commit

Permalink
Fix/fold matmul add (#584)
Browse files Browse the repository at this point in the history
* support dump quant error for layers

* fix bug

* remove useless code

* format code

* small change

* fix ofstream

* refine weights quantization with minimize mse

* Update config.yml

* Update config.yml

* Update test_runner.py

* Update config.yml

* register evaluators for quant related IRs and support dump quant error for k210

* remove useless files

* remove useless file

* remove useless file

* specify quant mode

* fix range

* Update convolution.cpp

* Update quantizer.cpp

* Update quantizer.cpp

* Update quantizer.cpp

* snake style

* Update quantizer.cpp

* remove assert

* fix data type

* Update quantizer.cpp

* dump op range for import graph

* add count_include_pad in tflite pool importer

* revert

* dump output range in order

* support dump range for noptq

* fix test_runner

* fix bug

* format issue

* format issue

* add k230 target in config.yml

* add bitcast clamp motion pass

* apply code-format changes

* add do_letterbox flag

* apply code-format changes

* revert bitcast motion, do it in another branch

* specify do_letterbox flag for each preprocess test cases

* fix config

* flag for ncc

* use input_shape to judge whether do letterbox or not

* fix typo

* fix letterbox bug

* judge input shape according to both input layout and network framework type

* apply code-format changes

* fix shape

* fix dump quant error

* dump data for each layer before and after quant

* apply code-format changes

* fix data_dir

* fix bias round issue

* apply code-format changes

* formatted

* data type

* support any bits for quant

* support int16 quant

* do not modify src_bin now

* int16 for deq

* support multiple input quant

* fix set_input_tensor

* exclude wrong model

* fix no inputs condition

* Update test_runner.py

* fold matmul-bitcast-add pattern

Co-authored-by: zhangjizhao <zhangjizhao@canaan-creative.com>
Co-authored-by: aaltonenzhang <aaltonenzhang@users.noreply.github.com>
  • Loading branch information
3 people authored May 11, 2022
1 parent 8951bab commit b33ad94
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions src/transforms/neutral/fold_matmul_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
* limitations under the License.
*/
#include <nncase/ir/ops/binary.h>
#include <nncase/ir/ops/bitcast.h>
#include <nncase/ir/ops/constant.h>
#include <nncase/ir/ops/matmul.h>
#include <nncase/ir/visitor.h>
Expand All @@ -26,11 +27,12 @@ bool fold_matmul_add_transform::on_try_match(node &node, transform_context &cont
{
matmul *mm = nullptr;
binary *add = nullptr;
bitcast *bc = nullptr;
constant *bias_constant = nullptr;
constant *add_constant = nullptr;
if ((add = node_cast<binary>(node))
&& (add->binary_op() == binary_add)
&& (((mm = try_get_direct_parent<matmul>(*add, 0)) && (add_constant = try_get_direct_parent<constant>(*add, 1))) || ((mm = try_get_direct_parent<matmul>(*add, 1)) && (add_constant = try_get_direct_parent<constant>(*add, 0))))
&& ((((mm = try_get_direct_parent<matmul>(*add, 0)) && (add_constant = try_get_direct_parent<constant>(*add, 1))) || ((mm = try_get_direct_parent<matmul>(*add, 1)) && (add_constant = try_get_direct_parent<constant>(*add, 0)))) || ((bc = try_get_direct_parent<bitcast>(*add, 0)) && (mm = try_get_direct_parent<matmul>(*bc)) && (add_constant = try_get_direct_parent<constant>(*add, 1))) || ((bc = try_get_direct_parent<bitcast>(*add, 1)) && (mm = try_get_direct_parent<matmul>(*bc)) && (add_constant = try_get_direct_parent<constant>(*add, 0))))
&& (mm->fused_activation() == value_range<float>::full())
&& (bias_constant = node_cast<constant>(mm->bias().connection()->owner()))
&& (bias_constant->data().size() == add_constant->data().size()))
Expand Down Expand Up @@ -79,12 +81,16 @@ void fold_matmul_add_transform::process(transform_context &context)

// create new matmul
auto new_mm = context.graph.emplace<matmul>(old_mm->input_a().shape(), old_mm->input_b().shape(), mm_act);
auto new_bc = context.graph.emplace<bitcast>(new_mm->output().type(), new_mm->output().shape(), add->output().shape());

new_bc->name(old_mm->name() + "/bitcast");
new_mm->name(old_mm->name());
new_mm->input_a().connect(*context.inputs[0]->connection());
new_mm->input_b().connect(*context.inputs[1]->connection());
new_mm->bias().connect(new_bias->output());
new_bc->input().connect(new_mm->output());

auto inputs = context.outputs[0]->connections();
for (auto &in : dup(inputs))
in->connect(new_mm->output());
in->connect(new_bc->output());
}

0 comments on commit b33ad94

Please sign in to comment.