Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add cmap to pillow graph #347

Merged
merged 2 commits into from
Feb 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## [Unreleased]
### Added:
- Tree Exporter: `tree_to_pillow_graph` method to allow cmap for node background.

## [0.24.0] - 2025-02-09
### Added:
Expand Down
90 changes: 59 additions & 31 deletions bigtree/tree/export/images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import collections
import re
from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union

from bigtree.node import node
from bigtree.tree.export.stdout import yield_tree
Expand All @@ -22,6 +22,15 @@

Image = ImageDraw = ImageFont = MagicMock()

try:
import matplotlib as mpl
from matplotlib.colors import Normalize
except ImportError: # pragma: no cover
from unittest.mock import MagicMock

mpl = MagicMock()
Normalize = MagicMock()


__all__ = [
"tree_to_dot",
Expand Down Expand Up @@ -211,6 +220,21 @@ def _recursive_append(parent_name: Optional[str], child_node: T) -> None:
return _graph


def _load_font(font_family: str, font_size: int) -> ImageFont.truetype:
if not font_family:
from urllib.request import urlopen

dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true"
font_family = urlopen(dejavusans_url)
try:
font = ImageFont.truetype(font_family, font_size)
except OSError:
raise ValueError(
f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
)
return font


@exceptions.optional_dependencies_image("Pillow")
def tree_to_pillow_graph(
tree: T,
Expand All @@ -225,9 +249,10 @@ def tree_to_pillow_graph(
text_align: str = "center",
bg_colour: Union[Tuple[int, int, int], str] = "white",
rect_margin: Optional[Dict[str, int]] = None,
rect_fill: Union[Tuple[int, int, int], str] = "white",
rect_fill: Union[Tuple[int, int, int], str, mpl.colors.Colormap] = "white",
rect_cmap_attr: Optional[str] = None,
rect_outline: Union[Tuple[int, int, int], str] = "black",
rect_width: Union[float, int] = 1,
rect_width: int = 1,
) -> Image.Image:
r"""Export tree to PIL.Image.Image object. Object can be
converted to other formats, such as jpg, or png. Image will look
Expand Down Expand Up @@ -267,13 +292,21 @@ def tree_to_pillow_graph(
text_align (str): text align for multi-line text
bg_colour (Union[Tuple[int, int, int], str]): background of image, accepts tuple of RGB values or string, defaults to white
rect_margin (Dict[str, int]): (for rectangle) margin of text to rectangle, in pixels
rect_fill (Union[Tuple[int, int, int], str]): (for rectangle) colour to use for fill
rect_fill (Union[Tuple[int, int, int], str, mpl.colormap]): (for rectangle) colour to use for fill
rect_cmap_attr (str): (for rectangle) if rect_fill is a colormap, attribute of node to retrieve fill from colormap,
must be a float/int attribute
rect_outline (Union[Tuple[int, int, int], str]): (for rectangle) colour to use for outline
rect_width (Union[float, int]): (for rectangle) line width, in pixels
rect_width (int): (for rectangle) line width, in pixels

Returns:
(PIL.Image.Image)
"""
use_cmap = isinstance(rect_fill, mpl.colors.Colormap)
if use_cmap and rect_cmap_attr is None:
raise ValueError(
"`rect_cmap_attr` cannot be None if rect_fill is mpl.colormaps"
)

default_margin = {"t": 10, "b": 10, "l": 10, "r": 10}
default_rect_margin = {"t": 5, "b": 5, "l": 5, "r": 5}
if not margin:
Expand All @@ -286,19 +319,10 @@ def tree_to_pillow_graph(
rect_margin = {**default_rect_margin, **rect_margin}

# Initialize font
if not font_family:
from urllib.request import urlopen
font = _load_font(font_family, font_size)

dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true"
font_family = urlopen(dejavusans_url)
try:
font = ImageFont.truetype(font_family, font_size)
except OSError:
raise ValueError(
f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
)

# Calculate image dimension from text, otherwise override with argument
# Iterate tree once to obtain attributes
# Calculate image dimension from text, get range for colourmap if applicable
_max_text_width = 0
_max_text_height = 0
_image = Image.new("RGB", (0, 0))
Expand All @@ -309,11 +333,11 @@ def get_node_text(_node: T, _node_content: str) -> str:
matches = re.findall(pattern, _node_content)
for match in matches:
_node_content = _node_content.replace(
f"{{{match}}}",
str(_node.get_attr(match)) if _node.get_attr(match) else "",
f"{{{match}}}", str(_node.get_attr(match, ""))
)
return _node_content

cmap_range: Set[Union[float, int]] = set()
for _, _, _node in yield_tree(tree):
l, t, r, b = _draw.multiline_textbbox(
(0, 0), get_node_text(_node, node_content), font=font
Expand All @@ -324,6 +348,15 @@ def get_node_text(_node: T, _node_content: str) -> str:
_max_text_height = max(
_max_text_height, t + b + rect_margin.get("t", 0) + rect_margin.get("b", 0)
)
if use_cmap:
cmap_range.add(_node.get_attr(rect_cmap_attr, 0))

cmap_dict = {}
if use_cmap:
norm = Normalize(vmin=min(cmap_range), vmax=max(cmap_range))
cmap_range_list = [norm(c) for c in cmap_range]
cmap_colour_list = rect_fill(cmap_range_list) # type: ignore
cmap_dict = dict(zip(cmap_range_list, cmap_colour_list))

# Get x, y, coordinates and height, width of diagram
from bigtree.utils.plot import reingold_tilford
Expand Down Expand Up @@ -357,8 +390,13 @@ def get_node_text(_node: T, _node_content: str) -> str:
x1, x2 = _x - 0.5 * _max_text_width, _x + 0.5 * _max_text_width
y1, y2 = _y - 0.5 * _max_text_height, _y + 0.5 * _max_text_height
# Draw box
_rect_fill = rect_fill
if use_cmap:
_rect_fill = mpl.colors.rgb2hex(
cmap_dict[norm(_node.get_attr(rect_cmap_attr, 0))]
)
image_draw.rectangle(
[x1, y1, x2, y2], fill=rect_fill, outline=rect_outline, width=rect_width
[x1, y1, x2, y2], fill=_rect_fill, outline=rect_outline, width=rect_width
)
# Draw text
image_draw.text(
Expand Down Expand Up @@ -445,17 +483,7 @@ def tree_to_pillow(
(PIL.Image.Image)
"""
# Initialize font
if not font_family:
from urllib.request import urlopen

dejavusans_url = "https://github.com/kayjan/bigtree/raw/master/assets/DejaVuSans.ttf?raw=true"
font_family = urlopen(dejavusans_url)
try:
font = ImageFont.truetype(font_family, font_size)
except OSError:
raise ValueError(
f"Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
)
font = _load_font(font_family, font_size)

# Initialize text
image_text = []
Expand Down
8 changes: 4 additions & 4 deletions bigtree/tree/export/stdout.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def print_tree(
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
- (List[str]): Choose own style for stem, branch, and final stem icons, they must have the same number of characters
- (constants.BasePrintStyle): `ANSIPrintStyle`, `ASCIIPrintStyle`, `ConstPrintStyle`, `ConstBoldPrintStyle`, `RoundedPrintStyle`,
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`

Examples:
**Printing tree**
Expand Down Expand Up @@ -235,7 +235,7 @@ def yield_tree(
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
- (List[str]): Choose own style for stem, branch, and final stem icons, they must have the same number of characters
- (constants.BasePrintStyle): `ANSIPrintStyle`, `ASCIIPrintStyle`, `ConstPrintStyle`, `ConstBoldPrintStyle`, `RoundedPrintStyle`,
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`
`DoublePrintStyle` style or inherit from `constants.BasePrintStyle`

Examples:
**Yield tree**
Expand Down Expand Up @@ -413,7 +413,7 @@ def hprint_tree(
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
- (List[str]): Choose own style icons, they must have the same number of characters
- (constants.BaseHPrintStyle): `ANSIHPrintStyle`, `ASCIIHPrintStyle`, `ConstHPrintStyle`, `ConstBoldHPrintStyle`,
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle

Examples:
**Printing tree**
Expand Down Expand Up @@ -535,7 +535,7 @@ def hyield_tree(
- (str): `ansi`, `ascii`, `const` (default), `const_bold`, `rounded`, `double` style
- (List[str]): Choose own style icons, they must have the same number of characters
- (constants.BaseHPrintStyle): `ANSIHPrintStyle`, `ASCIIHPrintStyle`, `ConstHPrintStyle`, `ConstBoldHPrintStyle`,
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle
`RoundedHPrintStyle`, `DoubleHPrintStyle` style or inherit from constants.BaseHPrintStyle

Examples:
**Printing tree**
Expand Down
3 changes: 3 additions & 0 deletions tests/test_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,9 @@ class Constants:
"Node name or path {node_name_or_path} not found"
)
ERROR_NODE_EXPORT_PILLOW_FONT_FAMILY = "Font file {font_family} is not found, set `font_family` parameter to point to a valid .ttf file."
ERROR_NODE_EXPORT_PILLOW_CMAP = (
"`rect_cmap_attr` cannot be None if rect_fill is mpl.colormaps"
)
ERROR_NODE_MERMAID_INVALID_STYLE = "Unable to construct style!"

ERROR_NODE_EXPORT_PRINT_INVALID_STYLE = "Choose one of "
Expand Down
50 changes: 50 additions & 0 deletions tests/tree/export/test_images.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import matplotlib as mpl
import pydot
import pytest

Expand Down Expand Up @@ -744,6 +745,14 @@ def test_tree_to_pillow_graph_buffer(tree_node):
if LOCAL:
pillow_image.save("tests/tree_pillow_graph_buffer.png")

@staticmethod
def test_tree_to_pillow_graph_bg_colour(tree_node):
pillow_image = export.tree_to_pillow_graph(
tree_node, node_content="{node_name}\nAge: {age}", bg_colour="beige"
)
if LOCAL:
pillow_image.save("tests/tree_pillow_graph_bg_colour.png")

@staticmethod
def test_tree_to_pillow_graph_rect_tb_margins(tree_node):
pillow_image = export.tree_to_pillow_graph(
Expand All @@ -764,6 +773,47 @@ def test_tree_to_pillow_graph_rect_lr_margins(tree_node):
if LOCAL:
pillow_image.save("tests/tree_pillow_graph_rect_lr_margins.png")

@staticmethod
def test_tree_to_pillow_graph_rect_fill(tree_node):
pillow_image = export.tree_to_pillow_graph(
tree_node,
node_content="{node_name}\nAge: {age}",
rect_fill="beige",
)
if LOCAL:
pillow_image.save("tests/tree_pillow_graph_rect_fill.png")

@staticmethod
def test_tree_to_pillow_graph_rect_fill_cmap_error(tree_node):
with pytest.raises(ValueError) as exc_info:
export.tree_to_pillow_graph(
tree_node,
node_content="{node_name}\nAge: {age}",
rect_fill=mpl.colormaps["RdBu"],
)
assert str(exc_info.value) == Constants.ERROR_NODE_EXPORT_PILLOW_CMAP

@staticmethod
def test_tree_to_pillow_graph_rect_fill_cmap(tree_node):
pillow_image = export.tree_to_pillow_graph(
tree_node,
node_content="{node_name}\nAge: {age}",
rect_fill=mpl.colormaps["RdBu"],
rect_cmap_attr="age",
)
if LOCAL:
pillow_image.save("tests/tree_pillow_graph_rect_fill_cmap.png")

@staticmethod
def test_tree_to_pillow_graph_rect_width(tree_node):
pillow_image = export.tree_to_pillow_graph(
tree_node,
node_content="{node_name}\nAge: {age}",
rect_width=3,
)
if LOCAL:
pillow_image.save("tests/tree_pillow_graph_rect_width.png")


class TestTreeToPillow:
@staticmethod
Expand Down