Skip to content

Commit

Permalink
add squeeze, gather, cast Ops (PaddlePaddle#72)
Browse files Browse the repository at this point in the history
* add Op

* add gather, cast, squeeze

* pre-commit
  • Loading branch information
yaozhixin authored Aug 18, 2021
1 parent 1ae448d commit 44d9053
Show file tree
Hide file tree
Showing 7 changed files with 431 additions and 2 deletions.
55 changes: 55 additions & 0 deletions paddle/fluid/framework/ipu/ipu_backend.cc
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,61 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
auto perm = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("axis"));
auto result = builder_->aiOnnxOpset11().transpose(inputs, perm);

tensors_.emplace(outputs[0], result);
} else if (op_type == "Gather") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
int64_t axis = 0;
auto result = builder_->aiOnnxOpset11().gather(inputs, axis);

tensors_.emplace(outputs[0], result);
} else if (op_type == "Squeeze") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
auto axes = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("axes"));
auto result = builder_->aiOnnxOpset11().squeeze(inputs, axes);

tensors_.emplace(outputs[0], result);
} else if (op_type == "Cast") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
// TODO(yaozhixin): support np.dtype and core.VarDesc.VarType
auto to_ = BOOST_GET_CONST(int, op->GetAttr("to"));
std::string to = "";
switch (to_) {
case proto::VarType::UINT8:
to = "UINT8";
break;
case proto::VarType::INT8:
to = "INT8";
break;
case proto::VarType::INT16:
to = "INT16";
break;
case proto::VarType::INT32:
to = "INT32";
break;
case proto::VarType::INT64:
to = "INT64";
break;
case proto::VarType::BOOL:
to = "BOOL";
break;
case proto::VarType::FP64:
to = "DOUBLE";
break;
case proto::VarType::FP32:
to = "FLOAT";
break;
case proto::VarType::FP16:
to = "FLOAT16";
break;
default:
PADDLE_THROW(
paddle::platform::errors::Unavailable("Unsupported data type."));
}
auto result = builder_->aiOnnxOpset11().cast(inputs, to);

tensors_.emplace(outputs[0], result);
} else {
PADDLE_THROW(platform::errors::Unimplemented("Unimplemented op type %s.",
Expand Down
58 changes: 58 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization/tensor_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,69 @@ ir::Node *reshape_handler(ir::Graph *graph, ir::Node *node) {
return new_node_reshape;
}

// Builds a "Gather" base op for `node`.
// Inputs are the node's "X" and "Index" input nodes, the output is its "Out"
// output node, and the attribute map is empty. Returns the newly created op.
ir::Node *gather_handler(ir::Graph *graph, ir::Node *node) {
  auto data_node = GetInputNode("X", node);
  auto index_node = GetInputNode("Index", node);
  auto out_node = GetOutputNode("Out", node);

  auto gather_op =
      CreateBaseOp(graph, "Gather", {data_node, index_node}, {out_node}, {});
  // Same hand-over sequence as the other single-op handlers in this file:
  // outputs first, then inputs.
  ReplaceNodeOutputs(node, gather_op);
  ReplaceNodeInputs(node, gather_op);
  return gather_op;
}

// Builds a "Squeeze" base op for `node`.
//
// The op's "axes" attribute (read as std::vector<int>) is widened to
// std::vector<int64_t> and attached to the new op. When the attribute is
// empty, the axes are derived from the shape of the first "X" input variable:
// every dimension whose extent equals 1 is listed.
// Returns the newly created op.
ir::Node *squeeze_handler(ir::Graph *graph, ir::Node *node) {
  auto *op = node->Op();
  auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axes"));
  // NOTE(review): FindVar's result is dereferenced without a null check —
  // confirm the variable is always present in the block at this point.
  auto input_shape_ = op->Block()->FindVar(op->Input("X")[0])->GetShape();

  // Range construction; parentheses make it obvious this is not an
  // initializer-list of two elements.
  std::vector<int64_t> axes(axes_.begin(), axes_.end());
  if (axes_.empty()) {
    // Unsigned index: avoids comparing a signed int against size().
    for (size_t i = 0; i < input_shape_.size(); ++i) {
      if (input_shape_[i] == 1) {
        axes.push_back(static_cast<int64_t>(i));
      }
    }
  }
  auto new_node_squeeze =
      CreateBaseOp(graph, "Squeeze", {GetInputNode("X", node)},
                   {GetOutputNode("Out", node)}, {{"axes", axes}});
  ReplaceNodeOutputs(node, new_node_squeeze);
  ReplaceNodeInputs(node, new_node_squeeze);
  return new_node_squeeze;
}

// Builds a "Cast" base op for `node`.
// The node's int "out_dtype" attribute is forwarded unchanged as the new op's
// "to" attribute; input is "X", output is "Out". Returns the new op.
ir::Node *cast_handler(ir::Graph *graph, ir::Node *node) {
  auto target_dtype = BOOST_GET_CONST(int, node->Op()->GetAttr("out_dtype"));
  auto in_node = GetInputNode("X", node);
  auto out_node = GetOutputNode("Out", node);

  auto cast_op = CreateBaseOp(graph, "Cast", {in_node}, {out_node},
                              {{"to", target_dtype}});
  ReplaceNodeOutputs(node, cast_op);
  ReplaceNodeInputs(node, cast_op);
  return cast_op;
}

// Builds two base ops for `node`: a "Squeeze" over the "Ids" input with
// axes {-1}, then a "Gather" whose inputs are the "W" input and the squeeze
// op's first output, and whose output is the node's "Out".
// Returns the "Gather" op.
// NOTE(review): no attribute of the original op is read here (e.g. a padding
// index) — confirm that is intended.
ir::Node *lookup_table_handler(ir::Graph *graph, ir::Node *node) {
// The squeeze op is created with an empty output list.
auto new_node_squeeze =
CreateBaseOp(graph, "Squeeze", {GetInputNode("Ids", node)}, {},
{{"axes", std::vector<int64_t>{-1}}});
// NOTE(review): the single-op handlers in this file call ReplaceNodeOutputs
// and ReplaceNodeInputs on the same new op; here the squeeze op gets
// ReplaceNodeOutputs and the gather op gets ReplaceNodeInputs — confirm this
// pairing is intended.
ReplaceNodeOutputs(node, new_node_squeeze);

// Reads new_node_squeeze->outputs[0] although the op was created without
// outputs; presumably the call above populates it — TODO confirm.
auto new_node_gather = CreateBaseOp(
graph, "Gather", {GetInputNode("W", node), new_node_squeeze->outputs[0]},
{GetOutputNode("Out", node)}, {});
ReplaceNodeInputs(node, new_node_gather);
return new_node_gather;
}

// Handler registrations: each line pairs an op type name (first argument)
// with the handler function defined in this file (second argument).
// REGISTER_HANDLER is defined elsewhere; presumably it adds the pair to the
// table consulted by the canonicalization pass — TODO confirm.
REGISTER_HANDLER(fill_constant, fill_constant_handler);
REGISTER_HANDLER(gaussian_random, gaussian_random_handler);
REGISTER_HANDLER(uniform_random, uniform_random_handler);
REGISTER_HANDLER(transpose2, transpose_handler);
REGISTER_HANDLER(reshape2, reshape_handler);
REGISTER_HANDLER(gather, gather_handler);
REGISTER_HANDLER(squeeze2, squeeze_handler);
REGISTER_HANDLER(cast, cast_handler);
REGISTER_HANDLER(lookup_table, lookup_table_handler);

} // namespace
} // namespace ipu
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/operators/ipu_runtime_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,12 @@ REGISTER_OP_CPU_KERNEL(ipu_runtime, ops::IpuRuntimeKernel<float>,
ops::IpuRuntimeKernel<double>,
ops::IpuRuntimeKernel<int>,
ops::IpuRuntimeKernel<int64_t>,
ops::IpuRuntimeKernel<bool>);
ops::IpuRuntimeKernel<bool>,
ops::IpuRuntimeKernel<paddle::platform::float16>);

REGISTER_OP_IPU_KERNEL(ipu_runtime, ops::IpuRuntimeKernel<float>,
ops::IpuRuntimeKernel<double>,
ops::IpuRuntimeKernel<int>,
ops::IpuRuntimeKernel<int64_t>,
ops::IpuRuntimeKernel<bool>);
ops::IpuRuntimeKernel<bool>,
ops::IpuRuntimeKernel<paddle::platform::float16>);
78 changes: 78 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_cast_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler

# The tests below build static-graph programs.
paddle.enable_static()
# Seed shared by the programs and by numpy so CPU and IPU runs see equal data.
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestCastNet(unittest.TestCase):
    """Runs a float32 -> float16 ``paddle.cast`` on CPU and IPU and compares."""

    def _test(self, run_ipu=True):
        """Build the cast program, execute it once and return the fetched array.

        Args:
            run_ipu: when True the program is compiled with ``IpuCompiler`` and
                run on ``IPUPlace``; otherwise it runs as-is on ``CPUPlace``.

        Returns:
            The first fetched result of ``exe.run``.
        """
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        host_input = np.random.rand(1, 3, 1, 5).astype(np.float32)

        with paddle.static.program_guard(main_prog, startup_prog):
            image = paddle.static.data(
                name='image', shape=[1, 3, 1, 5], dtype="float32")
            cast = paddle.cast(image, "float16")

        place = paddle.IPUPlace() if run_ipu else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        # The CPU path runs the program unchanged; the IPU path swaps in the
        # compiled program.
        program = main_prog
        if run_ipu:
            ipu_strategy = compiler.get_ipu_strategy()
            ipu_strategy.is_training = False
            ipu_compiler = compiler.IpuCompiler(
                main_prog, ipu_strategy=ipu_strategy)
            program = ipu_compiler.compile([image.name], [cast.name])

        fetched = exe.run(
            program, feed={"image": host_input}, fetch_list=[cast])
        return fetched[0]

    def test_cast(self):
        """IPU output must match CPU output within an absolute tolerance of 1e-3."""
        cpu_out = self._test(False)
        print(cpu_out.shape)
        print(cpu_out)
        ipu_out = self._test(True)
        print(ipu_out.shape)
        print(ipu_out)
        self.assertTrue(np.allclose(ipu_out, cpu_out, atol=1e-3))


if __name__ == "__main__":
    unittest.main()
80 changes: 80 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_gather_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler

# The tests below build static-graph programs.
paddle.enable_static()
# Seed shared by the programs and by numpy so CPU and IPU runs see equal data.
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestGatherNet(unittest.TestCase):
    """Runs ``paddle.fluid.layers.gather`` on CPU and IPU and compares."""

    def _test(self, run_ipu=True):
        """Build the gather program, execute it once and return the result.

        A random float32 array of shape (3, 2) is fed as ``image`` and the
        int32 array ``[1, 2]`` as ``index``.

        Args:
            run_ipu: when True the program is compiled with ``IpuCompiler`` and
                run on ``IPUPlace``; otherwise it runs as-is on ``CPUPlace``.

        Returns:
            The first fetched result of ``exe.run``.
        """
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        host_image = np.random.rand(3, 2).astype(np.float32)
        host_index = np.array([1, 2]).astype(np.int32)

        with paddle.static.program_guard(main_prog, startup_prog):
            image = paddle.static.data(
                name='image', shape=[3, 2], dtype='float32')
            index = paddle.static.data(name='index', shape=[2], dtype='int32')
            gather = paddle.fluid.layers.gather(image, index)

        place = paddle.IPUPlace() if run_ipu else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        # The CPU path runs the program unchanged; the IPU path swaps in the
        # compiled program.
        program = main_prog
        if run_ipu:
            ipu_strategy = compiler.get_ipu_strategy()
            ipu_strategy.is_training = False
            ipu_compiler = compiler.IpuCompiler(
                main_prog, ipu_strategy=ipu_strategy)
            program = ipu_compiler.compile([image.name, index.name],
                                           [gather.name])

        feed = {"image": host_image, "index": host_index}
        fetched = exe.run(program, feed=feed, fetch_list=[gather])
        return fetched[0]

    def test_gather(self):
        """IPU output must match CPU output within an absolute tolerance of 1e-4."""
        cpu_out = self._test(False)
        print(cpu_out.shape)
        ipu_out = self._test(True)
        print(ipu_out.shape)
        self.assertTrue(np.allclose(ipu_out, cpu_out, atol=1e-4))


if __name__ == "__main__":
    unittest.main()
79 changes: 79 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_lookuptable_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import sys
import paddle
import paddle.fluid as fluid
import paddle.fluid.compiler as compiler

# The tests below build static-graph programs.
paddle.enable_static()
# Seed shared by the programs and by numpy.
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
                 "core is not compiled with IPU")
class TestLookupTableNet(unittest.TestCase):
    """Smoke test: runs ``paddle.fluid.layers.embedding`` on IPU."""

    def _test(self, run_ipu=True):
        """Build the embedding program, execute it once and return the result.

        Args:
            run_ipu: when True the program is compiled with ``IpuCompiler`` and
                run on ``IPUPlace``; otherwise it runs as-is on ``CPUPlace``
                and its readable code is printed.

        Returns:
            The first fetched result of ``exe.run``.
        """
        main_prog = paddle.static.Program()
        startup_prog = paddle.static.Program()
        main_prog.random_seed = SEED
        startup_prog.random_seed = SEED
        np.random.seed(SEED)

        # NOTE(review): the fed ids are int64 while the `image` variable below
        # is declared int32 — confirm the mismatch is intentional.
        host_ids = np.array(
            [[[1], [3]], [[2], [4]], [[4], [127]]]).astype(np.int64)

        with paddle.static.program_guard(main_prog, startup_prog):
            image = paddle.static.data(
                name='image', shape=[3, 2, 1], dtype='int32')
            lookup = paddle.fluid.layers.embedding(
                input=image, size=[128, 16], padding_idx=-1)

        place = paddle.IPUPlace() if run_ipu else paddle.CPUPlace()
        exe = paddle.static.Executor(place)
        exe.run(startup_prog)

        if run_ipu:
            ipu_strategy = compiler.get_ipu_strategy()
            ipu_strategy.is_training = False
            ipu_compiler = compiler.IpuCompiler(
                main_prog, ipu_strategy=ipu_strategy)
            program = ipu_compiler.compile([image.name], [lookup.name])
        else:
            program = main_prog
            print(program._to_readable_code())

        fetched = exe.run(
            program, feed={"image": host_ids}, fetch_list=[lookup])
        return fetched[0]

    def test_gather(self):
        """Only checks that the IPU run completes; nothing is asserted.

        NOTE(review): the method name looks copied from the gather test, and
        the CPU comparison below is disabled — confirm both before relying on
        this test for correctness.
        """
        # Disabled CPU reference run and comparison, kept for when it can be
        # re-enabled:
        #cpu = self._test(False)
        #print(cpu.shape)
        #print(cpu)
        ipu = self._test(True)
        #print(ipu.shape)
        #self.assertTrue(np.allclose(ipu, cpu, atol=1e-4))


if __name__ == "__main__":
    unittest.main()
Loading

0 comments on commit 44d9053

Please sign in to comment.