3
3
from typing import Any , Dict , List , Tuple , Type
4
4
5
5
from bigtree .node .dagnode import DAGNode
6
+ from bigtree .utils .assertions import (
7
+ assert_dataframe_no_duplicate_attribute ,
8
+ assert_dataframe_not_empty ,
9
+ assert_dictionary_not_empty ,
10
+ assert_length_not_empty ,
11
+ filter_attributes ,
12
+ isnull ,
13
+ )
6
14
from bigtree .utils .exceptions import optional_dependencies_pandas
7
15
8
16
try :
@@ -35,15 +43,15 @@ def list_to_dag(
35
43
Returns:
36
44
(DAGNode)
37
45
"""
38
- if not len (relations ):
39
- raise ValueError ("Input list does not contain any data, check `relations`" )
46
+ assert_length_not_empty (relations , "Input list" , "relations" )
40
47
41
48
relation_data = pd .DataFrame (relations , columns = ["parent" , "child" ])
42
49
return dataframe_to_dag (
43
50
relation_data , child_col = "child" , parent_col = "parent" , node_type = node_type
44
51
)
45
52
46
53
54
+ @optional_dependencies_pandas
47
55
def dict_to_dag (
48
56
relation_attrs : Dict [str , Any ],
49
57
parent_key : str = "parents" ,
@@ -75,8 +83,7 @@ def dict_to_dag(
75
83
Returns:
76
84
(DAGNode)
77
85
"""
78
- if not len (relation_attrs ):
79
- raise ValueError ("Dictionary does not contain any data, check `relation_attrs`" )
86
+ assert_dictionary_not_empty (relation_attrs , "relation_attrs" )
80
87
81
88
# Convert dictionary to dataframe
82
89
data = pd .DataFrame (relation_attrs ).T .rename_axis ("_tmp_child" ).reset_index ()
@@ -110,6 +117,8 @@ def dataframe_to_dag(
110
117
- If columns are not specified, `child_col` takes first column, `parent_col` takes second column, and all other
111
118
columns are `attribute_cols`.
112
119
120
+ Only attributes in `attribute_cols` with non-null values will be added to the tree.
121
+
113
122
Examples:
114
123
>>> import pandas as pd
115
124
>>> from bigtree import dataframe_to_dag, dag_iterator
@@ -141,12 +150,7 @@ def dataframe_to_dag(
141
150
Returns:
142
151
(DAGNode)
143
152
"""
144
- data = data .copy ()
145
-
146
- if not len (data .columns ):
147
- raise ValueError ("Data does not contain any columns, check `data`" )
148
- if not len (data ):
149
- raise ValueError ("Data does not contain any rows, check `data`" )
153
+ assert_dataframe_not_empty (data )
150
154
151
155
if not child_col :
152
156
child_col = data .columns [0 ]
@@ -160,27 +164,12 @@ def dataframe_to_dag(
160
164
attribute_cols = list (data .columns )
161
165
attribute_cols .remove (child_col )
162
166
attribute_cols .remove (parent_col )
163
- elif any ([col not in data .columns for col in attribute_cols ]):
164
- raise ValueError (
165
- f"One or more attribute column(s) not in data, check `attribute_cols`: { attribute_cols } "
166
- )
167
167
168
- data_check = data .copy ()[[child_col , parent_col ] + attribute_cols ].drop_duplicates (
169
- subset = [child_col ] + attribute_cols
170
- )
171
- _duplicate_check = (
172
- data_check [child_col ]
173
- .value_counts ()
174
- .to_frame ("counts" )
175
- .rename_axis (child_col )
176
- .reset_index ()
168
+ data = data [[child_col , parent_col ] + attribute_cols ].copy ()
169
+
170
+ assert_dataframe_no_duplicate_attribute (
171
+ data , "child name" , child_col , attribute_cols
177
172
)
178
- _duplicate_check = _duplicate_check [_duplicate_check ["counts" ] > 1 ]
179
- if len (_duplicate_check ):
180
- raise ValueError (
181
- f"There exists duplicate child name with different attributes\n "
182
- f"Check { _duplicate_check } "
183
- )
184
173
if sum (data [child_col ].isnull ()):
185
174
raise ValueError (f"Child name cannot be empty, check column: { child_col } " )
186
175
@@ -190,15 +179,14 @@ def dataframe_to_dag(
190
179
for row in data .reset_index (drop = True ).to_dict (orient = "index" ).values ():
191
180
child_name = row [child_col ]
192
181
parent_name = row [parent_col ]
193
- node_attrs = row .copy ()
194
- del node_attrs [child_col ]
195
- del node_attrs [parent_col ]
196
- node_attrs = {k : v for k , v in node_attrs .items () if not pd .isnull (v )}
197
- child_node = node_dict .get (child_name , node_type (child_name ))
182
+ node_attrs = filter_attributes (
183
+ row , omit_keys = ["name" , child_col , parent_col ], omit_null_values = True
184
+ )
185
+ child_node = node_dict .get (child_name , node_type (child_name , ** node_attrs ))
198
186
child_node .set_attrs (node_attrs )
199
187
node_dict [child_name ] = child_node
200
188
201
- if not pd . isnull (parent_name ):
189
+ if not isnull (parent_name ):
202
190
parent_node = node_dict .get (parent_name , node_type (parent_name ))
203
191
node_dict [parent_name ] = parent_node
204
192
child_node .parents = [parent_node ]
0 commit comments