Skip to content

Commit 36b3afc

Browse files
authored
Merge pull request #1266 from mito-ds/parser-index-labels-bug
Fix bug in parsing specific index labels
2 parents 3c1a6b8 + 65a56dc commit 36b3afc

File tree

3 files changed

+48
-7
lines changed

3 files changed

+48
-7
lines changed

mitosheet/mitosheet/parser.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -1045,12 +1045,7 @@ def parse_formula(
10451045
if is_datetime_index(df.index) and is_prev_version(get_pandas_version(), '1.0.0'):
10461046
index_labels = pd.to_datetime(index_labels)
10471047

1048-
if len(column_header_dependencies) > 0:
1049-
final_set_code = f'({code_with_functions}).loc[{get_column_header_list_as_transpiled_code(index_labels)}]' # type: ignore
1050-
else:
1051-
final_set_code = f'{code_with_functions}'
1052-
1053-
final_code = f'{df_name}.loc[{get_column_header_list_as_transpiled_code(index_labels)}, [{transpiled_column_header}]] = {final_set_code}' # type: ignore
1048+
final_code = f'{df_name}.loc[{get_column_header_list_as_transpiled_code(index_labels)}, [{transpiled_column_header}]] = {code_with_functions}' # type: ignore
10541049

10551050
else:
10561051
final_code = f'{code_with_functions}'

mitosheet/mitosheet/tests/step_performers/column_steps/test_set_column_formula.py

+14
Original file line numberDiff line numberDiff line change
@@ -535,6 +535,20 @@ def test_set_specific_index_labels_twice():
535535

536536
assert mito.dfs[0].equals(pd.DataFrame({'A': [1, 2, 3], 'B': [2, 0, 0]}))
537537

538+
def test_set_specific_index_labels_to_header_header():
539+
mito = create_mito_wrapper(pd.DataFrame({'A': [1, 2, 3]}))
540+
mito.add_column(0, 'B')
541+
mito.set_formula('=SUM(A:A)', 0, 'B', index_labels=[0])
542+
543+
assert mito.dfs[0].equals(pd.DataFrame({'A': [1, 2, 3], 'B': [6, 0, 0]}))
544+
545+
546+
def test_set_specific_index_labels_to_header_header_multiple_header_dependencies():
547+
mito = create_mito_wrapper(pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}))
548+
mito.add_column(0, 'C')
549+
mito.set_formula('=SUM(A:B)', 0, 'C', index_labels=[0])
550+
551+
assert mito.dfs[0].equals(pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6], 'C': [21, 0, 0]}))
538552

539553
CROSS_SHEET_TESTS = [
540554
(

mitosheet/mitosheet/tests/test_parse.py

+33-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from mitosheet.errors import MitoError
1313
from mitosheet.parser import get_backend_formula_from_frontend_formula, parse_formula, safe_contains, get_frontend_formula
14-
from mitosheet.types import FORMULA_ENTIRE_COLUMN_TYPE
14+
from mitosheet.types import FORMULA_ENTIRE_COLUMN_TYPE, FORMULA_SPECIFIC_INDEX_LABELS_TYPE
1515
from mitosheet.tests.decorators import pandas_post_1_2_only
1616

1717

@@ -1274,6 +1274,38 @@ def test_parse_cross_sheet_formulas(formula, column_header, formula_label, dfs,
12741274
)
12751275
]
12761276

1277+
def test_specific_index_labels_header_header():
1278+
formula = '=SUM(A:A)'
1279+
column_header = 'B'
1280+
formula_label = 0
1281+
df = pd.DataFrame(get_number_data_for_df(['A', 'B'], 2), index=pd.RangeIndex(0, 2))
1282+
python_code = 'df.loc[[0], [\'B\']] = SUM(df[[\'A\']])'
1283+
functions = set(['SUM'])
1284+
columns = set(['A'])
1285+
code, funcs, cols, _ = parse_formula(formula, column_header, formula_label, {'type': FORMULA_SPECIFIC_INDEX_LABELS_TYPE, 'index_labels': [0]}, [df], ['df'], 0)
1286+
assert (code, funcs, cols) == \
1287+
(
1288+
python_code,
1289+
functions,
1290+
columns
1291+
)
1292+
1293+
def test_specific_index_labels_header_header_multiple_header_dependencies():
1294+
formula = '=SUM(A:B)'
1295+
column_header = 'C'
1296+
formula_label = 0
1297+
df = pd.DataFrame(get_number_data_for_df(['A', 'B', 'C'], 2), index=pd.RangeIndex(0, 2))
1298+
python_code = 'df.loc[[0], [\'C\']] = SUM(df.loc[:, \'A\':\'B\'])'
1299+
functions = set(['SUM'])
1300+
columns = set(['A', 'B'])
1301+
code, funcs, cols, _ = parse_formula(formula, column_header, formula_label, {'type': FORMULA_SPECIFIC_INDEX_LABELS_TYPE, 'index_labels': [0]}, [df], ['df'], 0)
1302+
assert (code, funcs, cols) == \
1303+
(
1304+
python_code,
1305+
functions,
1306+
columns
1307+
)
1308+
12771309
@pytest.mark.parametrize("formula,column_header,formula_label,dfs,df_names,sheet_index,python_code,functions,columns", POST_PD_1_2_VLOOKUP_TESTS)
12781310
@pandas_post_1_2_only
12791311
def post_pandas_1_2_cross_sheet_tests(formula,column_header,formula_label,dfs,df_names,sheet_index,python_code,functions,columns):

0 commit comments

Comments
 (0)