Skip to content

Commit a119afd

Browse files
authored
Merge pull request #188 from lucasimi/feature/igraph-layout
Feature/igraph layout
2 parents f8e3b6f + d301392 commit a119afd

File tree

5 files changed

+64
-14
lines changed

5 files changed

+64
-14
lines changed

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ keywords = ["tda", "mapper", "topology", "topological data analysis"]
1919
dependencies = [
2020
"matplotlib>=3.3.4",
2121
"networkx>=2.5",
22+
"igraph>=0.11.8",
2223
"numba>=0.54",
2324
"numpy>=1.20.1, <2.0.0",
2425
"plotly>=4.14.3",

src/tdamapper/clustering.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, *args, **kwargs):
4242

4343
class _MapperClustering(EstimatorMixin, ParamsMixin):
4444

45-
def __init__(self, cover=None, clustering=None, n_jobs=1):
45+
def __init__(self, cover=None, clustering=None, n_jobs=-1):
4646
self.cover = cover
4747
self.clustering = clustering
4848
self.n_jobs = n_jobs

src/tdamapper/core.py

+7-7
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@
5858
)
5959

6060

61-
def mapper_labels(X, y, cover, clustering, n_jobs=1):
61+
def mapper_labels(X, y, cover, clustering, n_jobs=-1):
6262
"""
6363
Identify the nodes of the Mapper graph.
6464
@@ -87,7 +87,7 @@ def mapper_labels(X, y, cover, clustering, n_jobs=1):
8787
interface, typically from :mod:`sklearn.cluster`.
8888
:param n_jobs: The maximum number of parallel clustering jobs. This
8989
parameter is passed to the constructor of :class:`joblib.Parallel`.
90-
Defaults to 1.
90+
Defaults to -1.
9191
:type n_jobs: int
9292
:return: A list of node labels for each point in the dataset.
9393
:rtype: list[list[int]]
@@ -112,7 +112,7 @@ def _run_clustering(local_ids):
112112
return itm_lbls
113113

114114

115-
def mapper_connected_components(X, y, cover, clustering, n_jobs=1):
115+
def mapper_connected_components(X, y, cover, clustering, n_jobs=-1):
116116
"""
117117
Identify the connected components of the Mapper graph.
118118
@@ -139,7 +139,7 @@ def mapper_connected_components(X, y, cover, clustering, n_jobs=1):
139139
interface, typically from :mod:`sklearn.cluster`.
140140
:param n_jobs: The maximum number of parallel clustering jobs. This
141141
parameter is passed to the constructor of :class:`joblib.Parallel`.
142-
Defaults to 1.
142+
Defaults to -1.
143143
:type n_jobs: int
144144
:return: A list of labels. The label at position i identifies the connected
145145
component of the point at position i in the dataset.
@@ -162,7 +162,7 @@ def mapper_connected_components(X, y, cover, clustering, n_jobs=1):
162162
return labels
163163

164164

165-
def mapper_graph(X, y, cover, clustering, n_jobs=1):
165+
def mapper_graph(X, y, cover, clustering, n_jobs=-1):
166166
"""
167167
Create the Mapper graph.
168168
@@ -189,7 +189,7 @@ def mapper_graph(X, y, cover, clustering, n_jobs=1):
189189
interface, typically from :mod:`sklearn.cluster`.
190190
:param n_jobs: The maximum number of parallel clustering jobs. This
191191
parameter is passed to the constructor of :class:`joblib.Parallel`.
192-
Defaults to 1.
192+
Defaults to -1.
193193
:type n_jobs: int
194194
:return: The Mapper graph.
195195
:rtype: :class:`networkx.Graph`
@@ -378,7 +378,7 @@ def __init__(
378378
clustering=None,
379379
failsafe=True,
380380
verbose=True,
381-
n_jobs=1,
381+
n_jobs=-1,
382382
):
383383
self.cover = cover
384384
self.clustering = clustering

src/tdamapper/learn.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@ class MapperClustering(_MapperClustering):
3535
:mod:`sklearn.cluster`
3636
:param n_jobs: The maximum number of parallel clustering jobs. This
3737
parameter is passed to the constructor of :class:`joblib.Parallel`.
38-
Defaults to 1.
38+
Defaults to -1.
3939
:type n_jobs: int
4040
"""
4141

4242
def __init__(
4343
self,
4444
cover=None,
4545
clustering=None,
46-
n_jobs=1,
46+
n_jobs=-1,
4747
):
4848
super().__init__(
4949
cover=cover,
@@ -99,7 +99,7 @@ class MapperAlgorithm(_MapperAlgorithm):
9999
:type verbose: bool, optional
100100
:param n_jobs: The maximum number of parallel clustering jobs. This
101101
parameter is passed to the constructor of :class:`joblib.Parallel`.
102-
Defaults to 1.
102+
Defaults to -1.
103103
:type n_jobs: int
104104
"""
105105

@@ -109,7 +109,7 @@ def __init__(
109109
clustering=None,
110110
failsafe=True,
111111
verbose=True,
112-
n_jobs=1,
112+
n_jobs=-1,
113113
):
114114
super().__init__(
115115
cover=cover,

src/tdamapper/plot.py

+51-2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
This module provides functionalities to visualize the Mapper graph.
33
"""
44
import networkx as nx
5+
import igraph as ig
56

67
import numpy as np
78

@@ -31,20 +32,68 @@ class MapperPlot:
3132
:param seed: The random seed used to construct the graph embedding.
3233
Defaults to None.
3334
:type seed: int, optional
35+
:param layout_engine: The engine used to compute the graph layout in the
36+
specified dimensions. Possible values are 'igraph' and 'networkx'.
37+
Defaults to 'igraph'.
38+
:type layout_engine: str, optional
3439
"""
3540

36-
def __init__(self, graph, dim, iterations=50, seed=None):
41+
def __init__(
42+
self,
43+
graph,
44+
dim,
45+
iterations=50,
46+
seed=None,
47+
layout_engine='igraph',
48+
):
3749
self.graph = graph
3850
self.dim = dim
3951
self.iterations = iterations
4052
self.seed = seed
41-
self.positions = nx.spring_layout(
53+
self.layout_engine = layout_engine
54+
self.positions = self._compute_pos()
55+
56+
def _compute_pos(self):
57+
if self.layout_engine == 'igraph':
58+
return self._compute_pos_ig()
59+
elif self.layout_engine == 'networkx':
60+
return self._compute_pos_nx()
61+
else:
62+
raise ValueError(
63+
f'Unknown engine {self.layout_engine}. '
64+
"Only possible values are 'igraph' and 'networkx'"
65+
)
66+
67+
def _compute_pos_nx(self):
68+
return nx.spring_layout(
4269
self.graph,
4370
dim=self.dim,
4471
seed=self.seed,
4572
iterations=self.iterations,
4673
)
4774

75+
def _compute_pos_ig(self):
76+
if self.graph.number_of_nodes() == 0:
77+
return {}
78+
rng = np.random.default_rng(self.seed)
79+
random_pos = rng.random((len(self.graph.nodes()), self.dim))
80+
graph_ig = ig.Graph.from_networkx(self.graph)
81+
if self.dim == 2:
82+
layout = graph_ig.layout_fruchterman_reingold(
83+
niter=self.iterations,
84+
seed=random_pos,
85+
)
86+
pos = {node: (layout[i][0], layout[i][1]) for i, node in
87+
enumerate(self.graph.nodes())}
88+
elif self.dim == 3:
89+
layout = graph_ig.layout_fruchterman_reingold_3d(
90+
niter=self.iterations,
91+
seed=random_pos,
92+
)
93+
pos = {node: (layout[i][0], layout[i][1], layout[i][2])
94+
for i, node in enumerate(self.graph.nodes())}
95+
return pos
96+
4897
def plot_matplotlib(
4998
self,
5099
colors,

0 commit comments

Comments
 (0)