Skip to content

Commit

Permalink
fix bitcast buffer allocation (#632)
Browse files Browse the repository at this point in the history
  • Loading branch information
xhuohai authored Aug 25, 2022
1 parent 8baf43a commit bb95a5b
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 0 deletions.
18 changes: 18 additions & 0 deletions include/nncase/transforms/neutral/optimize_allocation.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,15 @@ class NNCASE_API add_copy_to_output_pass : public graph_pass
void run_core(graph &graph, nncase::target &target, const run_pass_options &options) override;
};

// Graph pass that walks the graph and, for every bitcast node whose input is
// not already produced by a copy op, inserts a copy node between the producer
// and the bitcast (see run_core in optimize_allocation.cpp).
class NNCASE_API add_copy_to_bitcast_pass : public graph_pass
{
public:
// Inherit the graph_pass constructors unchanged.
using graph_pass::graph_pass;

protected:
// Pass body; target and options are accepted to satisfy the graph_pass
// interface and are not used by this pass.
void run_core(graph &graph, nncase::target &target, const run_pass_options &options) override;
};

class NNCASE_API remove_exclusive_copy_to_output_transform : public transform
{
public:
Expand All @@ -89,6 +98,15 @@ class NNCASE_API remove_exclusive_copy_to_concat_transform : public transform
bool on_try_match(ir::node &node, transform_context &context) override;
};

// Transform that removes a copy node feeding a bitcast when the copy's source
// connector lives in mem_data and does not carry cnctr_attr_no_buffer_fusion.
class NNCASE_API remove_exclusive_copy_to_bitcast_transform : public transform
{
public:
// Rewires every consumer of the matched copy's output to the copy's source
// connector and sets cnctr_attr_no_buffer_fusion on that source.
void process(transform_context &context) override;

protected:
// Returns true (and records the copy's input/output connectors in context)
// when node is a copy for which try_get_direct_child<bitcast> succeeds and
// the source connector satisfies the conditions above.
bool on_try_match(ir::node &node, transform_context &context) override;
};

class NNCASE_API remove_simple_copy_from_slice_transform : public transform
{
public:
Expand Down
2 changes: 2 additions & 0 deletions src/nncase/compiler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,12 +485,14 @@ class compiler_impl : public compiler
pmgr.add_pass<add_copy_to_concat_pass>();
pmgr.add_pass<add_copy_to_slice_pass>();
pmgr.add_pass<add_copy_to_output_pass>();
pmgr.add_pass<add_copy_to_bitcast_pass>();

transform_pass pass("optimize_copy");
pass.emplace<remove_exclusive_copy_to_output_transform>();
pass.emplace<remove_simple_copy_from_slice_transform>();
pass.emplace<remove_non_simple_copy_from_slice_transform>();
pass.emplace<remove_exclusive_copy_to_concat_transform>();
pass.emplace<remove_exclusive_copy_to_bitcast_transform>();
pmgr.add_pass(std::move(pass));
});
}
Expand Down
53 changes: 53 additions & 0 deletions src/transforms/neutral/optimize_allocation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,25 @@ void add_copy_to_output_pass::run_core(graph &graph, [[maybe_unused]] nncase::ta
alias_visitor.visit(graph);
}

// Inserts a copy node in front of every bitcast whose input is not already
// produced by a copy op. The new copy takes the producer's type and shape,
// the graph's module type, and the producer's name suffixed with "/copy".
// target and options are unused.
void add_copy_to_bitcast_pass::run_core(graph &graph, [[maybe_unused]] nncase::target &target, [[maybe_unused]] const run_pass_options &options)
{
    auto bitcast_visitor = make_relay_ir_visitor([&](node &visited) {
        auto bc = node_cast<bitcast>(visited);
        if (!bc)
            return;

        // Connector currently feeding the bitcast.
        auto &producer_out = *bc->input().connection();
        if (producer_out.owner().runtime_opcode() == op_copy)
            return; // already fed by a copy, nothing to insert

        auto inserted = graph.emplace<copy>(producer_out.type(), producer_out.shape());
        inserted->module_type(graph.module_type());
        inserted->name(producer_out.owner().name() + "/copy");

        // Splice the copy between the producer and the bitcast.
        inserted->input().connect(producer_out);
        bc->input().connect(inserted->output());
    });
    bitcast_visitor.visit(graph);
}

// x@data x@output
// | |
// copy |
Expand Down Expand Up @@ -222,6 +241,40 @@ void remove_exclusive_copy_to_concat_transform::process(transform_context &conte
in->connect(output);
}

// Matches a copy node that has a direct bitcast child and whose source
// connector is located in mem_data without cnctr_attr_no_buffer_fusion set.
// On success the copy's input and output connectors and the copy itself are
// recorded in context and true is returned; otherwise context is untouched.
bool remove_exclusive_copy_to_bitcast_transform::on_try_match(node &node, transform_context &context)
{
    copy *cp = node_cast<copy>(node);
    if (!cp)
        return false;

    bitcast *child = try_get_direct_child<bitcast>(*cp);
    if (!child)
        return false;

    auto source = cp->input().connection();

    // Checked in the same order as a short-circuit &&: location first,
    // then the fusion attribute.
    if (source->memory_location() != mem_data)
        return false;
    if ((source->attributes() & cnctr_attr_no_buffer_fusion) != 0)
        return false;

    context.inputs.emplace_back(&cp->input());
    context.outputs.emplace_back(&cp->output());

    context.matched_nodes.emplace_back(cp);
    return true;
}

// Bypasses the matched copy: marks the copy's source connector with
// cnctr_attr_no_buffer_fusion and reconnects every consumer of the copy's
// output directly to that source.
void remove_exclusive_copy_to_bitcast_transform::process(transform_context &context)
{
    auto &source = *context.inputs[0]->connection();
    auto consumers = context.outputs[0]->connections();

    source.attributes(source.attributes() | cnctr_attr_no_buffer_fusion);

    // Iterate over a duplicate because connect() is called on each consumer
    // while walking the list.
    for (auto &consumer : dup(consumers))
        consumer->connect(source);
}

// x x
// | |
// slice |
Expand Down
42 changes: 42 additions & 0 deletions tests/schedule/buffer_fusion/test_bitcast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright 2019-2021 Canaan Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=invalid-name, unused-argument, import-outside-toplevel

import pytest
import tensorflow as tf
import numpy as np
from tflite_test_runner import TfliteTestRunner


def _make_module():
    """Build a tf.Module whose call reshapes a [1, 4, 4, 3] float32 tensor to [1, -1, 3]."""
    class Module(tf.Module):
        def __init__(self):
            # Zero-argument super() so that tf.Module.__init__ actually runs.
            # The one-argument form super(Module) creates an unbound super
            # object, so calling __init__ on it never reaches the base class.
            super().__init__()

        @tf.function(input_signature=[tf.TensorSpec([1, 4, 4, 3], tf.float32)])
        def __call__(self, x):
            return tf.reshape(x, [1, -1, 3])
    return Module()


def test_bitcast(request):
    """Convert the reshape module through TfliteTestRunner and run it."""
    tf_module = _make_module()

    # The runner is keyed by the pytest node name of this test.
    test_runner = TfliteTestRunner(request.node.name)
    converted_model = test_runner.from_tensorflow(tf_module)
    test_runner.run(converted_model)


# Allow running this file directly: hand the file to pytest in verbose mode.
if __name__ == "__main__":
    pytest_args = ['-vv', 'test_bitcast.py']
    pytest.main(pytest_args)

0 comments on commit bb95a5b

Please sign in to comment.