From c1944a23892b91529987118fd3d36a9725347c60 Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Wed, 6 Jan 2021 12:02:37 +0800 Subject: [PATCH 1/8] bugfix for describe and convert_dtypes --- lux/core/frame.py | 2 +- lux/core/series.py | 16 ++++++++-------- lux/executor/PandasExecutor.py | 18 +++++------------- tests/test_nan.py | 2 ++ tests/test_pandas.py | 14 ++++++++++++++ 5 files changed, 30 insertions(+), 22 deletions(-) diff --git a/lux/core/frame.py b/lux/core/frame.py index e4ed9e3e..8546c168 100644 --- a/lux/core/frame.py +++ b/lux/core/frame.py @@ -155,7 +155,7 @@ def _set_item(self, key, value): def _infer_structure(self): # If the dataframe is very small and the index column is not a range index, then it is likely that this is an aggregated data is_multi_index_flag = self.index.nlevels != 1 - not_int_index_flag = self.index.dtype != "int64" + not_int_index_flag = not pd.api.types.is_integer_dtype(self.index) small_df_flag = len(self) < 100 self.pre_aggregated = (is_multi_index_flag or not_int_index_flag) and small_df_flag if "Number of Records" in self.columns: diff --git a/lux/core/series.py b/lux/core/series.py index aea13d0c..aebcabbd 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -45,14 +45,14 @@ def _constructor(self): def _constructor_expanddim(self): from lux.core.frame import LuxDataFrame - def f(*args, **kwargs): - df = LuxDataFrame(*args, **kwargs) - for attr in self._metadata: - df.__dict__[attr] = getattr(self, attr, None) - return df - - f._get_axis_number = super(LuxSeries, self)._get_axis_number - return f + # def f(*args, **kwargs): + # df = LuxDataFrame(*args, **kwargs) + # for attr in self._metadata: + # df.__dict__[attr] = getattr(self, attr, None) + # return df + + # f._get_axis_number = super(LuxSeries, self)._get_axis_number + return LuxDataFrame def to_pandas(self): import lux.core diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index 9708d8eb..56422866 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -428,9 +428,7 @@ def compute_data_type(self, ldf: LuxDataFrame): ldf.data_type[attr] = "temporal" else: ldf.data_type[attr] = "nominal" - # for attr in list(df.dtypes[df.dtypes=="int64"].keys()): - # if self.cardinality[attr]>50: - if ldf.index.dtype != "int64" and ldf.index.name: + if not pd.api.types.is_integer_dtype(ldf.index) and ldf.index.name: ldf.data_type[ldf.index.name] = "nominal" non_datetime_attrs = [] @@ -489,21 +487,15 @@ def compute_stats(self, ldf: LuxDataFrame): ldf.unique_values[attribute_repr] = list(ldf[attribute_repr].unique()) ldf.cardinality[attribute_repr] = len(ldf.unique_values[attribute_repr]) - # commenting this optimization out to make sure I can filter by cardinality when showing recommended vis - - # if ldf.dtypes[attribute] != "float64":# and not pd.api.types.is_datetime64_ns_dtype(self.dtypes[attribute]): - # ldf.unique_values[attribute_repr] = list(ldf[attribute].unique()) - # ldf.cardinality[attribute_repr] = len(ldf.unique_values[attribute]) - # else: - # ldf.cardinality[attribute_repr] = 999 # special value for non-numeric attribute - - if ldf.dtypes[attribute] == "float64" or ldf.dtypes[attribute] == "int64": + if pd.api.types.is_float_dtype(ldf.dtypes[attribute]) or pd.api.types.is_integer_dtype( + ldf.dtypes[attribute] + ): ldf._min_max[attribute_repr] = ( ldf[attribute].min(), ldf[attribute].max(), ) - if ldf.index.dtype != "int64": + if not pd.api.types.is_integer_dtype(ldf.index): index_column_name = ldf.index.name ldf.unique_values[index_column_name] = list(ldf.index) ldf.cardinality[index_column_name] = len(ldf.index) diff --git a/tests/test_nan.py b/tests/test_nan.py index b2d28fed..1701215f 100644 --- a/tests/test_nan.py +++ b/tests/test_nan.py @@ -22,11 +22,13 @@ def test_nan_column(global_var): df = pytest.college_df + old_geo = df["Geography"] df["Geography"] = np.nan df._repr_html_() for visList in df.recommendation.keys(): for vis in df.recommendation[visList]: assert vis.get_attr_by_attr_name("Geography") == [] + df["Geography"] = old_geo def test_nan_data_type_detection(): diff --git a/tests/test_pandas.py b/tests/test_pandas.py index b43cc1f9..34f68605 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -44,3 +44,17 @@ def test_head_tail(global_var): "Lux is visualizing the previous version of the dataframe before you applied tail." in df._message.to_html() ) + + +def test_describe(global_var): + df = pytest.college_df + summary = df.describe() + summary._repr_html_() + assert len(summary.recommendation["Column Groups"]) == len(summary.columns) == 10 + + +def test_convert_dtype(global_var): + df = pytest.college_df + cdf = df.convert_dtypes() + cdf._repr_html_() + assert list(cdf.recommendation.keys()) == ["Correlation", "Distribution", "Occurrence"] From 5c8b2849d449b1b1e0a7c0b5d57be26c926ae160 Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Wed, 6 Jan 2021 12:08:02 +0800 Subject: [PATCH 2/8] added back metadata series test --- tests/test_pandas.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 34f68605..e4935fde 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -16,16 +16,17 @@ import pytest import pandas as pd -# def test_df_to_series(): -# # Ensure metadata is kept when going from df to series -# df = pd.read_csv("lux/data/car.csv") -# df._repr_html_() # compute metadata -# assert df.cardinality is not None -# series = df["Weight"] -# assert isinstance(series,lux.core.series.LuxSeries), "Derived series is type LuxSeries." -# assert df["Weight"]._metadata == ['name','_intent', 'data_type_lookup', 'data_type', 'data_model_lookup', 'data_model', 'unique_values', 'cardinality', 'min_max', '_current_vis', '_widget', '_recommendation'], "Metadata is lost when going from Dataframe to Series." -# assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." -# assert series.name == "Weight", "Pandas Series original `name` property not retained." +def test_df_to_series(): + # Ensure metadata is kept when going from df to series + df = pd.read_csv("lux/data/car.csv") + df._repr_html_() # compute metadata + assert df.cardinality is not None + series = df["Weight"] + assert isinstance(series,lux.core.series.LuxSeries), "Derived series is type LuxSeries." + print (df["Weight"]._metadata) + assert df["Weight"]._metadata == ['_intent', 'data_type', 'unique_values', 'cardinality', '_rec_info', '_pandas_only', '_min_max', 'plot_config', '_current_vis', '_widget', '_recommendation', '_prev', '_history', '_saved_export', 'name'], "Metadata is lost when going from Dataframe to Series." + assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." + assert series.name == "Weight", "Pandas Series original `name` property not retained." def test_head_tail(global_var): From 49daeecbdc4b09b7e0013b6a9b940d7bb303e716 Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Wed, 6 Jan 2021 12:17:39 +0800 Subject: [PATCH 3/8] black --- tests/test_pandas.py | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index e4935fde..4b38ae1a 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -16,15 +16,32 @@ import pytest import pandas as pd + def test_df_to_series(): # Ensure metadata is kept when going from df to series df = pd.read_csv("lux/data/car.csv") - df._repr_html_() # compute metadata + df._repr_html_() # compute metadata assert df.cardinality is not None series = df["Weight"] - assert isinstance(series,lux.core.series.LuxSeries), "Derived series is type LuxSeries." - print (df["Weight"]._metadata) - assert df["Weight"]._metadata == ['_intent', 'data_type', 'unique_values', 'cardinality', '_rec_info', '_pandas_only', '_min_max', 'plot_config', '_current_vis', '_widget', '_recommendation', '_prev', '_history', '_saved_export', 'name'], "Metadata is lost when going from Dataframe to Series." + assert isinstance(series, lux.core.series.LuxSeries), "Derived series is type LuxSeries." + print(df["Weight"]._metadata) + assert df["Weight"]._metadata == [ + "_intent", + "data_type", + "unique_values", + "cardinality", + "_rec_info", + "_pandas_only", + "_min_max", + "plot_config", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + "name", + ], "Metadata is lost when going from Dataframe to Series." assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." assert series.name == "Weight", "Pandas Series original `name` property not retained." From 801b469fe5e375916f64a9abf858f0f35ff24fde Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Wed, 6 Jan 2021 15:58:42 +0800 Subject: [PATCH 4/8] default to pandas display when df.dtypes printed --- lux/core/series.py | 4 +++- tests/test_pandas.py | 29 ------------------------ tests/test_series.py | 53 ++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 56 insertions(+), 30 deletions(-) create mode 100644 tests/test_series.py diff --git a/lux/core/series.py b/lux/core/series.py index aebcabbd..1e3c4f8c 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -16,6 +16,7 @@ import lux import warnings import traceback +import numpy as np class LuxSeries(pd.Series): @@ -75,7 +76,8 @@ def __repr__(self): ldf = LuxDataFrame(self) try: - if ldf._pandas_only: + is_dtype_series = all(isinstance(val, np.dtype) for val in self.values) + if ldf._pandas_only or is_dtype_series: print(series_repr) ldf._pandas_only = False else: diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 4b38ae1a..26cd7333 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -17,35 +17,6 @@ import pandas as pd -def test_df_to_series(): - # Ensure metadata is kept when going from df to series - df = pd.read_csv("lux/data/car.csv") - df._repr_html_() # compute metadata - assert df.cardinality is not None - series = df["Weight"] - assert isinstance(series, lux.core.series.LuxSeries), "Derived series is type LuxSeries." - print(df["Weight"]._metadata) - assert df["Weight"]._metadata == [ - "_intent", - "data_type", - "unique_values", - "cardinality", - "_rec_info", - "_pandas_only", - "_min_max", - "plot_config", - "_current_vis", - "_widget", - "_recommendation", - "_prev", - "_history", - "_saved_export", - "name", - ], "Metadata is lost when going from Dataframe to Series." - assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." - assert series.name == "Weight", "Pandas Series original `name` property not retained." - - def test_head_tail(global_var): df = pytest.car_df df._repr_html_() diff --git a/tests/test_series.py b/tests/test_series.py new file mode 100644 index 00000000..62a4697f --- /dev/null +++ b/tests/test_series.py @@ -0,0 +1,53 @@ +# Copyright 2019-2020 The Lux Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .context import lux +import pytest +import pandas as pd +import warnings + + +def test_df_to_series(): + # Ensure metadata is kept when going from df to series + df = pd.read_csv("lux/data/car.csv") + df._repr_html_() # compute metadata + assert df.cardinality is not None + series = df["Weight"] + assert isinstance(series, lux.core.series.LuxSeries), "Derived series is type LuxSeries." + print(df["Weight"]._metadata) + assert df["Weight"]._metadata == [ + "_intent", + "data_type", + "unique_values", + "cardinality", + "_rec_info", + "_pandas_only", + "_min_max", + "plot_config", + "_current_vis", + "_widget", + "_recommendation", + "_prev", + "_history", + "_saved_export", + "name", + ], "Metadata is lost when going from Dataframe to Series." + assert df.cardinality is not None, "Metadata is lost when going from Dataframe to Series." + assert series.name == "Weight", "Pandas Series original `name` property not retained." + + +def test_print_dtypes(global_var): + df = pytest.college_df + with warnings.catch_warnings(record=True) as w: + print(df.dtypes) + assert len(w) == 0, "Warning displayed when printing dtypes" From a8ab02e7f8583d5fadb7ad699756e800e64a86c6 Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Thu, 7 Jan 2021 21:03:54 +0800 Subject: [PATCH 5/8] various fixes to support int columns --- lux/action/enhance.py | 4 +- lux/action/filter.py | 2 +- lux/action/generalize.py | 6 +-- lux/executor/PandasExecutor.py | 6 +-- lux/processor/Parser.py | 59 +++++++++++++++-------------- lux/processor/Validator.py | 12 +++--- lux/utils/utils.py | 15 ++++---- lux/vis/Clause.py | 20 +++++----- lux/vis/Vis.py | 11 ++++-- lux/vis/VisList.py | 23 +++++------ lux/vislib/altair/AltairChart.py | 14 ++++++- lux/vislib/altair/AltairRenderer.py | 11 +++--- lux/vislib/altair/BarChart.py | 35 ++++++++++------- lux/vislib/altair/Heatmap.py | 14 ++++--- lux/vislib/altair/Histogram.py | 18 +++++---- lux/vislib/altair/LineChart.py | 36 ++++++++++++------ lux/vislib/altair/ScatterChart.py | 20 +++++----- tests/test_columns.py | 10 +++++ 18 files changed, 184 insertions(+), 132 deletions(-) diff --git a/lux/action/enhance.py b/lux/action/enhance.py index 0370b2ce..94a4ea60 100644 --- a/lux/action/enhance.py +++ b/lux/action/enhance.py @@ -37,8 +37,8 @@ def enhance(ldf): # Collect variables that already exist in the intent attr_specs = list(filter(lambda x: x.value == "" and x.attribute != "Record", ldf._intent)) fltr_str = [fltr.attribute + fltr.filter_op + str(fltr.value) for fltr in filters] - attr_str = [clause.attribute for clause in attr_specs] - intended_attrs = '

' + ", ".join(attr_str + fltr_str) + "

" + attr_str = [str(clause.attribute) for clause in attr_specs] + intended_attrs = f'

{", ".join(attr_str + fltr_str)}

' if len(attr_specs) == 1: recommendation = { "action": "Enhance", diff --git a/lux/action/filter.py b/lux/action/filter.py index e8833b0f..b36a4c2c 100644 --- a/lux/action/filter.py +++ b/lux/action/filter.py @@ -91,7 +91,7 @@ def get_complementary_ops(fltr_op): else: intended_attrs = ", ".join( [ - clause.attribute + str(clause.attribute) for clause in ldf._intent if clause.value == "" and clause.attribute != "Record" ] diff --git a/lux/action/generalize.py b/lux/action/generalize.py index d95bcb26..91b83239 100644 --- a/lux/action/generalize.py +++ b/lux/action/generalize.py @@ -42,8 +42,8 @@ def generalize(ldf): filters = utils.get_filter_specs(ldf._intent) fltr_str = [fltr.attribute + fltr.filter_op + str(fltr.value) for fltr in filters] - attr_str = [clause.attribute for clause in attributes] - intended_attrs = '

' + ", ".join(attr_str + fltr_str) + "

" + attr_str = [str(clause.attribute) for clause in attributes] + intended_attrs = f'

{", ".join(attr_str + fltr_str)}

' recommendation = { "action": "Generalize", @@ -66,7 +66,7 @@ def generalize(ldf): temp_vis.remove_column_from_spec(column, remove_first=True) excluded_columns.append(column) output.append(temp_vis) - elif type(columns) == str: + else: if columns not in excluded_columns: temp_vis = Vis(ldf.copy_intent(), score=1) temp_vis.remove_column_from_spec(columns, remove_first=True) diff --git a/lux/executor/PandasExecutor.py b/lux/executor/PandasExecutor.py index 56422866..cb8a8ce7 100644 --- a/lux/executor/PandasExecutor.py +++ b/lux/executor/PandasExecutor.py @@ -90,11 +90,11 @@ def execute(vislist: VisList, ldf: LuxDataFrame): # Select relevant data based on attribute information attributes = set([]) for clause in vis._inferred_intent: - if clause.attribute: - if clause.attribute != "Record": - attributes.add(clause.attribute) + if clause.attribute != "Record": + attributes.add(clause.attribute) # TODO: Add some type of cap size on Nrows ? vis._vis_data = vis.data[list(attributes)] + if vis.mark == "bar" or vis.mark == "line": PandasExecutor.execute_aggregate(vis, isFiltered=filter_executed) elif vis.mark == "histogram": diff --git a/lux/processor/Parser.py b/lux/processor/Parser.py index a6538e09..090c619d 100644 --- a/lux/processor/Parser.py +++ b/lux/processor/Parser.py @@ -46,7 +46,6 @@ def parse(intent: List[Union[Clause, str]]) -> List[Clause]: ) import re - # intent = ldf.get_context() new_context = [] # checks for and converts users' string inputs into lux specifications for clause in intent: @@ -59,37 +58,40 @@ def parse(intent: List[Union[Clause, str]]) -> List[Clause]: valid_values.append(v) temp_spec = Clause(attribute=valid_values) new_context.append(temp_spec) - elif isinstance(clause, str): - # case where user specifies a filter - if "=" in clause: - eqInd = clause.index("=") - var = clause[0:eqInd] - if "|" in clause: - values = clause[eqInd + 1 :].split("|") - for v in values: - # if v in ldf.unique_values[var]: #TODO: Move validation check to Validator - valid_values.append(v) + elif isinstance(clause, Clause): + new_context.append(clause) + else: + if isinstance(clause, str): + # case where user specifies a filter + if "=" in clause: + eqInd = clause.index("=") + var = clause[0:eqInd] + if "|" in clause: + values = clause[eqInd + 1 :].split("|") + for v in values: + # if v in ldf.unique_values[var]: #TODO: Move validation check to Validator + valid_values.append(v) + else: + valid_values = clause[eqInd + 1 :] + # if var in list(ldf.columns): #TODO: Move validation check to Validator + temp_spec = Clause(attribute=var, filter_op="=", value=valid_values) + new_context.append(temp_spec) + # case where user specifies a variable else: - valid_values = clause[eqInd + 1 :] - # if var in list(ldf.columns): #TODO: Move validation check to Validator - temp_spec = Clause(attribute=var, filter_op="=", value=valid_values) - new_context.append(temp_spec) - # case where user specifies a variable + if "|" in clause: + values = clause.split("|") + for v in values: + # if v in list(ldf.columns): #TODO: Move validation check to Validator + valid_values.append(v) + else: + valid_values = clause + temp_spec = Clause(attribute=valid_values) + new_context.append(temp_spec) else: - if "|" in clause: - values = clause.split("|") - for v in values: - # if v in list(ldf.columns): #TODO: Move validation check to Validator - valid_values.append(v) - else: - valid_values = clause - temp_spec = Clause(attribute=valid_values) + temp_spec = Clause(attribute=clause) new_context.append(temp_spec) - elif type(clause) is Clause: - new_context.append(clause) - intent = new_context - # ldf._intent = new_context + intent = new_context for clause in intent: if clause.description: # TODO: Move validation check to Validator @@ -112,4 +114,3 @@ def parse(intent: List[Union[Clause, str]]) -> List[Clause]: else: # then it is probably a value clause.value = clause.description return intent - # ldf._intent = intent diff --git a/lux/processor/Validator.py b/lux/processor/Validator.py index c72dc63b..2550ac31 100644 --- a/lux/processor/Validator.py +++ b/lux/processor/Validator.py @@ -57,9 +57,7 @@ def validate_intent(intent: List[Clause], ldf: LuxDataFrame) -> None: def validate_clause(clause): warn_msg = "" - if not ( - (clause.attribute and clause.attribute == "?") or (clause.value and clause.value == "?") - ): + if not (clause.attribute == "?" or clause.value == "?" or clause.attribute == ""): if isinstance(clause.attribute, list): for attr in clause.attribute: if attr not in list(ldf.columns): @@ -69,7 +67,9 @@ def validate_clause(clause): else: if clause.attribute != "Record": # we don't value check datetime since datetime can take filter values that don't exactly match the exact TimeStamp representation - if clause.attribute and not is_datetime_string(clause.attribute): + if isinstance(clause.attribute, str) and not is_datetime_string( + clause.attribute + ): if not clause.attribute in list(ldf.columns): search_val = clause.attribute match_attr = False @@ -80,9 +80,7 @@ def validate_clause(clause): warn_msg = f"\n- The input '{search_val}' looks like a value that belongs to the '{match_attr}' attribute. \n Please specify the value fully, as something like {match_attr}={search_val}." else: warn_msg = f"\n- The input attribute '{clause.attribute}' does not exist in the DataFrame. \n Please check your input intent for typos." - if clause.value and clause.attribute and clause.filter_op == "=": - import math - + if clause.value != "" and clause.attribute != "" and clause.filter_op == "=": # Skip check for NaN filter values if not lux.utils.utils.like_nan(clause.value): series = ldf[clause.attribute] diff --git a/lux/utils/utils.py b/lux/utils/utils.py index e19afcf4..3ae4503d 100644 --- a/lux/utils/utils.py +++ b/lux/utils/utils.py @@ -57,16 +57,17 @@ def check_import_lux_widget(): def get_agg_title(clause): + attr = str(clause.attribute) if clause.aggregation is None: - if len(clause.attribute) > 25: - return clause.attribute[:15] + "..." + clause.attribute[-10:] - return f"{clause.attribute}" - elif clause.attribute == "Record": + if len(attr) > 25: + return attr[:15] + "..." + attr[-10:] + return f"{attr}" + elif attr == "Record": return f"Number of Records" else: - if len(clause.attribute) > 15: - return f"{clause._aggregation_name.capitalize()} of {clause.attribute[:15]}..." - return f"{clause._aggregation_name.capitalize()} of {clause.attribute}" + if len(attr) > 15: + return f"{clause._aggregation_name.capitalize()} of {attr[:15]}..." + return f"{clause._aggregation_name.capitalize()} of {attr}" def check_if_id_like(df, attribute): diff --git a/lux/vis/Clause.py b/lux/vis/Clause.py index ca7efd76..fcaf71d3 100644 --- a/lux/vis/Clause.py +++ b/lux/vis/Clause.py @@ -116,7 +116,7 @@ def to_string(self): if isinstance(self.attribute, list): clauseStr = "|".join(self.attribute) elif self.value == "": - clauseStr = self.attribute + clauseStr = str(self.attribute) else: clauseStr = f"{self.attribute}{self.filter_op}{self.value}" return clauseStr @@ -126,23 +126,23 @@ def __repr__(self): if self.description != "": attributes.append(f" description: {self.description}") if self.channel != "": - attributes.append(" channel: " + self.channel) - if len(self.attribute) != 0: - attributes.append(" attribute: " + str(self.attribute)) + attributes.append(f" channel: {self.channel}") + if self.attribute != "": + attributes.append(f" attribute: {str(self.attribute)}") if self.filter_op != "=": attributes.append(f" filter_op: {str(self.filter_op)}") if self.aggregation != "" and self.aggregation is not None: attributes.append(" aggregation: " + self._aggregation_name) if self.value != "" or len(self.value) != 0: - attributes.append(" value: " + str(self.value)) + attributes.append(f" value: {str(self.value)}") if self.data_model != "": - attributes.append(" data_model: " + self.data_model) + attributes.append(f" data_model: {self.data_model}") if len(self.data_type) != 0: - attributes.append(" data_type: " + str(self.data_type)) - if self.bin_size != None: - attributes.append(" bin_size: " + str(self.bin_size)) + attributes.append(f" data_type: {str(self.data_type)}") + if self.bin_size != 0: + attributes.append(f" bin_size: {str(self.bin_size)}") if len(self.exclude) != 0: - attributes.append(" exclude: " + str(self.exclude)) + attributes.append(f" exclude: {str(self.exclude)}") attributes[0] = " 0: - attribute = "BIN(" + clause.attribute + ")" + attribute = f"BIN({clause.attribute})" else: attribute = clause.attribute if clause.channel == "x": @@ -64,7 +64,7 @@ def __repr__(self): channels.extend(additional_channels) str_channels = "" for channel in channels: - str_channels += channel[0] + ": " + channel[1] + ", " + str_channels += f"{channel[0]}: {channel[1]}, " if filter_intents: return f"" @@ -324,5 +324,8 @@ def check_not_vislist_intent(self): for i in range(len(self._intent)): clause = self._intent[i] - if type(clause) != Clause and ("|" in clause or type(clause) == list or "?" in clause): + if isinstance(clause, str): + if "|" in clause or "?" in clause: + raise TypeError(syntaxMsg) + if isinstance(clause, list): raise TypeError(syntaxMsg) diff --git a/lux/vis/VisList.py b/lux/vis/VisList.py index 9e74961a..a346e6cc 100644 --- a/lux/vis/VisList.py +++ b/lux/vis/VisList.py @@ -133,16 +133,17 @@ def __repr__(self): for vis in self._collection: filter_intents = None for clause in vis._inferred_intent: + attr = str(clause.attribute) if clause.value != "": filter_intents = clause if clause.aggregation != "" and clause.aggregation is not None: - attribute = clause._aggregation_name.upper() + "(" + clause.attribute + ")" + attribute = clause._aggregation_name.upper() + f"({attr})" elif clause.bin_size > 0: - attribute = "BIN(" + clause.attribute + ")" + attribute = f"BIN({attr})" else: - attribute = clause.attribute - + attribute = attr + attribute = str(attribute) if clause.channel == "x" and len(x_channel) < len(attribute): x_channel = attribute if clause.channel == "y" and len(y_channel) < len(attribute): @@ -151,9 +152,9 @@ def __repr__(self): largest_mark = len(vis.mark) if ( filter_intents - and len(str(filter_intents.value)) + len(filter_intents.attribute) > largest_filter + and len(str(filter_intents.value)) + len(str(filter_intents.attribute)) > largest_filter ): - largest_filter = len(str(filter_intents.value)) + len(filter_intents.attribute) + largest_filter = len(str(filter_intents.value)) + len(str(filter_intents.attribute)) vis_repr = [] largest_x_length = len(x_channel) largest_y_length = len(y_channel) @@ -164,16 +165,16 @@ def __repr__(self): y_channel = "" additional_channels = [] for clause in vis._inferred_intent: + attr = str(clause.attribute) if clause.value != "": filter_intents = clause if clause.aggregation != "" and clause.aggregation is not None and vis.mark != "scatter": - attribute = clause._aggregation_name.upper() + "(" + clause.attribute + ")" + attribute = clause._aggregation_name.upper() + f"({attr})" elif clause.bin_size > 0: - attribute = "BIN(" + clause.attribute + ")" + attribute = f"BIN({attr})" else: - attribute = clause.attribute - + attribute = attr if clause.channel == "x": x_channel = attribute.ljust(largest_x_length) elif clause.channel == "y": @@ -197,7 +198,7 @@ def __repr__(self): if filter_intents: aligned_filter = ( " -- [" - + filter_intents.attribute + + str(filter_intents.attribute) + filter_intents.filter_op + str(filter_intents.value) + "]" diff --git a/lux/vislib/altair/AltairChart.py b/lux/vislib/altair/AltairChart.py index de4830f7..77fef1ec 100644 --- a/lux/vislib/altair/AltairChart.py +++ b/lux/vislib/altair/AltairChart.py @@ -87,7 +87,7 @@ def encode_color(self): timeUnit = compute_date_granularity(self.vis.data[color_attr_name]) self.chart = self.chart.encode( color=alt.Color( - color_attr_name, + str(color_attr_name), type=color_attr_type, timeUnit=timeUnit, title=color_attr_name, @@ -95,7 +95,9 @@ def encode_color(self): ) self.code += f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}',timeUnit='{timeUnit}',title='{color_attr_name}'))" else: - self.chart = self.chart.encode(color=alt.Color(color_attr_name, type=color_attr_type)) + self.chart = self.chart.encode( + color=alt.Color(str(color_attr_name), type=color_attr_type) + ) self.code += f"chart = chart.encode(color=alt.Color('{color_attr_name}',type='{color_attr_type}'))\n" elif len(color_attr) > 1: raise ValueError( @@ -111,3 +113,11 @@ def add_title(self): def initialize_chart(self): return NotImplemented + + @classmethod + def sanitize_dataframe(self, df): + for attr in df.columns: + # Altair can not visualize non-string columns + # convert all non-string columns in to strings + df = df.rename(columns={attr: str(attr)}) + return df diff --git a/lux/vislib/altair/AltairRenderer.py b/lux/vislib/altair/AltairRenderer.py index 2957cd17..080dd8a2 100644 --- a/lux/vislib/altair/AltairRenderer.py +++ b/lux/vislib/altair/AltairRenderer.py @@ -66,11 +66,12 @@ def create_vis(self, vis, standalone=True): vis.data[attr].iloc[0], pd.Interval ): vis.data[attr] = vis.data[attr].astype(str) - if "." in attr: - attr_clause = vis.get_attr_by_attr_name(attr)[0] - # Suppress special character ".", not displayable in Altair - # attr_clause.attribute = attr_clause.attribute.replace(".", "") - vis._vis_data = vis.data.rename(columns={attr: attr.replace(".", "")}) + if isinstance(attr, str): + if "." in attr: + attr_clause = vis.get_attr_by_attr_name(attr)[0] + # Suppress special character ".", not displayable in Altair + # attr_clause.attribute = attr_clause.attribute.replace(".", "") + vis._vis_data = vis.data.rename(columns={attr: attr.replace(".", "")}) if vis.mark == "histogram": chart = Histogram(vis) elif vis.mark == "bar": diff --git a/lux/vislib/altair/BarChart.py b/lux/vislib/altair/BarChart.py index 91e17b29..b989d857 100644 --- a/lux/vislib/altair/BarChart.py +++ b/lux/vislib/altair/BarChart.py @@ -40,28 +40,32 @@ def initialize_chart(self): x_attr = self.vis.get_attr_by_channel("x")[0] y_attr = self.vis.get_attr_by_channel("y")[0] - x_attr_abv = x_attr.attribute - y_attr_abv = y_attr.attribute + x_attr_abv = str(x_attr.attribute) + y_attr_abv = str(y_attr.attribute) - if len(x_attr.attribute) > 25: + if len(x_attr_abv) > 25: x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] - if len(y_attr.attribute) > 25: + if len(y_attr_abv) > 25: y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] - - x_attr.attribute = x_attr.attribute.replace(".", "") - y_attr.attribute = y_attr.attribute.replace(".", "") + if isinstance(x_attr.attribute, str): + x_attr.attribute = x_attr.attribute.replace(".", "") + if isinstance(y_attr.attribute, str): + y_attr.attribute = y_attr.attribute.replace(".", "") if x_attr.data_model == "measure": agg_title = get_agg_title(x_attr) measure_attr = x_attr.attribute bar_attr = y_attr.attribute y_attr_field = alt.Y( - y_attr.attribute, + str(y_attr.attribute), type=y_attr.data_type, axis=alt.Axis(labelOverlap=True, title=y_attr_abv), ) x_attr_field = alt.X( - x_attr.attribute, type=x_attr.data_type, title=agg_title, axis=alt.Axis(title=agg_title) + str(x_attr.attribute), + type=x_attr.data_type, + title=agg_title, + axis=alt.Axis(title=agg_title), ) y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(labelOverlap=True, title='{y_attr_abv}'))" x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', title='{agg_title}', axis=alt.Axis(title='{agg_title}'))" @@ -74,13 +78,16 @@ def initialize_chart(self): measure_attr = y_attr.attribute bar_attr = x_attr.attribute x_attr_field = alt.X( - x_attr.attribute, + str(x_attr.attribute), type=x_attr.data_type, axis=alt.Axis(labelOverlap=True, title=x_attr_abv), ) x_attr_field_code = f"alt.X('{x_attr.attribute}', type= '{x_attr.data_type}', axis=alt.Axis(labelOverlap=True, title='{x_attr_abv}'))" y_attr_field = alt.Y( - y_attr.attribute, type=y_attr.data_type, title=agg_title, axis=alt.Axis(title=agg_title) + str(y_attr.attribute), + type=y_attr.data_type, + title=agg_title, + axis=alt.Axis(title=agg_title), ) y_attr_field_code = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', title='{agg_title}', axis=alt.Axis(title='{agg_title}'))" if x_attr.sort == "ascending": @@ -89,9 +96,11 @@ def initialize_chart(self): k = 10 self._topkcode = "" n_bars = len(self.data.iloc[:, 0].unique()) + if n_bars > k: # Truncating to only top k remaining_bars = n_bars - k - self.data = self.data.nlargest(k, measure_attr) + self.data = self.data.nlargest(k, columns=measure_attr) + self.data = AltairChart.sanitize_dataframe(self.data) self.text = alt.Chart(self.data).mark_text( x=155, y=142, @@ -110,7 +119,7 @@ def initialize_chart(self): text=f"+ {remaining_bars} more ..." ) chart = chart + text\n""" - + self.data = AltairChart.sanitize_dataframe(self.data) chart = alt.Chart(self.data).mark_bar().encode(y=y_attr_field, x=x_attr_field) # TODO: tooltip messes up the count() bar charts diff --git a/lux/vislib/altair/Heatmap.py b/lux/vislib/altair/Heatmap.py index 4432de56..f83a3bbb 100644 --- a/lux/vislib/altair/Heatmap.py +++ b/lux/vislib/altair/Heatmap.py @@ -39,16 +39,18 @@ def initialize_chart(self): x_attr = self.vis.get_attr_by_channel("x")[0] y_attr = self.vis.get_attr_by_channel("y")[0] - x_attr_abv = x_attr.attribute - y_attr_abv = y_attr.attribute + x_attr_abv = str(x_attr.attribute) + y_attr_abv = str(y_attr.attribute) - if len(x_attr.attribute) > 25: + if len(x_attr_abv) > 25: x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] - if len(y_attr.attribute) > 25: + if len(y_attr_abv) > 25: y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] - x_attr.attribute = x_attr.attribute.replace(".", "") - y_attr.attribute = y_attr.attribute.replace(".", "") + if isinstance(x_attr.attribute, str): + x_attr.attribute = x_attr.attribute.replace(".", "") + if isinstance(y_attr.attribute, str): + y_attr.attribute = y_attr.attribute.replace(".", "") chart = ( alt.Chart(self.data) diff --git a/lux/vislib/altair/Histogram.py b/lux/vislib/altair/Histogram.py index 38e578ab..60c2d999 100644 --- a/lux/vislib/altair/Histogram.py +++ b/lux/vislib/altair/Histogram.py @@ -38,28 +38,30 @@ def initialize_chart(self): self.tooltip = False measure = self.vis.get_attr_by_data_model("measure", exclude_record=True)[0] msr_attr = self.vis.get_attr_by_channel(measure.channel)[0] + msr_attr_abv = str(msr_attr.attribute) - msr_attr_abv = msr_attr.attribute - - if len(msr_attr.attribute) > 17: - msr_attr_abv = msr_attr.attribute[:10] + "..." + msr_attr.attribute[-7:] + if len(msr_attr_abv) > 17: + msr_attr_abv = msr_attr_abv[:10] + "..." + msr_attr_abv[-7:] x_min = self.vis.min_max[msr_attr.attribute][0] x_max = self.vis.min_max[msr_attr.attribute][1] - msr_attr.attribute = msr_attr.attribute.replace(".", "") + if isinstance(msr_attr.attribute, str): + msr_attr.attribute = msr_attr.attribute.replace(".", "") - x_range = abs(max(self.vis.data[msr_attr.attribute]) - min(self.vis.data[msr_attr.attribute])) + colval = self.vis.data[msr_attr.attribute] + x_range = abs(max(colval) - min(colval)) plot_range = abs(x_max - x_min) markbar = x_range / plot_range * 12 + self.data = AltairChart.sanitize_dataframe(self.data) if measure.channel == "x": chart = ( alt.Chart(self.data) .mark_bar(size=markbar) .encode( alt.X( - msr_attr.attribute, + str(msr_attr.attribute), title=f"{msr_attr.attribute} (binned)", bin=alt.Bin(binned=True), type=msr_attr.data_type, @@ -76,7 +78,7 @@ def initialize_chart(self): .encode( x=alt.X("Number of Records", type="quantitative"), y=alt.Y( - msr_attr.attribute, + str(msr_attr.attribute), title=f"{msr_attr.attribute} (binned)", bin=alt.Bin(binned=True), axis=alt.Axis(labelOverlap=True, title=f"{msr_attr_abv} (binned)"), diff --git a/lux/vislib/altair/LineChart.py b/lux/vislib/altair/LineChart.py index 54b28c46..ae589030 100644 --- a/lux/vislib/altair/LineChart.py +++ b/lux/vislib/altair/LineChart.py @@ -40,16 +40,18 @@ def initialize_chart(self): x_attr = self.vis.get_attr_by_channel("x")[0] y_attr = self.vis.get_attr_by_channel("y")[0] - x_attr_abv = x_attr.attribute - y_attr_abv = y_attr.attribute + x_attr_abv = str(x_attr.attribute) + y_attr_abv = str(y_attr.attribute) - if len(x_attr.attribute) > 25: + if len(x_attr_abv) > 25: x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] - if len(y_attr.attribute) > 25: + if len(y_attr_abv) > 25: y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] - x_attr.attribute = x_attr.attribute.replace(".", "") - y_attr.attribute = y_attr.attribute.replace(".", "") + if isinstance(x_attr.attribute, str): + x_attr.attribute = x_attr.attribute.replace(".", "") + if isinstance(y_attr.attribute, str): + y_attr.attribute = y_attr.attribute.replace(".", "") # Remove NaNs only for Line Charts (offsets axis range) self.data = self.data.dropna(subset=[x_attr.attribute, y_attr.attribute]) @@ -60,21 +62,31 @@ def initialize_chart(self): if y_attr.data_model == "measure": agg_title = get_agg_title(y_attr) - x_attr_spec = alt.X(x_attr.attribute, type=x_attr.data_type, axis=alt.Axis(title=x_attr_abv)) + x_attr_spec = alt.X( + str(x_attr.attribute), type=x_attr.data_type, axis=alt.Axis(title=x_attr_abv) + ) y_attr_spec = alt.Y( - y_attr.attribute, type=y_attr.data_type, title=agg_title, axis=alt.Axis(title=y_attr_abv) + str(y_attr.attribute), + type=y_attr.data_type, + title=agg_title, + axis=alt.Axis(title=y_attr_abv), ) x_attr_field_code = f"alt.X('{x_attr.attribute}', type = '{x_attr.data_type}', axis=alt.Axis(title='{x_attr_abv}'))" y_attr_fieldCode = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', title='{agg_title}', axis=alt.Axis(title='{y_attr_abv}')" else: agg_title = get_agg_title(x_attr) x_attr_spec = alt.X( - x_attr.attribute, type=x_attr.data_type, title=agg_title, axis=alt.Axis(title=x_attr_abv) + str(x_attr.attribute), + type=x_attr.data_type, + title=agg_title, + axis=alt.Axis(title=x_attr_abv), + ) + y_attr_spec = alt.Y( + str(y_attr.attribute), type=y_attr.data_type, axis=alt.Axis(title=y_attr_abv) ) - y_attr_spec = alt.Y(y_attr.attribute, type=y_attr.data_type, axis=alt.Axis(title=y_attr_abv)) x_attr_field_code = f"alt.X('{x_attr.attribute}', type = '{x_attr.data_type}', title='{agg_title}', axis=alt.Axis(title='{x_attr_abv}')" - y_attr_fieldCode = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(title='{u_attr_abv}')" - + y_attr_fieldCode = f"alt.Y('{y_attr.attribute}', type= '{y_attr.data_type}', axis=alt.Axis(title='{y_attr_abv}')" + self.data = AltairChart.sanitize_dataframe(self.data) chart = alt.Chart(self.data).mark_line().encode(x=x_attr_spec, y=y_attr_spec) chart = chart.interactive() # Enable Zooming and Panning self.code += f""" diff --git a/lux/vislib/altair/ScatterChart.py b/lux/vislib/altair/ScatterChart.py index da645cda..21ae39ab 100644 --- a/lux/vislib/altair/ScatterChart.py +++ b/lux/vislib/altair/ScatterChart.py @@ -38,12 +38,12 @@ def initialize_chart(self): x_attr = self.vis.get_attr_by_channel("x")[0] y_attr = self.vis.get_attr_by_channel("y")[0] - x_attr_abv = x_attr.attribute - y_attr_abv = y_attr.attribute + x_attr_abv = str(x_attr.attribute) + y_attr_abv = str(y_attr.attribute) - if len(x_attr.attribute) > 25: + if len(x_attr_abv) > 25: x_attr_abv = x_attr.attribute[:15] + "..." + x_attr.attribute[-10:] - if len(y_attr.attribute) > 25: + if len(y_attr_abv) > 25: y_attr_abv = y_attr.attribute[:15] + "..." + y_attr.attribute[-10:] x_min = self.vis.min_max[x_attr.attribute][0] @@ -52,21 +52,23 @@ def initialize_chart(self): y_min = self.vis.min_max[y_attr.attribute][0] y_max = self.vis.min_max[y_attr.attribute][1] - x_attr.attribute = x_attr.attribute.replace(".", "") - y_attr.attribute = y_attr.attribute.replace(".", "") - + if isinstance(x_attr.attribute, str): + x_attr.attribute = x_attr.attribute.replace(".", "") + if isinstance(y_attr.attribute, str): + y_attr.attribute = y_attr.attribute.replace(".", "") + self.data = AltairChart.sanitize_dataframe(self.data) chart = ( alt.Chart(self.data) .mark_circle() .encode( x=alt.X( - x_attr.attribute, + str(x_attr.attribute), scale=alt.Scale(domain=(x_min, x_max)), type=x_attr.data_type, axis=alt.Axis(title=x_attr_abv), ), y=alt.Y( - y_attr.attribute, + str(y_attr.attribute), scale=alt.Scale(domain=(y_min, y_max)), type=y_attr.data_type, axis=alt.Axis(title=y_attr_abv), diff --git a/tests/test_columns.py b/tests/test_columns.py index 6216b471..19db44b0 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -121,3 +121,13 @@ def test_abbrev_agg(): test = pd.DataFrame(dataset) vis = Vis([long_var, "normal"], test).to_Altair() assert "axis=alt.Axis(title='Mean of Lorem ipsum dol...')" in vis + + +def test_int_columns(global_var): + df = pd.read_csv("lux/data/college.csv") + df.columns = range(len(df.columns)) + assert list(df.recommendation.keys()) == ["Correlation", "Distribution", "Occurrence"] + df.intent = [8, 3] + assert list(df.recommendation.keys()) == ["Enhance", "Filter", "Generalize"] + df.intent = [0] + assert list(df.recommendation.keys()) == ["Enhance", "Filter"] From 74f2c7e9d365d85545560cab5c576cad46c5a911 Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Fri, 8 Jan 2021 09:23:38 +0800 Subject: [PATCH 6/8] skip series vis for df.iterrows series element --- lux/core/series.py | 6 +++++- tests/test_series.py | 9 +++++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/lux/core/series.py b/lux/core/series.py index 3a4068d3..89382660 100644 --- a/lux/core/series.py +++ b/lux/core/series.py @@ -84,8 +84,12 @@ def __repr__(self): ldf = LuxDataFrame(self) try: + # Ignore recommendations when Series a results of: + # 1) Values of the series are of dtype objects (df.dtypes) is_dtype_series = all(isinstance(val, np.dtype) for val in self.values) - if ldf._pandas_only or is_dtype_series: + # 2) Mixed type, often a result of a "row" acting as a series (df.iterrows, df.iloc[0]) + mixed_dtype = len(set([type(val) for val in self.values])) > 1 + if ldf._pandas_only or is_dtype_series or mixed_dtype: print(series_repr) ldf._pandas_only = False else: diff --git a/tests/test_series.py b/tests/test_series.py index 62a4697f..6bbbed26 100644 --- a/tests/test_series.py +++ b/tests/test_series.py @@ -51,3 +51,12 @@ def test_print_dtypes(global_var): with warnings.catch_warnings(record=True) as w: print(df.dtypes) assert len(w) == 0, "Warning displayed when printing dtypes" + + +def test_print_iterrow(global_var): + df = pytest.college_df + with warnings.catch_warnings(record=True) as w: + for index, row in df.iterrows(): + print(row) + break + assert len(w) == 0, "Warning displayed when printing iterrow" From 298a87e81c92797126e03cbf3cf64e0ec6f9fd5f Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Fri, 8 Jan 2021 16:26:40 +0800 Subject: [PATCH 7/8] config setting for modifying top K and sorting --- doc/source/reference/config.rst | 32 +++++++++++ .../gen/lux._config.config.Config.rst | 4 ++ .../gen/lux.core.series.LuxSeries.rst | 1 - .../reference/gen/lux.vis.VisList.VisList.rst | 3 +- ....vislib.altair.AltairChart.AltairChart.rst | 1 + .../lux.vislib.altair.BarChart.BarChart.rst | 1 + .../lux.vislib.altair.Histogram.Histogram.rst | 1 + .../lux.vislib.altair.LineChart.LineChart.rst | 1 + ...islib.altair.ScatterChart.ScatterChart.rst | 1 + lux/_config/config.py | 53 ++++++++++++++++++- lux/action/correlation.py | 3 +- lux/action/enhance.py | 3 +- lux/action/filter.py | 3 +- lux/action/generalize.py | 1 + lux/action/univariate.py | 1 - lux/vis/VisList.py | 22 ++++---- tests/test_config.py | 41 +++++++++++++- 17 files changed, 152 insertions(+), 20 deletions(-) diff --git a/doc/source/reference/config.rst b/doc/source/reference/config.rst index 7b85b687..bb5bd455 100644 --- a/doc/source/reference/config.rst +++ b/doc/source/reference/config.rst @@ -108,3 +108,35 @@ The above results in the following changes: See `this page `__ for more details. +Modify Sorting and Ranking in Recommendations +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +In Lux, we select a small subset of visualizations to display in each action tab to avoid displaying too many charts at once. +Certain recommendation categories ranks and selects the top K most interesting visualizations to display. +You can modify the sorting order and selection cutoff via :code:`lux.config`. +By default, the recommendations are sorted in a :code:`"descending"` order based on their interestingness score, you can reverse the ordering by setting the sort order as: + +.. code-block:: python + + lux.config.sort = "ascending" + +To turn off the sorting of visualizations based on its score completely and ensure that the visualizations show up in the same order across all dataframes, you can set the sorting as "none": + +.. code-block:: python + + lux.config.sort = "none" + +For recommendation actions that generate a lot of visualizations, we select the cutoff criteria as the top 15 visualizations. If you would like to see only see the top 6 visualizations, you can set: + +.. code-block:: python + + lux.config.topk = 6 + +If you would like to turn off the selection criteria completely and display everything, you can turn off the top K selection by: + +.. code-block:: python + + lux.config.topk = False + +Beware that this may generate large numbers of visualizations (e.g., for 10 quantitative variables, this will generate 45 scatterplots in the Correlation action!) + diff --git a/doc/source/reference/gen/lux._config.config.Config.rst b/doc/source/reference/gen/lux._config.config.Config.rst index 0000b36f..48db70fb 100644 --- a/doc/source/reference/gen/lux._config.config.Config.rst +++ b/doc/source/reference/gen/lux._config.config.Config.rst @@ -14,6 +14,8 @@ lux.\_config.config.Config .. autosummary:: ~Config.__init__ + ~Config.register_action + ~Config.remove_action ~Config.set_SQL_connection ~Config.set_executor_type @@ -30,5 +32,7 @@ lux.\_config.config.Config ~Config.sampling ~Config.sampling_cap ~Config.sampling_start + ~Config.sort + ~Config.topk \ No newline at end of file diff --git a/doc/source/reference/gen/lux.core.series.LuxSeries.rst b/doc/source/reference/gen/lux.core.series.LuxSeries.rst index 0f50d3e4..a136115e 100644 --- a/doc/source/reference/gen/lux.core.series.LuxSeries.rst +++ b/doc/source/reference/gen/lux.core.series.LuxSeries.rst @@ -53,7 +53,6 @@ lux.core.series.LuxSeries ~LuxSeries.cumsum ~LuxSeries.describe ~LuxSeries.diff - ~LuxSeries.display_pandas ~LuxSeries.div ~LuxSeries.divide ~LuxSeries.divmod diff --git a/doc/source/reference/gen/lux.vis.VisList.VisList.rst b/doc/source/reference/gen/lux.vis.VisList.VisList.rst index daf9c501..f22b5bae 100644 --- a/doc/source/reference/gen/lux.vis.VisList.VisList.rst +++ b/doc/source/reference/gen/lux.vis.VisList.VisList.rst @@ -14,7 +14,6 @@ lux.vis.VisList.VisList .. autosummary:: ~VisList.__init__ - ~VisList.bottomK ~VisList.get ~VisList.map ~VisList.normalize_score @@ -23,8 +22,8 @@ lux.vis.VisList.VisList ~VisList.remove_index ~VisList.set ~VisList.set_intent + ~VisList.showK ~VisList.sort - ~VisList.topK diff --git a/doc/source/reference/gen/lux.vislib.altair.AltairChart.AltairChart.rst b/doc/source/reference/gen/lux.vislib.altair.AltairChart.AltairChart.rst index accd69eb..b2cafeed 100644 --- a/doc/source/reference/gen/lux.vislib.altair.AltairChart.AltairChart.rst +++ b/doc/source/reference/gen/lux.vislib.altair.AltairChart.AltairChart.rst @@ -19,6 +19,7 @@ lux.vislib.altair.AltairChart.AltairChart ~AltairChart.apply_default_config ~AltairChart.encode_color ~AltairChart.initialize_chart + ~AltairChart.sanitize_dataframe diff --git a/doc/source/reference/gen/lux.vislib.altair.BarChart.BarChart.rst b/doc/source/reference/gen/lux.vislib.altair.BarChart.BarChart.rst index 5c4878f8..b55c95b3 100644 --- a/doc/source/reference/gen/lux.vislib.altair.BarChart.BarChart.rst +++ b/doc/source/reference/gen/lux.vislib.altair.BarChart.BarChart.rst @@ -20,6 +20,7 @@ lux.vislib.altair.BarChart.BarChart ~BarChart.apply_default_config ~BarChart.encode_color ~BarChart.initialize_chart + ~BarChart.sanitize_dataframe diff --git a/doc/source/reference/gen/lux.vislib.altair.Histogram.Histogram.rst b/doc/source/reference/gen/lux.vislib.altair.Histogram.Histogram.rst index 47733466..920d6394 100644 --- a/doc/source/reference/gen/lux.vislib.altair.Histogram.Histogram.rst +++ b/doc/source/reference/gen/lux.vislib.altair.Histogram.Histogram.rst @@ -19,6 +19,7 @@ lux.vislib.altair.Histogram.Histogram ~Histogram.apply_default_config ~Histogram.encode_color ~Histogram.initialize_chart + ~Histogram.sanitize_dataframe diff --git a/doc/source/reference/gen/lux.vislib.altair.LineChart.LineChart.rst b/doc/source/reference/gen/lux.vislib.altair.LineChart.LineChart.rst index 3143e2f9..89257108 100644 --- a/doc/source/reference/gen/lux.vislib.altair.LineChart.LineChart.rst +++ b/doc/source/reference/gen/lux.vislib.altair.LineChart.LineChart.rst @@ -19,6 +19,7 @@ lux.vislib.altair.LineChart.LineChart ~LineChart.apply_default_config ~LineChart.encode_color ~LineChart.initialize_chart + ~LineChart.sanitize_dataframe diff --git a/doc/source/reference/gen/lux.vislib.altair.ScatterChart.ScatterChart.rst b/doc/source/reference/gen/lux.vislib.altair.ScatterChart.ScatterChart.rst index f7a1d283..be0569f7 100644 --- a/doc/source/reference/gen/lux.vislib.altair.ScatterChart.ScatterChart.rst +++ b/doc/source/reference/gen/lux.vislib.altair.ScatterChart.ScatterChart.rst @@ -19,6 +19,7 @@ lux.vislib.altair.ScatterChart.ScatterChart ~ScatterChart.apply_default_config ~ScatterChart.encode_color ~ScatterChart.initialize_chart + ~ScatterChart.sanitize_dataframe diff --git a/lux/_config/config.py b/lux/_config/config.py index 419f9909..09acb132 100644 --- a/lux/_config/config.py +++ b/lux/_config/config.py @@ -3,9 +3,9 @@ For more resources, see https://github.com/pandas-dev/pandas/blob/master/pandas/_config """ from collections import namedtuple -from typing import Any, Callable, Dict, Iterable, List, Optional -import warnings +from typing import Any, Callable, Dict, Iterable, List, Optional, Union import lux +import warnings RegisteredOption = namedtuple("RegisteredOption", "name action display_condition args") @@ -30,6 +30,55 @@ def __init__(self): self._sampling_cap = 30000 self._sampling_flag = True self._heatmap_flag = True + self._topk = 15 + self._sort = "descending" + + @property + def topk(self): + return self._topk + + @topk.setter + def topk(self, k: Union[int, bool]): + """ + Setting parameter to display top k visualizations in each action + + Parameters + ---------- + k : Union[int,bool] + False: if display all visualizations (no top-k) + k: number of visualizations to display + """ + if isinstance(k, int) or isinstance(k, bool): + self._topk = k + else: + warnings.warn( + "Parameter to lux.config.topk must be an integer or a boolean.", + stacklevel=2, + ) + + @property + def sort(self): + return self._sort + + @sort.setter + def sort(self, flag: Union[str]): + """ + Setting parameter to determine sort order of each action + + Parameters + ---------- + flag : Union[str] + "none", "ascending","descending" + No sorting, sort by ascending order, sort by descending order + """ + flag = flag.lower() + if isinstance(flag, str) and flag in ["none", "ascending", "descending"]: + self._sort = flag + else: + warnings.warn( + "Parameter to lux.config.sort must be one of the following: 'none', 'ascending', or 'descending'.", + stacklevel=2, + ) @property def sampling_cap(self): diff --git a/lux/action/correlation.py b/lux/action/correlation.py index 53cc8540..6d178e84 100644 --- a/lux/action/correlation.py +++ b/lux/action/correlation.py @@ -77,7 +77,8 @@ def correlation(ldf: LuxDataFrame, ignore_transpose: bool = True): if ignore_rec_flag: recommendation["collection"] = [] return recommendation - vlist = vlist.topK(15) + vlist.sort() + vlist = vlist.showK() recommendation["collection"] = vlist return recommendation diff --git a/lux/action/enhance.py b/lux/action/enhance.py index 94a4ea60..c6f240eb 100644 --- a/lux/action/enhance.py +++ b/lux/action/enhance.py @@ -66,6 +66,7 @@ def enhance(ldf): for vis in vlist: vis.score = interestingness(vis, ldf) - vlist = vlist.topK(15) + vlist.sort() + vlist = vlist.showK() recommendation["collection"] = vlist return recommendation diff --git a/lux/action/filter.py b/lux/action/filter.py index 6b27b843..a353d449 100644 --- a/lux/action/filter.py +++ b/lux/action/filter.py @@ -132,7 +132,8 @@ def get_complementary_ops(fltr_op): vlist_copy = lux.vis.VisList.VisList(output, ldf) for i in range(len(vlist_copy)): vlist[i].score = interestingness(vlist_copy[i], ldf) - vlist = vlist.topK(15) + vlist.sort() + vlist = vlist.showK() if recommendation["action"] == "Similarity": recommendation["collection"] = vlist[1:] else: diff --git a/lux/action/generalize.py b/lux/action/generalize.py index 91b83239..45e9d0f8 100644 --- a/lux/action/generalize.py +++ b/lux/action/generalize.py @@ -93,5 +93,6 @@ def generalize(ldf): vlist.remove_duplicates() vlist.sort(remove_invalid=True) + vlist._collection = list(filter(lambda x: x.score != -1, vlist._collection)) recommendation["collection"] = vlist return recommendation diff --git a/lux/action/univariate.py b/lux/action/univariate.py index 030a6f03..740f9105 100644 --- a/lux/action/univariate.py +++ b/lux/action/univariate.py @@ -82,7 +82,6 @@ def univariate(ldf, *args): vlist = VisList(intent, ldf) for vis in vlist: vis.score = interestingness(vis, ldf) - # vlist = vlist.topK(15) # Basic visualizations should not be capped vlist.sort() recommendation["collection"] = vlist return recommendation diff --git a/lux/vis/VisList.py b/lux/vis/VisList.py index a346e6cc..e3bdfa3e 100644 --- a/lux/vis/VisList.py +++ b/lux/vis/VisList.py @@ -233,18 +233,22 @@ def sort(self, remove_invalid=True, descending=True): # remove the items that have invalid (-1) score if remove_invalid: self._collection = list(filter(lambda x: x.score != -1, self._collection)) + if lux.config.sort == "none": + return + elif lux.config.sort == "ascending": + descending = False + elif lux.config.sort == "descending": + descending = True # sort in-place by “score” by default if available, otherwise user-specified field to sort by self._collection.sort(key=lambda x: x.score, reverse=descending) - def topK(self, k): - # sort and truncate list to first K items - self.sort(remove_invalid=True) - return VisList(self._collection[:k]) - - def bottomK(self, k): - # sort and truncate list to first K items - self.sort(descending=False, remove_invalid=True) - return VisList(self._collection[:k]) + def showK(self): + k = lux.config.topk + if k == False: + return self + elif isinstance(k, int): + k = abs(k) + return VisList(self._collection[:k]) def normalize_score(self, invert_order=False): max_score = max(list(self.get("score"))) diff --git a/tests/test_config.py b/tests/test_config.py index 644c4628..8a721d97 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -28,7 +28,8 @@ def random_categorical(ldf): vlist = VisList(intent, ldf) for vis in vlist: vis.score = 10 - vlist = vlist.topK(15) + vlist.sort() + vlist = vlist.showK() return { "action": "bars", "description": "Random list of Bar charts", @@ -105,7 +106,8 @@ def random_categorical(ldf): vlist = VisList(intent, ldf) for vis in vlist: vis.score = 10 - vlist = vlist.topK(15) + vlist.sort() + vlist = vlist.showK() return { "action": "bars", "description": "Random list of Bar charts", @@ -235,6 +237,41 @@ def test_heatmap_flag_config(): lux.config.heatmap = True +def test_topk(global_var): + df = pd.read_csv("lux/data/college.csv") + lux.config.topk = False + df._repr_html_() + assert len(df.recommendation["Correlation"]) == 45, "Turn off top K" + lux.config.topk = 20 + df = pd.read_csv("lux/data/college.csv") + df._repr_html_() + assert len(df.recommendation["Correlation"]) == 20, "Show top 20" + for vis in df.recommendation["Correlation"]: + assert vis.score > 0.2 + + +def test_sort(global_var): + df = pd.read_csv("lux/data/college.csv") + lux.config.topk = 15 + df._repr_html_() + assert len(df.recommendation["Correlation"]) == 15, "Show top 15" + for vis in df.recommendation["Correlation"]: + assert vis.score > 0.2 + df = pd.read_csv("lux/data/college.csv") + lux.config.sort = "ascending" + df._repr_html_() + assert len(df.recommendation["Correlation"]) == 15, "Show bottom 15" + for vis in df.recommendation["Correlation"]: + assert vis.score < 0.2 + + lux.config.sort = "none" + df = pd.read_csv("lux/data/college.csv") + df._repr_html_() + scorelst = [x.score for x in df.recommendation["Distribution"]] + assert sorted(scorelst) != scorelst, "unsorted setting" + lux.config.sort = "descending" + + # TODO: This test does not pass in pytest but is working in Jupyter notebook. # def test_plot_setting(global_var): # df = pytest.car_df From a495180d7ccc69f4e40a67738f8aa5e2922185f3 Mon Sep 17 00:00:00 2001 From: Doris Lee Date: Fri, 8 Jan 2021 16:58:24 +0800 Subject: [PATCH 8/8] note about regenerated config --- doc/source/reference/config.rst | 23 ++++++++++++++++++++++- 1 file changed, 22 insertions(+), 1 deletion(-) diff --git a/doc/source/reference/config.rst b/doc/source/reference/config.rst index bb5bd455..14ac5e48 100644 --- a/doc/source/reference/config.rst +++ b/doc/source/reference/config.rst @@ -2,7 +2,28 @@ Configuration Settings *********************** -In Lux, users can customize various global settings to configure the behavior of Lux through :py:class:`lux.config.Config`. This page documents some of the configurations that you can apply in Lux. +In Lux, users can customize various global settings to configure the behavior of Lux through :py:class:`lux.config.Config`. These configurations are applied across all dataframes in the session. This page documents some of the configurations that you can apply in Lux. + +.. note:: + + Lux caches past generated recommendations, so if you have already printed the dataframe in the past, the recommendations would not be regenerated with the new config properties. In order for the config properties to apply, you would need to explicitly expire the recommendations as such: + + .. code-block:: python + + df = pd.read_csv("..") + df # recommendations already generated here + + df.expire_recs() + lux.config.SOME_SETTING = "..." + df # recommendation will be generated again here + + Alternatively, you can place the config settings before you first print out the dataframe for the first time: + + .. code-block:: python + + df = pd.read_csv("..") + lux.config.SOME_SETTING = "..." + df # recommendations generated for the first time with config Change the default display of Lux