diff --git a/thicket/ensemble.py b/thicket/ensemble.py index c129175a..1535c501 100644 --- a/thicket/ensemble.py +++ b/thicket/ensemble.py @@ -76,21 +76,21 @@ def _check_structures(): """Check that the structures of the thicket objects are valid for the incoming operations.""" # Required/expected format of the data for th in thickets: - verify_thicket_structures(th.dataframe, index=["node", "profile"]) + assert th.dataframe.index.nlevels == 2 + assert th.metadata.index.nlevels == 1 + assert th.dataframe.index.names[1] == th.metadata.index.name verify_thicket_structures(th.statsframe.dataframe, index=["node"]) - verify_thicket_structures(th.metadata, index=["profile"]) # Check for metadata_key in metadata if metadata_key: for th in thickets: - verify_thicket_structures(th.metadata, columns=[metadata_key]) + if metadata_key != th.metadata.index.name: + verify_thicket_structures(th.metadata, columns=[metadata_key]) # Check length of profiles match if metadata key is not provided if metadata_key is None: for i in range(len(thickets) - 1): if len(thickets[i].profile) != len(thickets[i + 1].profile): raise ValueError( - "Length of all thicket profiles must match if 'metadata_key' is not provided. {} != {}".format( - len(thickets[i].profile), len(thickets[i + 1].profile) - ) + f"Length of all thicket profiles must match if 'metadata_key' is not provided. {len(thickets[i].profile)} != {len(thickets[i + 1].profile)}" ) # Ensure all thickets profiles are sorted. Must be true when metadata_key=None to # guarantee performance data table and metadata table match up. @@ -120,14 +120,16 @@ def _create_multiindex_columns(df, upper_idx_name): def _handle_metadata(): """Handle operations to create new concatenated columnar axis metadata table.""" # Update index to reflect performance data table index - for i in range(len(thickets_cp)): - thickets_cp[i].metadata.reset_index(drop=True, inplace=True) + if metadata_key != inner_idx: + for i in range(len(thickets_cp)): + thickets_cp[i].metadata.reset_index(drop=True, inplace=True) if metadata_key is None: for i in range(len(thickets_cp)): thickets_cp[i].metadata.index.set_names("profile", inplace=True) else: for i in range(len(thickets_cp)): - thickets_cp[i].metadata.set_index(metadata_key, inplace=True) + if metadata_key != inner_idx: + thickets_cp[i].metadata.set_index(metadata_key, inplace=True) thickets_cp[i].metadata.sort_index(inplace=True) # Create multi-index columns @@ -179,17 +181,17 @@ def _handle_perfdata(): thickets_cp[i].metadata_column_to_perfdata( "new_profiles", drop=True ) - thickets_cp[i].dataframe.reset_index(level="profile", inplace=True) + thickets_cp[i].dataframe.reset_index(level=inner_idx, inplace=True) new_mappings.update( pd.Series( thickets_cp[i] .dataframe["new_profiles"] .map(lambda x: (x, headers[i])) .values, - index=thickets_cp[i].dataframe["profile"], + index=thickets_cp[i].dataframe[inner_idx], ).to_dict() ) - thickets_cp[i].dataframe.drop("profile", axis=1, inplace=True) + thickets_cp[i].dataframe.drop(inner_idx, axis=1, inplace=True) thickets_cp[i].dataframe.set_index( "new_profiles", append=True, inplace=True ) @@ -198,18 +200,20 @@ def _handle_perfdata(): ) else: # Change second-level index to be from metadata's "metadata_key" column for i in range(len(thickets_cp)): - thickets_cp[i].metadata_column_to_perfdata(metadata_key) - thickets_cp[i].dataframe.reset_index(level="profile", inplace=True) + if metadata_key not in thickets_cp[i].dataframe.index.names: + thickets_cp[i].metadata_column_to_perfdata(metadata_key) + thickets_cp[i].dataframe.reset_index(level=inner_idx, inplace=True) new_mappings.update( pd.Series( thickets_cp[i] .dataframe[metadata_key] .map(lambda x: (x, headers[i])) .values, - index=thickets_cp[i].dataframe["profile"], + index=thickets_cp[i].dataframe[inner_idx], ).to_dict() ) - thickets_cp[i].dataframe.drop("profile", axis=1, inplace=True) + if inner_idx != metadata_key: + thickets_cp[i].dataframe.drop(inner_idx, axis=1, inplace=True) thickets_cp[i].dataframe.set_index( metadata_key, append=True, inplace=True ) @@ -266,11 +270,12 @@ def _handle_statsframe(): ), ) - # Step 0A: Pre-check of data structures - _check_structures() - # Step 0B: Variable Initialization + # Step 0A: Variable Initialization combined_th = thickets[0].deepcopy() thickets_cp = [th.deepcopy() for th in thickets] + inner_idx = thickets_cp[0].dataframe.index.names[1] + # Step 0B: Pre-check of data structures + _check_structures() # Step 1: Unify the thickets union_graph, _thickets = Ensemble._unify(thickets_cp) diff --git a/thicket/utils.py b/thicket/utils.py index 3f2c2435..0fd09239 100644 --- a/thicket/utils.py +++ b/thicket/utils.py @@ -55,7 +55,11 @@ def verify_sorted_profile(thicket_component): thicket_component (DataFrame): component of thicket to check """ profile_index_values = list( - OrderedDict.fromkeys(thicket_component.index.get_level_values("profile")) + OrderedDict.fromkeys( + thicket_component.index.get_level_values( + thicket_component.index.nlevels - 1 + ) # Innermost index + ) ) if profile_index_values != sorted(profile_index_values): raise ValueError(