[Hackathon No.7] Add apply API for Paddle -part (PaddlePaddle#59374)
* add tensor apply

* fix

* fix 2023-11-27

* fix

* fix V2

* add apply in Variable

* add apply in newir

* add test

* fix

* fix2

* fix example code

* change shape

* fix docs

* fix docs
yangguohao authored and Wanglongzhi2001 committed Jan 7, 2024
1 parent 0f9f113 commit 64d1ca8
Showing 6 changed files with 312 additions and 1 deletion.
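
In short, this commit adds an out-of-place Tensor.apply(func) and an in-place Tensor.apply_(func) to dygraph tensors, plus matching hooks for the legacy static graph (Variable.apply) and the new PIR Value. All of them refuse to run on tensors that participate in autograd. A minimal usage sketch, distilled from the docstrings and tests in this diff:

import paddle

x = paddle.to_tensor([[0.3, 0.5, 0.1],
                      [0.9, 0.9, 0.7]]).to("cpu", "float64")
f = lambda t: 3 * t + 2

y = x.apply(f)   # out-of-place: returns a new tensor, x unchanged
x.apply_(f)      # in-place: overwrites x with f(x)

x2 = paddle.to_tensor([1.0, 2.0, 3.0])
x2.stop_gradient = False
try:
    x2.apply(f)  # apply only works on tensors with stop_gradient=True
except RuntimeError as e:
    print(e)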
32 changes: 32 additions & 0 deletions paddle/fluid/pybind/eager_method.cc
@@ -1754,6 +1754,30 @@ static PyObject* tensor__setitem_dygraph(TensorObject* self,
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor_apply(TensorObject* self,
                              PyObject* args,
                              PyObject* kwargs) {
  EAGER_TRY
  PyObject* apply_func = PyTuple_GET_ITEM(args, 0);
  PyTensorHook func = PyTensorHook(apply_func);
  paddle::Tensor out = func(self->tensor);
  return ToPyObject(out);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor_apply_(TensorObject* self,
                               PyObject* args,
                               PyObject* kwargs) {
  EAGER_TRY
  PyObject* apply_func = PyTuple_GET_ITEM(args, 0);
  PyTensorHook func = PyTensorHook(apply_func);
  paddle::Tensor out = func(self->tensor);
  self->tensor.set_impl(out.impl());
  Py_INCREF(self);
  return reinterpret_cast<PyObject*>(self);
  EAGER_CATCH_AND_THROW_RETURN_NULL
}

static PyObject* tensor_register_grad_hook(TensorObject* self,
PyObject* args,
PyObject* kwargs) {
@@ -3167,6 +3191,14 @@ PyMethodDef variable_methods[] = { // NOLINT
     (PyCFunction)(void (*)())tensor__setitem_dygraph,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_apply",
     (PyCFunction)(void (*)())tensor_apply,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_apply_",
     (PyCFunction)(void (*)())tensor_apply_,
     METH_VARARGS | METH_KEYWORDS,
     nullptr},
    {"_register_grad_hook",
     (PyCFunction)(void (*)())tensor_register_grad_hook,
     METH_VARARGS | METH_KEYWORDS,
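
The two bindings above differ only in how they treat the result: tensor_apply wraps the callable in a PyTensorHook, invokes it on self->tensor, and returns the new tensor, while tensor_apply_ also swaps the result's implementation into self via set_impl and hands back the same Python object. A small sketch of the observable difference, using the public wrappers that route to these private _apply/_apply_ methods:

import paddle

x = paddle.to_tensor([1.0, 2.0, 3.0])
f = lambda t: 3 * t + 2

y = x.apply(f)    # backed by tensor_apply: a fresh tensor object
z = x.apply_(f)   # backed by tensor_apply_: the same object, mutated

print(id(y) == id(x))  # False
print(id(z) == id(x))  # True -- matches the assertion in test_inplace.py below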
35 changes: 34 additions & 1 deletion paddle/fluid/pybind/pir.cc
@@ -60,7 +60,7 @@
#include "paddle/fluid/pir/transforms/inplace_pass.h"
#include "paddle/fluid/pir/transforms/replace_fetch_with_shadow_output_pass.h"
#include "paddle/fluid/pir/transforms/shape_optimization_pass.h"

#include "paddle/fluid/pybind/eager_utils.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/pir/core/attribute.h"
#include "paddle/pir/core/block.h"
@@ -581,6 +581,38 @@ const phi::DDim &GetValueDims(Value value) {
  }
}

pir::OpResult apply(Value self, py::object func) {
  py::gil_scoped_acquire gil;
  auto stop_gradient = self.attribute<BoolAttribute>(kAttrStopGradients);
  if (stop_gradient && !stop_gradient.data()) {
    PADDLE_THROW(phi::errors::Unavailable(
        "Cannot apply function on a tensor that requires gradient."));
  }
  PyObject *py_func = func.release().ptr();
  Py_INCREF(py_func);
  PyObject *res = nullptr;
  try {
    py::object obj = py::cast(self);
    PyObject *tmp_self = obj.release().ptr();
    Py_INCREF(tmp_self);
    res = PyObject_CallFunctionObjArgs(py_func, tmp_self, nullptr);
    Py_DECREF(tmp_self);
  } catch (std::exception &e) {
    PADDLE_THROW(phi::errors::Unavailable(
        "Apply function of Tensor raises an exception: %s.", e.what()));
  } catch (...) {
    PADDLE_THROW(phi::errors::Fatal(
        "Apply function of Tensor raises an unknown exception."));
  }
  // If the Python function returns None, fall back to the value itself;
  // release the references before returning on either path.
  if (res == Py_None) {
    Py_DECREF(py_func);
    Py_DECREF(res);
    return self.dyn_cast<OpResult>();
  }
  auto out = CastPyArg2Value(res, "", 0);
  Py_DECREF(py_func);
  Py_DECREF(res);
  return out.dyn_cast<OpResult>();
}

void BindValue(py::module *m) {
py::class_<Value> value(*m, "Value", R"DOC(
Value class represents the SSA value in the IR system. It is a directed edge
@@ -738,6 +770,7 @@ void BindValue(py::module *m) {
print_stream << ")";
return print_stream.str();
})
.def("apply", &apply)
.def("is_same", &Value::operator==)
.def("hash", [](Value self) { return std::hash<pir::Value>{}(self); })
.def("__repr__", &Value2String);
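
With the .def("apply", &apply) binding above, a Value traced under the new IR supports the same apply call, so tensor.apply(func) keeps working inside paddle.jit.to_static. A sketch of that path, following the pattern used by test_apply.py later in this diff (the IrGuard/disable_static combination is copied from the test, not independently verified):

import paddle

def fn(x, func):
    return x.apply(func)

x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=True)
f = lambda t: 3 * t + 2

with paddle.pir_utils.IrGuard():
    paddle.disable_static()
    out = paddle.jit.to_static(fn)(x, f)  # traced through pir.cc's apply()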
100 changes: 100 additions & 0 deletions python/paddle/base/dygraph/tensor_patch_methods.py
@@ -383,6 +383,104 @@ def gradient(self):
        return (np.array(self.grad), np.array(self.grad.rows()))
    return np.array(self.grad)

@framework.dygraph_only
def apply_(self, func):
    """
    Apply the Python function ``func`` to the tensor in place.

    Returns:
        Tensor, the tensor itself, modified in place by ``func``.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                       [0.9, 0.9, 0.7],
            ...                       [0.4, 0.8, 0.2]]).to("cpu", "float64")
            >>> f = lambda x: 3 * x + 2
            >>> x.apply_(f)
            >>> print(x)
            Tensor(shape=[3, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
                   [[2.90000004, 3.50000000, 2.30000000],
                    [4.69999993, 4.69999993, 4.09999996],
                    [3.20000002, 4.40000004, 2.60000001]])

            >>> x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                       [0.9, 0.9, 0.7],
            ...                       [0.4, 0.8, 0.2]]).to("cpu", "float16")
            >>> x.apply_(f)

            >>> x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                       [0.9, 0.9, 0.7],
            ...                       [0.4, 0.8, 0.2]]).to("cpu", "bfloat16")
            >>> x.apply_(f)

            >>> if paddle.is_compiled_with_cuda():
            ...     x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                           [0.9, 0.9, 0.7],
            ...                           [0.4, 0.8, 0.2]]).to("gpu", "float32")
            ...     x.apply_(f)
    """
    if not self.stop_gradient:
        raise RuntimeError(
            "Cannot apply function on a tensor that requires gradient."
        )
    return self._apply_(func)

def apply(self, func):
    """
    Apply the Python function ``func`` to the tensor and return the result
    as a new tensor; the original tensor is left unchanged.

    Returns:
        Tensor, a new tensor holding ``func(self)``.

    Examples:
        .. code-block:: python

            >>> import paddle

            >>> x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                       [0.9, 0.9, 0.7],
            ...                       [0.4, 0.8, 0.2]]).to("cpu", "float64")
            >>> f = lambda x: 3 * x + 2
            >>> y = x.apply(f)
            >>> print(y)
            Tensor(shape=[3, 3], dtype=float64, place=Place(cpu), stop_gradient=True,
                   [[2.90000004, 3.50000000, 2.30000000],
                    [4.69999993, 4.69999993, 4.09999996],
                    [3.20000002, 4.40000004, 2.60000001]])

            >>> x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                       [0.9, 0.9, 0.7],
            ...                       [0.4, 0.8, 0.2]]).to("cpu", "float16")
            >>> y = x.apply(f)

            >>> x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                       [0.9, 0.9, 0.7],
            ...                       [0.4, 0.8, 0.2]]).to("cpu", "bfloat16")
            >>> y = x.apply(f)

            >>> if paddle.is_compiled_with_cuda():
            ...     x = paddle.to_tensor([[0.3, 0.5, 0.1],
            ...                           [0.9, 0.9, 0.7],
            ...                           [0.4, 0.8, 0.2]]).to("gpu", "float32")
            ...     y = x.apply(f)
    """
    if not self.stop_gradient:
        raise RuntimeError(
            "Cannot apply function on a tensor that requires gradient."
        )
    return self._apply(func)

@framework.dygraph_only
def register_hook(self, hook):
"""
@@ -1142,6 +1240,8 @@ def coalesce(self, name=None):
("clear_grad", clear_grad),
("inplace_version", inplace_version),
("gradient", gradient),
("apply_", apply_),
("apply", apply),
("register_hook", register_hook),
("__str__", __str__),
("__repr__", __str__),
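
These name/function pairs are consumed by the surrounding monkey-patch loop, which attaches each function to the eager Tensor type so the new methods appear as ordinary instance methods. Conceptually it works like this simplified sketch (the method name double_via_apply is hypothetical, purely for illustration):

import paddle

def double(self):  # hypothetical extra method, not part of this commit
    return self.apply(lambda t: 2 * t)

setattr(paddle.Tensor, "double_via_apply", double)
print(paddle.to_tensor([1.0, 2.0]).double_via_apply())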
10 changes: 10 additions & 0 deletions python/paddle/base/framework.py
@@ -1863,6 +1863,16 @@ def forward_hook_wrapper(x):
            skip_vars_in_backward_input=[self],
        )

    def apply(self, func):
        if not self.stop_gradient:
            raise RuntimeError(
                "Cannot apply function on a tensor that requires gradient."
            )
        try:
            return func(self)
        except Exception as e:
            raise ValueError(
                f"The PyFunc {func.__name__} could not be applied"
            ) from e

    def __str__(self):
        return self._to_readable_code()

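
For the legacy static graph, Variable.apply is just a guarded call: it checks stop_gradient, invokes the function on the Variable, and converts any failure into a ValueError. A sketch under old-IR static mode (this assumes paddle.static.data yields a Variable with stop_gradient=True, its default):

import paddle

paddle.enable_static()
main = paddle.static.Program()
with paddle.static.program_guard(main):
    x = paddle.static.data(name="x", shape=[3], dtype="float32")
    y = x.apply(lambda t: 3 * t + 2)  # appends 3*x+2 ops to the program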
105 changes: 105 additions & 0 deletions test/legacy_test/test_apply.py
@@ -0,0 +1,105 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

import paddle


class TestTensorApplyAPI(unittest.TestCase):
    def setUp(self):
        self.x = paddle.to_tensor([1, 2, 3, 4, 5], stop_gradient=True)
        self.function = lambda x: 3 * x + 2

    def test_dtype(self):
        for dtype in ["float64", "float16", "bfloat16"]:
            # Tensor.to is out-of-place, so the result must be reassigned
            self.x = self.x.to(dtype)
            self.test_dygraph()

    @unittest.skipIf(
        not paddle.is_compiled_with_cuda(),
        "only support cuda",
    )
    def test_on_gpu(self):
        self.x = self.x.to("gpu")
        self.test_dygraph()

    def test_dygraph(self):
        y = self.x.apply(self.function)
        np.testing.assert_allclose(
            self.function(self.x).numpy(), y.numpy(), rtol=1e-05
        )

    def test_error(self):
        self.x.stop_gradient = False

        def fn_inplace(x):
            x.apply_(self.function)

        def fn_outplace(x, func):
            x.apply(func)

        def function(x, y, z):
            return x + y + z

        self.assertRaises(RuntimeError, fn_inplace, self.x)
        self.assertRaises(RuntimeError, fn_outplace, self.x, self.function)
        with paddle.jit.api.sot_mode_guard(False):
            self.assertRaises(
                RuntimeError,
                paddle.jit.to_static(fn_outplace),
                self.x,
                self.function,
            )
            self.x.stop_gradient = True
            self.assertRaises(
                ValueError,
                paddle.jit.to_static(fn_outplace),
                self.x,
                function,
            )
            self.x.stop_gradient = False
        with paddle.pir_utils.IrGuard():
            paddle.disable_static()
            self.assertRaises(
                RuntimeError,
                paddle.jit.to_static(fn_outplace),
                self.x,
                self.function,
            )

    def test_to_static(self):
        def fn(x, func):
            y = x.apply(func)
            return y

        with paddle.jit.api.sot_mode_guard(False):
            jit_g = paddle.jit.to_static(fn)
            out_legacy_ir = jit_g(self.x, self.function)
        with paddle.pir_utils.IrGuard():
            paddle.disable_static()
            jit_g = paddle.jit.to_static(fn)
            out_pir = jit_g(self.x, self.function)
        np.testing.assert_allclose(
            self.function(self.x).numpy(), out_legacy_ir.numpy(), rtol=1e-05
        )
        np.testing.assert_allclose(
            self.function(self.x).numpy(), out_pir.numpy(), rtol=1e-05
        )


if __name__ == "__main__":
    unittest.main()
31 changes: 31 additions & 0 deletions test/legacy_test/test_inplace.py
@@ -1680,5 +1680,36 @@ def test_backward_error(self):
        loss.backward()


class TestDygraphTensorApplyInplace(unittest.TestCase):
    def setUp(self):
        self.init_data()
        self.set_np_compare_func()

    def init_data(self):
        self.input_var_numpy = np.random.uniform(-5, 5, [10, 20, 1])
        self.dtype = "float32"

    def set_np_compare_func(self):
        self.np_compare = np.array_equal

    def non_inplace_api_processing(self, var, f):
        return var.apply(f)

    def inplace_api_processing(self, var, f):
        return var.apply_(f)

    def test_inplace_api(self):
        var = paddle.to_tensor(self.input_var_numpy, stop_gradient=True).astype(
            self.dtype
        )
        f = lambda x: 3 * x + 2
        non_inplace_var = self.non_inplace_api_processing(var, f)
        inplace_var = self.inplace_api_processing(var, f)
        self.assertTrue(id(var) == id(inplace_var))
        np.testing.assert_array_equal(
            non_inplace_var.numpy(), inplace_var.numpy()
        )


if __name__ == '__main__':
    unittest.main()
