Skip to content

Commit 6ddba3a

Browse files
author
gyzhou2000
committed
add svd feature reduction
1 parent fb00d7f commit 6ddba3a

File tree

3 files changed

+55
-1
lines changed

3 files changed

+55
-1
lines changed

gammagl/transforms/__init__.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .drop_edge import DropEdge
77
from .random_link_split import RandomLinkSplit
88
from .vgae_pre import mask_test_edges, sparse_to_tuple
9+
from .svd_feature_reduction import SVDFeatureReduction
910

1011
__all__ = [
1112
'BaseTransform',
@@ -16,7 +17,8 @@
1617
'DropEdge',
1718
'RandomLinkSplit',
1819
'mask_test_edges',
19-
'sparse_to_tuple'
20+
'sparse_to_tuple',
21+
'SVDFeatureReduction'
2022

2123
]
2224

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from gammagl.transforms import BaseTransform
2+
3+
from gammagl.data import Graph
4+
from typing import List
5+
import numpy as np
6+
import tensorlayerx as tlx
7+
8+
9+
class SVDFeatureReduction(BaseTransform):
10+
r"""Dimensionality reduction of node features via Singular Value Decomposition (SVD)
11+
(functional name: :obj:`normalize_features`).
12+
13+
Parameters
14+
----------
15+
out_channels: int
16+
The dimensionlity of node features after reduction.
17+
18+
"""
19+
20+
def __init__(self, out_channels: int):
21+
self.out_channels = out_channels
22+
23+
def __call__(self, graph: Graph):
24+
assert graph.x is not None
25+
26+
if graph.x.shape[-1] > self.out_channels:
27+
x = tlx.convert_to_numpy(graph.x)
28+
U, S, _ = np.linalg.svd(x, full_matrices=False)
29+
U_reduced = U[:, :self.out_channels]
30+
S_reduced = np.diag(S[:self.out_channels])
31+
x = np.dot(U_reduced, S_reduced)
32+
graph.x = tlx.convert_to_tensor(x)
33+
34+
return graph
35+
36+
def __repr__(self) -> str:
37+
return f'{self.__class__.__name__}()'
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import tensorlayerx as tlx
2+
3+
from gammagl.data import Graph, HeteroGraph
4+
from gammagl.transforms import SVDFeatureReduction
5+
6+
7+
def test_normalize_scale():
8+
assert SVDFeatureReduction(2).__repr__() == 'SVDFeatureReduction()'
9+
10+
x = tlx.convert_to_tensor([[1, 2, 3, 4, 5], [5, 4, 3, 2, 1]], dtype=tlx.float32)
11+
12+
data = Graph(x=x)
13+
data = SVDFeatureReduction(out_channels=2)(data)
14+
assert len(data) == 1
15+
assert tlx.get_tensor_shape(data.x) == [2, 2]

0 commit comments

Comments
 (0)