Skip to content

Commit

Permalink
Implement Variable API (#242)
Browse files Browse the repository at this point in the history
  • Loading branch information
philippjfr authored Feb 16, 2022
1 parent d89cb11 commit 4e46494
Show file tree
Hide file tree
Showing 20 changed files with 662 additions and 137 deletions.
49 changes: 49 additions & 0 deletions lumen/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from functools import partial

import param

from .state import state
from .util import resolve_module_reference


class Component(param.Parameterized):
"""
Baseclass for all Lumen component types including Source, Filter,
Transform, Variable and View types.
"""

__abstract = True

def __init__(self, **params):
self._refs = params.pop('refs', {})
super().__init__(**params)
for p, ref in self._refs.items():
if isinstance(ref, str) and ref.startswith('$variables.'):
ref = ref.split('$variables.')[1]
state.variables.param.watch(partial(self._update_ref, p), ref)

def _update_ref(self, pname, event):
"""
Component should implement appropriate downstream events
following a change in a variable.
"""
self.param.update({pname: event.new})

@property
def refs(self):
return [k for k, v in self._refs.items() if v.startswith('$variables.')]

@classmethod
def _get_type(cls, component_type):
clsname = cls.__name__
clslower = clsname.lower()
if '.' in component_type:
return resolve_module_reference(component_type, cls)
try:
__import__(f'lumen.{clslower}s.{component_type}')
except Exception:
pass
for component in param.concrete_descendents(cls).values():
if getattr(component, f'{clsname.lower()}_type') == component_type:
return component
raise ValueError(f"No {clsname} for {clslower}_type '{component_type}' could be found.")
6 changes: 6 additions & 0 deletions lumen/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from .target import Target
from .transforms import Transform # noqa
from .util import expand_spec
from .variables import Variables
from .views import View # noqa

pn.config.css_files.append(
Expand Down Expand Up @@ -223,6 +224,8 @@ def __init__(self, specification=None, **params):
self.auth = Auth.from_spec(state.spec.get('auth', {}))
self.config = Config.from_spec(state.spec.get('config', {}))
self.defaults = Defaults.from_spec(state.spec.get('defaults', {}))
vars = Variables.from_spec(state.spec.get('variables', {}))
self.variables = state._variables[pn.state.curdoc] = vars
self.defaults.apply()

# Load and populate template
Expand Down Expand Up @@ -479,6 +482,9 @@ def _get_global_filters(self):
def _render_filters(self):
self._global_filters, global_panel = self._get_global_filters()
filters = [] if global_panel is None else [global_panel]
variable_panel = self.variables.panel
if len(variable_panel):
filters.append(variable_panel)
for i, target in enumerate(self.targets):
if isinstance(self._layout, pn.Tabs) and i != self._layout.active:
continue
Expand Down
17 changes: 2 additions & 15 deletions lumen/filters/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@
import panel as pn
import param

from ..base import Component
from ..schema import JSONSchema
from ..state import state
from ..util import resolve_module_reference


class Filter(param.Parameterized):
class Filter(Component):
"""
A Filter provides a query which will be used to filter the data
returned by a Source.
Expand Down Expand Up @@ -53,19 +53,6 @@ def _url_sync_error(self, values):
Called when URL syncing errors.
"""

@classmethod
def _get_type(cls, filter_type):
if '.' in filter_type:
return resolve_module_reference(filter_type, Filter)
try:
__import__(f'lumen.filters.{filter_type}')
except Exception:
pass
for filt in param.concrete_descendents(cls).values():
if filt.filter_type == filter_type:
return filt
raise ValueError(f"No Filter for filter_type '{filter_type}' could be found.")

@classmethod
def from_spec(cls, spec, source_schema, source_filters=None):
"""
Expand Down
73 changes: 20 additions & 53 deletions lumen/sources/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@
import param
import requests

from ..base import Component
from ..filters import Filter
from ..state import state
from ..transforms import Transform
from ..util import (
get_dataframe_schema, merge_schemas, resolve_module_reference
)
from ..util import get_dataframe_schema, is_ref, merge_schemas


def cached(with_query=True):
Expand Down Expand Up @@ -74,8 +73,7 @@ def wrapped(self, table=None):
return wrapped



class Source(param.Parameterized):
class Source(Component):
"""
A Source provides a set of tables which declare their available
fields. The Source must also be able to return a schema describing
Expand All @@ -98,18 +96,9 @@ class Source(param.Parameterized):

__abstract = True

@classmethod
def _get_type(cls, source_type):
if '.' in source_type:
return resolve_module_reference(source_type, Source)
try:
__import__(f'lumen.sources.{source_type}')
except Exception:
pass
for source in param.concrete_descendents(cls).values():
if source.source_type == source_type:
return source
raise ValueError(f"No Source for source_type '{source_type}' could be found.")
def _update_ref(self, pname, event):
self.clear_cache()
super()._update_ref(pname, event)

@classmethod
def _range_filter(cls, column, start, end):
Expand Down Expand Up @@ -172,34 +161,9 @@ def _filter_dataframe(cls, df, **query):
df = df[mask]
return df

@classmethod
def _resolve_reference(cls, reference):
refs = reference[1:].split('.')
if len(refs) == 3:
sourceref, table, field = refs
elif len(refs) == 2:
sourceref, table = refs
elif len(refs) == 1:
(sourceref,) = refs

source = cls.from_spec(sourceref)
if len(refs) == 1:
return source
if len(refs) == 2:
return source.get(table)
table_schema = source.get_schema(table)
if field not in table_schema:
raise ValueError(f"Field '{field}' was not found in "
f"'{sourceref}' table '{table}'.")
field_schema = table_schema[field]
if 'enum' not in field_schema:
raise ValueError(f"Field '{field}' schema does not "
"declare an enum.")
return field_schema['enum']

@classmethod
def _recursive_resolve(cls, spec, source_type):
resolved_spec = {}
resolved_spec, refs = {}, {}
if 'sources' in source_type.param and 'sources' in spec:
resolved_spec['sources'] = {
source: cls.from_spec(source)
Expand All @@ -208,17 +172,22 @@ def _recursive_resolve(cls, spec, source_type):
if 'source' in source_type.param and 'source' in spec:
resolved_spec['source'] = cls.from_spec(spec.pop('source'))
for k, v in spec.items():
if isinstance(v, str) and v.startswith('@'):
v = cls._resolve_reference(v)
if is_ref(v):
refs[k] = v
v = state.resolve_reference(v)
elif isinstance(v, dict):
v = cls._recursive_resolve(v, source_type)
v, subrefs = cls._recursive_resolve(v, source_type)
if subrefs:
cls.param.warning(
"Resolving nested variable references currently not supported."
)
if k == 'filters' and 'source' in resolved_spec:
source_schema = resolved_spec['source'].get_schema()
v = [Filter.from_spec(fspec, source_schema) for fspec in v]
if k == 'transforms':
v = [Transform.from_spec(tspec) for tspec in v]
resolved_spec[k] = v
return resolved_spec
return resolved_spec, refs

@classmethod
def from_spec(cls, spec):
Expand Down Expand Up @@ -250,8 +219,8 @@ def from_spec(cls, spec):

spec = dict(spec)
source_type = Source._get_type(spec.pop('type'))
resolved_spec = cls._recursive_resolve(dict(spec), source_type)
return source_type(**resolved_spec)
resolved_spec, refs = cls._recursive_resolve(spec, source_type)
return source_type(refs=refs, **resolved_spec)

def __init__(self, **params):
from ..config import config
Expand Down Expand Up @@ -523,8 +492,7 @@ def _named_files(self):

def _resolve_template_vars(self, table):
for m in self._template_re.findall(table):
ref = f'@{m[2:-1]}'
values = self._resolve_reference(ref)
values = state.resolve_reference(f'${m[2:-1]}')
values = ','.join([v for v in values])
table = table.replace(m, quote(values))
return [table]
Expand Down Expand Up @@ -601,8 +569,7 @@ def _resolve_template_vars(self, template):
template_vars = self._template_re.findall(template)
template_values = []
for m in template_vars:
ref = f'@{m[2:-1]}'
values = self._resolve_reference(ref)
values = state.resolve_reference(f'${m[2:-1]}')
template_values.append(values)
tables = []
cross_product = list(product(*template_values))
Expand Down
41 changes: 41 additions & 0 deletions lumen/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import panel as pn

from .util import is_ref


class _session_state:
"""
Expand All @@ -24,6 +26,8 @@ class _session_state:

_filters = WeakKeyDictionary() if pn.state.curdoc else {}

_variables = WeakKeyDictionary() if pn.state.curdoc else {}

@property
def app(self):
return self._apps.get(pn.state.curdoc)
Expand All @@ -47,6 +51,10 @@ def sources(self):
self._sources[pn.state.curdoc] = dict(self.global_sources)
return self._sources[pn.state.curdoc]

@property
def variables(self):
return self._variables.get(pn.state.curdoc, {})

@property
def loading_msg(self):
return self._loading.get(pn.state.curdoc)
Expand Down Expand Up @@ -131,5 +139,38 @@ def resolve_views(self):
for ext in exts:
__import__(pn.extension._imports[ext])

def _resolve_source_ref(self, refs):
if len(refs) == 3:
sourceref, table, field = refs
elif len(refs) == 2:
sourceref, table = refs
elif len(refs) == 1:
(sourceref,) = refs

from .sources import Source
source = Source.from_spec(sourceref)
if len(refs) == 1:
return source
if len(refs) == 2:
return source.get(table)
table_schema = source.get_schema(table)
if field not in table_schema:
raise ValueError(f"Field '{field}' was not found in "
f"'{sourceref}' table '{table}'.")
field_schema = table_schema[field]
if 'enum' not in field_schema:
raise ValueError(f"Field '{field}' schema does not "
"declare an enum.")
return field_schema['enum']

def resolve_reference(self, reference, variables=None):
if not is_ref(reference):
raise ValueError('References should be prefixed by $ symbol.')
refs = reference[1:].split('.')
vars = variables or self.variables
if refs[0] == 'variables':
return vars[refs[1]]
return self._resolve_source_ref(refs)


state = _session_state()
32 changes: 17 additions & 15 deletions lumen/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,11 +248,13 @@ def __init__(self, **params):

# Set up watchers
self.facet.param.watch(self._resort, ['sort', 'reverse'])
rerender = partial(self._rerender, invalidate_cache=True)
for filt in self.filters:
if isinstance(filt, FacetFilter):
continue
filt.param.watch(partial(self._rerender, invalidate_cache=True), 'value')
filt.param.watch(rerender, 'value')
self._update_views(init=True)
self.source.param.watch(rerender, self.source.refs)

def _resort(self, *events):
self._rerender(update_views=False)
Expand Down Expand Up @@ -430,11 +432,11 @@ def _update_views(self, invalidate_cache=True, update_views=True, init=False, ev
# Only the controls for the first facet is shown so link
# the other facets to the controls of the first
for v1, v2 in zip(linked_views, views):
v1.param.watch(partial(self._sync_component, v2), v1.controls)
v1.param.watch(partial(self._sync_component, v2), v1.refs)
for t1, t2 in zip(v1.transforms, v2.transforms):
t1.param.watch(partial(self._sync_component, t2), t1.controls)
t1.param.watch(partial(self._sync_component, t2), t1.refs)
for t1, t2 in zip(v1.sql_transforms, v2.sql_transforms):
t1.param.watch(partial(self._sync_component, t2), t1.controls)
t1.param.watch(partial(self._sync_component, t2), t1.refs)

# Validate that all filters are applied
for filt in self.filters:
Expand All @@ -445,23 +447,22 @@ def _update_views(self, invalidate_cache=True, update_views=True, init=False, ev
'found that matches such a field.'
)

# Re-render target when controls update but we ensure that
# all other views linked to the controls are updated first
# Re-render target when controls or refs update but we ensure
# that all other views linked to the controls are updated first
if init:
rerender = partial(self._rerender, invalidate_cache=False)
rerender_cache = partial(self._rerender, invalidate_cache=True)
transforms = []
for view in linked_views:
if view.controls:
view.param.watch(rerender, view.controls)
if view.refs:
view.param.watch(rerender_cache, view.refs)
for transform in view.transforms:
if transform.controls and not transform in transforms:
if transform.refs and not transform in transforms:
transforms.append(transform)
transform.param.watch(rerender_cache, transform.controls)
transform.param.watch(rerender_cache, transform.refs)
for transform in view.sql_transforms:
if transform.controls and not transform in transforms:
if transform.refs and not transform in transforms:
transforms.append(transform)
transform.param.watch(rerender_cache, transform.controls)
transform.param.watch(rerender_cache, transform.refs)

self._view_controls[:] = controls

Expand Down Expand Up @@ -649,11 +650,12 @@ def start(self, event=None):
self.update, refresh_rate
)

def update(self, *events):
def update(self, *events, clear_cache=True):
"""
Updates the views on this target by clearing any caches and
rerendering the views on this Target.
"""
self.source.clear_cache()
if clear_cache:
self.source.clear_cache()
self._rerender(invalidate_cache=True)
self._timestamp.object = f'Last updated: {dt.datetime.now().strftime(self.tsformat)}'
Loading

0 comments on commit 4e46494

Please sign in to comment.