Fix cache for dask.dataframe #305

Merged · 7 commits · Aug 24, 2022
67 changes: 39 additions & 28 deletions lumen/sources/base.py
@@ -1,7 +1,5 @@
 import hashlib
 import json
-import os
-import pathlib
 import re
 import shutil
 import sys
@@ -11,8 +9,9 @@
 from concurrent import futures
 from functools import wraps
 from itertools import product
+from os.path import basename
 from pathlib import Path
-from urllib.parse import quote
+from urllib.parse import quote, urlparse

 import numpy as np
 import pandas as pd
@@ -127,6 +126,9 @@ class Source(Component):
         dashboard. If set to `True` the Source will be loaded on
         initial server load.""")

+    root = param.ClassSelector(class_=Path, precedence=-1, doc="""
+        Root folder of the cache_dir, default is config.root""")
+
     source_type = None

     # Declare whether source supports SQL transforms
@@ -202,7 +204,7 @@ def from_spec(cls, spec):

     def __init__(self, **params):
         from ..config import config
-        self.root = params.pop('root', config.root)
+        params['root'] = Path(params.get('root', config.root))
         super().__init__(**params)
         self.param.watch(self.clear_cache, self._reload_params)
         self._cache = {}
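
This hunk pairs with the new `root` parameter declared above: instead of popping `root` out of `**params` as an untracked attribute, the value is coerced to a `pathlib.Path` and handed to `param` for validation. A minimal standalone sketch of that pattern (the class name and default are hypothetical, not lumen's):

```python
from pathlib import Path

import param


class MySource(param.Parameterized):
    # Declared parameter: validated as a Path, hidden from UIs by the
    # negative precedence, mirroring the ClassSelector in the hunk above.
    root = param.ClassSelector(class_=Path, precedence=-1)

    def __init__(self, **params):
        # Coerce a str (or a missing value) to Path before validation runs.
        params['root'] = Path(params.get('root', '.'))
        super().__init__(**params)


src = MySource(root='cache')
assert isinstance(src.root, Path)  # path arithmetic like src.root / 'sub' now works
```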
@@ -222,8 +224,8 @@ def _get_key(self, table, **query):
     def _get_schema_cache(self):
         schema = self._schema_cache if self._schema_cache else None
         if self.cache_dir:
-            path = os.path.join(self.root, self.cache_dir, f'{self.name}.json')
-            if not os.path.isfile(path):
+            path = self.root / self.cache_dir / f'{self.name}.json'
+            if not path.is_file():
                 return schema
             with open(path) as f:
                 json_schema = json.load(f)
@@ -246,7 +248,7 @@ def _get_schema_cache(self):
     def _set_schema_cache(self, schema):
         self._schema_cache = schema
         if self.cache_dir:
-            path = Path(os.path.join(self.root, self.cache_dir))
+            path = self.root / self.cache_dir
             path.mkdir(parents=True, exist_ok=True)
             try:
                 with open(path / f'{self.name}.json', 'w') as f:
@@ -267,30 +269,37 @@ def _get_cache(self, table, **query):
                 filename = f'{key}_{table}.parq'
             else:
                 filename = f'{table}.parq'
-            path = os.path.join(self.root, self.cache_dir, filename)
-            if os.path.isfile(path) or os.path.isdir(path):
-                if 'dask.dataframe' in sys.modules or os.path.isdir(path):
-                    import dask.dataframe as dd
-                    return dd.read_parquet(path), not bool(query)
+            path = self.root / self.cache_dir / filename
+            if path.is_file():
                 return pd.read_parquet(path), not bool(query)
+            if 'dask.dataframe' in sys.modules and path.is_dir():
+                import dask.dataframe as dd
+                return dd.read_parquet(path), not bool(query)
+            path = path.with_suffix('')
+            if 'dask.dataframe' in sys.modules and path.is_dir():
+                import dask.dataframe as dd
+                return dd.read_parquet(path), not bool(query)
         return None, not bool(query)

     def _set_cache(self, data, table, write_to_file=True, **query):
         query.pop('__dask', None)
         key = self._get_key(table, **query)
         self._cache[key] = data
         if self.cache_dir and write_to_file:
-            path = os.path.join(self.root, self.cache_dir)
-            Path(path).mkdir(parents=True, exist_ok=True)
+            path = self.root / self.cache_dir
+            path.mkdir(parents=True, exist_ok=True)
             if query:
                 filename = f'{key}_{table}.parq'
             else:
                 filename = f'{table}.parq'
-            filepath = os.path.join(path, filename)
+            filepath = path / filename
+            if 'dask.dataframe' in sys.modules:
+                import dask.dataframe as dd
+                if isinstance(data, dd.DataFrame):
+                    filepath = filepath.with_suffix('')
             try:
                 data.to_parquet(filepath)
             except Exception as e:
-                path = pathlib.Path(filepath)
                 if path.is_file():
                     path.unlink()
                 elif path.is_dir():
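
These `_get_cache`/`_set_cache` changes hinge on a parquet layout detail: pandas writes a single `.parq` file, while dask writes a directory of part files. The write path therefore strips the suffix for dask frames, and the read path probes file first, then the suffixed and unsuffixed directories. A self-contained sketch of that round trip, assuming hypothetical helper names rather than lumen's API:

```python
import sys
from pathlib import Path

import pandas as pd


def write_cache(data, path: Path):
    # `path` points at e.g. <cache_dir>/table.parq
    if 'dask.dataframe' in sys.modules:
        import dask.dataframe as dd
        if isinstance(data, dd.DataFrame):
            # dask's to_parquet creates a *directory* of part files, so
            # drop '.parq' to keep file and directory caches distinct.
            path = path.with_suffix('')
    data.to_parquet(path)


def read_cache(path: Path):
    if path.is_file():  # pandas cache: a single parquet file
        return pd.read_parquet(path)
    if 'dask.dataframe' in sys.modules:
        import dask.dataframe as dd
        if path.is_dir():  # directory still named table.parq (older caches)
            return dd.read_parquet(path)
        if path.with_suffix('').is_dir():  # new layout: directory named table
            return dd.read_parquet(path.with_suffix(''))
    return None
```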
@@ -307,8 +316,8 @@ def clear_cache(self, *events):
         self._cache = {}
         self._schema_cache = {}
         if self.cache_dir:
-            path = os.path.join(self.root, self.cache_dir)
-            if os.path.isdir(path):
+            path = self.root / self.cache_dir
+            if path.is_dir():
                 shutil.rmtree(path)

     @property
@@ -484,24 +493,27 @@ def _named_files(self):
             if f.startswith('http'):
                 name = f
             else:
-                name = '.'.join(os.path.basename(f).split('.')[:-1])
+                name = '.'.join(basename(f).split('.')[:-1])
             tables[name] = f
         else:
             tables = self.tables
         files = {}
         for name, table in tables.items():
             ext = None
             if isinstance(table, (list, tuple)):
                 table, ext = table
             else:
-                basename = os.path.basename(table)
-                if '.' in basename:
-                    ext = basename.split('.')[-1]
+                if table.startswith('http'):
+                    file = basename(urlparse(table).path)
+                else:
+                    file = basename(table)
+                ext = re.search(r"\.(\w+)$", file)
+                if ext:
+                    ext = ext.group(1)
             files[name] = (table, ext)
         return files

     def _resolve_template_vars(self, table):
-        for m in self._template_re.findall(table):
+        for m in self._template_re.findall(str(table)):
             values = state.resolve_reference(f'${m[2:-1]}')
             values = ','.join([v for v in values])
             table = table.replace(m, quote(values))
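
The rewritten extension lookup is what the new test at the bottom exercises: for URLs, `urlparse` discards the query string before `basename`, and an extension is only accepted when the final path component ends in `.<word>`. A hedged standalone sketch (`guess_extension` is an illustrative name, not lumen's):

```python
import re
from os.path import basename
from urllib.parse import urlparse


def guess_extension(table: str):
    if table.startswith('http'):
        # urlparse drops '?app_key=...', which a naive split('.')[-1]
        # on the raw URL would otherwise mistake for an extension.
        file = basename(urlparse(table).path)
    else:
        file = basename(table)
    ext = re.search(r"\.(\w+)$", file)
    return ext.group(1) if ext else None


assert guess_extension("data/penguins.csv") == "csv"
assert guess_extension("https://example.com/points/x.json?key=1.23") == "json"
assert guess_extension("https://example.com/points/@{id}?app_key=42") is None
```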
@@ -512,10 +524,9 @@ def get_tables(self):

     def _load_table(self, table, dask=True):
         df = None
-        for name, filepath in self._named_files.items():
-            filepath, ext = filepath
-            if '://' not in filepath:
-                filepath = os.path.join(self.root, filepath)
+        for name, (filepath, ext) in self._named_files.items():
+            if isinstance(filepath, Path) or '://' not in filepath:
+                filepath = self.root / filepath
             if name != table:
                 continue
             load_fn, kwargs = self._load_fn(ext, dask=dask)
6 changes: 6 additions & 0 deletions lumen/tests/sources/test_base.py
@@ -147,3 +147,9 @@ def test_file_source_variable(make_variable_filesource):
     df = source.get('test')
     expected = pd._testing.makeMixedDataFrame().iloc[::-1].reset_index(drop=True)
     pd.testing.assert_frame_equal(df, expected)
+
+
+def test_extension_of_comlicated_url(source):
+    url = "https://api.tfl.gov.uk/Occupancy/BikePoints/@{stations.stations.id}?app_key=random_numbers"
+    source.tables["test"] = url
+    assert source._named_files["test"][1] is None