API improvement for paddle.linalg.svd_lowrank (#62876)
* add svd lowrank api

* add test

* fix param M

* fix test timeout

* update docs
NKNaN authored Apr 7, 2024
1 parent 91e4455 commit 5188ef5
Showing 4 changed files with 398 additions and 76 deletions.
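In short: the commit hoists the low-rank SVD helpers out of `pca_lowrank` into module-level functions and exposes them as a new public API, `paddle.linalg.svd_lowrank`. A minimal usage sketch, assuming a Paddle build that includes this commit (the rank-3 test matrix and the reconstruction check below are illustrative, not part of the change):

    import paddle

    paddle.seed(2024)
    # Build an 8 x 6 matrix of rank 3, so a rank-q factorization can be exact.
    x = paddle.randn((8, 3), dtype='float64') @ paddle.randn((3, 6), dtype='float64')

    # q slightly overestimates the target rank; niter controls power iterations.
    U, S, V = paddle.linalg.svd_lowrank(x, q=3, niter=2)

    # The factors should reconstruct x up to floating-point error:
    # x ~= U @ diag(S) @ V^T.
    print((x - U @ paddle.diag(S) @ V.T).abs().max())  # on the order of 1e-15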
2 changes: 2 additions & 0 deletions python/paddle/linalg.py
@@ -40,6 +40,7 @@
    slogdet,
    solve,
    svd,
    svd_lowrank,
    triangular_solve,
    vector_norm,
)
@@ -61,6 +62,7 @@
    'qr',
    'householder_product',
    'pca_lowrank',
    'svd_lowrank',
    'lu',
    'lu_unpack',
    'matrix_exp',
2 changes: 2 additions & 0 deletions python/paddle/tensor/__init__.py
@@ -93,6 +93,7 @@
    qr,
    solve,
    svd,
    svd_lowrank,
    t,
    t_,
    transpose,
@@ -467,6 +468,7 @@
    'qr',
    'householder_product',
    'pca_lowrank',
    'svd_lowrank',
    'eigvals',
    'eigvalsh',
    'abs',
231 changes: 155 additions & 76 deletions python/paddle/tensor/linalg.py
@@ -2531,6 +2531,161 @@ def svd(x, full_matrices=False, name=None):
    return u, s, vh


def _conjugate(x):
    # Conjugate complex tensors; real tensors pass through unchanged.
    if x.is_complex():
        return x.conj()
    return x


def _transpose(x):
    # Swap the last two dimensions, leaving any batch dimensions untouched.
    shape = x.shape
    perm = list(range(0, len(shape)))
    perm = perm[:-2] + [perm[-1]] + [perm[-2]]
    return paddle.transpose(x, perm)


def _transjugate(x):
    # Conjugate (Hermitian) transpose over the last two dimensions.
    return _conjugate(_transpose(x))


def _get_approximate_basis(x, q, niter=2, M=None):
    # Return a tensor Q with q orthonormal columns that approximately span
    # the range of x (or of x - M when M is given).
    niter = 2 if niter is None else niter
    m, n = x.shape[-2:]
    qr = paddle.linalg.qr

    R = paddle.randn((n, q), dtype=x.dtype)

    A_t = _transpose(x)
    A_H = _conjugate(A_t)
    if M is None:
        Q = qr(paddle.matmul(x, R))[0]
        for i in range(niter):
            Q = qr(paddle.matmul(A_H, Q))[0]
            Q = qr(paddle.matmul(x, Q))[0]
    else:
        M_H = _transjugate(M)
        Q = qr(paddle.matmul(x, R) - paddle.matmul(M, R))[0]
        for i in range(niter):
            Q = qr(paddle.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
            Q = qr(paddle.matmul(x, Q) - paddle.matmul(M, Q))[0]

    return Q
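
# A note on `_get_approximate_basis`: it follows the randomized range finder
# of Halko, Martinsson and Tropp (2011), alternating multiplications by the
# matrix and its conjugate transpose with QR re-orthonormalization between
# steps to keep the power iterations numerically stable. An illustrative
# self-check (hypothetical snippet, not part of this commit):
#
#     A = paddle.randn((8, 3), dtype='float64') @ paddle.randn((3, 6), dtype='float64')
#     Q = _get_approximate_basis(A, q=3)
#     # For a rank-3 A, projecting onto span(Q) loses almost nothing:
#     print((A - Q @ _transjugate(Q) @ A).abs().max())  # on the order of 1e-15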


def svd_lowrank(x, q=None, niter=2, M=None, name=None):
    r"""
    Return the singular value decomposition (SVD) of a low-rank matrix or batches of such matrices.

    If :math:`X` is the input matrix or a batch of input matrices, the output should satisfy:

    .. math::
        X \approx U * diag(S) * V^{T}

    When :math:`M` is given, the output should satisfy:

    .. math::
        X - M \approx U * diag(S) * V^{T}

    Args:
        x (Tensor): The input tensor. Its shape should be `[..., N, M]`, where `...` is
            zero or more batch dimensions. N and M can be arbitrary positive integers.
            The data type of ``x`` should be float32 or float64.
        q (int, optional): A slightly overestimated rank of :math:`X`.
            Default value is None, in which case q is min(6, N, M).
        niter (int, optional): The number of iterations to perform. Default: 2.
        M (Tensor, optional): The input tensor's mean. Its shape should be `[..., 1, M]`.
            Default value is None.
        name (str, optional): Name for the operation. For more information, please
            refer to :ref:`api_guide_Name`. Default: None.

    Returns:
        tuple (U, S, V): A nearly optimal approximation of the singular value decomposition
        of the matrix :math:`X`, or of :math:`X - M` when ``M`` is given.

        - Tensor U, is an N x q matrix.
        - Tensor S, is a vector of length q.
        - Tensor V, is an M x q matrix.

    Examples:
        .. code-block:: python

            >>> import paddle
            >>> paddle.seed(2024)

            >>> x = paddle.randn((5, 5), dtype='float64')
            >>> U, S, V = paddle.linalg.svd_lowrank(x)
            >>> print(U)
            Tensor(shape=[5, 5], dtype=float64, place=Place(cpu), stop_gradient=True,
            [[-0.03586982, -0.17211503,  0.31536566, -0.38225676, -0.85059629],
             [-0.38386839,  0.67754925,  0.23222694,  0.51777188, -0.26749766],
             [-0.85977150, -0.28442378, -0.41412094, -0.08955629, -0.01948348],
             [ 0.18611503,  0.56047358, -0.67717019, -0.39286761, -0.19577062],
             [ 0.27841082, -0.34099254, -0.46535957,  0.65071250, -0.40770727]])

            >>> print(S)
            Tensor(shape=[5], dtype=float64, place=Place(cpu), stop_gradient=True,
            [4.11253399, 3.03227120, 2.45499752, 1.25602436, 0.45825337])

            >>> print(V)
            Tensor(shape=[5, 5], dtype=float64, place=Place(cpu), stop_gradient=True,
            [[ 0.46401347,  0.50977695, -0.08742316, -0.11140428, -0.71046833],
             [-0.48927226, -0.35047624,  0.07918771,  0.45431083, -0.65200463],
             [-0.20494730,  0.67097011, -0.05427719,  0.66510472,  0.24997083],
             [-0.69645001,  0.40237917,  0.09360970, -0.58032322, -0.08666357],
             [ 0.13512270,  0.07199989,  0.98710572,  0.04529277,  0.01134594]])
    """
    if not paddle.is_tensor(x):
        raise ValueError(f'Input must be tensor, but got {type(x)}')

    m, n = x.shape[-2:]
    if q is None:
        q = min(6, m, n)
    elif not (0 <= q <= min(m, n)):
        raise ValueError(
            f'q(={q}) must be non-negative integer'
            f' and not greater than min(m, n)={min(m, n)}'
        )

    if niter < 0:
        raise ValueError(f'niter(={niter}) must be non-negative integer')

    if M is None:
        M_t = None
    else:
        M = M.broadcast_to(x.shape)
        M_t = _transpose(M)
    A_t = _transpose(x)

    # Work on whichever of x and its transpose gives the smaller projected
    # matrix, then run an exact SVD on that (at most q-column) projection.
    if m < n or n > q:
        Q = _get_approximate_basis(A_t, q, niter=niter, M=M_t)
        Q_c = _conjugate(Q)
        if M is None:
            B_t = paddle.matmul(x, Q_c)
        else:
            B_t = paddle.matmul(x, Q_c) - paddle.matmul(M, Q_c)
        assert B_t.shape[-2] == m, (B_t.shape, m)
        assert B_t.shape[-1] == q, (B_t.shape, q)
        assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
        U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
        V = _transjugate(Vh)
        V = Q.matmul(V)
    else:
        Q = _get_approximate_basis(x, q, niter=niter, M=M)
        Q_c = _conjugate(Q)
        if M is None:
            B = paddle.matmul(A_t, Q_c)
        else:
            B = paddle.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
        B_t = _transpose(B)
        assert B_t.shape[-2] == q, (B_t.shape, q)
        assert B_t.shape[-1] == n, (B_t.shape, n)
        assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
        U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
        V = _transjugate(Vh)
        U = Q.matmul(U)

    return U, S, V
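
# A quick sanity sketch for `svd_lowrank` (hypothetical snippet, not part of
# this commit): when q equals min(m, n), the returned singular values should
# match the exact decomposition from `paddle.linalg.svd`:
#
#     x = paddle.randn((8, 6), dtype='float64')
#     U, S, V = svd_lowrank(x, q=6)
#     _, S_full, _ = paddle.linalg.svd(x, full_matrices=False)
#     print(paddle.allclose(S, S_full))  # expected True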


def pca_lowrank(x, q=None, center=True, niter=2, name=None):
    r"""
    Performs linear Principal Component Analysis (PCA) on a low-rank matrix or batches of such matrices.
@@ -2588,82 +2743,6 @@ def pca_lowrank(x, q=None, center=True, niter=2, name=None):
            [-0.67131070, -0.19071018,  0.07795789, -0.04615811,  0.71046714]])
    """

    def conjugate(x):
        if x.is_complex():
            return x.conj()
        return x

    def transpose(x):
        shape = x.shape
        perm = list(range(0, len(shape)))
        perm = perm[:-2] + [perm[-1]] + [perm[-2]]
        return paddle.transpose(x, perm)

    def transjugate(x):
        return conjugate(transpose(x))

    def get_approximate_basis(x, q, niter=2, M=None):
        niter = 2 if niter is None else niter
        m, n = x.shape[-2:]
        qr = paddle.linalg.qr

        R = paddle.randn((n, q), dtype=x.dtype)

        A_t = transpose(x)
        A_H = conjugate(A_t)
        if M is None:
            Q = qr(paddle.matmul(x, R))[0]
            for i in range(niter):
                Q = qr(paddle.matmul(A_H, Q))[0]
                Q = qr(paddle.matmul(x, Q))[0]
        else:
            M_H = transjugate(M)
            Q = qr(paddle.matmul(x, R) - paddle.matmul(M, R))[0]
            for i in range(niter):
                Q = qr(paddle.matmul(A_H, Q) - paddle.matmul(M_H, Q))[0]
                Q = qr(paddle.matmul(x, Q) - paddle.matmul(M, Q))[0]

        return Q

    def svd_lowrank(x, q=6, niter=2, M=None):
        q = 6 if q is None else q
        m, n = x.shape[-2:]
        if M is None:
            M_t = None
        else:
            M_t = transpose(M)
        A_t = transpose(x)

        if m < n or n > q:
            Q = get_approximate_basis(A_t, q, niter=niter, M=M_t)
            Q_c = conjugate(Q)
            if M is None:
                B_t = paddle.matmul(x, Q_c)
            else:
                B_t = paddle.matmul(x, Q_c) - paddle.matmul(M, Q_c)
            assert B_t.shape[-2] == m, (B_t.shape, m)
            assert B_t.shape[-1] == q, (B_t.shape, q)
            assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
            U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
            V = transjugate(Vh)
            V = Q.matmul(V)
        else:
            Q = get_approximate_basis(x, q, niter=niter, M=M)
            Q_c = conjugate(Q)
            if M is None:
                B = paddle.matmul(A_t, Q_c)
            else:
                B = paddle.matmul(A_t, Q_c) - paddle.matmul(M_t, Q_c)
            B_t = transpose(B)
            assert B_t.shape[-2] == q, (B_t.shape, q)
            assert B_t.shape[-1] == n, (B_t.shape, n)
            assert B_t.shape[-1] <= B_t.shape[-2], B_t.shape
            U, S, Vh = paddle.linalg.svd(B_t, full_matrices=False)
            V = transjugate(Vh)
            U = Q.matmul(U)

        return U, S, V

    if not paddle.is_tensor(x):
        raise ValueError(f'Input must be tensor, but got {type(x)}')

