[Prim]Support mean_grad decompose in vjp #64346

Merged: 3 commits merged on May 17, 2024
Changes from all commits
1 change: 1 addition & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -119,6 +119,7 @@
    'instance_norm_grad',
    'layer_norm_grad',
    'leaky_relu_grad',
    'mean_grad',
    'minimum_grad',
    'pow_grad',
    'relu_grad',
47 changes: 47 additions & 0 deletions paddle/fluid/primitive/rule/vjp/details.h
@@ -154,6 +154,53 @@ void sum_grad(const Tensor& x,
  set_output<T>(x_grad_tmp, x_grad);
}

template <typename T>
void mean_grad(const Tensor& x,
               const Tensor& out_grad,
               const IntArray& axis,
               bool keepdim,
               bool reduce_all,
               Tensor* x_grad) {
  if (!x_grad) {
    return;
  }
  Tensor x_grad_tmp;
  sum_grad<T>(x, out_grad, axis, keepdim, reduce_all, &x_grad_tmp);
Contributor: sum_grad has not yet been adapted for dynamic shapes.
Contributor Author: Support for this will be considered in a unified follow-up.

  Tensor div_factor = [&] {
    Tensor factor_tensor;
    auto axis_data = axis.GetData();
    const std::vector<int64_t> x_dim = x.shape();
    if (axis.size() == 0) {
      for (size_t i = 0; i < x_dim.size(); ++i) {
        axis_data.push_back(i);
      }
    }
    if (has_dynamic_shape(x_dim, axis_data)) {
      auto x_shape = shape<T>(x);
      factor_tensor =
          slice<T>(x_shape, {0}, {axis_data[0]}, {axis_data[0] + 1}, {1}, {0});
      for (size_t i = 1; i < axis_data.size(); ++i) {
        factor_tensor =
            factor_tensor *
            slice<T>(
                x_shape, {0}, {axis_data[i]}, {axis_data[i] + 1}, {1}, {0});
      }
      factor_tensor = cast<T>(factor_tensor, x.dtype());
    } else {
      int64_t factor = 1;
      for (int64_t idx : axis_data) {
        if (idx < 0) idx += x_dim.size();
        factor *= x_dim[idx];
      }
      factor_tensor = full<T>(std::vector<int64_t>{}, factor, x.dtype());
    }
    return factor_tensor;
  }();

  set_output<T>(x_grad_tmp / div_factor, x_grad);
}

template <typename T>
void gelu_grad(const Tensor& x,
               const Tensor& out_grad,
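For reference, the rule implemented above is simply: the gradient of mean is the gradient of sum divided by the number of reduced elements (the `div_factor` built either as a static `full<T>` scalar or from `shape<T>(x)` slices when a reduced dim is dynamic). The following is a minimal NumPy sketch of that rule for sanity-checking, not the Paddle implementation; `mean_grad_ref` and its axis handling are illustrative assumptions.

import numpy as np

def mean_grad_ref(x, out_grad, axis, keepdim=False):
    # Illustrative rule: grad of mean == grad of sum / (count of reduced elements).
    axis = tuple(range(x.ndim)) if len(axis) == 0 else tuple(a % x.ndim for a in axis)
    factor = np.prod([x.shape[a] for a in axis])
    if not keepdim and out_grad.ndim != x.ndim:
        out_grad = np.expand_dims(out_grad, axis)  # re-insert the reduced dims
    return np.broadcast_to(out_grad, x.shape) / factor

x = np.random.rand(2, 3, 4)
out_grad = np.ones((2, 4))                 # upstream grad of mean(x, axis=1)
gx = mean_grad_ref(x, out_grad, axis=[1])
assert gx.shape == x.shape
assert np.allclose(gx, 1.0 / 3.0)          # each element gets 1/3 of its upstream grad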
25 changes: 20 additions & 5 deletions paddle/fluid/primitive/utils/utils.h
@@ -180,12 +180,27 @@ static bool find_value(const std::vector<int64_t>& vec, int64_t value) {
  }
}

static bool has_dynamic_shape(const std::vector<int64_t>& vec) {
  if (std::find(vec.begin(), vec.end(), -1) != vec.end()) {
    return true;
  } else {
    return false;
  }
}
static bool has_dynamic_shape(const std::vector<int64_t>& shape) {
  return std::find(shape.begin(), shape.end(), -1) != shape.end();
}

static bool has_dynamic_shape(const std::vector<int64_t>& shape,
                              const std::vector<int64_t>& axis) {
  bool flag = false;
  const int64_t rank = shape.size();
  for (int64_t idx : axis) {
    if (idx < 0) idx += rank;
    PADDLE_ENFORCE_LT(
        idx,
        rank,
        ::common::errors::PreconditionNotMet(
            "Required idx < shape.size(), but received %d.", idx));
    if (shape[idx] == -1) {
      flag = true;
      break;
    }
  }
  return flag;
}

} // namespace primitive
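The new two-argument `has_dynamic_shape` overload reports a dynamic dimension only when it lies on one of the reduced axes, which is what lets `mean_grad` keep the cheap static `full<T>` factor whenever every reduced dim is known. A rough Python equivalent of the predicate, for illustration only (`-1` marks an unknown dimension, as in the C++ helper):

def has_dynamic_shape(shape, axis=None):
    # No axis given: any -1 entry makes the whole shape dynamic.
    if axis is None:
        return -1 in shape
    rank = len(shape)
    for idx in axis:
        idx = idx + rank if idx < 0 else idx  # normalize negative axes
        if idx >= rank:
            raise ValueError(f"axis {idx} out of range for rank {rank}")
        if shape[idx] == -1:
            return True
    return False

print(has_dynamic_shape([2, -1, 4]))               # True
print(has_dynamic_shape([2, -1, 4], axis=[0, 2]))  # False: the unknown dim is not reduced
print(has_dynamic_shape([2, -1, 4], axis=[-2]))    # True: -2 normalizes to axis 1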
108 changes: 105 additions & 3 deletions test/legacy_test/test_mean_op.py
@@ -44,34 +44,54 @@ class TestMeanOp(OpTest):
    def setUp(self):
        self.op_type = "mean"
        self.python_api = paddle.mean
        self.public_python_api = paddle.mean
        self.dtype = np.float64
        self.init_dtype_type()
        self.init_prim_type()
        self.inputs = {'X': np.random.random((10, 10)).astype(self.dtype)}
        self.outputs = {'Out': np.mean(self.inputs["X"])}

    def init_prim_type(self):
        self.prim_op_type = "comp"

    def init_dtype_type(self):
        pass

    def test_check_output(self):
        self.check_output(check_pir=True)

    def test_checkout_grad(self):
        self.check_grad(['X'], 'Out', check_pir=True)
        self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestMeanOpPrim(TestMeanOp):
    def init_prim_type(self):
        self.prim_op_type = "prim"


class TestMeanOp_ZeroDim(OpTest):
    def setUp(self):
        self.op_type = "mean"
        self.python_api = paddle.mean
        self.dtype = np.float64
        self.public_python_api = paddle.mean
        self.init_prim_type()
        self.inputs = {'X': np.random.random([]).astype(self.dtype)}
        self.outputs = {'Out': np.mean(self.inputs["X"])}

    def init_prim_type(self):
        self.prim_op_type = "comp"

    def test_check_output(self):
        self.check_output(check_pir=True)

    def test_checkout_grad(self):
        self.check_grad(['X'], 'Out', check_pir=True)
        self.check_grad(['X'], 'Out', check_pir=True, check_prim_pir=True)


class TestMeanOp_ZeroDim_Prim(TestMeanOp_ZeroDim):
    def init_prim_type(self):
        self.prim_op_type = "prim"


class TestMeanOpError(unittest.TestCase):
@@ -161,7 +181,7 @@ def setUp(self):
        self.op_type = 'reduce_mean'
        self.python_api = reduce_mean_wrapper
        self.public_python_api = reduce_mean_wrapper
        self.prim_op_type = "comp"
        self.init_prim_type()
        self.dtype = 'float64'
        self.init_shapes()
        self.axis = [0]
@@ -186,6 +206,9 @@ def setUp(self):
            'reduce_all': self.reduce_all,
        }

    def init_prim_type(self):
        self.prim_op_type = "comp"

    def init_shapes(self):
        self.shape = [2, 3, 4, 5]

@@ -231,6 +254,43 @@ def test_check_grad(self):
        )


class TestReduceMeanOpPrim(TestReduceMeanOp):
    def init_prim_type(self):
        self.prim_op_type = "prim"

    @test_with_pir_api
    def test_check_output(self):
        if self.dtype != 'float16':
            self.check_output(check_prim_pir=True, check_pir=True)
        else:
            place = paddle.CUDAPlace(0)
            self.check_output_with_place(
                place=place,
                check_prim_pir=True,
                check_pir=True,
            )

    @test_with_pir_api
    def test_check_grad(self):
        if self.dtype != 'float16':
            self.check_grad(
                ['X'],
                ['Out'],
                check_prim_pir=True,
                check_pir=True,
            )
        else:
            place = paddle.CUDAPlace(0)
            self.check_grad_with_place(
                place,
                ['X'],
                ['Out'],
                numeric_grad_delta=0.5,
                check_prim_pir=True,
                check_pir=True,
            )


class TestReduceMeanOp_ZeroDim(TestReduceMeanOp):
    def init_shapes(self):
        self.shape = []
@@ -306,16 +366,41 @@ def setUp(self):
        self.outputs = {'Out': out_np}


class TestReduceMeanOpDefaultAttrsForPrim(TestReduceMeanOpPrim):
    def setUp(self):
        self.op_type = 'reduce_mean'
        self.python_api = reduce_mean_wrapper
        self.public_python_api = reduce_mean_wrapper
        self.init_prim_type()
        self.dtype = 'float64'
        self.shape = [2, 3, 4, 5]

        x_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
        out_np = np.mean(x_np, axis=0)
        self.inputs = {'X': x_np}
        self.outputs = {'Out': out_np}


class TestReduceMeanOpFloat32(TestReduceMeanOp):
    def set_attrs(self):
        self.dtype = 'float32'


class TestReduceMeanOpFloat32Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.dtype = 'float32'


class TestReduceMeanOpFloat16(TestReduceMeanOp):
    def set_attrs(self):
        self.dtype = 'float16'


class TestReduceMeanOpFloat16Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.dtype = 'float16'


class TestReduceMeanOpShape1D(TestReduceMeanOp):
    def set_attrs(self):
        self.shape = [100]
@@ -348,12 +433,23 @@ def set_attrs(self):
        self.axis = [0, 1, 2, 3]


class TestReduceMeanOpAxisAllPrim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]


class TestReduceMeanOpAxisAllFP16(TestReduceMeanOp):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]
        self.dtype = 'float16'


class TestReduceMeanOpAxisAllFP16Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]
        self.dtype = 'float16'


class TestReduceMeanOpAxisAllBF16(TestReduceMeanBF16Op):
    def set_attrs(self):
        self.axis = [0, 1, 2, 3]
@@ -386,6 +482,12 @@ def set_attrs(self):
        self.dtype = 'float16'


class TestReduceMeanOpAxisNegativeFP16Prim(TestReduceMeanOpPrim):
    def set_attrs(self):
        self.axis = [-2, -1]
        self.dtype = 'float16'


class TestReduceMeanOpAxisNegativeBF16(TestReduceMeanBF16Op):
    def set_attrs(self):
        self.axis = [-2, -1]