Skip to content

Commit

Permalink
FIX Fix issue with setting model_diagram to False (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
BenjaminBossan authored Mar 13, 2023
1 parent e74dbf3 commit 6e4fde0
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 14 deletions.
58 changes: 44 additions & 14 deletions skops/card/_model_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,8 +380,18 @@ class Card:
``Path``/``str`` of the model or the actual model instance that will be
documented. If a ``Path`` or ``str`` is provided, model will be loaded.
model_diagram: bool, default=True
Set to True if model diagram should be plotted in the card.
model_diagram: bool or "auto" or str, default="auto"
If using the skops template, setting this to ``True`` or ``"auto"`` will
add the model diagram, as generated by sckit-learn, to the default
section, i.e "Model description/Training Procedure/Model Plot". Passing
a string to ``model_diagram`` will instead use that string as the
section name for the diagram. Set to ``False`` to not include the model
diagram.
If using a non-skops template, passing ``"auto"`` won't add the model
diagram because there is no pre-defined section to put it. The model
diagram can, however, always be added later using
:meth:`Card.add_model_plot`.
metadata: ModelCardData, optional
:class:`huggingface_hub.ModelCardData` object. The contents of this
Expand Down Expand Up @@ -481,27 +491,36 @@ class Card:
def __init__(
self,
model,
model_diagram: bool = True,
model_diagram: bool | Literal["auto"] | str = "auto",
metadata: ModelCardData | None = None,
template: Literal["skops"] | dict[str, str] | None = "skops",
trusted: bool = False,
) -> None:
self.model = model
self.model_diagram = model_diagram
self.metadata = metadata or ModelCardData()
self.template = template
self.trusted = trusted

self._data: dict[str, Section] = {}
self._metrics: dict[str, str | float | int] = {}

self._populate_template()
self._populate_template(model_diagram=model_diagram)

def _populate_template(self):
"""If initialized with a template, use it to populate the card."""
if not self.template:
return
def _populate_template(self, model_diagram: bool | Literal["auto"] | str):
"""If initialized with a template, use it to populate the card.
Parameters
----------
model_diagram: bool or "auto" or str
If using the default template, ``"auto"`` and ``True`` will add the
diagram in its default section. If using a custom template,
``"auto"`` will not add the diagram, and passing ``True`` will
result in an error. For either, passing ``False`` will result in the
model diagram being omitted, and passing a string (other than
``"auto"``) will put the model diagram into a section corresponding
to that string.
"""
if isinstance(self.template, str) and (self.template not in VALID_TEMPLATES):
valid_templates = ", ".join(f"'{val}'" for val in sorted(VALID_TEMPLATES))
msg = (
Expand All @@ -510,15 +529,29 @@ def _populate_template(self):
)
raise ValueError(msg)

# default template
if self.template == Templates.skops.value:
self.add(**SKOPS_TEMPLATE)
# for the skops template, automatically add some default sections
self.add_model_plot()
self.add_hyperparams()
self.add_get_started_code()
elif isinstance(self.template, Mapping):

if (model_diagram is True) or (model_diagram == "auto"):
self.add_model_plot()
elif isinstance(model_diagram, str):
self.add_model_plot(section=model_diagram)
return

# non-default template
if isinstance(self.template, Mapping):
self.add(**self.template)

if isinstance(model_diagram, str) and (model_diagram != "auto"):
self.add_model_plot(section=model_diagram)
elif model_diagram is True:
# will trigger an error
self.add_model_plot()

def get_model(self) -> Any:
"""Returns sklearn estimator object.
Expand Down Expand Up @@ -789,9 +822,6 @@ def add_model_plot(
self : object
Card object.
"""
if not self.model_diagram:
return self

if section is None:
if self.template == Templates.skops.value:
section = "Model description/Training Procedure/Model Plot"
Expand Down
61 changes: 61 additions & 0 deletions skops/card/tests/test_card.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,26 @@ def test_model_diagram_false(self):
).content
assert result == "The model plot is below."

def test_model_diagram_str(self):
# if passing a str, use that as the section name
model = fit_model()
other_section_name = "Here is the model diagram"
model_card = Card(model, model_diagram=other_section_name)

# first check that default section only contains placeholder
result = model_card.select(
"Model description/Training Procedure/Model Plot"
).format()
assert result == "The model plot is below."

# now check that the actual model diagram is in the other section
result = model_card.select(other_section_name).format()
assert result.startswith("The model plot is below.\n\n<style>#sk-")
assert "<style>" in result
assert result.endswith(
"<pre>LinearRegression()</pre></div></div></div></div></div>"
)

def test_other_section(self, model_card):
model_card.add_model_plot(section="Other section")
result = model_card.select("Other section").content
Expand Down Expand Up @@ -204,6 +224,47 @@ def test_custom_template_no_section_raises(self, template):
with pytest.raises(ValueError, match=msg):
model_card.add_model_plot()

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_init_str_works(self, template):
model = fit_model()
section_name = "Here is the model diagram"
model_card = Card(model, template=template, model_diagram=section_name)

result = model_card.select(section_name).format()
assert result.startswith("<style>#sk-")
assert "<style>" in result
assert result.endswith(
"<pre>LinearRegression()</pre></div></div></div></div></div>"
)

def test_default_template_and_model_diagram_true(self, model_card):
# setting model_diagram=True should not change anything vs auto with the
# default template
model = fit_model()
model_card = Card(model, model_diagram=True)
result = model_card.select(
"Model description/Training Procedure/Model Plot"
).content
# don't compare whole text, as it's quite long and non-deterministic
assert result.startswith("The model plot is below.\n\n<style>#sk-")
assert "<style>" in result
assert result.endswith(
"<pre>LinearRegression()</pre></div></div></div></div></div>"
)

@pytest.mark.parametrize("template", CUSTOM_TEMPLATES)
def test_custom_template_and_model_diagram_true(self, model_card, template):
# in contrast to the previous test, when setting model_diagram=True but
# using a custom template, we expect an error during initialization of
# the model cord
model = fit_model()
msg = (
"You are trying to add a model plot but you're using a custom template, "
"please pass the 'section' argument to determine where to put the content"
)
with pytest.raises(ValueError, match=msg):
Card(model, template=template, model_diagram=True)

def test_add_twice(self, model_card):
# it's possible to add the section twice, even if it doesn't make a lot
# of sense
Expand Down

0 comments on commit 6e4fde0

Please sign in to comment.