Skip to content

Commit a589ef6

Browse files
authored
Add cmap to pillow graph (#347)
* feat: allow cmaps for node background for pillow graph export * docs: update CHANGELOG
1 parent 7302c10 commit a589ef6

File tree

5 files changed

+118
-35
lines changed

5 files changed

+118
-35
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
55
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
66

77
## [Unreleased]
8+
### Added:
9+
- Tree Exporter: `tree_to_pillow_graph` method to allow cmap for node background.
810

911
## [0.24.0] - 2025-02-09
1012
### Added:

bigtree/tree/export/images.py

+59-31
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import collections
44
import re
5-
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
5+
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union
66

77
from bigtree.node import node
88
from bigtree.tree.export.stdout import yield_tree
@@ -22,6 +22,15 @@
2222

2323
Image = ImageDraw = ImageFont = MagicMock()
2424

25+
try:
26+
import matplotlib as mpl
27+
from matplotlib.colors import Normalize
28+
except ImportError: # pragma: no cover
29+
from unittest.mock import MagicMock
30+
31+
mpl = MagicMock()
32+
Normalize = MagicMock()
33+
2534

2635
__all__ = [
2736
"tree_to_dot",
@@ -211,6 +220,21 @@ def _recursive_append(parent_name: Optional[str], child_node: T) -> None:
211220
return _graph
212221

213222

223+
def _load_font(font_family: str, font_size: int) -> ImageFont.truetype:
224+
if not font_family:
225+
from urllib.request import urlopen
226+
227+
dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true"
228+
font_family = urlopen(dejavusans_url)
229+
try:
230+
font = ImageFont.truetype(font_family, font_size)
231+
except OSError:
232+
raise ValueError(
233+
f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
234+
)
235+
return font
236+
237+
214238
@exceptions.optional_dependencies_image("Pillow")
215239
def tree_to_pillow_graph(
216240
tree: T,
@@ -225,9 +249,10 @@ def tree_to_pillow_graph(
225249
text_align: str = "center",
226250
bg_colour: Union[Tuple[int, int, int], str] = "white",
227251
rect_margin: Optional[Dict[str, int]] = None,
228-
rect_fill: Union[Tuple[int, int, int], str] = "white",
252+
rect_fill: Union[Tuple[int, int, int], str, mpl.colors.Colormap] = "white",
253+
rect_cmap_attr: Optional[str] = None,
229254
rect_outline: Union[Tuple[int, int, int], str] = "black",
230-
rect_width: Union[float, int] = 1,
255+
rect_width: int = 1,
231256
) -> Image.Image:
232257
r"""Export tree to PIL.Image.Image object. Object can be
233258
converted to other formats, such as jpg, or png. Image will look
@@ -267,13 +292,21 @@ def tree_to_pillow_graph(
267292
text_align (str): text align for multi-line text
268293
bg_colour (Union[Tuple[int, int, int], str]): background of image, accepts tuple of RGB values or string, defaults to white
269294
rect_margin (Dict[str, int]): (for rectangle) margin of text to rectangle, in pixels
270-
rect_fill (Union[Tuple[int, int, int], str]): (for rectangle) colour to use for fill
295+
rect_fill (Union[Tuple[int, int, int], str, mpl.colormap]): (for rectangle) colour to use for fill
296+
rect_cmap_attr (str): (for rectangle) if rect_fill is a colormap, attribute of node to retrieve fill from colormap,
297+
must be a float/int attribute
271298
rect_outline (Union[Tuple[int, int, int], str]): (for rectangle) colour to use for outline
272-
rect_width (Union[float, int]): (for rectangle) line width, in pixels
299+
rect_width (int): (for rectangle) line width, in pixels
273300
274301
Returns:
275302
(PIL.Image.Image)
276303
"""
304+
use_cmap = isinstance(rect_fill, mpl.colors.Colormap)
305+
if use_cmap and rect_cmap_attr is None:
306+
raise ValueError(
307+
"`rect_cmap_attr` cannot be None if rect_fill is mpl.colormaps"
308+
)
309+
277310
default_margin = {"t": 10, "b": 10, "l": 10, "r": 10}
278311
default_rect_margin = {"t": 5, "b": 5, "l": 5, "r": 5}
279312
if not margin:
@@ -286,19 +319,10 @@ def tree_to_pillow_graph(
286319
rect_margin = {**default_rect_margin, **rect_margin}
287320

288321
# Initialize font
289-
if not font_family:
290-
from urllib.request import urlopen
322+
font = _load_font(font_family, font_size)
291323

292-
dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true"
293-
font_family = urlopen(dejavusans_url)
294-
try:
295-
font = ImageFont.truetype(font_family, font_size)
296-
except OSError:
297-
raise ValueError(
298-
f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
299-
)
300-
301-
# Calculate image dimension from text, otherwise override with argument
324+
# Iterate tree once to obtain attributes
325+
# Calculate image dimension from text, get range for colourmap if applicable
302326
_max_text_width = 0
303327
_max_text_height = 0
304328
_image = Image.new("RGB", (0, 0))
@@ -309,11 +333,11 @@ def get_node_text(_node: T, _node_content: str) -> str:
309333
matches = re.findall(pattern, _node_content)
310334
for match in matches:
311335
_node_content = _node_content.replace(
312-
f"{{{match}}}",
313-
str(_node.get_attr(match)) if _node.get_attr(match) else "",
336+
f"{{{match}}}", str(_node.get_attr(match, ""))
314337
)
315338
return _node_content
316339

340+
cmap_range: Set[Union[float, int]] = set()
317341
for _, _, _node in yield_tree(tree):
318342
l, t, r, b = _draw.multiline_textbbox(
319343
(0, 0), get_node_text(_node, node_content), font=font
@@ -324,6 +348,15 @@ def get_node_text(_node: T, _node_content: str) -> str:
324348
_max_text_height = max(
325349
_max_text_height, t + b + rect_margin.get("t", 0) + rect_margin.get("b", 0)
326350
)
351+
if use_cmap:
352+
cmap_range.add(_node.get_attr(rect_cmap_attr, 0))
353+
354+
cmap_dict = {}
355+
if use_cmap:
356+
norm = Normalize(vmin=min(cmap_range), vmax=max(cmap_range))
357+
cmap_range_list = [norm(c) for c in cmap_range]
358+
cmap_colour_list = rect_fill(cmap_range_list) # type: ignore
359+
cmap_dict = dict(zip(cmap_range_list, cmap_colour_list))
327360

328361
# Get x, y, coordinates and height, width of diagram
329362
from bigtree.utils.plot import reingold_tilford
@@ -357,8 +390,13 @@ def get_node_text(_node: T, _node_content: str) -> str:
357390
x1, x2 = _x - 0.5 * _max_text_width, _x + 0.5 * _max_text_width
358391
y1, y2 = _y - 0.5 * _max_text_height, _y + 0.5 * _max_text_height
359392
# Draw box
393+
_rect_fill = rect_fill
394+
if use_cmap:
395+
_rect_fill = mpl.colors.rgb2hex(
396+
cmap_dict[norm(_node.get_attr(rect_cmap_attr, 0))]
397+
)
360398
image_draw.rectangle(
361-
[x1, y1, x2, y2], fill=rect_fill, outline=rect_outline, width=rect_width
399+
[x1, y1, x2, y2], fill=_rect_fill, outline=rect_outline, width=rect_width
362400
)
363401
# Draw text
364402
image_draw.text(
@@ -445,17 +483,7 @@ def tree_to_pillow(
445483
(PIL.Image.Image)
446484
"""
447485
# Initialize font
448-
if not font_family:
449-
from urllib.request import urlopen
450-
451-
dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true"
452-
font_family = urlopen(dejavusans_url)
453-
try:
454-
font = ImageFont.truetype(font_family, font_size)
455-
except OSError:
456-
raise ValueError(
457-
f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
458-
)
486+
font = _load_font(font_family, font_size)
459487

460488
# Initialize text
461489
image_text = []

bigtree/tree/export/stdout.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def print_tree(
4545
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
4646
- (List[str]): Choose own style for stem, branch, and final stem icons, they must have the same number of characters
4747
- (constants.BasePrintStyle): `ANSIPrintStyle`, `ASCIIPrintStyle`, `ConstPrintStyle`, `ConstBoldPrintStyle`, `RoundedPrintStyle`,
48-
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`
48+
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`
4949
5050
Examples:
5151
**Printing tree**
@@ -235,7 +235,7 @@ def yield_tree(
235235
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
236236
- (List[str]): Choose own style for stem, branch, and final stem icons, they must have the same number of characters
237237
- (constants.BasePrintStyle): `ANSIPrintStyle`, `ASCIIPrintStyle`, `ConstPrintStyle`, `ConstBoldPrintStyle`, `RoundedPrintStyle`,
238-
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`
238+
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`
239239
240240
Examples:
241241
**Yield tree**
@@ -413,7 +413,7 @@ def hprint_tree(
413413
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
414414
- (List[str]): Choose own style icons, they must have the same number of characters
415415
- (constants.BaseHPrintStyle): `ANSIHPrintStyle`, `ASCIIHPrintStyle`, `ConstHPrintStyle`, `ConstBoldHPrintStyle`,
416-
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle
416+
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle
417417
418418
Examples:
419419
**Printing tree**
@@ -535,7 +535,7 @@ def hyield_tree(
535535
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
536536
- (List[str]): Choose own style icons, they must have the same number of characters
537537
- (constants.BaseHPrintStyle): `ANSIHPrintStyle`, `ASCIIHPrintStyle`, `ConstHPrintStyle`, `ConstBoldHPrintStyle`,
538-
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle
538+
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle
539539
540540
Examples:
541541
**Printing tree**

tests/test_constants.py

+3
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,9 @@ class Constants:
154154
"Node name or path {node_name_or_path} not found"
155155
)
156156
ERROR_NODE_EXPORT_PILLOW_FONT_FAMILY = "Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
157+
ERROR_NODE_EXPORT_PILLOW_CMAP = (
158+
"`rect_cmap_attr` cannot be None if rect_fill is mpl.colormaps"
159+
)
157160
ERROR_NODE_MERMAID_INVALID_STYLE = "Unable to construct style!"
158161

159162
ERROR_NODE_EXPORT_PRINT_INVALID_STYLE = "Choose one of "

tests/tree/export/test_images.py

+50
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import matplotlib as mpl
12
import pydot
23
import pytest
34

@@ -744,6 +745,14 @@ def test_tree_to_pillow_graph_buffer(tree_node):
744745
if LOCAL:
745746
pillow_image.save("tests/tree_pillow_graph_buffer.png")
746747

748+
@staticmethod
749+
def test_tree_to_pillow_graph_bg_colour(tree_node):
750+
pillow_image = export.tree_to_pillow_graph(
751+
tree_node, node_content="{node_name}\nAge: {age}", bg_colour="beige"
752+
)
753+
if LOCAL:
754+
pillow_image.save("tests/tree_pillow_graph_bg_colour.png")
755+
747756
@staticmethod
748757
def test_tree_to_pillow_graph_rect_tb_margins(tree_node):
749758
pillow_image = export.tree_to_pillow_graph(
@@ -764,6 +773,47 @@ def test_tree_to_pillow_graph_rect_lr_margins(tree_node):
764773
if LOCAL:
765774
pillow_image.save("tests/tree_pillow_graph_rect_lr_margins.png")
766775

776+
@staticmethod
777+
def test_tree_to_pillow_graph_rect_fill(tree_node):
778+
pillow_image = export.tree_to_pillow_graph(
779+
tree_node,
780+
node_content="{node_name}\nAge: {age}",
781+
rect_fill="beige",
782+
)
783+
if LOCAL:
784+
pillow_image.save("tests/tree_pillow_graph_rect_fill.png")
785+
786+
@staticmethod
787+
def test_tree_to_pillow_graph_rect_fill_cmap_error(tree_node):
788+
with pytest.raises(ValueError) as exc_info:
789+
export.tree_to_pillow_graph(
790+
tree_node,
791+
node_content="{node_name}\nAge: {age}",
792+
rect_fill=mpl.colormaps["RdBu"],
793+
)
794+
assert str(exc_info.value) == Constants.ERROR_NODE_EXPORT_PILLOW_CMAP
795+
796+
@staticmethod
797+
def test_tree_to_pillow_graph_rect_fill_cmap(tree_node):
798+
pillow_image = export.tree_to_pillow_graph(
799+
tree_node,
800+
node_content="{node_name}\nAge: {age}",
801+
rect_fill=mpl.colormaps["RdBu"],
802+
rect_cmap_attr="age",
803+
)
804+
if LOCAL:
805+
pillow_image.save("tests/tree_pillow_graph_rect_fill_cmap.png")
806+
807+
@staticmethod
808+
def test_tree_to_pillow_graph_rect_width(tree_node):
809+
pillow_image = export.tree_to_pillow_graph(
810+
tree_node,
811+
node_content="{node_name}\nAge: {age}",
812+
rect_width=3,
813+
)
814+
if LOCAL:
815+
pillow_image.save("tests/tree_pillow_graph_rect_width.png")
816+
767817

768818
class TestTreeToPillow:
769819
@staticmethod

0 commit comments

Comments
 (0)