Skip to content

Commit

Permalink
be dumb and refactor before adding tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattseddon committed Sep 20, 2023
1 parent 9b2c513 commit 925d4e8
Showing 1 changed file with 69 additions and 43 deletions.
112 changes: 69 additions & 43 deletions src/dvc_render/vega.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def get_filled_template(
self.properties.setdefault("y_label", self.properties.get("y"))
self.properties.setdefault("data", self.datapoints)

self._fill_optional_anchors(skip_anchors)
self._process_optional_anchors(skip_anchors)

names = ["title", "x", "y", "x_label", "y_label", "data"]
for name in names:
Expand Down Expand Up @@ -164,7 +164,7 @@ def generate_markdown(self, report_path=None) -> str:

return ""

def _fill_optional_anchors(self, skip_anchors: List[str]):
def _process_optional_anchors(self, skip_anchors: List[str]):
optional_anchors = [
anchor
for anchor in [
Expand All @@ -175,67 +175,79 @@ def _fill_optional_anchors(self, skip_anchors: List[str]):
"stroke_dash",
"shape",
]
if anchor not in skip_anchors and self.template.has_anchor(anchor)
if self.template.has_anchor(anchor)
]
if not optional_anchors:
return
if optional_anchors:
# split varied_keys out from _fill_optional_anchors to avoid bugs
# but first.... tests
varied_keys = self._fill_optional_anchors(skip_anchors, optional_anchors)
self._update_datapoints(varied_keys)

self._fill_color(optional_anchors)
def _fill_optional_anchors(
self, skip_anchors: List[str], optional_anchors: List[str]
) -> List[str]:
self._fill_color(skip_anchors, optional_anchors)

if not optional_anchors:
return
return []

y_defn = self.properties.get("anchors_y_defn", [])

if len(y_defn) <= 1:
self._fill_optional_anchor(optional_anchors, "group_by", ["rev"])
self._fill_optional_anchor(optional_anchors, "pivot_field", "rev")
self._fill_optional_anchor(
skip_anchors, optional_anchors, "group_by", ["rev"]
)
self._fill_optional_anchor(
skip_anchors, optional_anchors, "pivot_field", "rev"
)
for anchor in optional_anchors:
self.template.fill_anchor(anchor, {})
self._update_datapoints(to_remove=["filename", "field"])
return
return []

keys, variations = self._collect_variations(y_defn)
grouped_keys = ["rev", *keys]
concat_field = "::".join(keys)
self._fill_optional_anchor(optional_anchors, "group_by", grouped_keys)
varied_keys, variations = self._collect_variations(y_defn)
grouped_keys = ["rev", *varied_keys]
concat_field = "::".join(varied_keys)
self._fill_optional_anchor(
skip_anchors, optional_anchors, "group_by", grouped_keys
)
self._fill_optional_anchor(
skip_anchors,
optional_anchors,
"pivot_field",
"+ '::'".join([f"datum.{key}" for key in grouped_keys]),
" + '::' + ".join([f"datum.{key}" for key in grouped_keys]),
)
# concatenate grouped_keys together
self._fill_optional_anchor(optional_anchors, "row", {"field": concat_field})
self._fill_optional_anchor(
skip_anchors, optional_anchors, "row", {"field": concat_field}
)

if not optional_anchors:
return
return varied_keys

if len(keys) == 2:
self._update_datapoints(
to_remove=["filename", "field"], to_concatenate=[["filename", "field"]]
)
if len(varied_keys) == 2:
domain = ["::".join([d.get("filename"), d.get("field")]) for d in y_defn]
else:
filenameOrField = keys[0]
to_remove = ["filename", "field"]
to_remove.remove(filenameOrField)
self._update_datapoints(to_remove=to_remove)

filenameOrField = varied_keys[0]
domain = list(variations[filenameOrField])

stroke_dash_scale = self._set_optional_anchor_scale(
optional_anchors, concat_field, "stroke_dash", domain
)
self._fill_optional_anchor(optional_anchors, "stroke_dash", stroke_dash_scale)
self._fill_optional_anchor(
skip_anchors, optional_anchors, "stroke_dash", stroke_dash_scale
)

shape_scale = self._set_optional_anchor_scale(
optional_anchors, concat_field, "shape", domain
)
self._fill_optional_anchor(optional_anchors, "shape", shape_scale)
self._fill_optional_anchor(skip_anchors, optional_anchors, "shape", shape_scale)

def _fill_color(self, optional_anchors: List[str]):
return varied_keys

def _fill_color(self, skip_anchors: List[str], optional_anchors: List[str]):
all_revs = self.properties.get("anchor_revs", [])
self._fill_optional_anchor(
skip_anchors,
optional_anchors,
"color",
{
Expand Down Expand Up @@ -277,11 +289,21 @@ def _collect_variations(
lessValuesThanVariations.sort(reverse=True)
return lessValuesThanVariations, variations

def _fill_optional_anchor(self, optional_anchors: List[str], name: str, value: Any):
def _fill_optional_anchor(
self,
skip_anchors: List[str],
optional_anchors: List[str],
name: str,
value: Any,
):
if name not in optional_anchors:
return

optional_anchors.remove(name)

if name in skip_anchors:
return

self.template.fill_anchor(name, value)

def _set_optional_anchor_scale(
Expand All @@ -301,21 +323,25 @@ def _set_optional_anchor_scale(
self._optional_anchor_values[name][domain_value] = range_value
anchor_range.append(range_value)

return {"field": field, "scale": {"domain": domain, "range": anchor_range}}
return {
"field": field,
"scale": {"domain": domain, "range": anchor_range},
"legend": {"symbolFillColor": "transparent", "symbolStrokeColor": "grey"},
}

def _update_datapoints(
self,
to_concatenate: Optional[List[List[str]]] = None,
to_remove: Optional[List[str]] = None,
):
if to_concatenate is None:
def _update_datapoints(self, varied_keys: List[str]):
if len(varied_keys) == 2:
to_concatenate = varied_keys
to_remove = varied_keys
else:
to_concatenate = []
if to_remove is None:
to_remove = []
to_remove = [key for key in ["filename", "field"] if key not in varied_keys]

for datapoint in self.datapoints:
for keys in to_concatenate:
concat_key = "::".join(keys)
datapoint[concat_key] = "::".join([datapoint.get(k) for k in keys])
if to_concatenate:
concat_key = "::".join(to_concatenate)
datapoint[concat_key] = "::".join(
[datapoint.get(k) for k in to_concatenate]
)
for key in to_remove:
datapoint.pop(key, None)

0 comments on commit 925d4e8

Please sign in to comment.