From 5188ef51df908eed979e7eaf2c1a22a8f2901196 Mon Sep 17 00:00:00 2001
From: AyaseNana <49900969+NKNaN@users.noreply.github.com>
Date: Sun, 7 Apr 2024 10:56:19 +0800
Subject: [PATCH] API improvement for paddle.linalg.svd_lowrank (#62876)

* add svd lowrank api

* add test

* fix param M

* fix test timeout

* update docs
---
 python/paddle/linalg.py              |   2 +
 python/paddle/tensor/__init__.py     |   2 +
 python/paddle/tensor/linalg.py       | 231 +++++++++++++++++---------
 test/legacy_test/test_svd_lowrank.py | 239 +++++++++++++++++++++++++++
 4 files changed, 398 insertions(+), 76 deletions(-)
 create mode 100644 test/legacy_test/test_svd_lowrank.py

diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py
index 4ba51b20ba5dc..e83aedb740907 100644
--- a/python/paddle/linalg.py
+++ b/python/paddle/linalg.py
@@ -40,6 +40,7 @@
     slogdet,
     solve,
     svd,
+    svd_lowrank,
     triangular_solve,
     vector_norm,
 )
@@ -61,6 +62,7 @@
     'qr',
     'householder_product',
     'pca_lowrank',
+    'svd_lowrank',
     'lu',
     'lu_unpack',
     'matrix_exp',
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index 4513bcbdba8f8..3afdca0fb21ce 100644
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -93,6 +93,7 @@
     qr,
     solve,
     svd,
+    svd_lowrank,
     t,
     t_,
     transpose,
@@ -467,6 +468,7 @@
     'qr',
     'householder_product',
     'pca_lowrank',
+    'svd_lowrank',
     'eigvals',
     'eigvalsh',
     'abs',
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index 09030f9608f88..e88825c54390b 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -2531,6 +2531,161 @@ def svd(x, full_matrices=False, name=None):
     return u, s, vh
 
 
+def _conjugate(x):
+    if x.is_complex():
+        return x.conj()
+    return x
+
+
+def _transpose(x):
+    shape = x.shape
+    perm = list(range(0, len(shape)))
+    perm = perm[:-2] + [perm[-1]] + [perm[-2]]
+    return paddle.transpose(x, perm)
+
+
+def _transjugate(x):
+    return _conjugate(_transpose(x))
+
+
+def _get_approximate_basis(x, q, niter=2, M=None):
+    niter = 2 if niter is None else niter
+    m, n = x.shape[-2:]
+    qr = paddle.linalg.qr
+
+    R = paddle.randn((n, q), dtype=x.dtype)
+
+    A_t = _transpose(x)
+    A_H = _conjugate(A_t)
+    if M is None:
+        Q = qr(paddle.matmul(x, R))[0]
+        for i in range(niter):
+            Q = qr(paddle.matmul(A_H, Q))[0]
+            Q = qr(paddle.matmul(x, Q))[0]
+    else:
+        M_H = _transjugate(M)
+        Q = qr(paddle.matmul(x, R) - paddle.matmul(M, R))[0]
+        for i in range(niter):
+            Q = qr(paddle.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
+            Q = qr(paddle.matmul(x, Q) - paddle.matmul(M, Q))[0]
+
+    return Q
+
+
+def svd_lowrank(x, q=None, niter=2, M=None, name=None):
+    r"""
+    Return the singular value decomposition (SVD) of a low-rank matrix or batches of such matrices.
+
+    If :math:`X` is the input matrix or a batch of input matrices, the output satisfies:
+
+    .. math::
+        X \approx U * diag(S) * V^{T}
+
+    When :math:`M` is given, the output satisfies:
+
+    .. math::
+        X - M \approx U * diag(S) * V^{T}
+
+    Args:
+        x (Tensor): The input tensor. Its shape should be `[..., N, M]`, where `...` is
+            zero or more batch dimensions. N and M can be arbitrary positive numbers.
+            The data type of ``x`` should be float32 or float64.
+        q (int, optional): A slightly overestimated rank of :math:`X`.
+            Default value is None, which means q is set to min(6, N, M).
+        niter (int, optional): The number of iterations to perform. Default: 2.
+        M (Tensor, optional): The input tensor's mean. Its shape should be `[..., 1, M]`.
+            Default value is None.
+        name (str, optional): Name for the operation. For more information, please
+            refer to :ref:`api_guide_Name`. Default: None.
+
+    Returns:
+        - Tensor U, an N x q matrix.
+        - Tensor S, a vector of length q.
+        - Tensor V, an M x q matrix.
+
+        tuple (U, S, V): the nearly optimal approximation of a singular value decomposition of the matrix :math:`X`, or of :math:`X - M` when M is given.
+
+    Examples:
+        .. code-block:: python
+
+            >>> import paddle
+            >>> paddle.seed(2024)
+
+            >>> x = paddle.randn((5, 5), dtype='float64')
+            >>> U, S, V = paddle.linalg.svd_lowrank(x)
+            >>> print(U)
+            Tensor(shape=[5, 5], dtype=float64, place=Place(cpu), stop_gradient=True,
+            [[-0.03586982, -0.17211503,  0.31536566, -0.38225676, -0.85059629],
+             [-0.38386839,  0.67754925,  0.23222694,  0.51777188, -0.26749766],
+             [-0.85977150, -0.28442378, -0.41412094, -0.08955629, -0.01948348],
+             [ 0.18611503,  0.56047358, -0.67717019, -0.39286761, -0.19577062],
+             [ 0.27841082, -0.34099254, -0.46535957,  0.65071250, -0.40770727]])
+
+            >>> print(S)
+            Tensor(shape=[5], dtype=float64, place=Place(cpu), stop_gradient=True,
+            [4.11253399, 3.03227120, 2.45499752, 1.25602436, 0.45825337])
+
+            >>> print(V)
+            Tensor(shape=[5, 5], dtype=float64, place=Place(cpu), stop_gradient=True,
+            [[ 0.46401347,  0.50977695, -0.08742316, -0.11140428, -0.71046833],
+             [-0.48927226, -0.35047624,  0.07918771,  0.45431083, -0.65200463],
+             [-0.20494730,  0.67097011, -0.05427719,  0.66510472,  0.24997083],
+             [-0.69645001,  0.40237917,  0.09360970, -0.58032322, -0.08666357],
+             [ 0.13512270,  0.07199989,  0.98710572,  0.04529277,  0.01134594]])
+    """
+    if not paddle.is_tensor(x):
+        raise ValueError(f'Input must be tensor, but got {type(x)}')
+
+    m, n = x.shape[-2:]
+    if q is None:
+        q = min(6, m, n)
+    elif not (q >= 0 and q <= min(m, n)):
+        raise ValueError(
+            f'q(={q}) must be a non-negative integer'
+            f' and not greater than min(m, n)={min(m, n)}'
+        )
+
+    if not (niter >= 0):
+        raise ValueError(f'niter(={niter}) must be a non-negative integer')
+
+    if M is None:
+        M_t = None
+    else:
+        M = M.broadcast_to(x.shape)
+        M_t = _transpose(M)
+    A_t = _transpose(x)
+
+    if m < n or n > q:
+        Q = _get_approximate_basis(A_t, q, niter=niter, M=M_t)
+        Q_c = _conjugate(Q)
+        if M is None:
+            B_t = paddle.matmul(x, Q_c)
+        else:
+            B_t = paddle.matmul(x, Q_c) - paddle.matmul(M, Q_c)
+        assert B_t.shape[-2] == m, (B_t.shape, m)
+        assert B_t.shape[-1] == q, (B_t.shape, q)
+        assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
+        U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
+        V = _transjugate(Vh)
+        V = Q.matmul(V)
+    else:
+        Q = _get_approximate_basis(x, q, niter=niter, M=M)
+        Q_c = _conjugate(Q)
+        if M is None:
+            B = paddle.matmul(A_t, Q_c)
+        else:
+            B = paddle.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
+        B_t = _transpose(B)
+        assert B_t.shape[-2] == q, (B_t.shape, q)
+        assert B_t.shape[-1] == n, (B_t.shape, n)
+        assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
+        U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
+        V = _transjugate(Vh)
+        U = Q.matmul(U)
+
+    return U, S, V
+
+
 def pca_lowrank(x, q=None, center=True, niter=2, name=None):
     r"""
     Performs linear Principal Component Analysis (PCA) on a low-rank matrix or batches of such matrices.
@@ -2588,82 +2743,6 @@ def pca_lowrank(x, q=None, center=True, niter=2, name=None): [-0.67131070, -0.19071018, 0.07795789, -0.04615811, 0.71046714]]) """ - def conjugate(x): - if x.is_complex(): - return x.conj() - return x - - def transpose(x): - shape = x.shape - perm = list(range(0, len(shape))) - perm = perm[:-2] + [perm[-1]] + [perm[-2]] - return paddle.transpose(x, perm) - - def transjugate(x): - return conjugate(transpose(x)) - - def get_approximate_basis(x, q, niter=2, M=None): - niter = 2 if niter is None else niter - m, n = x.shape[-2:] - qr = paddle.linalg.qr - - R = paddle.randn((n, q), dtype=x.dtype) - - A_t = transpose(x) - A_H = conjugate(A_t) - if M is None: - Q = qr(paddle.matmul(x, R))[0] - for i in range(niter): - Q = qr(paddle.matmul(A_H, Q))[0] - Q = qr(paddle.matmul(x, Q))[0] - else: - M_H = transjugate(M) - Q = qr(paddle.matmul(x, R) - paddle.matmul(M, R))[0] - for i in range(niter): - Q = qr(paddle.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0] - Q = qr(paddle.matmul(x, Q) - paddle.matmul(M, Q))[0] - - return Q - - def svd_lowrank(x, q=6, niter=2, M=None): - q = 6 if q is None else q - m, n = x.shape[-2:] - if M is None: - M_t = None - else: - M_t = transpose(M) - A_t = transpose(x) - - if m < n or n > q: - Q = get_approximate_basis(A_t, q, niter=niter, M=M_t) - Q_c = conjugate(Q) - if M is None: - B_t = paddle.matmul(x, Q_c) - else: - B_t = paddle.matmul(x, Q_c) - paddle.matmul(M, Q_c) - assert B_t.shape[-2] == m, (B_t.shape, m) - assert B_t.shape[-1] == q, (B_t.shape, q) - assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape - U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False) - V = transjugate(Vh) - V = Q.matmul(V) - else: - Q = get_approximate_basis(x, q, niter=niter, M=M) - Q_c = conjugate(Q) - if M is None: - B = paddle.matmul(A_t, Q_c) - else: - B = paddle.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c) - B_t = transpose(B) - assert B_t.shape[-2] == q, (B_t.shape, q) - assert B_t.shape[-1] == n, (B_t.shape, n) - assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape - U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False) - V = transjugate(Vh) - U = Q.matmul(U) - - return U, S, V - if not paddle.is_tensor(x): raise ValueError(f'Input must be tensor, but got {type(x)}') diff --git a/test/legacy_test/test_svd_lowrank.py b/test/legacy_test/test_svd_lowrank.py new file mode 100644 index 0000000000000..acdcb81b50b54 --- /dev/null +++ b/test/legacy_test/test_svd_lowrank.py @@ -0,0 +1,239 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np + +import paddle + + +class TestSvdLowrankAPI(unittest.TestCase): + def transpose(self, x): + shape = x.shape + perm = list(range(0, len(shape))) + perm = perm[:-2] + [perm[-1]] + [perm[-2]] + return paddle.transpose(x, perm) + + def random_matrix(self, rows, columns, *batch_dims, **kwargs): + dtype = kwargs.get('dtype', paddle.float64) + + x = paddle.randn(batch_dims + (rows, columns), dtype=dtype) + if x.numel() == 0: + return x + u, _, vh = paddle.linalg.svd(x, full_matrices=False) + k = min(rows, columns) + s = paddle.linspace(1 / (k + 1), 1, k, dtype=dtype) + return (u * s.unsqueeze(-2)) @ vh + + def random_lowrank_matrix(self, rank, rows, columns, *batch_dims, **kwargs): + B = self.random_matrix(rows, rank, *batch_dims, **kwargs) + C = self.random_matrix(rank, columns, *batch_dims, **kwargs) + return B.matmul(C) + + def run_subtest( + self, guess_rank, actual_rank, matrix_size, batches, svd, **options + ): + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + a_input = self.random_lowrank_matrix( + actual_rank, rows, columns, *batches + ) + a = a_input + m = a_input.mean(axis=-2, keepdim=True) + + u, s, v = svd(a_input - m, q=guess_rank, **options) + + self.assertEqual(s.shape[-1], guess_rank) + self.assertEqual(u.shape[-2], rows) + self.assertEqual(u.shape[-1], guess_rank) + self.assertEqual(v.shape[-1], guess_rank) + self.assertEqual(v.shape[-2], columns) + + A1 = u.matmul(paddle.diag_embed(s)).matmul(self.transpose(v)) + ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype) + c = a.sum(axis=-2) / rows + c = c.reshape(batches + (1, columns)) + A2 = a - ones_m1.matmul(c) + np.testing.assert_allclose(A1.numpy(), A2.numpy(), atol=1e-5) + + detect_rank = (s.abs() > 1e-5).sum(axis=-1) + left = actual_rank * paddle.ones(batches, dtype=paddle.int64) + if not left.shape: + np.testing.assert_allclose(int(left), int(detect_rank)) + else: + np.testing.assert_allclose(left.numpy(), detect_rank.numpy()) + S = paddle.linalg.svd(A2, full_matrices=False)[1] + left = s[..., :actual_rank] + right = S[..., :actual_rank] + np.testing.assert_allclose(left.numpy(), right.numpy()) + + def test_forward(self): + svd_lowrank = paddle.linalg.svd_lowrank + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size in [ + (2, (17, 4)), + (6, (100, 40)), + ]: + for batches in all_batches: + for guess_rank in [ + actual_rank, + actual_rank + 2, + actual_rank + 6, + ]: + if guess_rank <= min(*size): + self.run_subtest( + guess_rank, actual_rank, size, batches, svd_lowrank + ) + self.run_subtest( + guess_rank, + actual_rank, + size[::-1], + batches, + svd_lowrank, + ) + x = np.random.randn(5, 5).astype('float64') + x = paddle.to_tensor(x) + q = None + U, S, V = svd_lowrank(x, q) + + def test_errors(self): + svd_lowrank = paddle.linalg.svd_lowrank + x = np.random.randn(5, 5).astype('float64') + x = paddle.to_tensor(x) + + def test_x_not_tensor(): + U, S, V = svd_lowrank(x.numpy()) + + self.assertRaises(ValueError, test_x_not_tensor) + + def test_q_range(): + q = -1 + U, S, V = svd_lowrank(x, q) + + self.assertRaises(ValueError, test_q_range) + + def test_niter_range(): + n = -1 + U, S, V = svd_lowrank(x, niter=n) + + self.assertRaises(ValueError, test_niter_range) + + +class TestStaticSvdLowrankAPI(unittest.TestCase): + def transpose(self, x): + shape = x.shape + perm = list(range(0, len(shape))) + perm = perm[:-2] + [perm[-1]] + [perm[-2]] + return paddle.transpose(x, perm) + + def random_matrix(self, rows, 
columns, *batch_dims, **kwargs): + dtype = kwargs.get('dtype', 'float64') + + x = paddle.randn(batch_dims + (rows, columns), dtype=dtype) + u, _, vh = paddle.linalg.svd(x, full_matrices=False) + k = min(rows, columns) + s = paddle.linspace(1 / (k + 1), 1, k, dtype=dtype) + return (u * s.unsqueeze(-2)) @ vh + + def random_lowrank_matrix(self, rank, rows, columns, *batch_dims, **kwargs): + B = self.random_matrix(rows, rank, *batch_dims, **kwargs) + C = self.random_matrix(rank, columns, *batch_dims, **kwargs) + return B.matmul(C) + + def run_subtest( + self, guess_rank, actual_rank, matrix_size, batches, svd, **options + ): + main = paddle.static.Program() + startup = paddle.static.Program() + with paddle.static.program_guard(main, startup): + if isinstance(matrix_size, int): + rows = columns = matrix_size + else: + rows, columns = matrix_size + a_input = self.random_lowrank_matrix( + actual_rank, rows, columns, *batches + ) + a = a_input + m = a_input.mean(axis=-2, keepdim=True) + + u, s, v = svd(a_input, q=guess_rank, M=m, **options) + + self.assertEqual(s.shape[-1], guess_rank) + self.assertEqual(u.shape[-2], rows) + self.assertEqual(u.shape[-1], guess_rank) + self.assertEqual(v.shape[-1], guess_rank) + self.assertEqual(v.shape[-2], columns) + + A1 = u.matmul(paddle.diag_embed(s)).matmul(self.transpose(v)) + ones_m1 = paddle.ones(batches + (rows, 1), dtype=a.dtype) + c = a.sum(axis=-2) / rows + c = c.reshape(batches + (1, columns)) + A2 = a - ones_m1.matmul(c) + detect_rank = (s.abs() > 1e-5).sum(axis=-1) + left1 = actual_rank * paddle.ones(batches, dtype=paddle.int64) + S = paddle.linalg.svd(A2, full_matrices=False)[1] + left2 = s[..., :actual_rank] + right = S[..., :actual_rank] + + exe = paddle.static.Executor() + exe.run(startup) + A1, A2, left1, detect_rank, left2, right = exe.run( + main, + feed={}, + fetch_list=[A1, A2, left1, detect_rank, left2, right], + ) + + np.testing.assert_allclose(A1, A2, atol=1e-5) + if not left1.shape: + np.testing.assert_allclose(int(left1), int(detect_rank)) + else: + np.testing.assert_allclose(left1, detect_rank) + np.testing.assert_allclose(left2, right) + + def test_forward(self): + with paddle.pir_utils.IrGuard(): + svd_lowrank = paddle.linalg.svd_lowrank + all_batches = [(), (1,), (3,), (2, 3)] + for actual_rank, size in [ + (2, (17, 4)), + (6, (100, 40)), + ]: + for batches in all_batches: + for guess_rank in [ + actual_rank, + actual_rank + 2, + actual_rank + 6, + ]: + if guess_rank <= min(*size): + self.run_subtest( + guess_rank, + actual_rank, + size, + batches, + svd_lowrank, + ) + self.run_subtest( + guess_rank, + actual_rank, + size[::-1], + batches, + svd_lowrank, + ) + + +if __name__ == "__main__": + unittest.main()
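
---

For reference, a minimal usage sketch of the new API, mirroring the docstring example and the tests above. The rank-3 test matrix, the choice q=6, and the 1e-5 tolerance are illustrative assumptions, not values fixed by the patch.

    import numpy as np
    import paddle

    paddle.seed(2024)

    # Build an exactly rank-3 matrix of shape (100, 40).
    B = paddle.randn((100, 3), dtype='float64')
    C = paddle.randn((3, 40), dtype='float64')
    X = B.matmul(C)

    # Low-rank SVD with a slightly overestimated rank q >= 3.
    U, S, V = paddle.linalg.svd_lowrank(X, q=6)

    # X is recovered to numerical precision from U @ diag(S) @ V^T.
    X_hat = U.matmul(paddle.diag_embed(S)).matmul(paddle.transpose(V, [1, 0]))
    np.testing.assert_allclose(X.numpy(), X_hat.numpy(), atol=1e-5)

    # With M given, the factorization approximates X - M instead,
    # e.g. PCA-style centering by the column mean.
    M = X.mean(axis=-2, keepdim=True)
    U, S, V = paddle.linalg.svd_lowrank(X, q=6, M=M)
    Xc_hat = U.matmul(paddle.diag_embed(S)).matmul(paddle.transpose(V, [1, 0]))
    np.testing.assert_allclose((X - M).numpy(), Xc_hat.numpy(), atol=1e-5)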