
Commit

fixed PR issues
prog
Yejashi committed May 20, 2024
1 parent 76e0f7a commit a97b34d
Showing 2 changed files with 34 additions and 15 deletions.
thicket/tests/test_query_stats.py (21 changes: 18 additions & 3 deletions)
@@ -33,6 +33,9 @@ def check_query(th_x, hnids, query):
filt_th = th_x.query_stats(query)
filt_nodes = list(filt_th.graph.traverse())

# Get statsframe nodes
sframe_nodes = filt_th.statsframe.dataframe.reset_index()["node"]

# MultiIndex check
if isinstance(th_x.statsframe.dataframe.columns, pd.MultiIndex):
assert isinstance(filt_th.statsframe.dataframe.columns, pd.MultiIndex)
@@ -49,6 +52,10 @@ def check_query(th_x, hnids, query):
th_df_profiles.unique().to_list()
)

assert len(sframe_nodes) == len(match)
assert all([n.frame in match_frames for n in sframe_nodes])
assert all([n.frame["name"] in match_names for n in sframe_nodes])

check_identity(th_x, filt_th, "default_metric")


@@ -95,20 +102,23 @@ def test_object_dialect_column_multi_index(rajaperf_seq_O3_1M_cali):
)
)

new_th = th_cj.query_stats(query, multi_index_mode="all")
new_th = th_cj.query_stats(query)

th.stats.mean(new_th, columns=[(0, "Min time/rank")])
th.stats.mean(new_th, columns=[(1, "Min time/rank")])

# return new_th

queried_nodes = list(new_th.graph.traverse())

# Get statsframe nodes
sframe_nodes = new_th.statsframe.dataframe.reset_index()["node"]

match_frames = list(sorted([n.frame for n in match]))
queried_frames = list(sorted([n.frame for n in queried_nodes]))

assert len(queried_nodes) == len(match)
assert all(m == q for m, q in zip(match_frames, queried_frames))

assert len(sframe_nodes) == len(match)
idx = pd.IndexSlice
assert (
(
@@ -152,6 +162,9 @@ def test_string_dialect_column_multi_index(rajaperf_seq_O3_1M_cali):
new_th = th_cj.query_stats(query)
queried_nodes = list(new_th.graph.traverse())

# Get statsframe nodes
sframe_nodes = new_th.statsframe.dataframe.reset_index()["node"]

match_frames = list(sorted([n.frame for n in match]))
queried_frames = list(sorted([n.frame for n in queried_nodes]))

@@ -160,6 +173,8 @@ def test_string_dialect_column_multi_index(rajaperf_seq_O3_1M_cali):

assert len(queried_nodes) == len(match)
assert all(m == q for m, q in zip(match_frames, queried_frames))

assert len(sframe_nodes) == len(match)
idx = pd.IndexSlice
assert (
(
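For context, the new assertions added throughout this test file compare the node index of the filtered statsframe against the nodes matched by the query. Below is a minimal, self-contained sketch of that pattern using pandas directly; the frame contents and column names are illustrative assumptions, not values from this commit.

import pandas as pd

# Stand-in for a statsframe: a DataFrame indexed by "node". In Thicket the
# index holds graph node objects; plain strings are used here only to keep
# the sketch self-contained.
stats_df = pd.DataFrame(
    {"Min time/rank_mean": [1.2, 3.4]},
    index=pd.Index(["Algorithm.block_128", "Stream.triad"], name="node"),
)

expected_matches = {"Algorithm.block_128", "Stream.triad"}

# Same pattern as the test: recover the node index as a column, then check
# that it lines up with the expected matches.
sframe_nodes = stats_df.reset_index()["node"]
assert len(sframe_nodes) == len(expected_matches)
assert all(n in expected_matches for n in sframe_nodes)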
thicket/thicket.py (28 changes: 16 additions & 12 deletions)
@@ -1173,9 +1173,7 @@ def query(
return filtered_th.squash(update_inc_cols=update_inc_cols)
return filtered_th

def query_stats(
self, query_obj, squash=True, update_inc_cols=True, multi_index_mode="off"
):
def query_stats(self, query_obj, squash=True, update_inc_cols=True):
"""Apply a Hatchet query to the Thicket object.
Arguments:
@@ -1191,19 +1189,17 @@ def query_stats(
"""
local_query_obj = query_obj
if isinstance(query_obj, list):
local_query_obj = ObjectQuery(query_obj, multi_index_mode=multi_index_mode)
local_query_obj = ObjectQuery(query_obj, multi_index_mode="off")
elif isinstance(query_obj, str):
local_query_obj = parse_string_dialect(
query_obj, multi_index_mode=multi_index_mode
)
local_query_obj = parse_string_dialect(query_obj, multi_index_mode="off")
elif not is_hatchet_query(query_obj):
raise TypeError(
"Encountered unrecognized query type (expected Query, CompoundQuery, or AbstractQuery, got {})".format(
type(query_obj)
)
)
sframe_copy = self.statsframe.dataframe.copy()
index_names = self.statsframe.dataframe.index.names
sf_index_names = self.statsframe.dataframe.index.names
sframe_copy.reset_index(inplace=True)

query = (
@@ -1219,19 +1215,27 @@
if filtered_sf_df.shape[0] == 0:
raise EmptyQuery("The provided query would have produced an empty Thicket.")

index_names = self.dataframe.index.names
df_index_names = self.dataframe.index.names
dframe_copy = self.dataframe.copy()
index_names = self.dataframe.index.names
dframe_copy.reset_index(inplace=True)

filtered_df = dframe_copy[dframe_copy["node"].isin(query_matches)]
filtered_df.set_index(index_names, inplace=True)
filtered_df.set_index(df_index_names, inplace=True)

filtered_th = self.deepcopy()
filtered_th.dataframe = filtered_df

sframe_copy = sframe_copy[
sframe_copy["node"].apply(lambda x: x in query_matches)
]
sframe_copy.set_index(sf_index_names, inplace=True)
filtered_th.statsframe.dataframe = sframe_copy

if squash:
return filtered_th.squash(update_inc_cols=update_inc_cols)
return filtered_th.squash(
update_inc_cols=update_inc_cols, new_statsframe=False
)

return filtered_th

def groupby(self, by):
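Taken together, these changes mean query_stats always builds the underlying query with multi_index_mode="off" and now filters the aggregated statsframe alongside the performance dataframe. A hedged usage sketch against the revised signature follows; the Caliper file path, metric name, and query pattern are placeholders chosen for illustration, not values from this repository.

import thicket as th

# Load a profile and aggregate a metric into the statsframe
# (file path and metric name are placeholders).
tk = th.Thicket.from_caliperreader("profile.cali")
th.stats.mean(tk, columns=["Min time/rank"])

# Object-dialect query on node names; note that multi_index_mode is no
# longer a parameter of query_stats after this commit.
query = [("*", {"name": "Algorithm.*"})]
filtered = tk.query_stats(query)

# Both the performance data and the aggregated stats are now restricted
# to the matched nodes.
print(filtered.dataframe.index.get_level_values("node").unique())
print(filtered.statsframe.dataframe)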
