Skip to content

Commit dcbdeb8

Browse files
committed
feat: obtain colorscale of a connectivity profile
1 parent 8dc3ee6 commit dcbdeb8

File tree

1 file changed

+45
-2
lines changed

1 file changed

+45
-2
lines changed

siibra/features/connectivity/regional_connectivity.py

+45-2
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import pandas as pd
3030
import numpy as np
31-
from typing import Callable, Union, List
31+
from typing import Callable, Union, List, Tuple, Generator
3232

3333
try:
3434
from typing import Literal
@@ -208,7 +208,6 @@ def get_profile(
208208
Parameters
209209
----------
210210
region: str, Region
211-
subject: str, default: None
212211
min_connectivity: float, default: 0
213212
Regions with connectivity less than this value are discarded.
214213
max_rows: int, default: None
@@ -320,6 +319,50 @@ def plot(
320319
else:
321320
return profile.data.plot(*args, backend=backend, **kwargs)
322321

322+
def get_profile_colorscale(
323+
self,
324+
region: Union[str, _region.Region],
325+
min_connectivity: float = 0,
326+
max_rows: int = None,
327+
direction: Literal['column', 'row'] = 'column',
328+
colorgradient: str = "jet"
329+
) -> Generator[_region.Region, Tuple[int, int, int]]:
330+
"""
331+
Extract the colorscale corresponding to the regional profile from the
332+
matrix sorted by the values. See `get_profile` for further details.
333+
334+
Note:
335+
-----
336+
Requires `plotly`.
337+
338+
Parameters
339+
----------
340+
region: str, Region
341+
min_connectivity: float, default: 0
342+
Regions with connectivity less than this value are discarded.
343+
max_rows: int, default: None
344+
Max number of regions with highest connectivity.
345+
direction: str, default: 'column'
346+
Choose the direction of profile extraction particularly for
347+
non-symmetric matrices. ('column' or 'row')
348+
colorgradient: str, default: 'jet'
349+
The gradient used to extract colorscale.
350+
Returns
351+
-------
352+
Generator[_region.Region, Tuple[int, int, int]]
353+
Color values are in RGB 255.
354+
"""
355+
from plotly.express.colors import sample_colorscale
356+
profile = self.get_profile(region, min_connectivity, max_rows, direction)
357+
colorscale = sample_colorscale(
358+
colorgradient,
359+
profile.data.values.reshape(len(profile.data))
360+
)
361+
return zip(
362+
profile.data.index.values,
363+
[eval(c.removeprefix('rgb')) for c in colorscale]
364+
)
365+
323366
def __len__(self):
324367
return len(self._filename)
325368

0 commit comments

Comments
 (0)