From a97b34dc8f8ba68fcc53fc754f729ac366695900 Mon Sep 17 00:00:00 2001 From: yejashi Date: Mon, 20 May 2024 08:06:12 -0700 Subject: [PATCH] fixed PR issues prog --- thicket/tests/test_query_stats.py | 21 ++++++++++++++++++--- thicket/thicket.py | 28 ++++++++++++++++------------ 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/thicket/tests/test_query_stats.py b/thicket/tests/test_query_stats.py index 58572abb..4ef8a36a 100644 --- a/thicket/tests/test_query_stats.py +++ b/thicket/tests/test_query_stats.py @@ -33,6 +33,9 @@ def check_query(th_x, hnids, query): filt_th = th_x.query_stats(query) filt_nodes = list(filt_th.graph.traverse()) + # Get statsframe nodes + sframe_nodes = filt_th.statsframe.dataframe.reset_index()["node"] + # MultiIndex check if isinstance(th_x.statsframe.dataframe.columns, pd.MultiIndex): assert isinstance(filt_th.statsframe.dataframe.columns, pd.MultiIndex) @@ -49,6 +52,10 @@ def check_query(th_x, hnids, query): th_df_profiles.unique().to_list() ) + assert len(sframe_nodes) == len(match) + assert all([n.frame in match_frames for n in sframe_nodes]) + assert all([n.frame["name"] in match_names for n in sframe_nodes]) + check_identity(th_x, filt_th, "default_metric") @@ -95,20 +102,23 @@ def test_object_dialect_column_multi_index(rajaperf_seq_O3_1M_cali): ) ) - new_th = th_cj.query_stats(query, multi_index_mode="all") + new_th = th_cj.query_stats(query) th.stats.mean(new_th, columns=[(0, "Min time/rank")]) th.stats.mean(new_th, columns=[(1, "Min time/rank")]) - # return new_th - queried_nodes = list(new_th.graph.traverse()) + # Get statsframe nodes + sframe_nodes = new_th.statsframe.dataframe.reset_index()["node"] + match_frames = list(sorted([n.frame for n in match])) queried_frames = list(sorted([n.frame for n in queried_nodes])) assert len(queried_nodes) == len(match) assert all(m == q for m, q in zip(match_frames, queried_frames)) + + assert len(sframe_nodes) == len(match) idx = pd.IndexSlice assert ( ( @@ -152,6 +162,9 @@ def test_string_dialect_column_multi_index(rajaperf_seq_O3_1M_cali): new_th = th_cj.query_stats(query) queried_nodes = list(new_th.graph.traverse()) + # Get statsframe nodes + sframe_nodes = new_th.statsframe.dataframe.reset_index()["node"] + match_frames = list(sorted([n.frame for n in match])) queried_frames = list(sorted([n.frame for n in queried_nodes])) @@ -160,6 +173,8 @@ def test_string_dialect_column_multi_index(rajaperf_seq_O3_1M_cali): assert len(queried_nodes) == len(match) assert all(m == q for m, q in zip(match_frames, queried_frames)) + + assert len(sframe_nodes) == len(match) idx = pd.IndexSlice assert ( ( diff --git a/thicket/thicket.py b/thicket/thicket.py index 6c0b2ded..7d01ed64 100644 --- a/thicket/thicket.py +++ b/thicket/thicket.py @@ -1173,9 +1173,7 @@ def query( return filtered_th.squash(update_inc_cols=update_inc_cols) return filtered_th - def query_stats( - self, query_obj, squash=True, update_inc_cols=True, multi_index_mode="off" - ): + def query_stats(self, query_obj, squash=True, update_inc_cols=True): """Apply a Hatchet query to the Thicket object. Arguments: @@ -1191,11 +1189,9 @@ def query_stats( """ local_query_obj = query_obj if isinstance(query_obj, list): - local_query_obj = ObjectQuery(query_obj, multi_index_mode=multi_index_mode) + local_query_obj = ObjectQuery(query_obj, multi_index_mode="off") elif isinstance(query_obj, str): - local_query_obj = parse_string_dialect( - query_obj, multi_index_mode=multi_index_mode - ) + local_query_obj = parse_string_dialect(query_obj, multi_index_mode="off") elif not is_hatchet_query(query_obj): raise TypeError( "Encountered unrecognized query type (expected Query, CompoundQuery, or AbstractQuery, got {})".format( @@ -1203,7 +1199,7 @@ def query_stats( ) ) sframe_copy = self.statsframe.dataframe.copy() - index_names = self.statsframe.dataframe.index.names + sf_index_names = self.statsframe.dataframe.index.names sframe_copy.reset_index(inplace=True) query = ( @@ -1219,19 +1215,27 @@ def query_stats( if filtered_sf_df.shape[0] == 0: raise EmptyQuery("The provided query would have produced an empty Thicket.") - index_names = self.dataframe.index.names + df_index_names = self.dataframe.index.names dframe_copy = self.dataframe.copy() - index_names = self.dataframe.index.names dframe_copy.reset_index(inplace=True) filtered_df = dframe_copy[dframe_copy["node"].isin(query_matches)] - filtered_df.set_index(index_names, inplace=True) + filtered_df.set_index(df_index_names, inplace=True) filtered_th = self.deepcopy() filtered_th.dataframe = filtered_df + sframe_copy = sframe_copy[ + sframe_copy["node"].apply(lambda x: x in query_matches) + ] + sframe_copy.set_index(sf_index_names, inplace=True) + filtered_th.statsframe.dataframe = sframe_copy + if squash: - return filtered_th.squash(update_inc_cols=update_inc_cols) + return filtered_th.squash( + update_inc_cols=update_inc_cols, new_statsframe=False + ) + return filtered_th def groupby(self, by):