Skip to content

Commit ca02309

Browse files
Merge pull request #77 from causy-dev/track-orientation-conflicts
feat(orientation_rules): track orientation conflicts
2 parents c5e68b3 + 2997d29 commit ca02309

File tree

2 files changed

+110
-7
lines changed

2 files changed

+110
-7
lines changed

causy/causal_discovery/constraint/orientation_rules/pc.py

+35-7
Original file line numberDiff line numberDiff line change
@@ -142,15 +142,30 @@ def process(
142142
unapplied_actions, y, z
143143
)
144144
if len(unapplied_actions_y_z) > 0 or len(unapplied_actions_x_z) > 0:
145-
logger.warning(
146-
f"Orientation conflict detected in ColliderTest stage when orienting the edge between {x.name} and {y.name}. The conflict is resolved using the strategy {self.conflict_resolution_strategy}, but orientation conflicts indicate assumption violations and can severely affect the accuracy of the results.",
147-
)
148145
if (
149146
ColliderTestConflictResolutionStrategies.KEEP_FIRST
150147
is self.conflict_resolution_strategy
151148
):
152-
# We keep the first edge that was removed
153-
continue
149+
# We prioritize the first edge that was removed, we do nothing
150+
if len(unapplied_actions_y_z) > 0:
151+
results.append(
152+
TestResult(
153+
u=z,
154+
v=y,
155+
action=TestResultAction.DO_NOTHING,
156+
data={"orientation_conflict": True},
157+
)
158+
)
159+
if len(unapplied_actions_x_z) > 0:
160+
results.append(
161+
TestResult(
162+
u=z,
163+
v=x,
164+
action=TestResultAction.DO_NOTHING,
165+
data={"orientation_conflict": True},
166+
)
167+
)
168+
154169
elif (
155170
ColliderTestConflictResolutionStrategies.KEEP_LAST
156171
is self.conflict_resolution_strategy
@@ -237,7 +252,12 @@ def process(
237252
breakflag = True
238253
break
239254
if breakflag is True:
240-
continue
255+
return TestResult(
256+
u=y,
257+
v=z,
258+
action=TestResultAction.DO_NOTHING,
259+
data={"orientation_conflict": True},
260+
)
241261
return TestResult(
242262
u=y,
243263
v=z,
@@ -250,7 +270,15 @@ def process(
250270
):
251271
for node in graph.nodes:
252272
if graph.only_directed_edge_exists(graph.nodes[node], x):
253-
continue
273+
breakflag = True
274+
break
275+
if breakflag is True:
276+
return TestResult(
277+
u=x,
278+
v=z,
279+
action=TestResultAction.DO_NOTHING,
280+
data={"orientation_conflict": True},
281+
)
254282
return TestResult(
255283
u=x,
256284
v=z,

tests/test_pc_e2e.py

+75
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
ComputeDirectEffectsInDAGsMultivariateRegression,
1111
)
1212
from causy.common_pipeline_steps.calculation import CalculatePearsonCorrelations
13+
from causy.edge_types import DirectedEdge
1314
from causy.generators import PairsWithNeighboursGenerator
1415
from causy.graph_model import graph_model_factory
1516
from causy.causal_discovery.constraint.independence_tests.common import (
@@ -426,6 +427,36 @@ def test_track_triples_three_nodes_pc_unconditionally_independent(self):
426427
# TODO: find issue with tracking in partial correlation test in this setting
427428
pass
428429

430+
def test_orientation_conflict_tracking(self):
431+
causal_insufficiency_four_nodes = IIDSampleGenerator(
432+
edges=[
433+
SampleEdge(NodeReference("U1"), NodeReference("X"), 1),
434+
SampleEdge(NodeReference("U1"), NodeReference("Y"), 1),
435+
SampleEdge(NodeReference("U2"), NodeReference("Y"), 1),
436+
SampleEdge(NodeReference("U2"), NodeReference("Z"), 1),
437+
SampleEdge(NodeReference("U3"), NodeReference("Z"), 1),
438+
SampleEdge(NodeReference("U3"), NodeReference("V"), 1),
439+
SampleEdge(NodeReference("U4"), NodeReference("V"), 1),
440+
SampleEdge(NodeReference("U4"), NodeReference("X"), 1),
441+
],
442+
)
443+
test_data, graph = causal_insufficiency_four_nodes.generate(10000)
444+
test_data.pop("U1")
445+
test_data.pop("U2")
446+
test_data.pop("U3")
447+
test_data.pop("U4")
448+
tst = PCClassic()
449+
tst.create_graph_from_data(test_data)
450+
tst.create_all_possible_edges()
451+
tst.execute_pipeline_steps()
452+
453+
nb_of_conflicts = 0
454+
for result in tst.graph.action_history:
455+
for proposed_action in result.all_proposed_actions:
456+
if "orientation_conflict" in proposed_action.data:
457+
nb_of_conflicts += 1
458+
self.assertGreater(nb_of_conflicts, 1)
459+
429460
def test_d_separation_on_output_of_pc(self):
430461
rdnv = self.seeded_random.normalvariate
431462
sample_generator = IIDSampleGenerator(
@@ -447,3 +478,47 @@ def test_d_separation_on_output_of_pc(self):
447478
z = tst.graph.node_by_id("Z")
448479
self.assertEqual(tst.graph.are_nodes_d_separated_cpdag(x, z, []), False)
449480
self.assertEqual(tst.graph.are_nodes_d_separated_cpdag(x, z, [y]), True)
481+
482+
def test_pc_faithfulness_violation(self):
483+
rdnv = self.seeded_random.normalvariate
484+
sample_generator = IIDSampleGenerator(
485+
edges=[
486+
SampleEdge(NodeReference("X"), NodeReference("V"), 2),
487+
SampleEdge(NodeReference("V"), NodeReference("W"), 2),
488+
SampleEdge(NodeReference("W"), NodeReference("Y"), -2),
489+
SampleEdge(NodeReference("X"), NodeReference("Y"), 8),
490+
],
491+
random=lambda: rdnv(0, 1),
492+
)
493+
test_data, graph = sample_generator.generate(10000)
494+
tst = PCClassic()
495+
tst.create_graph_from_data(test_data)
496+
tst.create_all_possible_edges()
497+
tst.execute_pipeline_steps()
498+
499+
self.assertEqual(tst.graph.edge_exists("X", "Y"), False)
500+
self.assertEqual(tst.graph.edge_exists("V", "Y"), False)
501+
self.assertEqual(tst.graph.edge_exists("W", "X"), False)
502+
self.assertEqual(tst.graph.edge_exists("W", "Y"), True)
503+
self.assertEqual(tst.graph.edge_exists("V", "W"), True)
504+
self.assertEqual(tst.graph.edge_exists("X", "V"), True)
505+
506+
def test_noncollider_triple_rule_e2e(self):
507+
rdnv = self.seeded_random.normalvariate
508+
sample_generator = IIDSampleGenerator(
509+
edges=[
510+
SampleEdge(NodeReference("X"), NodeReference("Y"), 2),
511+
SampleEdge(NodeReference("Z"), NodeReference("Y"), 2),
512+
SampleEdge(NodeReference("Y"), NodeReference("W"), 2),
513+
],
514+
random=lambda: rdnv(0, 1),
515+
)
516+
test_data, graph = sample_generator.generate(10000)
517+
tst = PCClassic()
518+
tst.create_graph_from_data(test_data)
519+
tst.create_all_possible_edges()
520+
tst.execute_pipeline_steps()
521+
522+
self.assertEqual(tst.graph.edge_of_type_exists("X", "Y", DirectedEdge()), True)
523+
self.assertEqual(tst.graph.edge_of_type_exists("Z", "Y", DirectedEdge()), True)
524+
self.assertEqual(tst.graph.edge_of_type_exists("Y", "W", DirectedEdge()), True)

0 commit comments

Comments
 (0)