Skip to content

Commit df56899

Browse files
authored
Merge pull request #532 from FZJ-INM1-BDA/conn_profile_colormap
feat: obtain colorscale of a connectivity profile
2 parents 8c0a453 + d3e831b commit df56899

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

siibra/features/connectivity/regional_connectivity.py

+45-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import pandas as pd
3030
import numpy as np
31-
from typing import Callable, Union, List
31+
from typing import Callable, Union, List, Tuple, Iterator
3232

3333
try:
3434
from typing import Literal
@@ -208,7 +208,6 @@ def get_profile(
208208
Parameters
209209
----------
210210
region: str, Region
211-
subject: str, default: None
212211
min_connectivity: float, default: 0
213212
Regions with connectivity less than this value are discarded.
214213
max_rows: int, default: None
@@ -321,6 +320,50 @@ def plot(
321320
else:
322321
return profile.data.plot(*args, backend=backend, **kwargs)
323322

323+
def get_profile_colorscale(
324+
self,
325+
region: Union[str, _region.Region],
326+
min_connectivity: float = 0,
327+
max_rows: int = None,
328+
direction: Literal['column', 'row'] = 'column',
329+
colorgradient: str = "jet"
330+
) -> Iterator[Tuple[_region.Region, Tuple[int, int, int]]]:
331+
"""
332+
Extract the colorscale corresponding to the regional profile from the
333+
matrix sorted by the values. See `get_profile` for further details.
334+
335+
Note:
336+
-----
337+
Requires `plotly`.
338+
339+
Parameters
340+
----------
341+
region: str, Region
342+
min_connectivity: float, default: 0
343+
Regions with connectivity less than this value are discarded.
344+
max_rows: int, default: None
345+
Max number of regions with highest connectivity.
346+
direction: str, default: 'column'
347+
Choose the direction of profile extraction particularly for
348+
non-symmetric matrices. ('column' or 'row')
349+
colorgradient: str, default: 'jet'
350+
The gradient used to extract colorscale.
351+
Returns
352+
-------
353+
Iterator[Tuple[_region.Region, Tuple[int, int, int]]]
354+
Color values are in RGB 255.
355+
"""
356+
from plotly.express.colors import sample_colorscale
357+
profile = self.get_profile(region, min_connectivity, max_rows, direction)
358+
colorscale = sample_colorscale(
359+
colorgradient,
360+
profile.data.values.reshape(len(profile.data))
361+
)
362+
return zip(
363+
profile.data.index.values,
364+
[eval(c.removeprefix('rgb')) for c in colorscale]
365+
)
366+
324367
def __len__(self):
325368
return len(self._filename)
326369

0 commit comments

Comments
 (0)