diff --git a/include/nncase/transforms/neutral/optimize_allocation.h b/include/nncase/transforms/neutral/optimize_allocation.h
index 43d064e4c6..c01b4a1f8b 100644
--- a/include/nncase/transforms/neutral/optimize_allocation.h
+++ b/include/nncase/transforms/neutral/optimize_allocation.h
@@ -71,6 +71,15 @@ class NNCASE_API add_copy_to_output_pass : public graph_pass
     void run_core(graph &graph, nncase::target &target, const run_pass_options &options) override;
 };
 
+class NNCASE_API add_copy_to_bitcast_pass : public graph_pass // inserts a copy ahead of each bitcast whose input is not already produced by a copy
+{
+public:
+    using graph_pass::graph_pass;
+
+protected:
+    void run_core(graph &graph, nncase::target &target, const run_pass_options &options) override;
+};
+
 class NNCASE_API remove_exclusive_copy_to_output_transform : public transform
 {
 public:
@@ -89,6 +98,15 @@ class NNCASE_API remove_exclusive_copy_to_concat_transform : public transform
     bool on_try_match(ir::node &node, transform_context &context) override;
 };
 
+class NNCASE_API remove_exclusive_copy_to_bitcast_transform : public transform // matches copy -> bitcast and rewires the copy's consumers to the copy's source
+{
+public:
+    void process(transform_context &context) override;
+
+protected:
+    bool on_try_match(ir::node &node, transform_context &context) override;
+};
+
 class NNCASE_API remove_simple_copy_from_slice_transform : public transform
 {
 public:
diff --git a/src/nncase/compiler.cpp b/src/nncase/compiler.cpp
index 5f0d36da5a..77ebc2beb1 100644
--- a/src/nncase/compiler.cpp
+++ b/src/nncase/compiler.cpp
@@ -485,12 +485,14 @@ class compiler_impl : public compiler
             pmgr.add_pass();
             pmgr.add_pass();
             pmgr.add_pass();
+            pmgr.add_pass();
 
             transform_pass pass("optimize_copy");
             pass.emplace();
             pass.emplace();
             pass.emplace();
             pass.emplace();
+            pass.emplace();
             pmgr.add_pass(std::move(pass));
         });
     }
diff --git a/src/transforms/neutral/optimize_allocation.cpp b/src/transforms/neutral/optimize_allocation.cpp
index cd25b67c41..211a917bfb 100644
--- a/src/transforms/neutral/optimize_allocation.cpp
+++ b/src/transforms/neutral/optimize_allocation.cpp
@@ -140,6 +140,25 @@ void add_copy_to_output_pass::run_core(graph &graph, [[maybe_unused]] nncase::ta
     alias_visitor.visit(graph);
 }
 
+void add_copy_to_bitcast_pass::run_core(graph &graph, [[maybe_unused]] nncase::target &target, [[maybe_unused]] const run_pass_options &options)
+{
+    auto alias_visitor = make_relay_ir_visitor([&](node &node) {
+        if (auto b = node_cast(node))
+        {
+            auto &out = *b->input().connection();
+            if (out.owner().runtime_opcode() != op_copy) // producer is not already a copy
+            {
+                auto cp = graph.emplace(out.type(), out.shape());
+                cp->module_type(graph.module_type());
+                cp->name(out.owner().name() + "/copy");
+                cp->input().connect(out);
+                b->input().connect(cp->output()); // the bitcast now reads from the new copy
+            }
+        }
+    });
+    alias_visitor.visit(graph);
+}
+
 //     x@data       x@output
 //       |            |
 //      copy          |
@@ -222,6 +241,40 @@ void remove_exclusive_copy_to_concat_transform::process(transform_context &conte
         in->connect(output);
 }
 
+bool remove_exclusive_copy_to_bitcast_transform::on_try_match(node &node, transform_context &context)
+{
+    copy *cp;
+    bitcast *b;
+
+    if ((cp = node_cast(node))
+        && (b = try_get_direct_child(*cp)))
+    {
+        auto input = cp->input().connection();
+
+        if (input->memory_location() == mem_data
+            && ((input->attributes() & cnctr_attr_no_buffer_fusion) == 0)) // only sources not yet flagged no-buffer-fusion
+        {
+            context.inputs.emplace_back(&cp->input());
+            context.outputs.emplace_back(&cp->output());
+
+            context.matched_nodes.emplace_back(cp);
+            return true;
+        }
+    }
+
+    return false;
+}
+
+void remove_exclusive_copy_to_bitcast_transform::process(transform_context &context)
+{
+    auto &output = *context.inputs[0]->connection(); // connector that fed the matched copy
+    auto inputs = context.outputs[0]->connections(); // consumers of the matched copy's output
+
+    output.attributes(output.attributes() | cnctr_attr_no_buffer_fusion); // set the flag that on_try_match tests for
+    for (auto &in : dup(inputs)) // NOTE(review): dup() presumably snapshots the list so reconnecting is safe - confirm
+        in->connect(output);
+}
+
 //    x                 x
 //    |                 |
 //  slice               |
diff --git a/tests/schedule/buffer_fusion/test_bitcast.py b/tests/schedule/buffer_fusion/test_bitcast.py
new file mode 100644
index 0000000000..5797e0c29e
--- /dev/null
+++ b/tests/schedule/buffer_fusion/test_bitcast.py
@@ -0,0 +1,42 @@
+# Copyright 2019-2021 Canaan Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# pylint: disable=invalid-name, unused-argument, import-outside-toplevel
+
+import pytest
+import tensorflow as tf
+import numpy as np
+from tflite_test_runner import TfliteTestRunner
+
+
+def _make_module():
+    class Module(tf.Module):
+        def __init__(self):
+            super().__init__()  # zero-argument super() is bound to self, so tf.Module.__init__ actually runs
+
+        @tf.function(input_signature=[tf.TensorSpec([1, 4, 4, 3], tf.float32)])
+        def __call__(self, x):
+            return tf.reshape(x, [1, -1, 3])  # reshape of a [1, 4, 4, 3] input to [1, -1, 3]
+    return Module()
+
+
+def test_bitcast(request):
+    module = _make_module()
+
+    runner = TfliteTestRunner(request.node.name)
+    model_file = runner.from_tensorflow(module)
+    runner.run(model_file)
+
+
+if __name__ == "__main__":
+    pytest.main(['-vv', 'test_bitcast.py'])