Skip to content

Commit ad6f051

Browse files
committed
.
1 parent aeec225 commit ad6f051

File tree

3 files changed

+18
-8
lines changed

3 files changed

+18
-8
lines changed

causalspyne/dag_viewer.py

+12
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,10 @@ def __init__(self, dag, rng=default_rng(0)):
5757
self._list_nodes2hide = None
5858
self._success = False
5959

60+
@property
61+
def dag(self):
62+
return self._dag
63+
6064
def run(self, num_samples, list_nodes2hide=None, confound=False):
6165
"""
6266
generate subgraph adjcency matrix and corresponding data
@@ -173,6 +177,14 @@ def to_csv(self, title="data_subdag.csv"):
173177
def list_global_inds_nodes2hide(self):
174178
return self._list_global_inds_unobserved
175179

180+
@property
181+
def col_inds(self):
182+
subview_global_inds = \
183+
[self.dag._dict_node_names2ind[name]
184+
for name in self.dag.list_node_names
185+
if name not in self.str_node2hide]
186+
return subview_global_inds
187+
176188
@property
177189
def list_global_inds_observed(self):
178190
if self._list_global_inds_observed is None:

causalspyne/main.py

+1-5
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,4 @@ def re_hide(subview, dag, num_sample, list_confounder2hide, output_dir,
116116
",".join(str(node) for node in
117117
subview._list_global_inds_unobserved)
118118
)
119-
subview_global_inds = [dag._dict_node_names2ind[name]
120-
for name in dag.list_node_names if
121-
name not in str_node2hide]
122-
123-
return subview, dag, subview_global_inds
119+
return subview

tests/test_utils_ancestral_acc.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ def test_ancestral_acc():
1111
"""
1212
test if ancestral accuracy runs
1313
"""
14-
arr_data, dag, subview_global_inds = gen_partially_observed(
14+
subview = gen_partially_observed(
1515
size_micro_node_dag=3,
1616
num_macro_nodes=2,
1717
degree=2, # average vertex/node degree
@@ -23,8 +23,10 @@ def test_ancestral_acc():
2323
rng=np.random.default_rng(1),
2424
graphviz=False
2525
)
26-
pred_order_inds = [dag._dict_node_names2ind[name] for name in dag.list_node_names]
27-
acc = ancestral_acc(true_dag=dag,
26+
pred_order_inds = \
27+
[subview.dag._dict_node_names2ind[name]
28+
for name in subview.dag.list_node_names]
29+
acc = ancestral_acc(true_dag=subview.dag,
2830
pred_order=pred_order_inds)
2931

3032
print(f"ancestral acc: {acc}")

0 commit comments

Comments
 (0)