From 2fb0d26847083461c125edd1001ec7bf76037758 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 13 Sep 2023 20:10:35 +1000 Subject: [PATCH 01/39] Update templates with new anchors --- src/dvc_render/vega_templates.py | 116 ++++++++++++++++--------------- tests/test_vega.py | 10 +-- 2 files changed, 64 insertions(+), 62 deletions(-) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 36e23d9..c102f02 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -159,10 +159,8 @@ class BarHorizontalSortedTemplate(Template): "sort": "-x", }, "yOffset": {"field": "rev"}, - "color": { - "field": "rev", - "type": "nominal", - }, + "color": Template.anchor("color"), + "row": Template.anchor("row"), }, } @@ -190,10 +188,8 @@ class BarHorizontalTemplate(Template): "title": Template.anchor("y_label"), }, "yOffset": {"field": "rev"}, - "color": { - "field": "rev", - "type": "nominal", - }, + "color": Template.anchor("color"), + "row": Template.anchor("row"), }, } @@ -443,10 +439,8 @@ class ScatterTemplate(Template): "type": "quantitative", "title": Template.anchor("y_label"), }, - "color": { - "field": "rev", - "type": "nominal", - }, + "color": Template.anchor("color"), + "shape": Template.anchor("shape"), }, } @@ -474,10 +468,8 @@ class ScatterJitterTemplate(Template): "field": Template.anchor("y"), "title": Template.anchor("y_label"), }, - "color": { - "field": "rev", - "type": "nominal", - }, + "color": Template.anchor("color"), + "shape": Template.anchor("shape"), "xOffset": {"field": "randomX", "type": "quantitative"}, "yOffset": {"field": "randomY", "type": "quantitative"}, }, @@ -504,15 +496,26 @@ class SmoothLinearTemplate(Template): }, }, ], + "encoding": { + "x": { + "field": Template.anchor("x"), + "type": "quantitative", + "title": Template.anchor("x_label"), + }, + "color": Template.anchor("color"), + "strokeDash": Template.anchor("stroke_dash"), + "shape": Template.anchor("shape"), + }, "layer": [ { - "mark": "line", - "encoding": { - "x": { - "field": Template.anchor("x"), - "type": "quantitative", - "title": Template.anchor("x_label"), + "layer": [ + {"mark": "line"}, + { + "transform": [{"filter": {"param": "hover", "empty": False}}], + "mark": "point", }, + ], + "encoding": { "y": { "field": Template.anchor("y"), "type": "quantitative", @@ -523,18 +526,6 @@ class SmoothLinearTemplate(Template): "field": "rev", "type": "nominal", }, - "tooltip": [ - { - "field": Template.anchor("x"), - "title": Template.anchor("x_label"), - "type": "quantitative", - }, - { - "field": Template.anchor("y"), - "title": Template.anchor("y_label"), - "type": "quantitative", - }, - ], }, "transform": [ { @@ -560,26 +551,10 @@ class SmoothLinearTemplate(Template): "scale": {"zero": False}, }, "color": {"field": "rev", "type": "nominal"}, - "tooltip": [ - { - "field": Template.anchor("x"), - "title": Template.anchor("x_label"), - "type": "quantitative", - }, - { - "field": Template.anchor("y"), - "title": Template.anchor("y_label"), - "type": "quantitative", - }, - ], }, }, { - "mark": { - "type": "circle", - "size": 10, - "tooltip": {"content": "encoding"}, - }, + "mark": {"type": "circle", "size": 10}, "encoding": { "x": { "aggregate": "max", @@ -597,6 +572,38 @@ class SmoothLinearTemplate(Template): "color": {"field": "rev", "type": "nominal"}, }, }, + { + "transform": [ + { + "pivot": Template.anchor("group_by"), + "value": Template.anchor("y"), + "groupby": [Template.anchor("x")], + } + ], + "mark": { + "type": "rule", + "tooltip": {"content": "data"}, + "stroke": "grey", + }, + "encoding": { + "opacity": { + "condition": {"value": 0.3, "param": "hover", "empty": False}, + "value": 0, + } + }, + "params": [ + { + "name": "hover", + "select": { + "type": "point", + "fields": [Template.anchor("x")], + "nearest": True, + "on": "mouseover", + "clear": "mouseout", + }, + } + ], + }, ], } @@ -630,10 +637,9 @@ class SimpleLinearTemplate(Template): "title": Template.anchor("y_label"), "scale": {"zero": False}, }, - "color": { - "field": "rev", - "type": "nominal", - }, + "color": Template.anchor("color"), + "strokeDash": Template.anchor("stroke_dash"), + "shape": Template.anchor("shape"), }, } diff --git a/tests/test_vega.py b/tests/test_vega.py index 0cdd3a9..5dcaf5b 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -44,15 +44,11 @@ def test_default_template_mark(): plot_content = VegaRenderer(datapoints, "foo").get_filled_template(as_string=False) - assert plot_content["layer"][0]["mark"] == "line" + assert plot_content["layer"][0]["layer"][0]["mark"] == "line" assert plot_content["layer"][1]["mark"] == {"type": "line", "opacity": 0.2} - assert plot_content["layer"][2]["mark"] == { - "type": "circle", - "size": 10, - "tooltip": {"content": "encoding"}, - } + assert plot_content["layer"][2]["mark"] == {"type": "circle", "size": 10} def test_choose_axes(): @@ -78,7 +74,7 @@ def test_choose_axes(): "second_val": 300, }, ] - assert plot_content["layer"][0]["encoding"]["x"]["field"] == "first_val" + assert plot_content["encoding"]["x"]["field"] == "first_val" assert plot_content["layer"][0]["encoding"]["y"]["field"] == "second_val" From dcafa805d9b4367ba8da4d25636b48631b5eff74 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 20 Sep 2023 14:00:33 +1000 Subject: [PATCH 02/39] prototype filling optional anchors --- src/dvc_render/vega.py | 177 ++++++++++++++++++++++++++++++- src/dvc_render/vega_templates.py | 7 +- 2 files changed, 180 insertions(+), 4 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index b10ab0a..5e7d95c 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -1,8 +1,9 @@ import base64 import io import json +from collections import defaultdict from pathlib import Path -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from warnings import warn from .base import Renderer @@ -38,6 +39,29 @@ def __init__(self, datapoints: List, name: str, **properties): self.properties.get("template", None), self.properties.get("template_dir", None), ) + self._optional_anchor_ranges: Dict[ + str, + Union[ + List[str], + List[List[int]], + ], + ] = { + "stroke_dash": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]], + "color": [ + "#945dd6", + "#13adc7", + "#f46837", + "#48bb78", + "#4299e1", + "#ed8936", + "#f56565", + ], + "shape": ["square", "circle", "triangle", "diamond"], + } + self._optional_anchor_values: Dict[ + str, + Dict[str, Dict[str, str]], + ] = defaultdict() def get_filled_template( self, @@ -85,6 +109,8 @@ def get_filled_template( value = self.template.escape_special_characters(value) self.template.fill_anchor(name, value) + self._fill_optional_anchors(skip_anchors) + if as_string: return json.dumps(self.template.content) @@ -137,3 +163,152 @@ def generate_markdown(self, report_path=None) -> str: return f"\n![{self.name}]({src})" return "" + + def _fill_optional_anchors(self, skip_anchors: List[str]): + optional_anchors = [ + anchor + for anchor in [ + "row", + "group_by", + "pivot_field", + "color", + "stroke_dash", + "shape", + ] + if anchor not in skip_anchors and self.template.has_anchor(anchor) + ] + if not optional_anchors: + return + + self._fill_color(optional_anchors) + + if not optional_anchors: + return + + y_defn = self.properties.get("anchors_y_defn", []) + + if len(y_defn) <= 1: + self._fill_optional_anchor(optional_anchors, "group_by", ["rev"]) + self._fill_optional_anchor(optional_anchors, "pivot_field", "rev") + for anchor in optional_anchors: + self.template.fill_anchor(anchor, {}) + self._update_datapoints(to_remove=["filename", "file"]) + return + + keys, variations = self._collect_variations(y_defn) + grouped_keys = ["rev", *keys] + self._fill_optional_anchor(optional_anchors, "group_by", grouped_keys) + self._fill_optional_anchor( + optional_anchors, "pivot_field", "::".join(grouped_keys) + ) + # concatenate grouped_keys together + self._fill_optional_anchor(optional_anchors, "row", {"field": "::".join(keys)}) + + if not optional_anchors: + return + + if len(keys) == 2: + self._update_datapoints( + to_remove=["filename", "file"], to_concatenate=[["filename", "file"]] + ) + domain = ["::".join([d.get("filename"), d.get("file")]) for d in y_defn] + else: + filenameOrField = keys[0] + to_remove = ["filename", "file"] + to_remove.remove(filenameOrField) + self._update_datapoints(to_remove=to_remove) + + domain = list(variations[filenameOrField]) + + stroke_dash_scale = self._set_optional_anchor_scale( + optional_anchors, "stroke_dash", domain + ) + self._fill_optional_anchor(optional_anchors, "stroke_dash", stroke_dash_scale) + + shape_scale = self._set_optional_anchor_scale(optional_anchors, "shape", domain) + self._fill_optional_anchor(optional_anchors, "shape", shape_scale) + + def _fill_color(self, optional_anchors: List[str]): + all_revs = self.properties.get("anchor_revs", []) + self._fill_optional_anchor( + optional_anchors, + "color", + { + "scale": { + "domain": list(all_revs), + "range": self._optional_anchor_ranges.get("color", [])[ + : len(all_revs) + ], + } + }, + ) + + def _collect_variations( + self, y_defn: List[Dict[str, str]] + ) -> Tuple[List[str], Dict[str, set]]: + variations = defaultdict(set) + for defn in y_defn: + for key in ["filename", "field"]: + variations[key].add(defn.get(key, None)) + + valuesMatchVariations = [] + lessValuesThanVariations = [] + + for filenameOrField, valueSet in variations.items(): + num_values = len(valueSet) + if num_values == 1: + continue + if num_values == len(y_defn): + valuesMatchVariations.append(filenameOrField) + continue + lessValuesThanVariations.append(filenameOrField) + + if valuesMatchVariations: + valuesMatchVariations.extend(lessValuesThanVariations) + valuesMatchVariations.sort(reverse=True) + return valuesMatchVariations, variations + + lessValuesThanVariations.sort(reverse=True) + return lessValuesThanVariations, variations + + def _fill_optional_anchor(self, optional_anchors: List[str], name: str, value: Any): + if name not in optional_anchors: + return + + optional_anchors.remove(name) + self.template.fill_anchor(name, value) + + def _set_optional_anchor_scale( + self, optional_anchors: List[str], name: str, domain: List[str] + ): + if name not in optional_anchors: + return {"scale": {"domain": [], "range": []}} + + full_range_values: List[Any] = self._optional_anchor_ranges.get(name, []) + anchor_range_values = full_range_values.copy() + anchor_range = [] + + for domain_value in domain: + if not anchor_range_values: + anchor_range_values = full_range_values.copy() + range_value = anchor_range_values.pop() + self._optional_anchor_values[name][domain_value] = range_value + anchor_range.append(range_value) + + return {"scale": {"domain": domain, "range": anchor_range}} + + def _update_datapoints( + self, + to_remove: Optional[List[str]] = None, + to_concatenate: Optional[List[List[str]]] = None, + ): + if to_concatenate: + for datapoint in self.datapoints: + for keys in to_concatenate: + concat_key = "::".join(keys) + datapoint[concat_key] = "::".join([datapoint.get(k) for k in keys]) + + if to_remove: + for datapoint in self.datapoints: + for concat_key in to_remove: + datapoint.pop(concat_key, None) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index c102f02..3b08d52 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -531,7 +531,7 @@ class SmoothLinearTemplate(Template): { "loess": Template.anchor("y"), "on": Template.anchor("x"), - "groupby": ["rev", "filename", "field", "filename::field"], + "groupby": Template.anchor("group_by"), "bandwidth": {"signal": "smooth"}, }, ], @@ -574,11 +574,12 @@ class SmoothLinearTemplate(Template): }, { "transform": [ + {"calculate": Template.anchor("pivot_field"), "as": "pivot_field"}, { - "pivot": Template.anchor("group_by"), + "pivot": "pivot_field", "value": Template.anchor("y"), "groupby": [Template.anchor("x")], - } + }, ], "mark": { "type": "rule", From c9f1d8ad9d92cd2a65ed5f7df9691f73d081e185 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 20 Sep 2023 15:44:45 +1000 Subject: [PATCH 03/39] update based on quick manual test --- src/dvc_render/vega.py | 61 ++++++++++++++++++-------------- src/dvc_render/vega_templates.py | 2 -- 2 files changed, 34 insertions(+), 29 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 5e7d95c..0edafb7 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -61,7 +61,7 @@ def __init__(self, datapoints: List, name: str, **properties): self._optional_anchor_values: Dict[ str, Dict[str, Dict[str, str]], - ] = defaultdict() + ] = defaultdict(dict) def get_filled_template( self, @@ -91,6 +91,8 @@ def get_filled_template( self.properties.setdefault("y_label", self.properties.get("y")) self.properties.setdefault("data", self.datapoints) + self._fill_optional_anchors(skip_anchors) + names = ["title", "x", "y", "x_label", "y_label", "data"] for name in names: if name in skip_anchors: @@ -109,8 +111,6 @@ def get_filled_template( value = self.template.escape_special_characters(value) self.template.fill_anchor(name, value) - self._fill_optional_anchors(skip_anchors) - if as_string: return json.dumps(self.template.content) @@ -192,40 +192,45 @@ def _fill_optional_anchors(self, skip_anchors: List[str]): self._fill_optional_anchor(optional_anchors, "pivot_field", "rev") for anchor in optional_anchors: self.template.fill_anchor(anchor, {}) - self._update_datapoints(to_remove=["filename", "file"]) + self._update_datapoints(to_remove=["filename", "field"]) return keys, variations = self._collect_variations(y_defn) grouped_keys = ["rev", *keys] + concat_field = "::".join(keys) self._fill_optional_anchor(optional_anchors, "group_by", grouped_keys) self._fill_optional_anchor( - optional_anchors, "pivot_field", "::".join(grouped_keys) + optional_anchors, + "pivot_field", + "+ '::'".join([f"datum.{key}" for key in grouped_keys]), ) # concatenate grouped_keys together - self._fill_optional_anchor(optional_anchors, "row", {"field": "::".join(keys)}) + self._fill_optional_anchor(optional_anchors, "row", {"field": concat_field}) if not optional_anchors: return if len(keys) == 2: self._update_datapoints( - to_remove=["filename", "file"], to_concatenate=[["filename", "file"]] + to_remove=["filename", "field"], to_concatenate=[["filename", "field"]] ) - domain = ["::".join([d.get("filename"), d.get("file")]) for d in y_defn] + domain = ["::".join([d.get("filename"), d.get("field")]) for d in y_defn] else: filenameOrField = keys[0] - to_remove = ["filename", "file"] + to_remove = ["filename", "field"] to_remove.remove(filenameOrField) self._update_datapoints(to_remove=to_remove) domain = list(variations[filenameOrField]) stroke_dash_scale = self._set_optional_anchor_scale( - optional_anchors, "stroke_dash", domain + optional_anchors, concat_field, "stroke_dash", domain ) self._fill_optional_anchor(optional_anchors, "stroke_dash", stroke_dash_scale) - shape_scale = self._set_optional_anchor_scale(optional_anchors, "shape", domain) + shape_scale = self._set_optional_anchor_scale( + optional_anchors, concat_field, "shape", domain + ) self._fill_optional_anchor(optional_anchors, "shape", shape_scale) def _fill_color(self, optional_anchors: List[str]): @@ -234,12 +239,13 @@ def _fill_color(self, optional_anchors: List[str]): optional_anchors, "color", { + "field": "rev", "scale": { "domain": list(all_revs), "range": self._optional_anchor_ranges.get("color", [])[ : len(all_revs) ], - } + }, }, ) @@ -279,10 +285,10 @@ def _fill_optional_anchor(self, optional_anchors: List[str], name: str, value: A self.template.fill_anchor(name, value) def _set_optional_anchor_scale( - self, optional_anchors: List[str], name: str, domain: List[str] + self, optional_anchors: List[str], field: str, name: str, domain: List[str] ): if name not in optional_anchors: - return {"scale": {"domain": [], "range": []}} + return {"field": field, "scale": {"domain": [], "range": []}} full_range_values: List[Any] = self._optional_anchor_ranges.get(name, []) anchor_range_values = full_range_values.copy() @@ -291,24 +297,25 @@ def _set_optional_anchor_scale( for domain_value in domain: if not anchor_range_values: anchor_range_values = full_range_values.copy() - range_value = anchor_range_values.pop() + range_value = anchor_range_values.pop(0) self._optional_anchor_values[name][domain_value] = range_value anchor_range.append(range_value) - return {"scale": {"domain": domain, "range": anchor_range}} + return {"field": field, "scale": {"domain": domain, "range": anchor_range}} def _update_datapoints( self, - to_remove: Optional[List[str]] = None, to_concatenate: Optional[List[List[str]]] = None, + to_remove: Optional[List[str]] = None, ): - if to_concatenate: - for datapoint in self.datapoints: - for keys in to_concatenate: - concat_key = "::".join(keys) - datapoint[concat_key] = "::".join([datapoint.get(k) for k in keys]) - - if to_remove: - for datapoint in self.datapoints: - for concat_key in to_remove: - datapoint.pop(concat_key, None) + if to_concatenate is None: + to_concatenate = [] + if to_remove is None: + to_remove = [] + + for datapoint in self.datapoints: + for keys in to_concatenate: + concat_key = "::".join(keys) + datapoint[concat_key] = "::".join([datapoint.get(k) for k in keys]) + for key in to_remove: + datapoint.pop(key, None) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 3b08d52..909a331 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -504,7 +504,6 @@ class SmoothLinearTemplate(Template): }, "color": Template.anchor("color"), "strokeDash": Template.anchor("stroke_dash"), - "shape": Template.anchor("shape"), }, "layer": [ { @@ -640,7 +639,6 @@ class SimpleLinearTemplate(Template): }, "color": Template.anchor("color"), "strokeDash": Template.anchor("stroke_dash"), - "shape": Template.anchor("shape"), }, } From bad457f3effa77db0606b9c7010562dac2f12285 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 20 Sep 2023 20:26:24 +1000 Subject: [PATCH 04/39] be dumb and refactor before adding tests --- src/dvc_render/vega.py | 112 +++++++++++++++++++++++++---------------- 1 file changed, 69 insertions(+), 43 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 0edafb7..bbe6a9c 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -91,7 +91,7 @@ def get_filled_template( self.properties.setdefault("y_label", self.properties.get("y")) self.properties.setdefault("data", self.datapoints) - self._fill_optional_anchors(skip_anchors) + self._process_optional_anchors(skip_anchors) names = ["title", "x", "y", "x_label", "y_label", "data"] for name in names: @@ -164,7 +164,7 @@ def generate_markdown(self, report_path=None) -> str: return "" - def _fill_optional_anchors(self, skip_anchors: List[str]): + def _process_optional_anchors(self, skip_anchors: List[str]): optional_anchors = [ anchor for anchor in [ @@ -175,67 +175,79 @@ def _fill_optional_anchors(self, skip_anchors: List[str]): "stroke_dash", "shape", ] - if anchor not in skip_anchors and self.template.has_anchor(anchor) + if self.template.has_anchor(anchor) ] - if not optional_anchors: - return + if optional_anchors: + # split varied_keys out from _fill_optional_anchors to avoid bugs + # but first.... tests + varied_keys = self._fill_optional_anchors(skip_anchors, optional_anchors) + self._update_datapoints(varied_keys) - self._fill_color(optional_anchors) + def _fill_optional_anchors( + self, skip_anchors: List[str], optional_anchors: List[str] + ) -> List[str]: + self._fill_color(skip_anchors, optional_anchors) if not optional_anchors: - return + return [] y_defn = self.properties.get("anchors_y_defn", []) if len(y_defn) <= 1: - self._fill_optional_anchor(optional_anchors, "group_by", ["rev"]) - self._fill_optional_anchor(optional_anchors, "pivot_field", "rev") + self._fill_optional_anchor( + skip_anchors, optional_anchors, "group_by", ["rev"] + ) + self._fill_optional_anchor( + skip_anchors, optional_anchors, "pivot_field", "rev" + ) for anchor in optional_anchors: self.template.fill_anchor(anchor, {}) - self._update_datapoints(to_remove=["filename", "field"]) - return + return [] - keys, variations = self._collect_variations(y_defn) - grouped_keys = ["rev", *keys] - concat_field = "::".join(keys) - self._fill_optional_anchor(optional_anchors, "group_by", grouped_keys) + varied_keys, variations = self._collect_variations(y_defn) + grouped_keys = ["rev", *varied_keys] + concat_field = "::".join(varied_keys) self._fill_optional_anchor( + skip_anchors, optional_anchors, "group_by", grouped_keys + ) + self._fill_optional_anchor( + skip_anchors, optional_anchors, "pivot_field", - "+ '::'".join([f"datum.{key}" for key in grouped_keys]), + " + '::' + ".join([f"datum.{key}" for key in grouped_keys]), ) # concatenate grouped_keys together - self._fill_optional_anchor(optional_anchors, "row", {"field": concat_field}) + self._fill_optional_anchor( + skip_anchors, optional_anchors, "row", {"field": concat_field} + ) if not optional_anchors: - return + return varied_keys - if len(keys) == 2: - self._update_datapoints( - to_remove=["filename", "field"], to_concatenate=[["filename", "field"]] - ) + if len(varied_keys) == 2: domain = ["::".join([d.get("filename"), d.get("field")]) for d in y_defn] else: - filenameOrField = keys[0] - to_remove = ["filename", "field"] - to_remove.remove(filenameOrField) - self._update_datapoints(to_remove=to_remove) - + filenameOrField = varied_keys[0] domain = list(variations[filenameOrField]) stroke_dash_scale = self._set_optional_anchor_scale( optional_anchors, concat_field, "stroke_dash", domain ) - self._fill_optional_anchor(optional_anchors, "stroke_dash", stroke_dash_scale) + self._fill_optional_anchor( + skip_anchors, optional_anchors, "stroke_dash", stroke_dash_scale + ) shape_scale = self._set_optional_anchor_scale( optional_anchors, concat_field, "shape", domain ) - self._fill_optional_anchor(optional_anchors, "shape", shape_scale) + self._fill_optional_anchor(skip_anchors, optional_anchors, "shape", shape_scale) - def _fill_color(self, optional_anchors: List[str]): + return varied_keys + + def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]): all_revs = self.properties.get("anchor_revs", []) self._fill_optional_anchor( + skip_anchors, optional_anchors, "color", { @@ -277,11 +289,21 @@ def _collect_variations( lessValuesThanVariations.sort(reverse=True) return lessValuesThanVariations, variations - def _fill_optional_anchor(self, optional_anchors: List[str], name: str, value: Any): + def _fill_optional_anchor( + self, + skip_anchors: List[str], + optional_anchors: List[str], + name: str, + value: Any, + ): if name not in optional_anchors: return optional_anchors.remove(name) + + if name in skip_anchors: + return + self.template.fill_anchor(name, value) def _set_optional_anchor_scale( @@ -301,21 +323,25 @@ def _set_optional_anchor_scale( self._optional_anchor_values[name][domain_value] = range_value anchor_range.append(range_value) - return {"field": field, "scale": {"domain": domain, "range": anchor_range}} + return { + "field": field, + "scale": {"domain": domain, "range": anchor_range}, + "legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"}, + } - def _update_datapoints( - self, - to_concatenate: Optional[List[List[str]]] = None, - to_remove: Optional[List[str]] = None, - ): - if to_concatenate is None: + def _update_datapoints(self, varied_keys: List[str]): + if len(varied_keys) == 2: + to_concatenate = varied_keys + to_remove = varied_keys + else: to_concatenate = [] - if to_remove is None: - to_remove = [] + to_remove = [key for key in ["filename", "field"] if key not in varied_keys] for datapoint in self.datapoints: - for keys in to_concatenate: - concat_key = "::".join(keys) - datapoint[concat_key] = "::".join([datapoint.get(k) for k in keys]) + if to_concatenate: + concat_key = "::".join(to_concatenate) + datapoint[concat_key] = "::".join( + [datapoint.get(k) for k in to_concatenate] + ) for key in to_remove: datapoint.pop(key, None) From 143288d619fca76d97a80291d92498b7bb2c9e82 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 21 Sep 2023 11:03:06 +1000 Subject: [PATCH 05/39] add tests for linear/smooth template --- src/dvc_render/vega.py | 24 ++-- tests/test_vega.py | 258 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 271 insertions(+), 11 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index bbe6a9c..9e35dea 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -198,7 +198,7 @@ def _fill_optional_anchors( skip_anchors, optional_anchors, "group_by", ["rev"] ) self._fill_optional_anchor( - skip_anchors, optional_anchors, "pivot_field", "rev" + skip_anchors, optional_anchors, "pivot_field", "datum.rev" ) for anchor in optional_anchors: self.template.fill_anchor(anchor, {}) @@ -230,6 +230,8 @@ def _fill_optional_anchors( filenameOrField = varied_keys[0] domain = list(variations[filenameOrField]) + domain.sort() + stroke_dash_scale = self._set_optional_anchor_scale( optional_anchors, concat_field, "stroke_dash", domain ) @@ -269,25 +271,25 @@ def _collect_variations( for key in ["filename", "field"]: variations[key].add(defn.get(key, None)) - valuesMatchVariations = [] - lessValuesThanVariations = [] + values_match_variations = [] + less_values_than_variations = [] for filenameOrField, valueSet in variations.items(): num_values = len(valueSet) if num_values == 1: continue if num_values == len(y_defn): - valuesMatchVariations.append(filenameOrField) + values_match_variations.append(filenameOrField) continue - lessValuesThanVariations.append(filenameOrField) + less_values_than_variations.append(filenameOrField) - if valuesMatchVariations: - valuesMatchVariations.extend(lessValuesThanVariations) - valuesMatchVariations.sort(reverse=True) - return valuesMatchVariations, variations + if values_match_variations: + values_match_variations.extend(less_values_than_variations) + values_match_variations.sort(reverse=True) + return values_match_variations, variations - lessValuesThanVariations.sort(reverse=True) - return lessValuesThanVariations, variations + less_values_than_variations.sort(reverse=True) + return less_values_than_variations, variations def _fill_optional_anchor( self, diff --git a/tests/test_vega.py b/tests/test_vega.py index 5dcaf5b..17cbd93 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -263,3 +263,261 @@ def test_fill_anchor_in_string(tmp_dir): assert filled["transform"][1]["calculate"] == "pow(datum.lab - datum.SR,2)" assert filled["encoding"]["x"]["field"] == x assert filled["encoding"]["y"]["field"] == y + +@pytest.mark.parametrize( + ",".join( + [ + "datapoints", + "y", + "anchors_y_defn", + "expected_dp_keys", + "color_encoding", + "stroke_dash_encoding", + "pivot_field", + "group_by", + ] + ), + ( + ( + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + ], + "acc", + [{"filename": "test", "field": "acc"}], + ["rev", "acc", "step"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + {}, + "datum.rev", + ["rev"], + ), + ( + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "acc": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.09", + "filename": "train", + "field": "acc", + "step": 2, + }, + ], + "acc", + [ + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + ["rev", "acc", "step", "filename"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + { + "field": "filename", + "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.filename", + ["rev", "filename"], + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "filename": "train", + "field": "acc_norm", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.09", + "filename": "test", + "field": "acc_norm", + "step": 2, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "test", "field": "acc"}, + {"filename": "test", "field": "acc_norm"}, + ], + ["rev", "dvc_inferred_y_value", "step", "field"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + { + "field": "field", + "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.field", + ["rev", "field"], + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.09", + "filename": "train", + "field": "acc", + "step": 2, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.02", + "filename": "test", + "field": "acc_norm", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.07", + "filename": "test", + "field": "acc_norm", + "step": 2, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "test", "field": "acc_norm"}, + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + ["rev", "dvc_inferred_y_value", "step", "filename::field"], + { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, + { + "field": "filename::field", + "scale": { + "domain": ["test::acc", "test::acc_norm", "train::acc"], + "range": [[1, 0], [8, 8], [8, 4]], + }, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + "datum.rev + '::' + datum.filename + '::' + datum.field", + ["rev", "filename", "field"], + ), + ), +) +def test_optional_anchors_linear( + datapoints, + y, + anchors_y_defn, + expected_dp_keys, + color_encoding, + stroke_dash_encoding, + pivot_field, + group_by, +): # pylint: disable=too-many-arguments + props = { + "template": "linear", + "x": "step", + "y": y, + "anchor_revs": ["B"], + "anchors_y_defn": anchors_y_defn, + } + + expected_datapoints = [] + for datapoint in datapoints: + expected_datapoint = {} + for key in expected_dp_keys: + if key == "filename::field": + expected_datapoint[ + key + ] = f"{datapoint['filename']}::{datapoint['field']}" + else: + expected_datapoint[key] = datapoint.get(key) + expected_datapoints.append(expected_datapoint) + + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template( + as_string=False + ) + + assert plot_content["data"]["values"] == expected_datapoints + assert plot_content["encoding"]["color"] == color_encoding + assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding + assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field + assert plot_content["layer"][0]["transform"][0]["groupby"] == group_by From ba7826dcc859d3e2441df42e69ce3aa36c0120be Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 21 Sep 2023 20:06:11 +1000 Subject: [PATCH 06/39] refactor --- src/dvc_render/vega.py | 210 ++++++++++++++++++++++++++--------------- tests/test_vega.py | 170 ++++++++++++++++++++++++++++++--- 2 files changed, 292 insertions(+), 88 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 9e35dea..d2a66dd 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -8,7 +8,7 @@ from .base import Renderer from .utils import list_dict_to_dict_list -from .vega_templates import BadTemplateError, LinearTemplate, get_template +from .vega_templates import BadTemplateError, LinearTemplate, Template, get_template class VegaRenderer(Renderer): @@ -58,14 +58,12 @@ def __init__(self, datapoints: List, name: str, **properties): ], "shape": ["square", "circle", "triangle", "diamond"], } - self._optional_anchor_values: Dict[ - str, - Dict[str, Dict[str, str]], - ] = defaultdict(dict) + + self._split_content: Dict[str, Any] = {} def get_filled_template( self, - skip_anchors: Optional[List[str]] = None, + split_anchors: Optional[List[str]] = None, strict: bool = True, as_string: bool = True, ) -> Union[str, Dict[str, Any]]: @@ -74,8 +72,8 @@ def get_filled_template( if not self.datapoints: return {} - if skip_anchors is None: - skip_anchors = [] + if split_anchors is None: + split_anchors = [] if strict: if self.properties.get("x"): @@ -91,15 +89,18 @@ def get_filled_template( self.properties.setdefault("y_label", self.properties.get("y")) self.properties.setdefault("data", self.datapoints) - self._process_optional_anchors(skip_anchors) + self._process_optional_anchors(split_anchors) names = ["title", "x", "y", "x_label", "y_label", "data"] for name in names: - if name in skip_anchors: - continue value = self.properties.get(name) if value is None: continue + + if name in split_anchors: + self._set_split_content(name, value) + continue + if name == "data": if not self.template.has_anchor(name): anchor = self.template.anchor(name) @@ -116,6 +117,15 @@ def get_filled_template( return self.template.content + def get_partial_filled_template(self): + """ + Returns a partially filled template along with the split out anchor content + """ + content = self.get_filled_template( + split_anchors=["data", "color", "stroke_dash", "shape"], strict=True + ) + return content, self._split_content + def partial_html(self, **kwargs) -> str: return self.get_filled_template() # type: ignore @@ -164,7 +174,7 @@ def generate_markdown(self, report_path=None) -> str: return "" - def _process_optional_anchors(self, skip_anchors: List[str]): + def _process_optional_anchors(self, split_anchors: List[str]): optional_anchors = [ anchor for anchor in [ @@ -177,79 +187,85 @@ def _process_optional_anchors(self, skip_anchors: List[str]): ] if self.template.has_anchor(anchor) ] - if optional_anchors: - # split varied_keys out from _fill_optional_anchors to avoid bugs - # but first.... tests - varied_keys = self._fill_optional_anchors(skip_anchors, optional_anchors) - self._update_datapoints(varied_keys) - - def _fill_optional_anchors( - self, skip_anchors: List[str], optional_anchors: List[str] - ) -> List[str]: - self._fill_color(skip_anchors, optional_anchors) - if not optional_anchors: - return [] + return y_defn = self.properties.get("anchors_y_defn", []) + is_single_source = len(y_defn) <= 1 - if len(y_defn) <= 1: - self._fill_optional_anchor( - skip_anchors, optional_anchors, "group_by", ["rev"] - ) - self._fill_optional_anchor( - skip_anchors, optional_anchors, "pivot_field", "datum.rev" - ) - for anchor in optional_anchors: - self.template.fill_anchor(anchor, {}) - return [] + if is_single_source: + self._process_single_source_plot(split_anchors, optional_anchors) + return + + self._process_multi_source_plot(split_anchors, optional_anchors, y_defn) + + def _process_single_source_plot( + self, split_anchors: List[str], optional_anchors: List[str] + ): + self._fill_color(split_anchors, optional_anchors) + self._fill_optional_anchor(split_anchors, optional_anchors, "group_by", ["rev"]) + self._fill_optional_anchor( + split_anchors, optional_anchors, "pivot_field", "datum.rev" + ) + for anchor in optional_anchors: + self.template.fill_anchor(anchor, {}) + + self._update_datapoints([]) + + def _process_multi_source_plot( + self, + split_anchors: List[str], + optional_anchors: List[str], + y_defn: List[Dict[str, str]], + ): + varied_keys, varied_values = self._collect_variations(y_defn) + domain = self._get_domain(varied_keys, varied_values, y_defn) + + self._fill_optional_multi_source_anchors( + split_anchors, optional_anchors, varied_keys, domain + ) + self._update_datapoints(varied_keys) + + def _fill_optional_multi_source_anchors( + self, + split_anchors: List[str], + optional_anchors: List[str], + varied_keys: List[str], + domain: List[str], + ): + self._fill_color(split_anchors, optional_anchors) + + if not optional_anchors: + return - varied_keys, variations = self._collect_variations(y_defn) grouped_keys = ["rev", *varied_keys] - concat_field = "::".join(varied_keys) self._fill_optional_anchor( - skip_anchors, optional_anchors, "group_by", grouped_keys + split_anchors, optional_anchors, "group_by", grouped_keys ) self._fill_optional_anchor( - skip_anchors, + split_anchors, optional_anchors, "pivot_field", " + '::' + ".join([f"datum.{key}" for key in grouped_keys]), ) - # concatenate grouped_keys together - self._fill_optional_anchor( - skip_anchors, optional_anchors, "row", {"field": concat_field} - ) - - if not optional_anchors: - return varied_keys - - if len(varied_keys) == 2: - domain = ["::".join([d.get("filename"), d.get("field")]) for d in y_defn] - else: - filenameOrField = varied_keys[0] - domain = list(variations[filenameOrField]) - domain.sort() - - stroke_dash_scale = self._set_optional_anchor_scale( - optional_anchors, concat_field, "stroke_dash", domain - ) + concat_field = "::".join(varied_keys) self._fill_optional_anchor( - skip_anchors, optional_anchors, "stroke_dash", stroke_dash_scale + split_anchors, optional_anchors, "row", {"field": concat_field} ) - shape_scale = self._set_optional_anchor_scale( - optional_anchors, concat_field, "shape", domain - ) - self._fill_optional_anchor(skip_anchors, optional_anchors, "shape", shape_scale) + if not optional_anchors: + return - return varied_keys + for field in ["stroke_dash", "shape"]: + self._fill_optional_anchor_mapping( + split_anchors, optional_anchors, concat_field, field, domain + ) - def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]): + def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): all_revs = self.properties.get("anchor_revs", []) self._fill_optional_anchor( - skip_anchors, + split_anchors, optional_anchors, "color", { @@ -266,15 +282,15 @@ def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]): def _collect_variations( self, y_defn: List[Dict[str, str]] ) -> Tuple[List[str], Dict[str, set]]: - variations = defaultdict(set) + varied_values = defaultdict(set) for defn in y_defn: for key in ["filename", "field"]: - variations[key].add(defn.get(key, None)) + varied_values[key].add(defn.get(key, None)) values_match_variations = [] less_values_than_variations = [] - for filenameOrField, valueSet in variations.items(): + for filenameOrField, valueSet in varied_values.items(): num_values = len(valueSet) if num_values == 1: continue @@ -286,14 +302,14 @@ def _collect_variations( if values_match_variations: values_match_variations.extend(less_values_than_variations) values_match_variations.sort(reverse=True) - return values_match_variations, variations + return values_match_variations, varied_values less_values_than_variations.sort(reverse=True) - return less_values_than_variations, variations + return less_values_than_variations, varied_values def _fill_optional_anchor( self, - skip_anchors: List[str], + split_anchors: List[str], optional_anchors: List[str], name: str, value: Any, @@ -303,26 +319,63 @@ def _fill_optional_anchor( optional_anchors.remove(name) - if name in skip_anchors: + if name in split_anchors: return self.template.fill_anchor(name, value) - def _set_optional_anchor_scale( - self, optional_anchors: List[str], field: str, name: str, domain: List[str] + def _get_domain( + self, + varied_keys: List[str], + varied_values: Dict[str, set], + y_defn: List[Dict[str, str]], ): + if len(varied_keys) == 2: + domain = [ + "::".join([d.get("filename", ""), d.get("field", "")]) for d in y_defn + ] + else: + filenameOrField = varied_keys[0] + domain = list(varied_values[filenameOrField]) + + domain.sort() + return domain + + def _fill_optional_anchor_mapping( + self, + split_anchors: List[str], + optional_anchors: List[str], + field: str, + name: str, + domain: List[str], + ): # pylint: disable=too-many-arguments if name not in optional_anchors: - return {"field": field, "scale": {"domain": [], "range": []}} + return + + optional_anchors.remove(name) + + encoding = self._get_optional_anchor_mapping(field, name, domain) + if name in split_anchors: + self._set_split_content(name, encoding) + return + + self.template.fill_anchor(name, encoding) + + def _get_optional_anchor_mapping( + self, + field: str, + name: str, + domain: List[str], + ): full_range_values: List[Any] = self._optional_anchor_ranges.get(name, []) anchor_range_values = full_range_values.copy() - anchor_range = [] - for domain_value in domain: + anchor_range = [] + for _ in range(len(domain)): if not anchor_range_values: anchor_range_values = full_range_values.copy() range_value = anchor_range_values.pop(0) - self._optional_anchor_values[name][domain_value] = range_value anchor_range.append(range_value) return { @@ -347,3 +400,6 @@ def _update_datapoints(self, varied_keys: List[str]): ) for key in to_remove: datapoint.pop(key, None) + + def _set_split_content(self, name: str, value: Any): + self._split_content[Template.anchor(name)] = value diff --git a/tests/test_vega.py b/tests/test_vega.py index 17cbd93..a9adcb1 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -1,4 +1,5 @@ import json +from typing import Any, Dict, List import pytest @@ -500,7 +501,159 @@ def test_optional_anchors_linear( "anchors_y_defn": anchors_y_defn, } - expected_datapoints = [] + expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) + + renderer = VegaRenderer(datapoints, "foo", **props) + plot_content = renderer.get_filled_template(as_string=False) + + assert plot_content["data"]["values"] == expected_datapoints + assert plot_content["encoding"]["color"] == color_encoding + assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding + assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field + assert plot_content["layer"][0]["transform"][0]["groupby"] == group_by + + +@pytest.mark.parametrize( + "datapoints,y,anchors_y_defn,expected_dp_keys,stroke_dash_encoding", + ( + ( + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.1", + "filename": "test", + "field": "acc", + "step": 2, + }, + ], + "acc", + [{"filename": "test", "field": "acc"}], + ["rev", "acc", "step"], + {}, + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "filename": "train", + "field": "acc_norm", + "step": 1, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "test", "field": "acc"}, + {"filename": "test", "field": "acc_norm"}, + ], + ["rev", "dvc_inferred_y_value", "step", "field"], + { + "field": "field", + "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.02", + "filename": "test", + "field": "acc_norm", + "step": 1, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "test", "field": "acc_norm"}, + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + ["rev", "dvc_inferred_y_value", "step", "filename::field"], + { + "field": "filename::field", + "scale": { + "domain": ["test::acc", "test::acc_norm", "train::acc"], + "range": [[1, 0], [8, 8], [8, 4]], + }, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + ), + ), +) +def test_partial_filled_template( + datapoints, + y, + anchors_y_defn, + expected_dp_keys, + stroke_dash_encoding, +): + props = { + "template": "linear", + "x": "step", + "y": y, + "anchor_revs": ["B"], + "anchors_y_defn": anchors_y_defn, + } + + expected_split = { + Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys) + } + + split_anchors = [ + Template.anchor("color"), + Template.anchor("data"), + ] + if len(anchors_y_defn) > 1: + split_anchors.append(Template.anchor("stroke_dash")) + expected_split[Template.anchor("stroke_dash")] = stroke_dash_encoding + + renderer = VegaRenderer(datapoints, "foo", **props) + content, split = renderer.get_partial_filled_template() + + for anchor in split_anchors: + assert anchor in content + assert split == expected_split + + +def _get_expected_datapoints( + datapoints: List[Dict[str, Any]], expected_dp_keys: List[str] +): + expected_datapoints: List[Dict[str, Any]] = [] for datapoint in datapoints: expected_datapoint = {} for key in expected_dp_keys: @@ -509,15 +662,10 @@ def test_optional_anchors_linear( key ] = f"{datapoint['filename']}::{datapoint['field']}" else: - expected_datapoint[key] = datapoint.get(key) + value = datapoint.get(key) + if value is None: + continue + expected_datapoint[key] = value expected_datapoints.append(expected_datapoint) - plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template( - as_string=False - ) - - assert plot_content["data"]["values"] == expected_datapoints - assert plot_content["encoding"]["color"] == color_encoding - assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding - assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field - assert plot_content["layer"][0]["transform"][0]["groupby"] == group_by + return datapoints From d16cbd3f75a69e33301ffdfbb78704e3f1e5238c Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 21 Sep 2023 20:29:49 +1000 Subject: [PATCH 07/39] move update datapoints to top level --- src/dvc_render/vega.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index d2a66dd..fb5ac9c 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -89,7 +89,8 @@ def get_filled_template( self.properties.setdefault("y_label", self.properties.get("y")) self.properties.setdefault("data", self.datapoints) - self._process_optional_anchors(split_anchors) + varied_keys = self._process_optional_anchors(split_anchors) + self._update_datapoints(varied_keys) names = ["title", "x", "y", "x_label", "y_label", "data"] for name in names: @@ -188,16 +189,16 @@ def _process_optional_anchors(self, split_anchors: List[str]): if self.template.has_anchor(anchor) ] if not optional_anchors: - return + return None y_defn = self.properties.get("anchors_y_defn", []) is_single_source = len(y_defn) <= 1 if is_single_source: self._process_single_source_plot(split_anchors, optional_anchors) - return + return [] - self._process_multi_source_plot(split_anchors, optional_anchors, y_defn) + return self._process_multi_source_plot(split_anchors, optional_anchors, y_defn) def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] @@ -210,8 +211,6 @@ def _process_single_source_plot( for anchor in optional_anchors: self.template.fill_anchor(anchor, {}) - self._update_datapoints([]) - def _process_multi_source_plot( self, split_anchors: List[str], @@ -224,7 +223,7 @@ def _process_multi_source_plot( self._fill_optional_multi_source_anchors( split_anchors, optional_anchors, varied_keys, domain ) - self._update_datapoints(varied_keys) + return varied_keys def _fill_optional_multi_source_anchors( self, @@ -384,7 +383,10 @@ def _get_optional_anchor_mapping( "legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"}, } - def _update_datapoints(self, varied_keys: List[str]): + def _update_datapoints(self, varied_keys: Optional[List[str]] = None): + if varied_keys is None: + return + if len(varied_keys) == 2: to_concatenate = varied_keys to_remove = varied_keys From e22ad99b30a4c7e6d98b33e32bff4392bb320b40 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Fri, 22 Sep 2023 08:24:42 +1000 Subject: [PATCH 08/39] add color into anchor definitions --- src/dvc_render/vega.py | 28 +++++++++++++--------------- tests/test_vega.py | 8 ++++++-- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index fb5ac9c..cdd92e2 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -125,7 +125,7 @@ def get_partial_filled_template(self): content = self.get_filled_template( split_anchors=["data", "color", "stroke_dash", "shape"], strict=True ) - return content, self._split_content + return content, {"anchor_definitions": self._split_content} def partial_html(self, **kwargs) -> str: return self.get_filled_template() # type: ignore @@ -191,6 +191,8 @@ def _process_optional_anchors(self, split_anchors: List[str]): if not optional_anchors: return None + self._fill_color(split_anchors, optional_anchors) + y_defn = self.properties.get("anchors_y_defn", []) is_single_source = len(y_defn) <= 1 @@ -203,7 +205,6 @@ def _process_optional_anchors(self, split_anchors: List[str]): def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] ): - self._fill_color(split_anchors, optional_anchors) self._fill_optional_anchor(split_anchors, optional_anchors, "group_by", ["rev"]) self._fill_optional_anchor( split_anchors, optional_anchors, "pivot_field", "datum.rev" @@ -232,8 +233,6 @@ def _fill_optional_multi_source_anchors( varied_keys: List[str], domain: List[str], ): - self._fill_color(split_anchors, optional_anchors) - if not optional_anchors: return @@ -263,19 +262,12 @@ def _fill_optional_multi_source_anchors( def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): all_revs = self.properties.get("anchor_revs", []) - self._fill_optional_anchor( + self._fill_optional_anchor_mapping( split_anchors, optional_anchors, + "rev", "color", - { - "field": "rev", - "scale": { - "domain": list(all_revs), - "range": self._optional_anchor_ranges.get("color", [])[ - : len(all_revs) - ], - }, - }, + all_revs, ) def _collect_variations( @@ -377,10 +369,16 @@ def _get_optional_anchor_mapping( range_value = anchor_range_values.pop(0) anchor_range.append(range_value) + legend = ( + {"legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"}} + if name != "color" + else {} + ) + return { "field": field, "scale": {"domain": domain, "range": anchor_range}, - "legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"}, + **legend, } def _update_datapoints(self, varied_keys: Optional[List[str]] = None): diff --git a/tests/test_vega.py b/tests/test_vega.py index a9adcb1..fda9ee9 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -631,7 +631,11 @@ def test_partial_filled_template( } expected_split = { - Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys) + Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys), + Template.anchor("color"): { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + }, } split_anchors = [ @@ -647,7 +651,7 @@ def test_partial_filled_template( for anchor in split_anchors: assert anchor in content - assert split == expected_split + assert split["anchor_definitions"] == expected_split def _get_expected_datapoints( From 9ee0ab9acbfdb2928662a9a787501fe5c945d3e3 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Fri, 22 Sep 2023 09:38:59 +1000 Subject: [PATCH 09/39] add get_revs method to renderer --- src/dvc_render/vega.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index cdd92e2..786946a 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -175,6 +175,13 @@ def generate_markdown(self, report_path=None) -> str: return "" + def get_revs(self): + """ + Returns all revisions that were collected. + Potentially will include revisions that have no datapoints + """ + return self.properties.get("anchor_revs", []) + def _process_optional_anchors(self, split_anchors: List[str]): optional_anchors = [ anchor From 69dec78da4cb98f4128a0978fc41e5a50a54b536 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Fri, 22 Sep 2023 12:22:10 +1000 Subject: [PATCH 10/39] improve names --- src/dvc_render/vega.py | 56 ++++++++++++++++++++++++------------------ tests/test_vega.py | 14 +++++------ 2 files changed, 39 insertions(+), 31 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 786946a..329154a 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -10,6 +10,11 @@ from .utils import list_dict_to_dict_list from .vega_templates import BadTemplateError, LinearTemplate, Template, get_template +FIELD_SEPARATOR = "::" +FILENAME = "filename" +FIELD = "field" +FILENAME_FIELD = [FILENAME, FIELD] + class VegaRenderer(Renderer): """Renderer for vega plots.""" @@ -200,14 +205,16 @@ def _process_optional_anchors(self, split_anchors: List[str]): self._fill_color(split_anchors, optional_anchors) - y_defn = self.properties.get("anchors_y_defn", []) - is_single_source = len(y_defn) <= 1 + y_definitions = self.properties.get("anchors_y_definitions", []) + is_single_source = len(y_definitions) <= 1 if is_single_source: self._process_single_source_plot(split_anchors, optional_anchors) return [] - return self._process_multi_source_plot(split_anchors, optional_anchors, y_defn) + return self._process_multi_source_plot( + split_anchors, optional_anchors, y_definitions + ) def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] @@ -223,10 +230,10 @@ def _process_multi_source_plot( self, split_anchors: List[str], optional_anchors: List[str], - y_defn: List[Dict[str, str]], + y_definitions: List[Dict[str, str]], ): - varied_keys, varied_values = self._collect_variations(y_defn) - domain = self._get_domain(varied_keys, varied_values, y_defn) + varied_keys, varied_values = self._collect_variations(y_definitions) + domain = self._get_domain(varied_keys, varied_values, y_definitions) self._fill_optional_multi_source_anchors( split_anchors, optional_anchors, varied_keys, domain @@ -254,7 +261,7 @@ def _fill_optional_multi_source_anchors( " + '::' + ".join([f"datum.{key}" for key in grouped_keys]), ) - concat_field = "::".join(varied_keys) + concat_field = FIELD_SEPARATOR.join(varied_keys) self._fill_optional_anchor( split_anchors, optional_anchors, "row", {"field": concat_field} ) @@ -262,9 +269,9 @@ def _fill_optional_multi_source_anchors( if not optional_anchors: return - for field in ["stroke_dash", "shape"]: + for anchor in ["stroke_dash", "shape"]: self._fill_optional_anchor_mapping( - split_anchors, optional_anchors, concat_field, field, domain + split_anchors, optional_anchors, concat_field, anchor, domain ) def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): @@ -278,24 +285,24 @@ def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): ) def _collect_variations( - self, y_defn: List[Dict[str, str]] + self, y_definitions: List[Dict[str, str]] ) -> Tuple[List[str], Dict[str, set]]: varied_values = defaultdict(set) - for defn in y_defn: - for key in ["filename", "field"]: + for defn in y_definitions: + for key in FILENAME_FIELD: varied_values[key].add(defn.get(key, None)) values_match_variations = [] less_values_than_variations = [] - for filenameOrField, valueSet in varied_values.items(): - num_values = len(valueSet) + for filename_or_field, value_set in varied_values.items(): + num_values = len(value_set) if num_values == 1: continue - if num_values == len(y_defn): - values_match_variations.append(filenameOrField) + if num_values == len(y_definitions): + values_match_variations.append(filename_or_field) continue - less_values_than_variations.append(filenameOrField) + less_values_than_variations.append(filename_or_field) if values_match_variations: values_match_variations.extend(less_values_than_variations) @@ -326,15 +333,16 @@ def _get_domain( self, varied_keys: List[str], varied_values: Dict[str, set], - y_defn: List[Dict[str, str]], + y_definitions: List[Dict[str, str]], ): if len(varied_keys) == 2: domain = [ - "::".join([d.get("filename", ""), d.get("field", "")]) for d in y_defn + FIELD_SEPARATOR.join([d.get(FILENAME, ""), d.get(FIELD, "")]) + for d in y_definitions ] else: - filenameOrField = varied_keys[0] - domain = list(varied_values[filenameOrField]) + filename_or_field = varied_keys[0] + domain = list(varied_values[filename_or_field]) domain.sort() return domain @@ -397,12 +405,12 @@ def _update_datapoints(self, varied_keys: Optional[List[str]] = None): to_remove = varied_keys else: to_concatenate = [] - to_remove = [key for key in ["filename", "field"] if key not in varied_keys] + to_remove = [key for key in FILENAME_FIELD if key not in varied_keys] for datapoint in self.datapoints: if to_concatenate: - concat_key = "::".join(to_concatenate) - datapoint[concat_key] = "::".join( + concat_key = FIELD_SEPARATOR.join(to_concatenate) + datapoint[concat_key] = FIELD_SEPARATOR.join( [datapoint.get(k) for k in to_concatenate] ) for key in to_remove: diff --git a/tests/test_vega.py b/tests/test_vega.py index fda9ee9..ac6393c 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -270,7 +270,7 @@ def test_fill_anchor_in_string(tmp_dir): [ "datapoints", "y", - "anchors_y_defn", + "anchors_y_definitions", "expected_dp_keys", "color_encoding", "stroke_dash_encoding", @@ -486,7 +486,7 @@ def test_fill_anchor_in_string(tmp_dir): def test_optional_anchors_linear( datapoints, y, - anchors_y_defn, + anchors_y_definitions, expected_dp_keys, color_encoding, stroke_dash_encoding, @@ -498,7 +498,7 @@ def test_optional_anchors_linear( "x": "step", "y": y, "anchor_revs": ["B"], - "anchors_y_defn": anchors_y_defn, + "anchors_y_definitions": anchors_y_definitions, } expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) @@ -514,7 +514,7 @@ def test_optional_anchors_linear( @pytest.mark.parametrize( - "datapoints,y,anchors_y_defn,expected_dp_keys,stroke_dash_encoding", + "datapoints,y,anchors_y_definitions,expected_dp_keys,stroke_dash_encoding", ( ( [ @@ -618,7 +618,7 @@ def test_optional_anchors_linear( def test_partial_filled_template( datapoints, y, - anchors_y_defn, + anchors_y_definitions, expected_dp_keys, stroke_dash_encoding, ): @@ -627,7 +627,7 @@ def test_partial_filled_template( "x": "step", "y": y, "anchor_revs": ["B"], - "anchors_y_defn": anchors_y_defn, + "anchors_y_definitions": anchors_y_definitions, } expected_split = { @@ -642,7 +642,7 @@ def test_partial_filled_template( Template.anchor("color"), Template.anchor("data"), ] - if len(anchors_y_defn) > 1: + if len(anchors_y_definitions) > 1: split_anchors.append(Template.anchor("stroke_dash")) expected_split[Template.anchor("stroke_dash")] = stroke_dash_encoding From 6668c3d18b0d39358e35264fe803506eee07d850 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Fri, 22 Sep 2023 14:13:45 +1000 Subject: [PATCH 11/39] refactor y definitions out of get domain --- src/dvc_render/vega.py | 95 ++++++++++++++++++++---------------------- 1 file changed, 46 insertions(+), 49 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 329154a..f723b50 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -14,6 +14,7 @@ FILENAME = "filename" FIELD = "field" FILENAME_FIELD = [FILENAME, FIELD] +CONCAT_FIELDS = FIELD_SEPARATOR.join(FILENAME_FIELD) class VegaRenderer(Renderer): @@ -216,6 +217,16 @@ def _process_optional_anchors(self, split_anchors: List[str]): split_anchors, optional_anchors, y_definitions ) + def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): + all_revs = self.get_revs() + self._fill_optional_anchor_mapping( + split_anchors, + optional_anchors, + "rev", + "color", + all_revs, + ) + def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] ): @@ -233,13 +244,45 @@ def _process_multi_source_plot( y_definitions: List[Dict[str, str]], ): varied_keys, varied_values = self._collect_variations(y_definitions) - domain = self._get_domain(varied_keys, varied_values, y_definitions) + domain = self._get_domain(varied_keys, varied_values) self._fill_optional_multi_source_anchors( split_anchors, optional_anchors, varied_keys, domain ) return varied_keys + def _collect_variations( + self, y_definitions: List[Dict[str, str]] + ) -> Tuple[List[str], Dict[str, set]]: + varied_values = defaultdict(set) + for defn in y_definitions: + for key in FILENAME_FIELD: + varied_values[key].add(defn.get(key, None)) + varied_values[CONCAT_FIELDS].add( + FIELD_SEPARATOR.join([defn.get(FILENAME, ""), defn.get(FIELD, "")]) + ) + + values_match_variations = [] + less_values_than_variations = [] + + for filename_or_field in FILENAME_FIELD: + value_set = varied_values[filename_or_field] + num_values = len(value_set) + if num_values == 1: + continue + if num_values == len(y_definitions): + values_match_variations.append(filename_or_field) + continue + less_values_than_variations.append(filename_or_field) + + if values_match_variations: + values_match_variations.extend(less_values_than_variations) + values_match_variations.sort(reverse=True) + return values_match_variations, varied_values + + less_values_than_variations.sort(reverse=True) + return less_values_than_variations, varied_values + def _fill_optional_multi_source_anchors( self, split_anchors: List[str], @@ -274,44 +317,6 @@ def _fill_optional_multi_source_anchors( split_anchors, optional_anchors, concat_field, anchor, domain ) - def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): - all_revs = self.properties.get("anchor_revs", []) - self._fill_optional_anchor_mapping( - split_anchors, - optional_anchors, - "rev", - "color", - all_revs, - ) - - def _collect_variations( - self, y_definitions: List[Dict[str, str]] - ) -> Tuple[List[str], Dict[str, set]]: - varied_values = defaultdict(set) - for defn in y_definitions: - for key in FILENAME_FIELD: - varied_values[key].add(defn.get(key, None)) - - values_match_variations = [] - less_values_than_variations = [] - - for filename_or_field, value_set in varied_values.items(): - num_values = len(value_set) - if num_values == 1: - continue - if num_values == len(y_definitions): - values_match_variations.append(filename_or_field) - continue - less_values_than_variations.append(filename_or_field) - - if values_match_variations: - values_match_variations.extend(less_values_than_variations) - values_match_variations.sort(reverse=True) - return values_match_variations, varied_values - - less_values_than_variations.sort(reverse=True) - return less_values_than_variations, varied_values - def _fill_optional_anchor( self, split_anchors: List[str], @@ -329,17 +334,9 @@ def _fill_optional_anchor( self.template.fill_anchor(name, value) - def _get_domain( - self, - varied_keys: List[str], - varied_values: Dict[str, set], - y_definitions: List[Dict[str, str]], - ): + def _get_domain(self, varied_keys: List[str], varied_values: Dict[str, set]): if len(varied_keys) == 2: - domain = [ - FIELD_SEPARATOR.join([d.get(FILENAME, ""), d.get(FIELD, "")]) - for d in y_definitions - ] + domain = list(varied_values[CONCAT_FIELDS]) else: filename_or_field = varied_keys[0] domain = list(varied_values[filename_or_field]) From e1dc7963f131d136ad51c90a2fdffb30739bdb41 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Fri, 22 Sep 2023 20:51:25 +1000 Subject: [PATCH 12/39] refactor domain collection --- src/dvc_render/vega.py | 21 ++++++--------------- 1 file changed, 6 insertions(+), 15 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index f723b50..ba53472 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -243,8 +243,7 @@ def _process_multi_source_plot( optional_anchors: List[str], y_definitions: List[Dict[str, str]], ): - varied_keys, varied_values = self._collect_variations(y_definitions) - domain = self._get_domain(varied_keys, varied_values) + varied_keys, domain = self._collect_variations(y_definitions) self._fill_optional_multi_source_anchors( split_anchors, optional_anchors, varied_keys, domain @@ -253,7 +252,7 @@ def _process_multi_source_plot( def _collect_variations( self, y_definitions: List[Dict[str, str]] - ) -> Tuple[List[str], Dict[str, set]]: + ) -> Tuple[List[str], List[str]]: varied_values = defaultdict(set) for defn in y_definitions: for key in FILENAME_FIELD: @@ -262,26 +261,18 @@ def _collect_variations( FIELD_SEPARATOR.join([defn.get(FILENAME, ""), defn.get(FIELD, "")]) ) - values_match_variations = [] - less_values_than_variations = [] + varied_keys = [] for filename_or_field in FILENAME_FIELD: value_set = varied_values[filename_or_field] num_values = len(value_set) if num_values == 1: continue - if num_values == len(y_definitions): - values_match_variations.append(filename_or_field) - continue - less_values_than_variations.append(filename_or_field) + varied_keys.append(filename_or_field) - if values_match_variations: - values_match_variations.extend(less_values_than_variations) - values_match_variations.sort(reverse=True) - return values_match_variations, varied_values + domain = self._get_domain(varied_keys, varied_values) - less_values_than_variations.sort(reverse=True) - return less_values_than_variations, varied_values + return varied_keys, domain def _fill_optional_multi_source_anchors( self, From 221f9512ebd3d508e58f18a754b7a2764e6b8db3 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 26 Sep 2023 09:59:06 +1000 Subject: [PATCH 13/39] send all anchor_definitions as strings --- src/dvc_render/vega.py | 11 ++++++++--- tests/test_vega.py | 3 ++- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index ba53472..8c50384 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -65,7 +65,7 @@ def __init__(self, datapoints: List, name: str, **properties): "shape": ["square", "circle", "triangle", "diamond"], } - self._split_content: Dict[str, Any] = {} + self._split_content: Dict[str, str] = {} def get_filled_template( self, @@ -129,7 +129,12 @@ def get_partial_filled_template(self): Returns a partially filled template along with the split out anchor content """ content = self.get_filled_template( - split_anchors=["data", "color", "stroke_dash", "shape"], strict=True + split_anchors=[ + "data", + "color", + "stroke_dash", + "shape", + ] # add y_label, x_label so we can truncate, strict=True ) return content, {"anchor_definitions": self._split_content} @@ -405,4 +410,4 @@ def _update_datapoints(self, varied_keys: Optional[List[str]] = None): datapoint.pop(key, None) def _set_split_content(self, name: str, value: Any): - self._split_content[Template.anchor(name)] = value + self._split_content[Template.anchor(name)] = json.dumps(value) diff --git a/tests/test_vega.py b/tests/test_vega.py index ac6393c..45df64b 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -651,7 +651,8 @@ def test_partial_filled_template( for anchor in split_anchors: assert anchor in content - assert split["anchor_definitions"] == expected_split + for key, value in split["anchor_definitions"].items(): + assert json.loads(value) == expected_split[key] def _get_expected_datapoints( From 65fa9126eb0d700c56d1ee68bb3f29d8f08b1287 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 26 Sep 2023 19:25:59 +1000 Subject: [PATCH 14/39] add anchors to fix confusion templates --- src/dvc_render/vega.py | 31 +++++++++++++++++++++++++++---- src/dvc_render/vega_templates.py | 10 +++++----- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 8c50384..d6836d3 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -199,6 +199,8 @@ def _process_optional_anchors(self, split_anchors: List[str]): for anchor in [ "row", "group_by", + "group_by_x", + "group_by_y", "pivot_field", "color", "stroke_dash", @@ -235,7 +237,7 @@ def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] ): - self._fill_optional_anchor(split_anchors, optional_anchors, "group_by", ["rev"]) + self._fill_group_by(split_anchors, optional_anchors, ["rev"]) self._fill_optional_anchor( split_anchors, optional_anchors, "pivot_field", "datum.rev" ) @@ -290,9 +292,8 @@ def _fill_optional_multi_source_anchors( return grouped_keys = ["rev", *varied_keys] - self._fill_optional_anchor( - split_anchors, optional_anchors, "group_by", grouped_keys - ) + self._fill_group_by(split_anchors, optional_anchors, grouped_keys) + self._fill_optional_anchor( split_anchors, optional_anchors, @@ -313,6 +314,28 @@ def _fill_optional_multi_source_anchors( split_anchors, optional_anchors, concat_field, anchor, domain ) + def _fill_group_by( + self, + split_anchors: List[str], + optional_anchors: List[str], + grouped_keys: List[str], + ): + self._fill_optional_anchor( + split_anchors, optional_anchors, "group_by", grouped_keys + ) + self._fill_optional_anchor( + split_anchors, + optional_anchors, + "group_by_x", + [*grouped_keys, self.properties.get("x")], + ) + self._fill_optional_anchor( + split_anchors, + optional_anchors, + "group_by_y", + [*grouped_keys, self.properties.get("y")], + ) + def _fill_optional_anchor( self, split_anchors: List[str], diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 909a331..6a6e8df 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -200,7 +200,7 @@ class ConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": {"field": "rev", "type": "nominal"}, + "facet": {"column": {"field": "rev"}, "row": {"field": Template.anchor("row")}}, "spec": { "transform": [ { @@ -209,13 +209,13 @@ class ConfusionTemplate(Template): }, { "impute": "xy_count", - "groupby": ["rev", Template.anchor("y")], + "groupby": Template.anchor("group_by_y"), "key": Template.anchor("x"), "value": 0, }, { "impute": "xy_count", - "groupby": ["rev", Template.anchor("x")], + "groupby": Template.anchor("group_by_x"), "key": Template.anchor("y"), "value": 0, }, @@ -310,7 +310,7 @@ class NormalizedConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": {"field": "rev", "type": "nominal"}, + "facet": {"column": {"field": "rev"}, "row": {"field": Template.anchor("row")}}, "spec": { "transform": [ { @@ -325,7 +325,7 @@ class NormalizedConfusionTemplate(Template): }, { "impute": "xy_count", - "groupby": ["rev", Template.anchor("x")], + "groupby": Template.anchor("group_by_x"), "key": Template.anchor("y"), "value": 0, }, From 9cbc6c9bc0863fc77b99da0b6ea865c1313d9cd1 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 26 Sep 2023 19:28:00 +1000 Subject: [PATCH 15/39] split x and y labels out so they can be truncated by vs code --- src/dvc_render/vega.py | 2 ++ tests/test_vega.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index d6836d3..03b6a4a 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -134,6 +134,8 @@ def get_partial_filled_template(self): "color", "stroke_dash", "shape", + "x_label", + "y_label", ] # add y_label, x_label so we can truncate, strict=True ) return content, {"anchor_definitions": self._split_content} diff --git a/tests/test_vega.py b/tests/test_vega.py index 45df64b..c3da959 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -636,6 +636,8 @@ def test_partial_filled_template( "field": "rev", "scale": {"domain": ["B"], "range": ["#945dd6"]}, }, + Template.anchor("x_label"): "step", + Template.anchor("y_label"): y, } split_anchors = [ From 9e5c3eab81d806b19331f27596fe8d9671ae26bc Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 27 Sep 2023 11:25:48 +1000 Subject: [PATCH 16/39] remove erroneous comment --- src/dvc_render/vega.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 03b6a4a..b744bd9 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -136,7 +136,8 @@ def get_partial_filled_template(self): "shape", "x_label", "y_label", - ] # add y_label, x_label so we can truncate, strict=True + ], + strict=True, ) return content, {"anchor_definitions": self._split_content} From 33076e3057352a49fbaf34990996e92fc42237eb Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 28 Sep 2023 13:49:23 +1000 Subject: [PATCH 17/39] fix issue with confusion matrix --- src/dvc_render/vega_templates.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 6a6e8df..33957d8 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -200,7 +200,7 @@ class ConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": {"column": {"field": "rev"}, "row": {"field": Template.anchor("row")}}, + "facet": {"column": {"field": "rev"}, "row": Template.anchor("row")}, "spec": { "transform": [ { @@ -310,7 +310,7 @@ class NormalizedConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": {"column": {"field": "rev"}, "row": {"field": Template.anchor("row")}}, + "facet": {"column": {"field": "rev"}, "row": Template.anchor("row")}, "spec": { "transform": [ { From 1f0601069945b57ce51653c57874973c644815bf Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 28 Sep 2023 14:08:35 +1000 Subject: [PATCH 18/39] fix string issues --- src/dvc_render/vega.py | 4 +++- tests/test_vega.py | 3 +++ 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index b744bd9..f396c43 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -436,4 +436,6 @@ def _update_datapoints(self, varied_keys: Optional[List[str]] = None): datapoint.pop(key, None) def _set_split_content(self, name: str, value: Any): - self._split_content[Template.anchor(name)] = json.dumps(value) + self._split_content[Template.anchor(name)] = ( + value if isinstance(value, str) else json.dumps(value) + ) diff --git a/tests/test_vega.py b/tests/test_vega.py index c3da959..c2830e2 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -654,6 +654,9 @@ def test_partial_filled_template( for anchor in split_anchors: assert anchor in content for key, value in split["anchor_definitions"].items(): + if key in [Template.anchor("x_label"), Template.anchor("y_label")]: + assert value == expected_split[key] + continue assert json.loads(value) == expected_split[key] From dc5d8b4c7c3d24010db7fc25e8b1cd8349e629e4 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 28 Sep 2023 15:55:07 +1000 Subject: [PATCH 19/39] add title to split anchors --- src/dvc_render/vega.py | 5 +++-- tests/test_vega.py | 9 ++++++++- 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index f396c43..d2e103c 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -130,10 +130,11 @@ def get_partial_filled_template(self): """ content = self.get_filled_template( split_anchors=[ - "data", "color", - "stroke_dash", + "data", "shape", + "stroke_dash", + "title", "x_label", "y_label", ], diff --git a/tests/test_vega.py b/tests/test_vega.py index c2830e2..36e0ba3 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -622,12 +622,14 @@ def test_partial_filled_template( expected_dp_keys, stroke_dash_encoding, ): + title = f"{y} by step" props = { "template": "linear", "x": "step", "y": y, "anchor_revs": ["B"], "anchors_y_definitions": anchors_y_definitions, + "title": title, } expected_split = { @@ -638,6 +640,7 @@ def test_partial_filled_template( }, Template.anchor("x_label"): "step", Template.anchor("y_label"): y, + Template.anchor("title"): title, } split_anchors = [ @@ -654,7 +657,11 @@ def test_partial_filled_template( for anchor in split_anchors: assert anchor in content for key, value in split["anchor_definitions"].items(): - if key in [Template.anchor("x_label"), Template.anchor("y_label")]: + if key in [ + Template.anchor("x_label"), + Template.anchor("y_label"), + Template.anchor("title"), + ]: assert value == expected_split[key] continue assert json.loads(value) == expected_split[key] From 06d2d21c417bd83a63bf66afff85cba20fec7dd0 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 3 Oct 2023 12:53:48 +1100 Subject: [PATCH 20/39] add zoom and pan anchor --- src/dvc_render/vega.py | 25 +++++++++++++++++++++---- src/dvc_render/vega_templates.py | 4 +++- tests/test_vega.py | 9 +++++++-- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index d2e103c..37e8b9c 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -137,6 +137,7 @@ def get_partial_filled_template(self): "title", "x_label", "y_label", + "zoom_and_pan", ], strict=True, ) @@ -201,14 +202,15 @@ def _process_optional_anchors(self, split_anchors: List[str]): optional_anchors = [ anchor for anchor in [ - "row", - "group_by", + "color", "group_by_x", "group_by_y", + "group_by", "pivot_field", - "color", - "stroke_dash", + "row", "shape", + "stroke_dash", + "zoom_and_pan", ] if self.template.has_anchor(anchor) ] @@ -216,6 +218,7 @@ def _process_optional_anchors(self, split_anchors: List[str]): return None self._fill_color(split_anchors, optional_anchors) + self._fill_zoom_and_pan(split_anchors, optional_anchors) y_definitions = self.properties.get("anchors_y_definitions", []) is_single_source = len(y_definitions) <= 1 @@ -238,6 +241,20 @@ def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): all_revs, ) + def _fill_zoom_and_pan(self, split_anchors: List[str], optional_anchors: List[str]): + name = "zoom_and_pan" + encoding = {"name": "grid", "select": "interval", "bind": "scales"} + if "zoom_and_pan" not in optional_anchors: + return + + optional_anchors.remove("zoom_and_pan") + + if name in split_anchors: + self._set_split_content(name, encoding) + return + + self.template.fill_anchor(name, encoding) + def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] ): diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 33957d8..ead167f 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -428,6 +428,7 @@ class ScatterTemplate(Template): "width": 300, "height": 300, "mark": {"type": "point", "tooltip": {"content": "data"}}, + "params": [Template.anchor("zoom_and_pan")], "encoding": { "x": { "field": Template.anchor("x"), @@ -508,7 +509,7 @@ class SmoothLinearTemplate(Template): "layer": [ { "layer": [ - {"mark": "line"}, + {"params": [Template.anchor("zoom_and_pan")], "mark": "line"}, { "transform": [{"filter": {"param": "hover", "empty": False}}], "mark": "point", @@ -619,6 +620,7 @@ class SimpleLinearTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), + "params": [Template.anchor("zoom_and_pan")], "width": 300, "height": 300, "mark": { diff --git a/tests/test_vega.py b/tests/test_vega.py index 36e0ba3..70c4ab8 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -633,14 +633,19 @@ def test_partial_filled_template( } expected_split = { - Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys), Template.anchor("color"): { "field": "rev", "scale": {"domain": ["B"], "range": ["#945dd6"]}, }, + Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys), + Template.anchor("title"): title, Template.anchor("x_label"): "step", Template.anchor("y_label"): y, - Template.anchor("title"): title, + Template.anchor("zoom_and_pan"): { + "name": "grid", + "select": "interval", + "bind": "scales", + }, } split_anchors = [ From 5131b4248010cc0c15fae27fa17bcfc0c7051664 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 3 Oct 2023 13:12:18 +1100 Subject: [PATCH 21/39] add zoom and pan anchor --- src/dvc_render/vega_templates.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index ead167f..2cb7a9d 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -145,6 +145,7 @@ class BarHorizontalSortedTemplate(Template): "width": 300, "height": 300, "mark": {"type": "bar"}, + "params": [Template.anchor("zoom_and_pan")], "encoding": { "x": { "field": Template.anchor("x"), @@ -175,6 +176,7 @@ class BarHorizontalTemplate(Template): "width": 300, "height": 300, "mark": {"type": "bar"}, + "params": [Template.anchor("zoom_and_pan")], "encoding": { "x": { "field": Template.anchor("x"), From c0f3cbb4ab132f8b01759542640833d42f2d84b7 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 5 Oct 2023 13:22:59 +1100 Subject: [PATCH 22/39] add tooltip anchor for Studio --- src/dvc_render/vega.py | 26 ++++++++++++++++++++++++-- src/dvc_render/vega_templates.py | 3 +++ 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 37e8b9c..35fcae3 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -134,6 +134,7 @@ def get_partial_filled_template(self): "data", "shape", "stroke_dash", + "tooltip", "title", "x_label", "y_label", @@ -210,6 +211,7 @@ def _process_optional_anchors(self, split_anchors: List[str]): "row", "shape", "stroke_dash", + "tooltip", "zoom_and_pan", ] if self.template.has_anchor(anchor) @@ -262,6 +264,7 @@ def _process_single_source_plot( self._fill_optional_anchor( split_anchors, optional_anchors, "pivot_field", "datum.rev" ) + self._fill_tooltip(split_anchors, optional_anchors) for anchor in optional_anchors: self.template.fill_anchor(anchor, {}) @@ -327,8 +330,7 @@ def _fill_optional_multi_source_anchors( split_anchors, optional_anchors, "row", {"field": concat_field} ) - if not optional_anchors: - return + self._fill_tooltip(split_anchors, optional_anchors, [concat_field]) for anchor in ["stroke_dash", "shape"]: self._fill_optional_anchor_mapping( @@ -357,6 +359,26 @@ def _fill_group_by( [*grouped_keys, self.properties.get("y")], ) + def _fill_tooltip( + self, + split_anchors: List[str], + optional_anchors: List[str], + additional_fields: Optional[List[str]] = None, + ): + if not additional_fields: + additional_fields = [] + self._fill_optional_anchor( + split_anchors, + optional_anchors, + "tooltip", + [ + {"field": "rev"}, + {"field": self.properties.get("x")}, + {"field": self.properties.get("y")}, + *[{"field": field} for field in additional_fields], + ], + ) + def _fill_optional_anchor( self, split_anchors: List[str], diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 2cb7a9d..bf219eb 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -444,6 +444,7 @@ class ScatterTemplate(Template): }, "color": Template.anchor("color"), "shape": Template.anchor("shape"), + "tooltip": Template.anchor("tooltip"), }, } @@ -473,6 +474,7 @@ class ScatterJitterTemplate(Template): }, "color": Template.anchor("color"), "shape": Template.anchor("shape"), + "tooltip": Template.anchor("tooltip"), "xOffset": {"field": "randomX", "type": "quantitative"}, "yOffset": {"field": "randomY", "type": "quantitative"}, }, @@ -643,6 +645,7 @@ class SimpleLinearTemplate(Template): }, "color": Template.anchor("color"), "strokeDash": Template.anchor("stroke_dash"), + "tooltip": Template.anchor("tooltip"), }, } From b470f995e99ea210558246bb79cac7cacc5e8bfd Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 5 Oct 2023 19:35:49 +1100 Subject: [PATCH 23/39] rename anchor_revs to revs_with_datapoints --- src/dvc_render/vega.py | 2 +- tests/test_vega.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 35fcae3..d8e8e3e 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -197,7 +197,7 @@ def get_revs(self): Returns all revisions that were collected. Potentially will include revisions that have no datapoints """ - return self.properties.get("anchor_revs", []) + return self.properties.get("revs_with_datapoints", []) def _process_optional_anchors(self, split_anchors: List[str]): optional_anchors = [ diff --git a/tests/test_vega.py b/tests/test_vega.py index 70c4ab8..4817e36 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -497,7 +497,7 @@ def test_optional_anchors_linear( "template": "linear", "x": "step", "y": y, - "anchor_revs": ["B"], + "revs_with_datapoints": ["B"], "anchors_y_definitions": anchors_y_definitions, } @@ -627,7 +627,7 @@ def test_partial_filled_template( "template": "linear", "x": "step", "y": y, - "anchor_revs": ["B"], + "revs_with_datapoints": ["B"], "anchors_y_definitions": anchors_y_definitions, "title": title, } From 9aa12a56dba1a7a126e3127f784f78a59642582a Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Fri, 6 Oct 2023 14:07:27 +1100 Subject: [PATCH 24/39] add empty sort property to order facet by rev in datapoints --- src/dvc_render/vega_templates.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index bf219eb..81b57be 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -202,7 +202,10 @@ class ConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": {"column": {"field": "rev"}, "row": Template.anchor("row")}, + "facet": { + "column": {"field": "rev", "sort": []}, + "row": Template.anchor("row"), + }, "spec": { "transform": [ { @@ -312,7 +315,10 @@ class NormalizedConfusionTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "facet": {"column": {"field": "rev"}, "row": Template.anchor("row")}, + "facet": { + "column": {"field": "rev", "sort": []}, + "row": Template.anchor("row"), + }, "spec": { "transform": [ { From c42504831d7f6258578e4ed35f07367ff74058a9 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Fri, 6 Oct 2023 19:33:15 +1100 Subject: [PATCH 25/39] add width and height anchors --- src/dvc_render/vega.py | 37 ++++++++++++++++++-------------- src/dvc_render/vega_templates.py | 32 +++++++++++++-------------- tests/test_vega.py | 2 ++ 3 files changed, 39 insertions(+), 32 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index d8e8e3e..78ed3f7 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -131,11 +131,15 @@ def get_partial_filled_template(self): content = self.get_filled_template( split_anchors=[ "color", + "column_width", "data", + "plot_height", + "plot_width", + "row_height", "shape", "stroke_dash", - "tooltip", "title", + "tooltip", "x_label", "y_label", "zoom_and_pan", @@ -204,10 +208,14 @@ def _process_optional_anchors(self, split_anchors: List[str]): anchor for anchor in [ "color", + "column_width", "group_by_x", "group_by_y", "group_by", "pivot_field", + "plot_height", + "plot_width", + "row_height", "row", "shape", "stroke_dash", @@ -220,7 +228,7 @@ def _process_optional_anchors(self, split_anchors: List[str]): return None self._fill_color(split_anchors, optional_anchors) - self._fill_zoom_and_pan(split_anchors, optional_anchors) + self._fill_set_encoding(split_anchors, optional_anchors) y_definitions = self.properties.get("anchors_y_definitions", []) is_single_source = len(y_definitions) <= 1 @@ -243,19 +251,15 @@ def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): all_revs, ) - def _fill_zoom_and_pan(self, split_anchors: List[str], optional_anchors: List[str]): - name = "zoom_and_pan" - encoding = {"name": "grid", "select": "interval", "bind": "scales"} - if "zoom_and_pan" not in optional_anchors: - return - - optional_anchors.remove("zoom_and_pan") - - if name in split_anchors: - self._set_split_content(name, encoding) - return - - self.template.fill_anchor(name, encoding) + def _fill_set_encoding(self, split_anchors: List[str], optional_anchors: List[str]): + for name, encoding in [ + ("zoom_and_pan", {"name": "grid", "select": "interval", "bind": "scales"}), + ("column_width", 300), + ("plot_height", 300), + ("plot_width", 300), + ("row_height", 300), + ]: + self._fill_optional_anchor(split_anchors, optional_anchors, name, encoding) def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] @@ -327,7 +331,7 @@ def _fill_optional_multi_source_anchors( concat_field = FIELD_SEPARATOR.join(varied_keys) self._fill_optional_anchor( - split_anchors, optional_anchors, "row", {"field": concat_field} + split_anchors, optional_anchors, "row", {"field": concat_field, "sort": []} ) self._fill_tooltip(split_anchors, optional_anchors, [concat_field]) @@ -392,6 +396,7 @@ def _fill_optional_anchor( optional_anchors.remove(name) if name in split_anchors: + self._set_split_content(name, value) return self.template.fill_anchor(name, value) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 81b57be..c3cb605 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -142,8 +142,8 @@ class BarHorizontalSortedTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": {"type": "bar"}, "params": [Template.anchor("zoom_and_pan")], "encoding": { @@ -173,8 +173,8 @@ class BarHorizontalTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": {"type": "bar"}, "params": [Template.anchor("zoom_and_pan")], "encoding": { @@ -252,8 +252,8 @@ class ConfusionTemplate(Template): "layer": [ { "mark": "rect", - "width": 300, - "height": 300, + "width": Template.anchor("column_width"), + "height": Template.anchor("row_height"), "encoding": { "color": { "field": "xy_count", @@ -365,8 +365,8 @@ class NormalizedConfusionTemplate(Template): "layer": [ { "mark": "rect", - "width": 300, - "height": 300, + "width": Template.anchor("column_width"), + "height": Template.anchor("row_height"), "encoding": { "color": { "field": "percent_of_y", @@ -433,8 +433,8 @@ class ScatterTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": {"type": "point", "tooltip": {"content": "data"}}, "params": [Template.anchor("zoom_and_pan")], "encoding": { @@ -462,8 +462,8 @@ class ScatterJitterTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "transform": [ {"calculate": "random()", "as": "randomX"}, {"calculate": "random()", "as": "randomY"}, @@ -493,8 +493,8 @@ class SmoothLinearTemplate(Template): "$schema": "https://vega.github.io/schema/vega-lite/v5.json", "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "params": [ { "name": "smooth", @@ -631,8 +631,8 @@ class SimpleLinearTemplate(Template): "data": {"values": Template.anchor("data")}, "title": Template.anchor("title"), "params": [Template.anchor("zoom_and_pan")], - "width": 300, - "height": 300, + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "mark": { "type": "line", "tooltip": {"content": "data"}, diff --git a/tests/test_vega.py b/tests/test_vega.py index 4817e36..aafb271 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -638,6 +638,8 @@ def test_partial_filled_template( "scale": {"domain": ["B"], "range": ["#945dd6"]}, }, Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys), + Template.anchor("plot_height"): 300, + Template.anchor("plot_width"): 300, Template.anchor("title"): title, Template.anchor("x_label"): "step", Template.anchor("y_label"): y, From 2f0f6f94c36fba4408e4be25f60401168d07934d Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 11 Oct 2023 19:45:02 +1100 Subject: [PATCH 26/39] add get_template_as_string method --- src/dvc_render/vega.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 78ed3f7..eb9c98e 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -148,6 +148,12 @@ def get_partial_filled_template(self): ) return content, {"anchor_definitions": self._split_content} + def get_template_as_string(self): + """ + Returns unfilled template as a string (for Studio) + """ + return json.dumps(self.template.content) + def partial_html(self, **kwargs) -> str: return self.get_filled_template() # type: ignore From e958a75a0240e975efda327a69f207ba833cc3bc Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 2 Nov 2023 13:21:13 +1100 Subject: [PATCH 27/39] add sort to y offset in horizontal bar templates --- src/dvc_render/vega_templates.py | 4 ++-- tests/test_vega.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index c3cb605..a1db9a6 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -159,7 +159,7 @@ class BarHorizontalSortedTemplate(Template): "title": Template.anchor("y_label"), "sort": "-x", }, - "yOffset": {"field": "rev"}, + "yOffset": {"field": "rev", "sort": []}, "color": Template.anchor("color"), "row": Template.anchor("row"), }, @@ -189,7 +189,7 @@ class BarHorizontalTemplate(Template): "type": "nominal", "title": Template.anchor("y_label"), }, - "yOffset": {"field": "rev"}, + "yOffset": {"field": "rev", "sort": []}, "color": Template.anchor("color"), "row": Template.anchor("row"), }, diff --git a/tests/test_vega.py b/tests/test_vega.py index aafb271..7fbf038 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -264,7 +264,8 @@ def test_fill_anchor_in_string(tmp_dir): assert filled["transform"][1]["calculate"] == "pow(datum.lab - datum.SR,2)" assert filled["encoding"]["x"]["field"] == x assert filled["encoding"]["y"]["field"] == y - + + @pytest.mark.parametrize( ",".join( [ From 9b9c0cf9bf262190a54a9256e7251dc925e4e592 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Thu, 9 Nov 2023 09:38:33 +1100 Subject: [PATCH 28/39] drop terrible idea of holding all data as strings --- src/dvc_render/vega.py | 22 ++++++++------------- tests/test_vega.py | 44 ++++++++++++++++++++++-------------------- 2 files changed, 31 insertions(+), 35 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index eb9c98e..281f835 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -71,8 +71,7 @@ def get_filled_template( self, split_anchors: Optional[List[str]] = None, strict: bool = True, - as_string: bool = True, - ) -> Union[str, Dict[str, Any]]: + ) -> Dict[str, Any]: """Returns a functional vega specification""" self.template.reset() if not self.datapoints: @@ -119,9 +118,6 @@ def get_filled_template( value = self.template.escape_special_characters(value) self.template.fill_anchor(name, value) - if as_string: - return json.dumps(self.template.content) - return self.template.content def get_partial_filled_template(self): @@ -148,14 +144,15 @@ def get_partial_filled_template(self): ) return content, {"anchor_definitions": self._split_content} - def get_template_as_string(self): + def get_template(self): """ - Returns unfilled template as a string (for Studio) + Returns unfilled template (for Studio) """ - return json.dumps(self.template.content) + return self.template.content def partial_html(self, **kwargs) -> str: - return self.get_filled_template() # type: ignore + content = self.get_filled_template() + return json.dumps(content) def generate_markdown(self, report_path=None) -> str: if not isinstance(self.template, LinearTemplate): @@ -204,8 +201,7 @@ def generate_markdown(self, report_path=None) -> str: def get_revs(self): """ - Returns all revisions that were collected. - Potentially will include revisions that have no datapoints + Returns all revisions that were collected that have datapoints. """ return self.properties.get("revs_with_datapoints", []) @@ -487,6 +483,4 @@ def _update_datapoints(self, varied_keys: Optional[List[str]] = None): datapoint.pop(key, None) def _set_split_content(self, name: str, value: Any): - self._split_content[Template.anchor(name)] = ( - value if isinstance(value, str) else json.dumps(value) - ) + self._split_content[Template.anchor(name)] = value diff --git a/tests/test_vega.py b/tests/test_vega.py index 7fbf038..7d80fe2 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -43,7 +43,7 @@ def test_default_template_mark(): {"first_val": 200, "second_val": 300, "val": 3}, ] - plot_content = VegaRenderer(datapoints, "foo").get_filled_template(as_string=False) + plot_content = VegaRenderer(datapoints, "foo").get_filled_template() assert plot_content["layer"][0]["layer"][0]["mark"] == "line" @@ -59,9 +59,7 @@ def test_choose_axes(): {"first_val": 200, "second_val": 300, "val": 3}, ] - plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template( - as_string=False - ) + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template() assert plot_content["data"]["values"] == [ { @@ -86,9 +84,7 @@ def test_confusion(): ] props = {"template": "confusion", "x": "predicted", "y": "actual"} - plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template( - as_string=False - ) + plot_content = VegaRenderer(datapoints, "foo", **props).get_filled_template() assert plot_content["data"]["values"] == [ {"predicted": "B", "actual": "A"}, @@ -218,7 +214,7 @@ def test_escape_special_characters(): ] props = {"template": "simple", "x": "foo.bar[0]", "y": "foo.bar[1]"} renderer = VegaRenderer(datapoints, "foo", **props) - filled = renderer.get_filled_template(as_string=False) + filled = renderer.get_filled_template() # data is not escaped assert filled["data"]["values"][0] == datapoints[0] # field and title yes @@ -260,7 +256,7 @@ def test_fill_anchor_in_string(tmp_dir): props = {"template": "custom.json", "x": x, "y": y} renderer = VegaRenderer(datapoints, "foo", **props) - filled = renderer.get_filled_template(as_string=False) + filled = renderer.get_filled_template() assert filled["transform"][1]["calculate"] == "pow(datum.lab - datum.SR,2)" assert filled["encoding"]["x"]["field"] == x assert filled["encoding"]["y"]["field"] == y @@ -505,7 +501,7 @@ def test_optional_anchors_linear( expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) renderer = VegaRenderer(datapoints, "foo", **props) - plot_content = renderer.get_filled_template(as_string=False) + plot_content = renderer.get_filled_template() assert plot_content["data"]["values"] == expected_datapoints assert plot_content["encoding"]["color"] == color_encoding @@ -659,20 +655,16 @@ def test_partial_filled_template( split_anchors.append(Template.anchor("stroke_dash")) expected_split[Template.anchor("stroke_dash")] = stroke_dash_encoding - renderer = VegaRenderer(datapoints, "foo", **props) - content, split = renderer.get_partial_filled_template() + content, split = VegaRenderer( + datapoints, "foo", **props + ).get_partial_filled_template() + + content_str = json.dumps(content) for anchor in split_anchors: - assert anchor in content + assert anchor in content_str for key, value in split["anchor_definitions"].items(): - if key in [ - Template.anchor("x_label"), - Template.anchor("y_label"), - Template.anchor("title"), - ]: - assert value == expected_split[key] - continue - assert json.loads(value) == expected_split[key] + assert value == expected_split[key] def _get_expected_datapoints( @@ -694,3 +686,13 @@ def _get_expected_datapoints( expected_datapoints.append(expected_datapoint) return datapoints + + +def test_partial_html(): + props = {"x": "x", "y": "y"} + datapoints = [ + {"x": 100, "y": 100, "val": 2}, + {"x": 200, "y": 300, "val": 3}, + ] + + assert isinstance(VegaRenderer(datapoints, "foo", **props).partial_html(), str) From 35678583f0a01fa9a6806a67d4fe25460faf612b Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 14 Nov 2023 12:57:58 +1100 Subject: [PATCH 29/39] move horizontal bar plots from row to column --- src/dvc_render/vega.py | 4 ++++ src/dvc_render/vega_templates.py | 8 ++++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 281f835..3282c96 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -210,6 +210,7 @@ def _process_optional_anchors(self, split_anchors: List[str]): anchor for anchor in [ "color", + "column", "column_width", "group_by_x", "group_by_y", @@ -335,6 +336,9 @@ def _fill_optional_multi_source_anchors( self._fill_optional_anchor( split_anchors, optional_anchors, "row", {"field": concat_field, "sort": []} ) + self._fill_optional_anchor( + split_anchors, optional_anchors, "column", {"field": concat_field, "sort": []} + ) self._fill_tooltip(split_anchors, optional_anchors, [concat_field]) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index a1db9a6..29dfc5d 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -161,7 +161,7 @@ class BarHorizontalSortedTemplate(Template): }, "yOffset": {"field": "rev", "sort": []}, "color": Template.anchor("color"), - "row": Template.anchor("row"), + "column": Template.anchor("column"), }, } @@ -191,7 +191,7 @@ class BarHorizontalTemplate(Template): }, "yOffset": {"field": "rev", "sort": []}, "color": Template.anchor("color"), - "row": Template.anchor("row"), + "column": Template.anchor("column"), }, } @@ -365,8 +365,8 @@ class NormalizedConfusionTemplate(Template): "layer": [ { "mark": "rect", - "width": Template.anchor("column_width"), - "height": Template.anchor("row_height"), + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "encoding": { "color": { "field": "percent_of_y", From 2ff42dacb70174f3c61536e347af125aebdd0d63 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 14 Nov 2023 12:58:31 +1100 Subject: [PATCH 30/39] remove row_height and column_width anchors --- src/dvc_render/vega.py | 11 ++++------- src/dvc_render/vega_templates.py | 4 ++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 3282c96..d6cffee 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -127,11 +127,9 @@ def get_partial_filled_template(self): content = self.get_filled_template( split_anchors=[ "color", - "column_width", "data", "plot_height", "plot_width", - "row_height", "shape", "stroke_dash", "title", @@ -211,14 +209,12 @@ def _process_optional_anchors(self, split_anchors: List[str]): for anchor in [ "color", "column", - "column_width", "group_by_x", "group_by_y", "group_by", "pivot_field", "plot_height", "plot_width", - "row_height", "row", "shape", "stroke_dash", @@ -257,10 +253,8 @@ def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): def _fill_set_encoding(self, split_anchors: List[str], optional_anchors: List[str]): for name, encoding in [ ("zoom_and_pan", {"name": "grid", "select": "interval", "bind": "scales"}), - ("column_width", 300), ("plot_height", 300), ("plot_width", 300), - ("row_height", 300), ]: self._fill_optional_anchor(split_anchors, optional_anchors, name, encoding) @@ -337,7 +331,10 @@ def _fill_optional_multi_source_anchors( split_anchors, optional_anchors, "row", {"field": concat_field, "sort": []} ) self._fill_optional_anchor( - split_anchors, optional_anchors, "column", {"field": concat_field, "sort": []} + split_anchors, + optional_anchors, + "column", + {"field": concat_field, "sort": []}, ) self._fill_tooltip(split_anchors, optional_anchors, [concat_field]) diff --git a/src/dvc_render/vega_templates.py b/src/dvc_render/vega_templates.py index 29dfc5d..ffac915 100644 --- a/src/dvc_render/vega_templates.py +++ b/src/dvc_render/vega_templates.py @@ -252,8 +252,8 @@ class ConfusionTemplate(Template): "layer": [ { "mark": "rect", - "width": Template.anchor("column_width"), - "height": Template.anchor("row_height"), + "width": Template.anchor("plot_width"), + "height": Template.anchor("plot_height"), "encoding": { "color": { "field": "xy_count", From b709fa9b86bf9b4368758f17dc567d6a807a61ec Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 14 Nov 2023 14:48:22 +1100 Subject: [PATCH 31/39] swap square and circle --- src/dvc_render/vega.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index d6cffee..cdba5a8 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -62,7 +62,7 @@ def __init__(self, datapoints: List, name: str, **properties): "#ed8936", "#f56565", ], - "shape": ["square", "circle", "triangle", "diamond"], + "shape": ["circle", "square", "triangle", "diamond"], } self._split_content: Dict[str, str] = {} From 9f34f8e17197f4255b75ce823aa45e2b52ac63c8 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 15 Nov 2023 11:02:23 +1100 Subject: [PATCH 32/39] use field separator for pivot field --- src/dvc_render/vega.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index cdba5a8..685e07b 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -323,7 +323,7 @@ def _fill_optional_multi_source_anchors( split_anchors, optional_anchors, "pivot_field", - " + '::' + ".join([f"datum.{key}" for key in grouped_keys]), + f" + '{FIELD_SEPARATOR}' + ".join([f"datum.{key}" for key in grouped_keys]), ) concat_field = FIELD_SEPARATOR.join(varied_keys) From 8150f7ab629eacaa4d01977a60ed86e8b2dbc323 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Wed, 15 Nov 2023 11:20:23 +1100 Subject: [PATCH 33/39] fix linear plots with varied filename field --- src/dvc_render/vega.py | 24 +++++++++++++----------- tests/test_vega.py | 2 +- 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 685e07b..5b270fb 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -11,6 +11,7 @@ from .vega_templates import BadTemplateError, LinearTemplate, Template, get_template FIELD_SEPARATOR = "::" +REV = "rev" FILENAME = "filename" FIELD = "field" FILENAME_FIELD = [FILENAME, FIELD] @@ -245,7 +246,7 @@ def _fill_color(self, split_anchors: List[str], optional_anchors: List[str]): self._fill_optional_anchor_mapping( split_anchors, optional_anchors, - "rev", + REV, "color", all_revs, ) @@ -261,7 +262,7 @@ def _fill_set_encoding(self, split_anchors: List[str], optional_anchors: List[st def _process_single_source_plot( self, split_anchors: List[str], optional_anchors: List[str] ): - self._fill_group_by(split_anchors, optional_anchors, ["rev"]) + self._fill_group_by(split_anchors, optional_anchors, [REV]) self._fill_optional_anchor( split_anchors, optional_anchors, "pivot_field", "datum.rev" ) @@ -316,17 +317,18 @@ def _fill_optional_multi_source_anchors( if not optional_anchors: return - grouped_keys = ["rev", *varied_keys] - self._fill_group_by(split_anchors, optional_anchors, grouped_keys) + concat_field = FIELD_SEPARATOR.join(varied_keys) + self._fill_group_by(split_anchors, optional_anchors, [REV, concat_field]) self._fill_optional_anchor( split_anchors, optional_anchors, "pivot_field", - f" + '{FIELD_SEPARATOR}' + ".join([f"datum.{key}" for key in grouped_keys]), + f" + '{FIELD_SEPARATOR}' + ".join( + [f"datum.{key}" for key in [REV, *varied_keys]] + ), ) - concat_field = FIELD_SEPARATOR.join(varied_keys) self._fill_optional_anchor( split_anchors, optional_anchors, "row", {"field": concat_field, "sort": []} ) @@ -348,22 +350,22 @@ def _fill_group_by( self, split_anchors: List[str], optional_anchors: List[str], - grouped_keys: List[str], + group_by: List[str], ): self._fill_optional_anchor( - split_anchors, optional_anchors, "group_by", grouped_keys + split_anchors, optional_anchors, "group_by", group_by ) self._fill_optional_anchor( split_anchors, optional_anchors, "group_by_x", - [*grouped_keys, self.properties.get("x")], + [*group_by, self.properties.get("x")], ) self._fill_optional_anchor( split_anchors, optional_anchors, "group_by_y", - [*grouped_keys, self.properties.get("y")], + [*group_by, self.properties.get("y")], ) def _fill_tooltip( @@ -379,7 +381,7 @@ def _fill_tooltip( optional_anchors, "tooltip", [ - {"field": "rev"}, + {"field": REV}, {"field": self.properties.get("x")}, {"field": self.properties.get("y")}, *[{"field": field} for field in additional_fields], diff --git a/tests/test_vega.py b/tests/test_vega.py index 7d80fe2..d655187 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -476,7 +476,7 @@ def test_fill_anchor_in_string(tmp_dir): }, }, "datum.rev + '::' + datum.filename + '::' + datum.field", - ["rev", "filename", "field"], + ["rev", "filename::field"], ), ), ) From 71e8793f7484b655dc00a53ac812ea5e452950a2 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 28 Nov 2023 11:52:33 +1100 Subject: [PATCH 34/39] extend optional anchor tests --- tests/test_vega.py | 373 ++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 353 insertions(+), 20 deletions(-) diff --git a/tests/test_vega.py b/tests/test_vega.py index d655187..dcb3248 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -6,7 +6,7 @@ from dvc_render.vega import BadTemplateError, VegaRenderer from dvc_render.vega_templates import NoFieldInDataError, Template -# pylint: disable=missing-function-docstring, C1803 +# pylint: disable=missing-function-docstring, C1803, C0302 @pytest.mark.parametrize( @@ -269,7 +269,6 @@ def test_fill_anchor_in_string(tmp_dir): "y", "anchors_y_definitions", "expected_dp_keys", - "color_encoding", "stroke_dash_encoding", "pivot_field", "group_by", @@ -296,10 +295,6 @@ def test_fill_anchor_in_string(tmp_dir): "acc", [{"filename": "test", "field": "acc"}], ["rev", "acc", "step"], - { - "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, - }, {}, "datum.rev", ["rev"], @@ -341,10 +336,6 @@ def test_fill_anchor_in_string(tmp_dir): {"filename": "train", "field": "acc"}, ], ["rev", "acc", "step", "filename"], - { - "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, - }, { "field": "filename", "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, @@ -393,10 +384,6 @@ def test_fill_anchor_in_string(tmp_dir): {"filename": "test", "field": "acc_norm"}, ], ["rev", "dvc_inferred_y_value", "step", "field"], - { - "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, - }, { "field": "field", "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, @@ -460,10 +447,6 @@ def test_fill_anchor_in_string(tmp_dir): {"filename": "train", "field": "acc"}, ], ["rev", "dvc_inferred_y_value", "step", "filename::field"], - { - "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, - }, { "field": "filename::field", "scale": { @@ -485,7 +468,6 @@ def test_optional_anchors_linear( y, anchors_y_definitions, expected_dp_keys, - color_encoding, stroke_dash_encoding, pivot_field, group_by, @@ -504,12 +486,363 @@ def test_optional_anchors_linear( plot_content = renderer.get_filled_template() assert plot_content["data"]["values"] == expected_datapoints - assert plot_content["encoding"]["color"] == color_encoding + assert plot_content["encoding"]["color"] == { + "field": "rev", + "scale": {"domain": ["B"], "range": ["#945dd6"]}, + } assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field assert plot_content["layer"][0]["transform"][0]["groupby"] == group_by +@pytest.mark.parametrize( + ",".join( + [ + "datapoints", + "y", + "anchors_y_definitions", + "expected_dp_keys", + "row_encoding", + "group_by_y", + "group_by_x", + ] + ), + ( + ( + [ + { + "rev": "B", + "predicted": "0.05", + "actual": "0.5", + "filename": "test", + "field": "predicted", + }, + { + "rev": "B", + "predicted": "0.9", + "actual": "0.9", + "filename": "test", + "field": "predicted", + }, + ], + "predicted", + [{"filename": "test", "field": "predicted"}], + ["rev", "predicted", "actual"], + {}, + ["rev", "predicted"], + ["rev", "actual"], + ), + ( + [ + { + "rev": "B", + "predicted": "0.05", + "actual": "0.5", + "filename": "test", + "field": "predicted", + }, + { + "rev": "B", + "predicted": "0.9", + "actual": "0.9", + "filename": "train", + "field": "predicted", + }, + ], + "predicted", + [ + {"filename": "test", "field": "predicted"}, + {"filename": "train", "field": "predicted"}, + ], + ["rev", "predicted", "actual"], + {"field": "filename", "sort": []}, + ["rev", "filename", "predicted"], + ["rev", "filename", "actual"], + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "predicted_test": "0.05", + "actual": "0.5", + "filename": "test", + "field": "predicted", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_test": "0.9", + "actual": "0.9", + "filename": "test", + "field": "predicted", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_train": "0.9", + "actual": "0.9", + "filename": "train", + "field": "predicted", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_train": "0.9", + "actual": "0.9", + "filename": "train", + "field": "predicted", + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "test", "field": "predicted_test"}, + {"filename": "train", "field": "predicted_train"}, + ], + ["rev", "predicted", "actual"], + {"field": "filename::field", "sort": []}, + ["rev", "filename::field", "dvc_inferred_y_value"], + ["rev", "filename::field", "actual"], + ), + ), +) +def test_optional_anchors_confusion( + datapoints, + y, + anchors_y_definitions, + expected_dp_keys, + row_encoding, + group_by_y, + group_by_x, +): # pylint: disable=too-many-arguments + props = { + "template": "confusion", + "x": "actual", + "y": y, + "revs_with_datapoints": ["B"], + "anchors_y_definitions": anchors_y_definitions, + } + + expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) + + renderer = VegaRenderer(datapoints, "foo", **props) + plot_content = renderer.get_filled_template() + + assert plot_content["data"]["values"] == expected_datapoints + assert plot_content["facet"]["row"] == row_encoding + assert plot_content["spec"]["transform"][0]["groupby"] == [y, "actual"] + assert plot_content["spec"]["transform"][1]["groupby"] == group_by_y + assert plot_content["spec"]["transform"][2]["groupby"] == group_by_x + assert plot_content["spec"]["layer"][0]["width"] == 300 + assert plot_content["spec"]["layer"][0]["height"] == 300 + + +@pytest.mark.parametrize( + ",".join( + [ + "datapoints", + "y", + "anchors_y_definitions", + "expected_dp_keys", + "shape_encoding", + "tooltip_encoding", + ] + ), + ( + ( + [ + { + "rev": "B", + "acc": "0.05", + "other": "field", + "filename": "test", + "field": "acc", + "loss": 0.1, + }, + { + "rev": "C", + "acc": "0.1", + "other": "field", + "filename": "test", + "field": "acc", + "loss": 2, + }, + ], + "acc", + [{"filename": "test", "field": "acc"}], + ["rev", "acc", "other", "loss"], + {}, + [{"field": "rev"}, {"field": "loss"}, {"field": "acc"}], + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "test_acc": "0.05", + "train_acc": "0.06", + "other": "field", + "filename": "data", + "field": "test_acc", + "loss": 0.1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.06", + "test_acc": "0.05", + "train_acc": "0.06", + "other": "field", + "filename": "data", + "field": "train_acc", + "loss": 0.1, + }, + { + "rev": "C", + "dvc_inferred_y_value": "0.1", + "train_acc": "0.1", + "test_acc": "0.2", + "other": "field", + "filename": "train_acc", + "field": "acc", + "loss": 2, + }, + { + "rev": "C", + "dvc_inferred_y_value": "0.2", + "train_acc": "0.1", + "test_acc": "0.2", + "other": "field", + "filename": "test_acc", + "field": "acc", + "loss": 2, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "data", "field": "train_acc"}, + {"filename": "data", "field": "test_acc"}, + ], + ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "other", "loss"], + { + "field": "field", + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + "scale": { + "domain": ["test_acc", "train_acc"], + "range": ["circle", "square"], + }, + }, + [ + {"field": "rev"}, + {"field": "loss"}, + {"field": "dvc_inferred_y_value"}, + {"field": "field"}, + ], + ), + ( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "test_acc": "0.05", + "other": "field", + "filename": "test", + "field": "test_acc", + "loss": 0.1, + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.06", + "train_acc": "0.06", + "other": "field", + "filename": "train", + "field": "train_acc", + "loss": 0.1, + }, + { + "rev": "C", + "dvc_inferred_y_value": "0.2", + "test_acc": "0.2", + "other": "field", + "filename": "test_acc", + "field": "acc", + "loss": 2, + }, + { + "rev": "C", + "dvc_inferred_y_value": "0.2", + "train_acc": "0.1", + "other": "field", + "filename": "train_acc", + "field": "acc", + "loss": 2, + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "train", "field": "train_acc"}, + {"filename": "test", "field": "test_acc"}, + ], + ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "other", "loss"], + { + "field": "filename::field", + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + "scale": { + "domain": ["test::test_acc", "train::train_acc"], + "range": ["circle", "square"], + }, + }, + [ + {"field": "rev"}, + {"field": "loss"}, + {"field": "dvc_inferred_y_value"}, + {"field": "filename::field"}, + ], + ), + ), +) +def test_optional_anchors_scatter( + datapoints, + y, + anchors_y_definitions, + expected_dp_keys, + shape_encoding, + tooltip_encoding, +): # pylint: disable=too-many-arguments + props = { + "template": "scatter", + "x": "loss", + "y": y, + "revs_with_datapoints": ["B", "C"], + "anchors_y_definitions": anchors_y_definitions, + } + + expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) + + renderer = VegaRenderer(datapoints, "foo", **props) + plot_content = renderer.get_filled_template() + + assert plot_content["data"]["values"] == expected_datapoints + assert plot_content["encoding"]["color"] == { + "field": "rev", + "scale": {"domain": ["B", "C"], "range": ["#945dd6", "#13adc7"]}, + } + assert plot_content["encoding"]["shape"] == shape_encoding + assert plot_content["encoding"]["tooltip"] == tooltip_encoding + assert plot_content["params"] == [ + { + "name": "grid", + "select": "interval", + "bind": "scales", + } + ] + + @pytest.mark.parametrize( "datapoints,y,anchors_y_definitions,expected_dp_keys,stroke_dash_encoding", ( From 86ae6c54b27f9040f1361aad481eabee665084d5 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Mon, 4 Dec 2023 11:25:19 +1100 Subject: [PATCH 35/39] ensure all 4 variations have test cases --- tests/test_vega.py | 172 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 155 insertions(+), 17 deletions(-) diff --git a/tests/test_vega.py b/tests/test_vega.py index dcb3248..0e3a26c 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -275,7 +275,7 @@ def test_fill_anchor_in_string(tmp_dir): ] ), ( - ( + pytest.param( [ { "rev": "B", @@ -298,8 +298,9 @@ def test_fill_anchor_in_string(tmp_dir): {}, "datum.rev", ["rev"], + id="single_source", ), - ( + pytest.param( [ { "rev": "B", @@ -346,8 +347,9 @@ def test_fill_anchor_in_string(tmp_dir): }, "datum.rev + '::' + datum.filename", ["rev", "filename"], + id="multi_filename", ), - ( + pytest.param( [ { "rev": "B", @@ -366,7 +368,7 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.04", - "filename": "train", + "filename": "test", "field": "acc_norm", "step": 1, }, @@ -394,8 +396,9 @@ def test_fill_anchor_in_string(tmp_dir): }, "datum.rev + '::' + datum.field", ["rev", "field"], + id="multi_field", ), - ( + pytest.param( [ { "rev": "B", @@ -460,6 +463,7 @@ def test_fill_anchor_in_string(tmp_dir): }, "datum.rev + '::' + datum.filename + '::' + datum.field", ["rev", "filename::field"], + id="multi_filename_field", ), ), ) @@ -508,7 +512,7 @@ def test_optional_anchors_linear( ] ), ( - ( + pytest.param( [ { "rev": "B", @@ -531,8 +535,9 @@ def test_optional_anchors_linear( {}, ["rev", "predicted"], ["rev", "actual"], + id="single_source", ), - ( + pytest.param( [ { "rev": "B", @@ -558,8 +563,41 @@ def test_optional_anchors_linear( {"field": "filename", "sort": []}, ["rev", "filename", "predicted"], ["rev", "filename", "actual"], + id="multi_filename", ), - ( + pytest.param( + [ + { + "rev": "B", + "dvc_inferred_y_value": "0.05", + "predicted_train": "0.05", + "predicted_test": "0.9", + "actual": "0.5", + "filename": "data", + "field": "predicted_test", + }, + { + "rev": "B", + "dvc_inferred_y_value": "0.9", + "predicted_train": "0.05", + "predicted_test": "0.9", + "actual": "0.5", + "filename": "data", + "field": "predicted_train", + }, + ], + "dvc_inferred_y_value", + [ + {"filename": "data", "field": "predicted_test"}, + {"filename": "data", "field": "predicted_train"}, + ], + ["rev", "dvc_inferred_y_value", "actual"], + {"field": "field", "sort": []}, + ["rev", "field", "dvc_inferred_y_value"], + ["rev", "field", "actual"], + id="multi_field", + ), + pytest.param( [ { "rev": "B", @@ -603,6 +641,7 @@ def test_optional_anchors_linear( {"field": "filename::field", "sort": []}, ["rev", "filename::field", "dvc_inferred_y_value"], ["rev", "filename::field", "actual"], + id="multi_filename_field", ), ), ) @@ -649,7 +688,7 @@ def test_optional_anchors_confusion( ] ), ( - ( + pytest.param( [ { "rev": "B", @@ -673,8 +712,69 @@ def test_optional_anchors_confusion( ["rev", "acc", "other", "loss"], {}, [{"field": "rev"}, {"field": "loss"}, {"field": "acc"}], + id="single_source", ), - ( + pytest.param( + [ + { + "rev": "B", + "acc": "0.05", + "other": "field", + "filename": "train", + "field": "acc", + "loss": "0.0001", + }, + { + "rev": "B", + "acc": "0.06", + "other": "field", + "filename": "test", + "field": "acc", + "loss": "200121", + }, + { + "rev": "C", + "acc": "0.1", + "other": "field", + "filename": "train", + "field": "acc", + "loss": "10", + }, + { + "rev": "C", + "acc": "0.1", + "other": "field", + "filename": "test", + "field": "acc", + "loss": "100", + }, + ], + "acc", + [ + {"filename": "train", "field": "acc"}, + {"filename": "test", "field": "acc"}, + ], + ["rev", "acc", "filename", "loss", "other"], + { + "field": "filename", + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + "scale": { + "domain": ["test", "train"], + "range": ["circle", "square"], + }, + }, + [ + {"field": "rev"}, + {"field": "loss"}, + {"field": "acc"}, + {"field": "filename"}, + ], + id="multi_filename", + ), + pytest.param( [ { "rev": "B", @@ -702,7 +802,7 @@ def test_optional_anchors_confusion( "train_acc": "0.1", "test_acc": "0.2", "other": "field", - "filename": "train_acc", + "filename": "data", "field": "acc", "loss": 2, }, @@ -712,7 +812,7 @@ def test_optional_anchors_confusion( "train_acc": "0.1", "test_acc": "0.2", "other": "field", - "filename": "test_acc", + "filename": "data", "field": "acc", "loss": 2, }, @@ -740,8 +840,9 @@ def test_optional_anchors_confusion( {"field": "dvc_inferred_y_value"}, {"field": "field"}, ], + id="multi_field", ), - ( + pytest.param( [ { "rev": "B", @@ -803,6 +904,7 @@ def test_optional_anchors_confusion( {"field": "dvc_inferred_y_value"}, {"field": "filename::field"}, ], + id="multi_filename_field", ), ), ) @@ -846,7 +948,7 @@ def test_optional_anchors_scatter( @pytest.mark.parametrize( "datapoints,y,anchors_y_definitions,expected_dp_keys,stroke_dash_encoding", ( - ( + pytest.param( [ { "rev": "B", @@ -867,8 +969,42 @@ def test_optional_anchors_scatter( [{"filename": "test", "field": "acc"}], ["rev", "acc", "step"], {}, + id="single_source", ), - ( + pytest.param( + [ + { + "rev": "B", + "acc": "0.05", + "filename": "test", + "field": "acc", + "step": 1, + }, + { + "rev": "B", + "acc": "0.04", + "filename": "train", + "field": "acc", + "step": 1, + }, + ], + "acc", + [ + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], + ["rev", "acc", "step", "field"], + { + "field": "filename", + "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, + "legend": { + "symbolFillColor": "transparent", + "symbolStrokeColor": "grey", + }, + }, + id="multi_filename", + ), + pytest.param( [ { "rev": "B", @@ -880,7 +1016,7 @@ def test_optional_anchors_scatter( { "rev": "B", "dvc_inferred_y_value": "0.04", - "filename": "train", + "filename": "test", "field": "acc_norm", "step": 1, }, @@ -899,8 +1035,9 @@ def test_optional_anchors_scatter( "symbolStrokeColor": "grey", }, }, + id="multi_field", ), - ( + pytest.param( [ { "rev": "B", @@ -942,6 +1079,7 @@ def test_optional_anchors_scatter( "symbolStrokeColor": "grey", }, }, + id="multi_filename_field", ), ), ) From e9fa62fc9cd8c56676b5fa7859fb09d87da4fd45 Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Mon, 4 Dec 2023 11:53:33 +1100 Subject: [PATCH 36/39] add constants and use in tests --- src/dvc_render/vega.py | 96 ++++++++++++++++++++---------------------- tests/test_vega.py | 38 +++++++++++------ 2 files changed, 71 insertions(+), 63 deletions(-) diff --git a/src/dvc_render/vega.py b/src/dvc_render/vega.py index 5b270fb..9a62bc6 100644 --- a/src/dvc_render/vega.py +++ b/src/dvc_render/vega.py @@ -17,6 +17,48 @@ FILENAME_FIELD = [FILENAME, FIELD] CONCAT_FIELDS = FIELD_SEPARATOR.join(FILENAME_FIELD) +SPLIT_ANCHORS = [ + "color", + "data", + "plot_height", + "plot_width", + "shape", + "stroke_dash", + "title", + "tooltip", + "x_label", + "y_label", + "zoom_and_pan", +] +OPTIONAL_ANCHORS = [ + "color", + "column", + "group_by_x", + "group_by_y", + "group_by", + "pivot_field", + "plot_height", + "plot_width", + "row", + "shape", + "stroke_dash", + "tooltip", + "zoom_and_pan", +] +OPTIONAL_ANCHOR_RANGES: Dict[str, Union[List[str], List[List[int]]]] = { + "stroke_dash": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]], + "color": [ + "#945dd6", + "#13adc7", + "#f46837", + "#48bb78", + "#4299e1", + "#ed8936", + "#f56565", + ], + "shape": ["circle", "square", "triangle", "diamond"], +} + class VegaRenderer(Renderer): """Renderer for vega plots.""" @@ -46,25 +88,6 @@ def __init__(self, datapoints: List, name: str, **properties): self.properties.get("template", None), self.properties.get("template_dir", None), ) - self._optional_anchor_ranges: Dict[ - str, - Union[ - List[str], - List[List[int]], - ], - ] = { - "stroke_dash": [[1, 0], [8, 8], [8, 4], [4, 4], [4, 2], [2, 1], [1, 1]], - "color": [ - "#945dd6", - "#13adc7", - "#f46837", - "#48bb78", - "#4299e1", - "#ed8936", - "#f56565", - ], - "shape": ["circle", "square", "triangle", "diamond"], - } self._split_content: Dict[str, str] = {} @@ -126,19 +149,7 @@ def get_partial_filled_template(self): Returns a partially filled template along with the split out anchor content """ content = self.get_filled_template( - split_anchors=[ - "color", - "data", - "plot_height", - "plot_width", - "shape", - "stroke_dash", - "title", - "tooltip", - "x_label", - "y_label", - "zoom_and_pan", - ], + split_anchors=SPLIT_ANCHORS, strict=True, ) return content, {"anchor_definitions": self._split_content} @@ -206,23 +217,7 @@ def get_revs(self): def _process_optional_anchors(self, split_anchors: List[str]): optional_anchors = [ - anchor - for anchor in [ - "color", - "column", - "group_by_x", - "group_by_y", - "group_by", - "pivot_field", - "plot_height", - "plot_width", - "row", - "shape", - "stroke_dash", - "tooltip", - "zoom_and_pan", - ] - if self.template.has_anchor(anchor) + anchor for anchor in OPTIONAL_ANCHORS if self.template.has_anchor(anchor) ] if not optional_anchors: return None @@ -443,7 +438,7 @@ def _get_optional_anchor_mapping( name: str, domain: List[str], ): - full_range_values: List[Any] = self._optional_anchor_ranges.get(name, []) + full_range_values: List[Any] = OPTIONAL_ANCHOR_RANGES.get(name, []) anchor_range_values = full_range_values.copy() anchor_range = [] @@ -454,6 +449,7 @@ def _get_optional_anchor_mapping( anchor_range.append(range_value) legend = ( + # fix stroke dash and shape legend entry appearance (use empty shapes) {"legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"}} if name != "color" else {} diff --git a/tests/test_vega.py b/tests/test_vega.py index 0e3a26c..9c6682c 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -3,7 +3,7 @@ import pytest -from dvc_render.vega import BadTemplateError, VegaRenderer +from dvc_render.vega import OPTIONAL_ANCHOR_RANGES, BadTemplateError, VegaRenderer from dvc_render.vega_templates import NoFieldInDataError, Template # pylint: disable=missing-function-docstring, C1803, C0302 @@ -339,7 +339,10 @@ def test_fill_anchor_in_string(tmp_dir): ["rev", "acc", "step", "filename"], { "field": "filename", - "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["test", "train"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -388,7 +391,10 @@ def test_fill_anchor_in_string(tmp_dir): ["rev", "dvc_inferred_y_value", "step", "field"], { "field": "field", - "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["acc", "acc_norm"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -454,7 +460,7 @@ def test_fill_anchor_in_string(tmp_dir): "field": "filename::field", "scale": { "domain": ["test::acc", "test::acc_norm", "train::acc"], - "range": [[1, 0], [8, 8], [8, 4]], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:3], }, "legend": { "symbolFillColor": "transparent", @@ -492,7 +498,7 @@ def test_optional_anchors_linear( assert plot_content["data"]["values"] == expected_datapoints assert plot_content["encoding"]["color"] == { "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, + "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:1]}, } assert plot_content["encoding"]["strokeDash"] == stroke_dash_encoding assert plot_content["layer"][3]["transform"][0]["calculate"] == pivot_field @@ -763,7 +769,7 @@ def test_optional_anchors_confusion( }, "scale": { "domain": ["test", "train"], - "range": ["circle", "square"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], }, }, [ @@ -831,7 +837,7 @@ def test_optional_anchors_confusion( }, "scale": { "domain": ["test_acc", "train_acc"], - "range": ["circle", "square"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], }, }, [ @@ -895,7 +901,7 @@ def test_optional_anchors_confusion( }, "scale": { "domain": ["test::test_acc", "train::train_acc"], - "range": ["circle", "square"], + "range": OPTIONAL_ANCHOR_RANGES["shape"][0:2], }, }, [ @@ -932,7 +938,7 @@ def test_optional_anchors_scatter( assert plot_content["data"]["values"] == expected_datapoints assert plot_content["encoding"]["color"] == { "field": "rev", - "scale": {"domain": ["B", "C"], "range": ["#945dd6", "#13adc7"]}, + "scale": {"domain": ["B", "C"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:2]}, } assert plot_content["encoding"]["shape"] == shape_encoding assert plot_content["encoding"]["tooltip"] == tooltip_encoding @@ -996,7 +1002,10 @@ def test_optional_anchors_scatter( ["rev", "acc", "step", "field"], { "field": "filename", - "scale": {"domain": ["test", "train"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["test", "train"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -1029,7 +1038,10 @@ def test_optional_anchors_scatter( ["rev", "dvc_inferred_y_value", "step", "field"], { "field": "field", - "scale": {"domain": ["acc", "acc_norm"], "range": [[1, 0], [8, 8]]}, + "scale": { + "domain": ["acc", "acc_norm"], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:2], + }, "legend": { "symbolFillColor": "transparent", "symbolStrokeColor": "grey", @@ -1072,7 +1084,7 @@ def test_optional_anchors_scatter( "field": "filename::field", "scale": { "domain": ["test::acc", "test::acc_norm", "train::acc"], - "range": [[1, 0], [8, 8], [8, 4]], + "range": OPTIONAL_ANCHOR_RANGES["stroke_dash"][0:3], }, "legend": { "symbolFillColor": "transparent", @@ -1103,7 +1115,7 @@ def test_partial_filled_template( expected_split = { Template.anchor("color"): { "field": "rev", - "scale": {"domain": ["B"], "range": ["#945dd6"]}, + "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:1]}, }, Template.anchor("data"): _get_expected_datapoints(datapoints, expected_dp_keys), Template.anchor("plot_height"): 300, From 7054de69c02d2466d56148830455c352916585ad Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Mon, 4 Dec 2023 12:50:13 +1100 Subject: [PATCH 37/39] hoist anchors_y_definition in tests --- tests/test_vega.py | 150 ++++++++++++++++++++++++--------------------- 1 file changed, 79 insertions(+), 71 deletions(-) diff --git a/tests/test_vega.py b/tests/test_vega.py index 9c6682c..ec20527 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -265,9 +265,9 @@ def test_fill_anchor_in_string(tmp_dir): @pytest.mark.parametrize( ",".join( [ + "anchors_y_definitions", "datapoints", "y", - "anchors_y_definitions", "expected_dp_keys", "stroke_dash_encoding", "pivot_field", @@ -276,6 +276,7 @@ def test_fill_anchor_in_string(tmp_dir): ), ( pytest.param( + [{"filename": "test", "field": "acc"}], [ { "rev": "B", @@ -293,7 +294,6 @@ def test_fill_anchor_in_string(tmp_dir): }, ], "acc", - [{"filename": "test", "field": "acc"}], ["rev", "acc", "step"], {}, "datum.rev", @@ -301,6 +301,10 @@ def test_fill_anchor_in_string(tmp_dir): id="single_source", ), pytest.param( + [ + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], [ { "rev": "B", @@ -332,10 +336,6 @@ def test_fill_anchor_in_string(tmp_dir): }, ], "acc", - [ - {"filename": "test", "field": "acc"}, - {"filename": "train", "field": "acc"}, - ], ["rev", "acc", "step", "filename"], { "field": "filename", @@ -353,6 +353,10 @@ def test_fill_anchor_in_string(tmp_dir): id="multi_filename", ), pytest.param( + [ + {"filename": "test", "field": "acc"}, + {"filename": "test", "field": "acc_norm"}, + ], [ { "rev": "B", @@ -384,10 +388,6 @@ def test_fill_anchor_in_string(tmp_dir): }, ], "dvc_inferred_y_value", - [ - {"filename": "test", "field": "acc"}, - {"filename": "test", "field": "acc_norm"}, - ], ["rev", "dvc_inferred_y_value", "step", "field"], { "field": "field", @@ -405,6 +405,11 @@ def test_fill_anchor_in_string(tmp_dir): id="multi_field", ), pytest.param( + [ + {"filename": "test", "field": "acc_norm"}, + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], [ { "rev": "B", @@ -450,11 +455,6 @@ def test_fill_anchor_in_string(tmp_dir): }, ], "dvc_inferred_y_value", - [ - {"filename": "test", "field": "acc_norm"}, - {"filename": "test", "field": "acc"}, - {"filename": "train", "field": "acc"}, - ], ["rev", "dvc_inferred_y_value", "step", "filename::field"], { "field": "filename::field", @@ -474,20 +474,20 @@ def test_fill_anchor_in_string(tmp_dir): ), ) def test_optional_anchors_linear( + anchors_y_definitions, datapoints, y, - anchors_y_definitions, expected_dp_keys, stroke_dash_encoding, pivot_field, group_by, ): # pylint: disable=too-many-arguments props = { + "anchors_y_definitions": anchors_y_definitions, + "revs_with_datapoints": ["B"], "template": "linear", "x": "step", "y": y, - "revs_with_datapoints": ["B"], - "anchors_y_definitions": anchors_y_definitions, } expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) @@ -508,9 +508,9 @@ def test_optional_anchors_linear( @pytest.mark.parametrize( ",".join( [ + "anchors_y_definitions", "datapoints", "y", - "anchors_y_definitions", "expected_dp_keys", "row_encoding", "group_by_y", @@ -519,6 +519,7 @@ def test_optional_anchors_linear( ), ( pytest.param( + [{"filename": "test", "field": "predicted"}], [ { "rev": "B", @@ -536,7 +537,6 @@ def test_optional_anchors_linear( }, ], "predicted", - [{"filename": "test", "field": "predicted"}], ["rev", "predicted", "actual"], {}, ["rev", "predicted"], @@ -544,6 +544,10 @@ def test_optional_anchors_linear( id="single_source", ), pytest.param( + [ + {"filename": "test", "field": "predicted"}, + {"filename": "train", "field": "predicted"}, + ], [ { "rev": "B", @@ -561,10 +565,6 @@ def test_optional_anchors_linear( }, ], "predicted", - [ - {"filename": "test", "field": "predicted"}, - {"filename": "train", "field": "predicted"}, - ], ["rev", "predicted", "actual"], {"field": "filename", "sort": []}, ["rev", "filename", "predicted"], @@ -572,6 +572,10 @@ def test_optional_anchors_linear( id="multi_filename", ), pytest.param( + [ + {"filename": "data", "field": "predicted_test"}, + {"filename": "data", "field": "predicted_train"}, + ], [ { "rev": "B", @@ -593,10 +597,6 @@ def test_optional_anchors_linear( }, ], "dvc_inferred_y_value", - [ - {"filename": "data", "field": "predicted_test"}, - {"filename": "data", "field": "predicted_train"}, - ], ["rev", "dvc_inferred_y_value", "actual"], {"field": "field", "sort": []}, ["rev", "field", "dvc_inferred_y_value"], @@ -604,6 +604,10 @@ def test_optional_anchors_linear( id="multi_field", ), pytest.param( + [ + {"filename": "test", "field": "predicted_test"}, + {"filename": "train", "field": "predicted_train"}, + ], [ { "rev": "B", @@ -639,10 +643,6 @@ def test_optional_anchors_linear( }, ], "dvc_inferred_y_value", - [ - {"filename": "test", "field": "predicted_test"}, - {"filename": "train", "field": "predicted_train"}, - ], ["rev", "predicted", "actual"], {"field": "filename::field", "sort": []}, ["rev", "filename::field", "dvc_inferred_y_value"], @@ -652,20 +652,20 @@ def test_optional_anchors_linear( ), ) def test_optional_anchors_confusion( + anchors_y_definitions, datapoints, y, - anchors_y_definitions, expected_dp_keys, row_encoding, group_by_y, group_by_x, ): # pylint: disable=too-many-arguments props = { + "anchors_y_definitions": anchors_y_definitions, + "revs_with_datapoints": ["B"], "template": "confusion", "x": "actual", "y": y, - "revs_with_datapoints": ["B"], - "anchors_y_definitions": anchors_y_definitions, } expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) @@ -685,9 +685,9 @@ def test_optional_anchors_confusion( @pytest.mark.parametrize( ",".join( [ + "anchors_y_definitions", "datapoints", "y", - "anchors_y_definitions", "expected_dp_keys", "shape_encoding", "tooltip_encoding", @@ -695,6 +695,7 @@ def test_optional_anchors_confusion( ), ( pytest.param( + [{"filename": "test", "field": "acc"}], [ { "rev": "B", @@ -714,13 +715,16 @@ def test_optional_anchors_confusion( }, ], "acc", - [{"filename": "test", "field": "acc"}], ["rev", "acc", "other", "loss"], {}, [{"field": "rev"}, {"field": "loss"}, {"field": "acc"}], id="single_source", ), pytest.param( + [ + {"filename": "train", "field": "acc"}, + {"filename": "test", "field": "acc"}, + ], [ { "rev": "B", @@ -756,10 +760,6 @@ def test_optional_anchors_confusion( }, ], "acc", - [ - {"filename": "train", "field": "acc"}, - {"filename": "test", "field": "acc"}, - ], ["rev", "acc", "filename", "loss", "other"], { "field": "filename", @@ -781,6 +781,10 @@ def test_optional_anchors_confusion( id="multi_filename", ), pytest.param( + [ + {"filename": "data", "field": "train_acc"}, + {"filename": "data", "field": "test_acc"}, + ], [ { "rev": "B", @@ -824,10 +828,6 @@ def test_optional_anchors_confusion( }, ], "dvc_inferred_y_value", - [ - {"filename": "data", "field": "train_acc"}, - {"filename": "data", "field": "test_acc"}, - ], ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "other", "loss"], { "field": "field", @@ -849,6 +849,10 @@ def test_optional_anchors_confusion( id="multi_field", ), pytest.param( + [ + {"filename": "train", "field": "train_acc"}, + {"filename": "test", "field": "test_acc"}, + ], [ { "rev": "B", @@ -888,10 +892,6 @@ def test_optional_anchors_confusion( }, ], "dvc_inferred_y_value", - [ - {"filename": "train", "field": "train_acc"}, - {"filename": "test", "field": "test_acc"}, - ], ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "other", "loss"], { "field": "filename::field", @@ -915,19 +915,19 @@ def test_optional_anchors_confusion( ), ) def test_optional_anchors_scatter( + anchors_y_definitions, datapoints, y, - anchors_y_definitions, expected_dp_keys, shape_encoding, tooltip_encoding, ): # pylint: disable=too-many-arguments props = { + "anchors_y_definitions": anchors_y_definitions, + "revs_with_datapoints": ["B", "C"], "template": "scatter", "x": "loss", "y": y, - "revs_with_datapoints": ["B", "C"], - "anchors_y_definitions": anchors_y_definitions, } expected_datapoints = _get_expected_datapoints(datapoints, expected_dp_keys) @@ -952,9 +952,18 @@ def test_optional_anchors_scatter( @pytest.mark.parametrize( - "datapoints,y,anchors_y_definitions,expected_dp_keys,stroke_dash_encoding", + ",".join( + [ + "anchors_y_definitions", + "datapoints", + "y", + "expected_dp_keys", + "stroke_dash_encoding", + ] + ), ( pytest.param( + [{"filename": "test", "field": "acc"}], [ { "rev": "B", @@ -972,12 +981,15 @@ def test_optional_anchors_scatter( }, ], "acc", - [{"filename": "test", "field": "acc"}], ["rev", "acc", "step"], {}, id="single_source", ), pytest.param( + [ + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], [ { "rev": "B", @@ -995,10 +1007,6 @@ def test_optional_anchors_scatter( }, ], "acc", - [ - {"filename": "test", "field": "acc"}, - {"filename": "train", "field": "acc"}, - ], ["rev", "acc", "step", "field"], { "field": "filename", @@ -1014,6 +1022,10 @@ def test_optional_anchors_scatter( id="multi_filename", ), pytest.param( + [ + {"filename": "test", "field": "acc"}, + {"filename": "test", "field": "acc_norm"}, + ], [ { "rev": "B", @@ -1031,10 +1043,6 @@ def test_optional_anchors_scatter( }, ], "dvc_inferred_y_value", - [ - {"filename": "test", "field": "acc"}, - {"filename": "test", "field": "acc_norm"}, - ], ["rev", "dvc_inferred_y_value", "step", "field"], { "field": "field", @@ -1050,6 +1058,11 @@ def test_optional_anchors_scatter( id="multi_field", ), pytest.param( + [ + {"filename": "test", "field": "acc_norm"}, + {"filename": "test", "field": "acc"}, + {"filename": "train", "field": "acc"}, + ], [ { "rev": "B", @@ -1074,11 +1087,6 @@ def test_optional_anchors_scatter( }, ], "dvc_inferred_y_value", - [ - {"filename": "test", "field": "acc_norm"}, - {"filename": "test", "field": "acc"}, - {"filename": "train", "field": "acc"}, - ], ["rev", "dvc_inferred_y_value", "step", "filename::field"], { "field": "filename::field", @@ -1096,20 +1104,20 @@ def test_optional_anchors_scatter( ), ) def test_partial_filled_template( + anchors_y_definitions, datapoints, y, - anchors_y_definitions, expected_dp_keys, stroke_dash_encoding, ): title = f"{y} by step" props = { + "anchors_y_definitions": anchors_y_definitions, + "revs_with_datapoints": ["B"], "template": "linear", + "title": title, "x": "step", "y": y, - "revs_with_datapoints": ["B"], - "anchors_y_definitions": anchors_y_definitions, - "title": title, } expected_split = { From 740c3494140da7c168fea3d43f53f3c94d47168f Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Mon, 4 Dec 2023 15:32:03 +1100 Subject: [PATCH 38/39] fix up test datapoints --- tests/test_vega.py | 79 +++++++++++++++++++++++++++++++--------------- 1 file changed, 53 insertions(+), 26 deletions(-) diff --git a/tests/test_vega.py b/tests/test_vega.py index ec20527..89fa0e9 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -361,6 +361,8 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.05", + "acc": "0.05", + "acc_norm": "0.04", "filename": "test", "field": "acc", "step": 1, @@ -368,6 +370,8 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.1", + "acc": "0.1", + "acc_norm": "0.09", "filename": "test", "field": "acc", "step": 2, @@ -375,6 +379,8 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.04", + "acc": "0.05", + "acc_norm": "0.04", "filename": "test", "field": "acc_norm", "step": 1, @@ -382,13 +388,15 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.09", + "acc": "0.1", + "acc_norm": "0.09", "filename": "test", "field": "acc_norm", "step": 2, }, ], "dvc_inferred_y_value", - ["rev", "dvc_inferred_y_value", "step", "field"], + ["rev", "dvc_inferred_y_value", "acc", "acc_norm", "step", "field"], { "field": "field", "scale": { @@ -414,6 +422,8 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.05", + "acc": "0.05", + "acc_norm": "0.02", "filename": "test", "field": "acc", "step": 1, @@ -421,6 +431,8 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.1", + "acc": "0.01", + "acc_norm": "0.07", "filename": "test", "field": "acc", "step": 2, @@ -428,6 +440,7 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.04", + "acc": "0.04", "filename": "train", "field": "acc", "step": 1, @@ -435,6 +448,7 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.09", + "acc": "0.09", "filename": "train", "field": "acc", "step": 2, @@ -442,6 +456,8 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.02", + "acc": "0.05", + "acc_norm": "0.02", "filename": "test", "field": "acc_norm", "step": 1, @@ -449,13 +465,22 @@ def test_fill_anchor_in_string(tmp_dir): { "rev": "B", "dvc_inferred_y_value": "0.07", + "acc": "0.01", + "acc_norm": "0.07", "filename": "test", "field": "acc_norm", "step": 2, }, ], "dvc_inferred_y_value", - ["rev", "dvc_inferred_y_value", "step", "filename::field"], + [ + "rev", + "dvc_inferred_y_value", + "acc", + "acc_norm", + "step", + "filename::field", + ], { "field": "filename::field", "scale": { @@ -615,7 +640,7 @@ def test_optional_anchors_linear( "predicted_test": "0.05", "actual": "0.5", "filename": "test", - "field": "predicted", + "field": "predicted_test", }, { "rev": "B", @@ -623,7 +648,7 @@ def test_optional_anchors_linear( "predicted_test": "0.9", "actual": "0.9", "filename": "test", - "field": "predicted", + "field": "predicted_test", }, { "rev": "B", @@ -631,7 +656,7 @@ def test_optional_anchors_linear( "predicted_train": "0.9", "actual": "0.9", "filename": "train", - "field": "predicted", + "field": "predicted_train", }, { "rev": "B", @@ -639,7 +664,7 @@ def test_optional_anchors_linear( "predicted_train": "0.9", "actual": "0.9", "filename": "train", - "field": "predicted", + "field": "predicted_train", }, ], "dvc_inferred_y_value", @@ -700,7 +725,6 @@ def test_optional_anchors_confusion( { "rev": "B", "acc": "0.05", - "other": "field", "filename": "test", "field": "acc", "loss": 0.1, @@ -708,14 +732,13 @@ def test_optional_anchors_confusion( { "rev": "C", "acc": "0.1", - "other": "field", "filename": "test", "field": "acc", "loss": 2, }, ], "acc", - ["rev", "acc", "other", "loss"], + ["rev", "acc", "loss"], {}, [{"field": "rev"}, {"field": "loss"}, {"field": "acc"}], id="single_source", @@ -729,7 +752,6 @@ def test_optional_anchors_confusion( { "rev": "B", "acc": "0.05", - "other": "field", "filename": "train", "field": "acc", "loss": "0.0001", @@ -737,7 +759,6 @@ def test_optional_anchors_confusion( { "rev": "B", "acc": "0.06", - "other": "field", "filename": "test", "field": "acc", "loss": "200121", @@ -745,7 +766,6 @@ def test_optional_anchors_confusion( { "rev": "C", "acc": "0.1", - "other": "field", "filename": "train", "field": "acc", "loss": "10", @@ -753,14 +773,13 @@ def test_optional_anchors_confusion( { "rev": "C", "acc": "0.1", - "other": "field", "filename": "test", "field": "acc", "loss": "100", }, ], "acc", - ["rev", "acc", "filename", "loss", "other"], + ["rev", "acc", "filename", "loss"], { "field": "filename", "legend": { @@ -791,7 +810,6 @@ def test_optional_anchors_confusion( "dvc_inferred_y_value": "0.05", "test_acc": "0.05", "train_acc": "0.06", - "other": "field", "filename": "data", "field": "test_acc", "loss": 0.1, @@ -801,7 +819,6 @@ def test_optional_anchors_confusion( "dvc_inferred_y_value": "0.06", "test_acc": "0.05", "train_acc": "0.06", - "other": "field", "filename": "data", "field": "train_acc", "loss": 0.1, @@ -811,7 +828,6 @@ def test_optional_anchors_confusion( "dvc_inferred_y_value": "0.1", "train_acc": "0.1", "test_acc": "0.2", - "other": "field", "filename": "data", "field": "acc", "loss": 2, @@ -821,14 +837,13 @@ def test_optional_anchors_confusion( "dvc_inferred_y_value": "0.2", "train_acc": "0.1", "test_acc": "0.2", - "other": "field", "filename": "data", "field": "acc", "loss": 2, }, ], "dvc_inferred_y_value", - ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "other", "loss"], + ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "loss"], { "field": "field", "legend": { @@ -858,7 +873,6 @@ def test_optional_anchors_confusion( "rev": "B", "dvc_inferred_y_value": "0.05", "test_acc": "0.05", - "other": "field", "filename": "test", "field": "test_acc", "loss": 0.1, @@ -867,7 +881,6 @@ def test_optional_anchors_confusion( "rev": "B", "dvc_inferred_y_value": "0.06", "train_acc": "0.06", - "other": "field", "filename": "train", "field": "train_acc", "loss": 0.1, @@ -876,7 +889,6 @@ def test_optional_anchors_confusion( "rev": "C", "dvc_inferred_y_value": "0.2", "test_acc": "0.2", - "other": "field", "filename": "test_acc", "field": "acc", "loss": 2, @@ -885,14 +897,13 @@ def test_optional_anchors_confusion( "rev": "C", "dvc_inferred_y_value": "0.2", "train_acc": "0.1", - "other": "field", "filename": "train_acc", "field": "acc", "loss": 2, }, ], "dvc_inferred_y_value", - ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "other", "loss"], + ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "loss"], { "field": "filename::field", "legend": { @@ -1030,6 +1041,8 @@ def test_optional_anchors_scatter( { "rev": "B", "dvc_inferred_y_value": "0.05", + "acc": "0.05", + "acc_norm": "0.04", "filename": "test", "field": "acc", "step": 1, @@ -1037,13 +1050,15 @@ def test_optional_anchors_scatter( { "rev": "B", "dvc_inferred_y_value": "0.04", + "acc": "0.05", + "acc_norm": "0.04", "filename": "test", "field": "acc_norm", "step": 1, }, ], "dvc_inferred_y_value", - ["rev", "dvc_inferred_y_value", "step", "field"], + ["rev", "dvc_inferred_y_value", "acc", "acc_norm", "step", "field"], { "field": "field", "scale": { @@ -1067,6 +1082,8 @@ def test_optional_anchors_scatter( { "rev": "B", "dvc_inferred_y_value": "0.05", + "acc": "0.05", + "acc_norm": "0.02", "filename": "test", "field": "acc", "step": 1, @@ -1074,6 +1091,7 @@ def test_optional_anchors_scatter( { "rev": "B", "dvc_inferred_y_value": "0.04", + "acc": "0.04", "filename": "train", "field": "acc", "step": 1, @@ -1082,12 +1100,21 @@ def test_optional_anchors_scatter( "rev": "B", "dvc_inferred_y_value": "0.02", "filename": "test", + "acc": "0.05", + "acc_norm": "0.02", "field": "acc_norm", "step": 1, }, ], "dvc_inferred_y_value", - ["rev", "dvc_inferred_y_value", "step", "filename::field"], + [ + "rev", + "dvc_inferred_y_value", + "acc", + "acc_norm", + "step", + "filename::field", + ], { "field": "filename::field", "scale": { From aea81873a11aec822d33e0ff6783a662f4cd91bb Mon Sep 17 00:00:00 2001 From: Matt Seddon Date: Tue, 5 Dec 2023 09:25:49 +1100 Subject: [PATCH 39/39] separate color test --- tests/test_vega.py | 218 +++++++++++++++++++++++++++++++++------------ 1 file changed, 161 insertions(+), 57 deletions(-) diff --git a/tests/test_vega.py b/tests/test_vega.py index 89fa0e9..482c1bb 100644 --- a/tests/test_vega.py +++ b/tests/test_vega.py @@ -729,13 +729,6 @@ def test_optional_anchors_confusion( "field": "acc", "loss": 0.1, }, - { - "rev": "C", - "acc": "0.1", - "filename": "test", - "field": "acc", - "loss": 2, - }, ], "acc", ["rev", "acc", "loss"], @@ -763,20 +756,6 @@ def test_optional_anchors_confusion( "field": "acc", "loss": "200121", }, - { - "rev": "C", - "acc": "0.1", - "filename": "train", - "field": "acc", - "loss": "10", - }, - { - "rev": "C", - "acc": "0.1", - "filename": "test", - "field": "acc", - "loss": "100", - }, ], "acc", ["rev", "acc", "filename", "loss"], @@ -823,24 +802,6 @@ def test_optional_anchors_confusion( "field": "train_acc", "loss": 0.1, }, - { - "rev": "C", - "dvc_inferred_y_value": "0.1", - "train_acc": "0.1", - "test_acc": "0.2", - "filename": "data", - "field": "acc", - "loss": 2, - }, - { - "rev": "C", - "dvc_inferred_y_value": "0.2", - "train_acc": "0.1", - "test_acc": "0.2", - "filename": "data", - "field": "acc", - "loss": 2, - }, ], "dvc_inferred_y_value", ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "loss"], @@ -885,22 +846,6 @@ def test_optional_anchors_confusion( "field": "train_acc", "loss": 0.1, }, - { - "rev": "C", - "dvc_inferred_y_value": "0.2", - "test_acc": "0.2", - "filename": "test_acc", - "field": "acc", - "loss": 2, - }, - { - "rev": "C", - "dvc_inferred_y_value": "0.2", - "train_acc": "0.1", - "filename": "train_acc", - "field": "acc", - "loss": 2, - }, ], "dvc_inferred_y_value", ["rev", "dvc_inferred_y_value", "train_acc", "test_acc", "loss"], @@ -935,7 +880,7 @@ def test_optional_anchors_scatter( ): # pylint: disable=too-many-arguments props = { "anchors_y_definitions": anchors_y_definitions, - "revs_with_datapoints": ["B", "C"], + "revs_with_datapoints": ["B"], "template": "scatter", "x": "loss", "y": y, @@ -949,7 +894,7 @@ def test_optional_anchors_scatter( assert plot_content["data"]["values"] == expected_datapoints assert plot_content["encoding"]["color"] == { "field": "rev", - "scale": {"domain": ["B", "C"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:2]}, + "scale": {"domain": ["B"], "range": OPTIONAL_ANCHOR_RANGES["color"][0:1]}, } assert plot_content["encoding"]["shape"] == shape_encoding assert plot_content["encoding"]["tooltip"] == tooltip_encoding @@ -962,6 +907,165 @@ def test_optional_anchors_scatter( ] +@pytest.mark.parametrize( + ",".join( + [ + "revs", + "datapoints", + ] + ), + ( + pytest.param( + ["B"], + [ + { + "rev": "B", + "acc": "0.05", + "step": "1", + "filename": "acc", + "field": "acc", + }, + ], + id="rev_count_1", + ), + pytest.param( + ["B", "C", "D", "E", "F"], + [ + { + "rev": "B", + "acc": "0.05", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "C", + "acc": "0.1", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "D", + "acc": "0.06", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "E", + "acc": "0.6", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "F", + "acc": "1.0", + "step": "1", + "filename": "acc", + "field": "acc", + }, + ], + id="rev_count_5", + ), + pytest.param( + ["B", "C", "D", "E", "F", "G", "H", "I", "J"], + [ + { + "rev": "B", + "acc": "0.05", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "C", + "acc": "0.1", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "D", + "acc": "0.06", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "E", + "acc": "0.6", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "F", + "acc": "1.0", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "G", + "acc": "0.006", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "H", + "acc": "0.00001", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "I", + "acc": "0.8", + "step": "1", + "filename": "acc", + "field": "acc", + }, + { + "rev": "J", + "acc": "0.001", + "step": "1", + "filename": "acc", + "field": "acc", + }, + ], + id="rev_count_9", + ), + ), +) +def test_color_anchor(revs, datapoints): + props = { + "anchors_y_definitions": [{"filename": "acc", "field": "acc"}], + "revs_with_datapoints": revs, + "template": "linear", + "x": "step", + "y": "acc", + } + + renderer = VegaRenderer(datapoints, "foo", **props) + plot_content = renderer.get_filled_template() + + colors = OPTIONAL_ANCHOR_RANGES["color"] + color_range = colors[0 : len(revs)] + if len(revs) > len(colors): + color_range.extend(colors[0 : len(revs) - len(colors)]) + + assert plot_content["encoding"]["color"] == { + "field": "rev", + "scale": { + "domain": revs, + "range": color_range, + }, + } + + @pytest.mark.parametrize( ",".join( [