Skip to content

Commit

Permalink
Clean up and test SQLTransforms (#384)
Browse files Browse the repository at this point in the history
* Clean up and test SQLTransforms

* Error if Pipeline.transforms given SQLTransform

* Apply isort fixes
  • Loading branch information
philippjfr authored Nov 14, 2022
1 parent cf524cb commit 0ce3420
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 32 deletions.
2 changes: 2 additions & 0 deletions lumen/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class Pipeline(Component):
def __init__(self, *, source, table, **params):
if 'schema' not in params:
params['schema'] = source.get_schema(table)
if any(isinstance(t, SQLTransform) for t in params.get('transforms', [])):
raise TypeError('Pipeline.transforms must be regular Transform components, not SQLTransform.')
super().__init__(source=source, table=table, **params)
self._update_widget = pn.Param(self.param['update'], widgets={'update': {'button_type': 'success'}})[0]
self._init_callbacks()
Expand Down
85 changes: 85 additions & 0 deletions lumen/tests/transforms/test_sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import datetime as dt

from lumen.transforms.sql import (
SQLColumns, SQLDistinct, SQLFilter, SQLGroupBy, SQLLimit, SQLMinMax,
)


def test_sql_group_by_single_column():
assert (
SQLGroupBy.apply_to('SELECT * FROM TABLE', by=['A'], aggregates={'AVG': 'B'}) ==
"""SELECT\n A, AVG(B) AS B\nFROM ( SELECT * FROM TABLE )\nGROUP BY A"""
)

def test_sql_group_by_multi_columns():
assert (
SQLGroupBy.apply_to('SELECT * FROM TABLE', by=['A'], aggregates={'AVG': ['B', 'C']}) ==
"""SELECT\n A, AVG(B) AS B, AVG(C) AS C\nFROM ( SELECT * FROM TABLE )\nGROUP BY A"""
)

def test_sql_limit():
assert (
SQLLimit.apply_to('SELECT * FROM TABLE', limit=10) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nLIMIT 10"""
)

def test_sql_columns():
assert (
SQLColumns.apply_to('SELECT * FROM TABLE', columns=['A', 'B']) ==
"""SELECT\n A, B\nFROM ( SELECT * FROM TABLE )"""
)

def test_sql_distinct():
assert (
SQLDistinct.apply_to('SELECT * FROM TABLE', columns=['A', 'B']) ==
"""SELECT DISTINCT\n A, B\nFROM ( SELECT * FROM TABLE )"""
)

def test_sql_min_max():
assert (
SQLMinMax.apply_to('SELECT * FROM TABLE', columns=['A', 'B']) ==
"""SELECT\n MIN(A) as A_min, MAX(A) as A_max, MIN(B) as B_min, MAX(B) as B_max\nFROM ( SELECT * FROM TABLE )"""
)

def test_sql_filter_none():
assert (
SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', None)]) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A IS NULL )"""
)

def test_sql_filter_scalar():
assert (
SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', 1)]) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A = 1 )"""
)


def test_sql_filter_isin():
assert (
SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', ['A', 'B', 'C'])]) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A IN ('A', 'B', 'C') )"""
)

def test_sql_filter_datetime():
assert (
SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', dt.datetime(2017, 4, 14))]) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A = '2017-04-14 00:00:00' )"""
)

def test_sql_filter_date():
assert (
SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', dt.date(2017, 4, 14))]) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-04-14 00:00:00' AND '2017-04-14 23:59:59' )"""
)

def test_sql_filter_date_range():
assert (
SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', (dt.date(2017, 2, 22), dt.date(2017, 4, 14)))]) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-02-22 00:00:00' AND '2017-04-14 23:59:59' )"""
)

def test_sql_filter_datetime_range():
assert (
SQLFilter.apply_to('SELECT * FROM TABLE', conditions=[('A', (dt.datetime(2017, 2, 22), dt.datetime(2017, 4, 14)))]) ==
"""SELECT\n *\nFROM ( SELECT * FROM TABLE )\nWHERE ( A BETWEEN '2017-02-22 00:00:00' AND '2017-04-14 00:00:00' )"""
)
56 changes: 24 additions & 32 deletions lumen/transforms/sql.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime as dt
import textwrap

import numpy as np
import param
Expand Down Expand Up @@ -47,6 +48,11 @@ def apply(self, sql_in):
"""
return sql_in

@classmethod
def _render_template(cls, template, **params):
template = textwrap.dedent(template).lstrip()
return Template(template, trim_blocks=True, lstrip_blocks=True).render(**params)


class SQLGroupBy(SQLTransform):
"""
Expand All @@ -64,18 +70,17 @@ class SQLGroupBy(SQLTransform):
def apply(self, sql_in):
template = """
SELECT
{{by_cols}},
{{aggs}}
{{by_cols}}, {{aggs}}
FROM ( {{sql_in}} )
GROUP BY {{by_cols}}
"""
GROUP BY {{by_cols}}"""
by_cols = ', '.join(self.by)
aggs = ', '.join([
f'{agg}({col}) AS {col}' for agg, col in self.aggregates.items()
])
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
by_cols=by_cols, aggs=aggs, sql_in=sql_in
)
aggs = []
for agg, cols in self.aggregates.items():
if isinstance(cols, str):
cols = [cols]
for col in cols:
aggs.append(f'{agg}({col}) AS {col}')
return self._render_template(template, by_cols=by_cols, aggs=', '.join(aggs), sql_in=sql_in)


class SQLLimit(SQLTransform):
Expand All @@ -94,9 +99,7 @@ def apply(self, sql_in):
FROM ( {{sql_in}} )
LIMIT {{limit}}
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
limit=self.limit, sql_in=sql_in
)
return self._render_template(template, sql_in=sql_in, limit=self.limit)


class SQLDistinct(SQLTransform):
Expand All @@ -109,11 +112,8 @@ def apply(self, sql_in):
template = """
SELECT DISTINCT
{{columns}}
FROM ( {{sql_in}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
columns=', '.join(self.columns), sql_in=sql_in
)
FROM ( {{sql_in}} )"""
return self._render_template(template, sql_in=sql_in, columns=', '.join(self.columns))


class SQLMinMax(SQLTransform):
Expand All @@ -130,11 +130,8 @@ def apply(self, sql_in):
template = """
SELECT
{{columns}}
FROM ( {{sql_in}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
columns=', '.join(aggs), sql_in=sql_in
)
FROM ( {{sql_in}} )"""
return self._render_template(template, sql_in=sql_in, columns=', '.join(aggs))


class SQLColumns(SQLTransform):
Expand All @@ -149,9 +146,7 @@ def apply(self, sql_in):
{{columns}}
FROM ( {{sql_in}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
columns=', '.join(self.columns), sql_in=sql_in
)
return self._render_template(template, sql_in=sql_in, columns=', '.join(self.columns))


class SQLFilter(SQLTransform):
Expand All @@ -172,7 +167,7 @@ def _range_filter(cls, col, v1, v2):
if isinstance(v1, dt.date) and not isinstance(v1, dt.datetime):
start += ' 00:00:00'
if isinstance(v2, dt.date) and not isinstance(v2, dt.datetime):
end += ' 00:00:00'
end += ' 23:59:59'
return f'{col} BETWEEN {start!r} AND {end!r}'

def apply(self, sql_in):
Expand Down Expand Up @@ -220,8 +215,5 @@ def apply(self, sql_in):
SELECT
*
FROM ( {{sql_in}} )
WHERE ( {{conditions}} )
"""
return Template(template, trim_blocks=True, lstrip_blocks=True).render(
conditions=' AND '.join(conditions), sql_in=sql_in
)
WHERE ( {{conditions}} )"""
return self._render_template(template, sql_in=sql_in, conditions=' AND '.join(conditions))

0 comments on commit 0ce3420

Please sign in to comment.