diff --git a/CHANGELOG.md b/CHANGELOG.md index b6b13e27..494672c5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## [Unreleased] +### Added: +- Tree Exporter: `tree_to_pillow_graph` method to allow cmap for node background. ## [0.24.0] - 2025-02-09 ### Added: diff --git a/bigtree/tree/export/images.py b/bigtree/tree/export/images.py index cc134f19..a97f1f70 100644 --- a/bigtree/tree/export/images.py +++ b/bigtree/tree/export/images.py @@ -2,7 +2,7 @@ import collections import re -from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union from bigtree.node import node from bigtree.tree.export.stdout import yield_tree @@ -22,6 +22,15 @@ Image = ImageDraw = ImageFont = MagicMock() +try: + import matplotlib as mpl + from matplotlib.colors import Normalize +except ImportError: # pragma: no cover + from unittest.mock import MagicMock + + mpl = MagicMock() + Normalize = MagicMock() + __all__ = [ "tree_to_dot", @@ -211,6 +220,21 @@ def _recursive_append(parent_name: Optional[str], child_node: T) -> None: return _graph +def _load_font(font_family: str, font_size: int) -> ImageFont.truetype: + if not font_family: + from urllib.request import urlopen + + dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true" + font_family = urlopen(dejavusans_url) + try: + font = ImageFont.truetype(font_family, font_size) + except OSError: + raise ValueError( + f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file." + ) + return font + + @exceptions.optional_dependencies_image("Pillow") def tree_to_pillow_graph( tree: T, @@ -225,9 +249,10 @@ def tree_to_pillow_graph( text_align: str = "center", bg_colour: Union[Tuple[int, int, int], str] = "white", rect_margin: Optional[Dict[str, int]] = None, - rect_fill: Union[Tuple[int, int, int], str] = "white", + rect_fill: Union[Tuple[int, int, int], str, mpl.colors.Colormap] = "white", + rect_cmap_attr: Optional[str] = None, rect_outline: Union[Tuple[int, int, int], str] = "black", - rect_width: Union[float, int] = 1, + rect_width: int = 1, ) -> Image.Image: r"""Export tree to PIL.Image.Image object. Object can be converted to other formats, such as jpg, or png. Image will look @@ -267,13 +292,21 @@ def tree_to_pillow_graph( text_align (str): text align for multi-line text bg_colour (Union[Tuple[int, int, int], str]): background of image, accepts tuple of RGB values or string, defaults to white rect_margin (Dict[str, int]): (for rectangle) margin of text to rectangle, in pixels - rect_fill (Union[Tuple[int, int, int], str]): (for rectangle) colour to use for fill + rect_fill (Union[Tuple[int, int, int], str, mpl.colormap]): (for rectangle) colour to use for fill + rect_cmap_attr (str): (for rectangle) if rect_fill is a colormap, attribute of node to retrieve fill from colormap, + must be a float/int attribute rect_outline (Union[Tuple[int, int, int], str]): (for rectangle) colour to use for outline - rect_width (Union[float, int]): (for rectangle) line width, in pixels + rect_width (int): (for rectangle) line width, in pixels Returns: (PIL.Image.Image) """ + use_cmap = isinstance(rect_fill, mpl.colors.Colormap) + if use_cmap and rect_cmap_attr is None: + raise ValueError( + "`rect_cmap_attr` cannot be None if rect_fill is mpl.colormaps" + ) + default_margin = {"t": 10, "b": 10, "l": 10, "r": 10} default_rect_margin = {"t": 5, "b": 5, "l": 5, "r": 5} if not margin: @@ -286,19 +319,10 @@ def tree_to_pillow_graph( rect_margin = {**default_rect_margin, **rect_margin} # Initialize font - if not font_family: - from urllib.request import urlopen + font = _load_font(font_family, font_size) - dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true" - font_family = urlopen(dejavusans_url) - try: - font = ImageFont.truetype(font_family, font_size) - except OSError: - raise ValueError( - f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file." - ) - - # Calculate image dimension from text, otherwise override with argument + # Iterate tree once to obtain attributes + # Calculate image dimension from text, get range for colourmap if applicable _max_text_width = 0 _max_text_height = 0 _image = Image.new("RGB", (0, 0)) @@ -309,11 +333,11 @@ def get_node_text(_node: T, _node_content: str) -> str: matches = re.findall(pattern, _node_content) for match in matches: _node_content = _node_content.replace( - f"{{{match}}}", - str(_node.get_attr(match)) if _node.get_attr(match) else "", + f"{{{match}}}", str(_node.get_attr(match, "")) ) return _node_content + cmap_range: Set[Union[float, int]] = set() for _, _, _node in yield_tree(tree): l, t, r, b = _draw.multiline_textbbox( (0, 0), get_node_text(_node, node_content), font=font @@ -324,6 +348,15 @@ def get_node_text(_node: T, _node_content: str) -> str: _max_text_height = max( _max_text_height, t + b + rect_margin.get("t", 0) + rect_margin.get("b", 0) ) + if use_cmap: + cmap_range.add(_node.get_attr(rect_cmap_attr, 0)) + + cmap_dict = {} + if use_cmap: + norm = Normalize(vmin=min(cmap_range), vmax=max(cmap_range)) + cmap_range_list = [norm(c) for c in cmap_range] + cmap_colour_list = rect_fill(cmap_range_list) # type: ignore + cmap_dict = dict(zip(cmap_range_list, cmap_colour_list)) # Get x, y, coordinates and height, width of diagram from bigtree.utils.plot import reingold_tilford @@ -357,8 +390,13 @@ def get_node_text(_node: T, _node_content: str) -> str: x1, x2 = _x - 0.5 * _max_text_width, _x + 0.5 * _max_text_width y1, y2 = _y - 0.5 * _max_text_height, _y + 0.5 * _max_text_height # Draw box + _rect_fill = rect_fill + if use_cmap: + _rect_fill = mpl.colors.rgb2hex( + cmap_dict[norm(_node.get_attr(rect_cmap_attr, 0))] + ) image_draw.rectangle( - [x1, y1, x2, y2], fill=rect_fill, outline=rect_outline, width=rect_width + [x1, y1, x2, y2], fill=_rect_fill, outline=rect_outline, width=rect_width ) # Draw text image_draw.text( @@ -445,17 +483,7 @@ def tree_to_pillow( (PIL.Image.Image) """ # Initialize font - if not font_family: - from urllib.request import urlopen - - dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true" - font_family = urlopen(dejavusans_url) - try: - font = ImageFont.truetype(font_family, font_size) - except OSError: - raise ValueError( - f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file." - ) + font = _load_font(font_family, font_size) # Initialize text image_text = [] diff --git a/bigtree/tree/export/stdout.py b/bigtree/tree/export/stdout.py index 758885dc..418f31e2 100644 --- a/bigtree/tree/export/stdout.py +++ b/bigtree/tree/export/stdout.py @@ -45,7 +45,7 @@ def print_tree( - (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style - (List[str]): Choose own style for stem, branch, and final stem icons, they must have the same number of characters - (constants.BasePrintStyle): `ANSIPrintStyle`, `ASCIIPrintStyle`, `ConstPrintStyle`, `ConstBoldPrintStyle`, `RoundedPrintStyle`, - `DoublePrintStyle` style or inherit from `constants.BasePrintStyle` + `DoublePrintStyle` style or inherit from `constants.BasePrintStyle` Examples: **Printing tree** @@ -235,7 +235,7 @@ def yield_tree( - (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style - (List[str]): Choose own style for stem, branch, and final stem icons, they must have the same number of characters - (constants.BasePrintStyle): `ANSIPrintStyle`, `ASCIIPrintStyle`, `ConstPrintStyle`, `ConstBoldPrintStyle`, `RoundedPrintStyle`, - `DoublePrintStyle` style or inherit from `constants.BasePrintStyle` + `DoublePrintStyle` style or inherit from `constants.BasePrintStyle` Examples: **Yield tree** @@ -413,7 +413,7 @@ def hprint_tree( - (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style - (List[str]): Choose own style icons, they must have the same number of characters - (constants.BaseHPrintStyle): `ANSIHPrintStyle`, `ASCIIHPrintStyle`, `ConstHPrintStyle`, `ConstBoldHPrintStyle`, - `RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle + `RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle Examples: **Printing tree** @@ -535,7 +535,7 @@ def hyield_tree( - (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style - (List[str]): Choose own style icons, they must have the same number of characters - (constants.BaseHPrintStyle): `ANSIHPrintStyle`, `ASCIIHPrintStyle`, `ConstHPrintStyle`, `ConstBoldHPrintStyle`, - `RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle + `RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle Examples: **Printing tree** diff --git a/tests/test_constants.py b/tests/test_constants.py index 7484d925..a26c8d58 100644 --- a/tests/test_constants.py +++ b/tests/test_constants.py @@ -154,6 +154,9 @@ class Constants: "Node name or path {node_name_or_path} not found" ) ERROR_NODE_EXPORT_PILLOW_FONT_FAMILY = "Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file." + ERROR_NODE_EXPORT_PILLOW_CMAP = ( + "`rect_cmap_attr` cannot be None if rect_fill is mpl.colormaps" + ) ERROR_NODE_MERMAID_INVALID_STYLE = "Unable to construct style!" ERROR_NODE_EXPORT_PRINT_INVALID_STYLE = "Choose one of " diff --git a/tests/tree/export/test_images.py b/tests/tree/export/test_images.py index 0a7d32a7..18efa248 100644 --- a/tests/tree/export/test_images.py +++ b/tests/tree/export/test_images.py @@ -1,3 +1,4 @@ +import matplotlib as mpl import pydot import pytest @@ -744,6 +745,14 @@ def test_tree_to_pillow_graph_buffer(tree_node): if LOCAL: pillow_image.save("tests/tree_pillow_graph_buffer.png") + @staticmethod + def test_tree_to_pillow_graph_bg_colour(tree_node): + pillow_image = export.tree_to_pillow_graph( + tree_node, node_content="{node_name}\nAge: {age}", bg_colour="beige" + ) + if LOCAL: + pillow_image.save("tests/tree_pillow_graph_bg_colour.png") + @staticmethod def test_tree_to_pillow_graph_rect_tb_margins(tree_node): pillow_image = export.tree_to_pillow_graph( @@ -764,6 +773,47 @@ def test_tree_to_pillow_graph_rect_lr_margins(tree_node): if LOCAL: pillow_image.save("tests/tree_pillow_graph_rect_lr_margins.png") + @staticmethod + def test_tree_to_pillow_graph_rect_fill(tree_node): + pillow_image = export.tree_to_pillow_graph( + tree_node, + node_content="{node_name}\nAge: {age}", + rect_fill="beige", + ) + if LOCAL: + pillow_image.save("tests/tree_pillow_graph_rect_fill.png") + + @staticmethod + def test_tree_to_pillow_graph_rect_fill_cmap_error(tree_node): + with pytest.raises(ValueError) as exc_info: + export.tree_to_pillow_graph( + tree_node, + node_content="{node_name}\nAge: {age}", + rect_fill=mpl.colormaps["RdBu"], + ) + assert str(exc_info.value) == Constants.ERROR_NODE_EXPORT_PILLOW_CMAP + + @staticmethod + def test_tree_to_pillow_graph_rect_fill_cmap(tree_node): + pillow_image = export.tree_to_pillow_graph( + tree_node, + node_content="{node_name}\nAge: {age}", + rect_fill=mpl.colormaps["RdBu"], + rect_cmap_attr="age", + ) + if LOCAL: + pillow_image.save("tests/tree_pillow_graph_rect_fill_cmap.png") + + @staticmethod + def test_tree_to_pillow_graph_rect_width(tree_node): + pillow_image = export.tree_to_pillow_graph( + tree_node, + node_content="{node_name}\nAge: {age}", + rect_width=3, + ) + if LOCAL: + pillow_image.save("tests/tree_pillow_graph_rect_width.png") + class TestTreeToPillow: @staticmethod