Skip to content

Commit

Permalink
【PIR API adaptor No.289】Migrate pca_lowrank to pir (#60320)
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f authored Dec 27, 2023
1 parent aec353c commit 2dfa0f7
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 1 deletion.
2 changes: 1 addition & 1 deletion python/paddle/tensor/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ def gaussian(shape, mean=0.0, std=1.0, seed=0, dtype=None, name=None):
op_type_for_check, supported_dtypes, dtype
)
)
if not isinstance(dtype, core.VarDesc.VarType):
if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
dtype = convert_np_dtype_to_dtype_(dtype)

if in_dynamic_or_pir_mode():
Expand Down
102 changes: 102 additions & 0 deletions test/legacy_test/test_pca_lowrank.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,107 @@ def test_niter_range():
self.assertRaises(ValueError, test_niter_range)


class TestStaticPcaLowrankAPI(unittest.TestCase):
    """Static-graph checks of a PCA routine on random low-rank matrices.

    ``test_forward`` drives ``run_subtest`` with ``paddle.linalg.pca_lowrank``
    inside ``paddle.pir_utils.IrGuard``; the remaining methods build the
    random inputs and the comparison graph.
    """

    def transpose(self, x):
        """Return ``x`` with its last two axes swapped."""
        axes = list(range(len(x.shape)))
        # Only the two trailing axes trade places; leading axes keep their
        # position.
        axes[-2], axes[-1] = axes[-1], axes[-2]
        return paddle.transpose(x, axes)

    def random_matrix(self, rows, columns, *batch_dims, **kwargs):
        """Create a random matrix of shape ``batch_dims + (rows, columns)``.

        A ``paddle.randn`` sample is decomposed with a reduced SVD and then
        recombined with the middle factor replaced by ``min(rows, columns)``
        evenly spaced values that end at 1.

        Args:
            rows: Size of the second-to-last axis.
            columns: Size of the last axis.
            *batch_dims: Leading batch sizes.
            **kwargs: Only ``dtype`` is read; it defaults to ``'float64'``.
        """
        dtype = kwargs.get('dtype', 'float64')
        num_values = min(rows, columns)

        noise = paddle.randn(batch_dims + (rows, columns), dtype=dtype)
        left, _, right_h = paddle.linalg.svd(noise, full_matrices=False)
        spectrum = paddle.linspace(
            1 / (num_values + 1), 1, num_values, dtype=dtype
        )
        # Scale the columns of the left factor by the prescribed values.
        scaled_left = left * spectrum.unsqueeze(-2)
        return scaled_left @ right_h

    def random_lowrank_matrix(self, rank, rows, columns, *batch_dims, **kwargs):
        """Return the product of a ``rows x rank`` and a ``rank x columns``
        matrix, each produced by ``random_matrix``."""
        tall = self.random_matrix(rows, rank, *batch_dims, **kwargs)
        wide = self.random_matrix(rank, columns, *batch_dims, **kwargs)
        return tall.matmul(wide)

    def run_subtest(
        self, guess_rank, actual_rank, matrix_size, batches, pca, **options
    ):
        """Build, run and verify one static program around ``pca``.

        Args:
            guess_rank: Value passed to ``pca`` as ``q``.
            actual_rank: Rank used to generate the input matrix.
            matrix_size: An int for a square matrix, otherwise a
                ``(rows, columns)`` pair.
            batches: Tuple of leading batch sizes (may be empty).
            pca: Callable invoked as ``pca(matrix, q=guess_rank, **options)``
                and expected to return three values ``u, s, v``.
            **options: Extra keyword arguments forwarded to ``pca``.
        """
        main_program = paddle.static.Program()
        startup_program = paddle.static.Program()
        with paddle.static.program_guard(main_program, startup_program):
            if isinstance(matrix_size, int):
                rows, columns = matrix_size, matrix_size
            else:
                rows, columns = matrix_size
            matrix = self.random_lowrank_matrix(
                actual_rank, rows, columns, *batches
            )

            u, s, v = pca(matrix, q=guess_rank, **options)

            # Shape checks on the graph outputs, before anything is executed.
            for tensor, axis, expected in (
                (s, -1, guess_rank),
                (u, -2, rows),
                (u, -1, guess_rank),
                (v, -1, guess_rank),
                (v, -2, columns),
            ):
                self.assertEqual(tensor.shape[axis], expected)

            # u @ diag(s) @ v^T is compared against the input with its
            # per-column mean (sum over axis -2 divided by rows) subtracted.
            reconstructed = u.matmul(paddle.diag_embed(s)).matmul(
                self.transpose(v)
            )
            ones_column = paddle.ones(batches + (rows, 1), dtype=matrix.dtype)
            column_mean = (matrix.sum(axis=-2) / rows).reshape(
                batches + (1, columns)
            )
            centered = matrix - ones_column.matmul(column_mean)

            # Count of entries of s above 1e-5 versus the generating rank.
            detected_rank = (s.abs() > 1e-5).sum(axis=-1)
            expected_rank = actual_rank * paddle.ones(
                batches, dtype=paddle.int64
            )

            # Leading values of s versus those of a plain SVD of the
            # mean-subtracted input.
            reference_s = paddle.linalg.svd(centered, full_matrices=False)[1]
            leading_s = s[..., :actual_rank]
            leading_reference = reference_s[..., :actual_rank]

            executor = paddle.static.Executor()
            executor.run(startup_program)
            fetched = executor.run(
                main_program,
                feed={},
                fetch_list=[
                    reconstructed,
                    centered,
                    expected_rank,
                    detected_rank,
                    leading_s,
                    leading_reference,
                ],
            )
            (
                reconstructed_np,
                centered_np,
                expected_rank_np,
                detected_rank_np,
                leading_s_np,
                leading_reference_np,
            ) = fetched

            np.testing.assert_allclose(
                reconstructed_np, centered_np, atol=1e-5
            )
            if expected_rank_np.shape:
                np.testing.assert_allclose(expected_rank_np, detected_rank_np)
            else:
                # With an empty shape the fetched ranks are compared as
                # plain Python ints.
                np.testing.assert_allclose(
                    int(expected_rank_np), int(detected_rank_np)
                )
            np.testing.assert_allclose(leading_s_np, leading_reference_np)

    def test_forward(self):
        """Run ``run_subtest`` for ``paddle.linalg.pca_lowrank`` under
        ``IrGuard`` over several ranks, sizes and batch shapes.

        Each admissible combination is exercised on both the listed size and
        its reverse.
        """
        rank_and_size = ((2, (17, 4)), (2, (100, 4)), (6, (100, 40)))
        batch_shapes = ((), (1,), (3,), (2, 3))
        with paddle.pir_utils.IrGuard():
            pca_lowrank = paddle.linalg.pca_lowrank
            for actual_rank, size in rank_and_size:
                for batches in batch_shapes:
                    for extra in (0, 2, 6):
                        guess_rank = actual_rank + extra
                        # Skip guesses larger than the smaller matrix side.
                        if guess_rank > min(size):
                            continue
                        for matrix_size in (size, size[::-1]):
                            self.run_subtest(
                                guess_rank,
                                actual_rank,
                                matrix_size,
                                batches,
                                pca_lowrank,
                            )


if __name__ == "__main__":
    # Executed as a script: hand control to the unittest runner.
    unittest.main()

0 comments on commit 2dfa0f7

Please sign in to comment.