diff --git a/lumen/tests/views/test_base.py b/lumen/tests/views/test_base.py index be89ec2ae..dd92ae829 100644 --- a/lumen/tests/views/test_base.py +++ b/lumen/tests/views/test_base.py @@ -238,3 +238,32 @@ def test_view_title_download(set_root, view_type): button._on_click() assert button.data.startswith('data:text/plain;charset=UTF-8;base64') + + assert view.download.filename is None + assert view.download.format == 'csv' + + +@pytest.mark.parametrize("view_type", ("table", "hvplot")) +def test_view_title_download_filename(set_root, view_type): + set_root(str(Path(__file__).parent.parent)) + source = FileSource(tables={'test': 'sources/test.csv'}) + view = { + 'type': view_type, + 'table': 'test', + 'title': 'Test title', + 'download': 'example.csv', + } + + view = View.from_spec(view, source) + + title = view.panel[0][0] + assert isinstance(title, pn.pane.HTML) + + button = view.panel[0][1] + assert isinstance(button, DownloadButton) + + button._on_click() + assert button.data.startswith('data:text/plain;charset=UTF-8;base64') + + assert view.download.filename == 'example' + assert view.download.format == 'csv' diff --git a/lumen/util.py b/lumen/util.py index a8efa05f2..e884ae38a 100644 --- a/lumen/util.py +++ b/lumen/util.py @@ -5,6 +5,7 @@ import os import re import sys +import unicodedata from functools import wraps from logging import getLogger @@ -314,3 +315,25 @@ def wrapper(*args, **kwargs): return decorator(function) return decorator + + +def slugify(value, allow_unicode=False) -> str: + """ + Convert to ASCII if 'allow_unicode' is False. Convert spaces or repeated + dashes to single dashes. Remove characters that aren't alphanumerics, + underscores, or hyphens. Convert to lowercase. Also strip leading and + trailing whitespace, dashes, and underscores. + + From: https://docs.djangoproject.com/en/4.0/_modules/django/utils/text/#slugify + """ + value = str(value) + if allow_unicode: + value = unicodedata.normalize("NFKC", value) + else: + value = ( + unicodedata.normalize("NFKD", value) + .encode("ascii", "ignore") + .decode("ascii") + ) + value = re.sub(r"[^\w\s-]", "", value.lower()) + return re.sub(r"[-\s]+", "-", value).strip("-_") diff --git a/lumen/views/base.py b/lumen/views/base.py index 863d91535..88cbc8f19 100644 --- a/lumen/views/base.py +++ b/lumen/views/base.py @@ -33,7 +33,7 @@ from ..transforms.base import Transform from ..transforms.sql import SQLTransform from ..util import ( - VARIABLE_RE, catch_and_notify, is_ref, resolve_module_reference, + VARIABLE_RE, catch_and_notify, is_ref, resolve_module_reference, slugify, ) from ..validation import ValidationError @@ -52,12 +52,20 @@ class Download(Component, Viewer): color = param.Color(default='grey', allow_None=True, doc=""" The color of the download button.""") - hide = param.Boolean(default=False, doc=""" - Whether the download button hides when not in focus.""") + filename = param.String(default=None, doc=""" + The filename of the downloaded table. + File extension is added automatic based on the format. + If filename is not defined, it will be the name of the orignal table of the view.""") format = param.ObjectSelector(default=None, objects=DOWNLOAD_FORMATS, doc=""" The format to download the data in.""") + hide = param.Boolean(default=False, doc=""" + Whether the download button hides when not in focus.""") + + index = param.Boolean(default=True, doc=""" + Whether the downloaded table has an index.""") + kwargs = param.Dict(default={}, doc=""" Keyword arguments passed to the serialization function, e.g. data.to_csv(file_obj, **kwargs).""") @@ -89,18 +97,19 @@ def _table_data(self) -> IO: io = BytesIO() data = self.view.get_data() if self.format == 'csv': - data.to_csv(io, **self.kwargs) + data.to_csv(io, index=self.index, **self.kwargs) elif self.format == 'json': - data.to_json(io, **self.kwargs) + data.to_json(io, index=self.index, **self.kwargs) elif self.format == 'xlsx': - data.to_excel(io, **self.kwargs) + data.to_excel(io, index=self.index, **self.kwargs) elif self.format == 'parquet': - data.to_parquet(io, **self.kwargs) + data.to_parquet(io, index=self.index, **self.kwargs) io.seek(0) return io def __panel__(self) -> DownloadButton: - filename = f'{self.view.pipeline.table}.{self.format}' + filename = self.filename or slugify(self.view.pipeline.table) + filename = f'{filename}.{self.format}' return DownloadButton( callback=self._table_data, filename=filename, color=self.color, size=18, hide=self.hide @@ -176,7 +185,9 @@ def __init__(self, **params): if pipeline is None: raise ValueError("Views must declare a Pipeline.") if isinstance(params.get("download"), str): - params["download"] = Download(format=params["download"]) + *filenames, ext = params.get("download").split(".") + filename = ".".join(filenames) or None + params["download"] = Download(filename=filename, format=ext) fields = list(pipeline.schema) for fp in self._field_params: if isinstance(self.param[fp], param.Selector): @@ -372,7 +383,9 @@ def from_spec( # Resolve download options download_spec = spec.pop('download', {}) if isinstance(download_spec, str): - download_spec = {'format': download_spec} + *filenames, ext = download_spec.split('.') + filename = '.'.join(filenames) or None + download_spec = {'filename': filename, 'format': ext} resolved_spec['download'] = Download.from_spec(download_spec) view = view_type(refs=refs, **resolved_spec) diff --git a/setup.py b/setup.py index 26077e5d0..bb3fba3c2 100644 --- a/setup.py +++ b/setup.py @@ -31,6 +31,7 @@ def get_setup_version(reponame): 'sql': [ 'duckdb', 'intake-sql', + 'sqlalchemy <2', # Don't work with pandas yet ], 'tests': [ 'pytest',