10
10
ComputeDirectEffectsInDAGsMultivariateRegression ,
11
11
)
12
12
from causy .common_pipeline_steps .calculation import CalculatePearsonCorrelations
13
+ from causy .edge_types import DirectedEdge
13
14
from causy .generators import PairsWithNeighboursGenerator
14
15
from causy .graph_model import graph_model_factory
15
16
from causy .causal_discovery .constraint .independence_tests .common import (
@@ -426,6 +427,36 @@ def test_track_triples_three_nodes_pc_unconditionally_independent(self):
426
427
# TODO: find issue with tracking in partial correlation test in this setting
427
428
pass
428
429
430
+ def test_orientation_conflict_tracking (self ):
431
+ causal_insufficiency_four_nodes = IIDSampleGenerator (
432
+ edges = [
433
+ SampleEdge (NodeReference ("U1" ), NodeReference ("X" ), 1 ),
434
+ SampleEdge (NodeReference ("U1" ), NodeReference ("Y" ), 1 ),
435
+ SampleEdge (NodeReference ("U2" ), NodeReference ("Y" ), 1 ),
436
+ SampleEdge (NodeReference ("U2" ), NodeReference ("Z" ), 1 ),
437
+ SampleEdge (NodeReference ("U3" ), NodeReference ("Z" ), 1 ),
438
+ SampleEdge (NodeReference ("U3" ), NodeReference ("V" ), 1 ),
439
+ SampleEdge (NodeReference ("U4" ), NodeReference ("V" ), 1 ),
440
+ SampleEdge (NodeReference ("U4" ), NodeReference ("X" ), 1 ),
441
+ ],
442
+ )
443
+ test_data , graph = causal_insufficiency_four_nodes .generate (10000 )
444
+ test_data .pop ("U1" )
445
+ test_data .pop ("U2" )
446
+ test_data .pop ("U3" )
447
+ test_data .pop ("U4" )
448
+ tst = PCClassic ()
449
+ tst .create_graph_from_data (test_data )
450
+ tst .create_all_possible_edges ()
451
+ tst .execute_pipeline_steps ()
452
+
453
+ nb_of_conflicts = 0
454
+ for result in tst .graph .action_history :
455
+ for proposed_action in result .all_proposed_actions :
456
+ if "orientation_conflict" in proposed_action .data :
457
+ nb_of_conflicts += 1
458
+ self .assertGreater (nb_of_conflicts , 1 )
459
+
429
460
def test_d_separation_on_output_of_pc (self ):
430
461
rdnv = self .seeded_random .normalvariate
431
462
sample_generator = IIDSampleGenerator (
@@ -447,3 +478,47 @@ def test_d_separation_on_output_of_pc(self):
447
478
z = tst .graph .node_by_id ("Z" )
448
479
self .assertEqual (tst .graph .are_nodes_d_separated_cpdag (x , z , []), False )
449
480
self .assertEqual (tst .graph .are_nodes_d_separated_cpdag (x , z , [y ]), True )
481
+
482
+ def test_pc_faithfulness_violation (self ):
483
+ rdnv = self .seeded_random .normalvariate
484
+ sample_generator = IIDSampleGenerator (
485
+ edges = [
486
+ SampleEdge (NodeReference ("X" ), NodeReference ("V" ), 2 ),
487
+ SampleEdge (NodeReference ("V" ), NodeReference ("W" ), 2 ),
488
+ SampleEdge (NodeReference ("W" ), NodeReference ("Y" ), - 2 ),
489
+ SampleEdge (NodeReference ("X" ), NodeReference ("Y" ), 8 ),
490
+ ],
491
+ random = lambda : rdnv (0 , 1 ),
492
+ )
493
+ test_data , graph = sample_generator .generate (10000 )
494
+ tst = PCClassic ()
495
+ tst .create_graph_from_data (test_data )
496
+ tst .create_all_possible_edges ()
497
+ tst .execute_pipeline_steps ()
498
+
499
+ self .assertEqual (tst .graph .edge_exists ("X" , "Y" ), False )
500
+ self .assertEqual (tst .graph .edge_exists ("V" , "Y" ), False )
501
+ self .assertEqual (tst .graph .edge_exists ("W" , "X" ), False )
502
+ self .assertEqual (tst .graph .edge_exists ("W" , "Y" ), True )
503
+ self .assertEqual (tst .graph .edge_exists ("V" , "W" ), True )
504
+ self .assertEqual (tst .graph .edge_exists ("X" , "V" ), True )
505
+
506
+ def test_noncollider_triple_rule_e2e (self ):
507
+ rdnv = self .seeded_random .normalvariate
508
+ sample_generator = IIDSampleGenerator (
509
+ edges = [
510
+ SampleEdge (NodeReference ("X" ), NodeReference ("Y" ), 2 ),
511
+ SampleEdge (NodeReference ("Z" ), NodeReference ("Y" ), 2 ),
512
+ SampleEdge (NodeReference ("Y" ), NodeReference ("W" ), 2 ),
513
+ ],
514
+ random = lambda : rdnv (0 , 1 ),
515
+ )
516
+ test_data , graph = sample_generator .generate (10000 )
517
+ tst = PCClassic ()
518
+ tst .create_graph_from_data (test_data )
519
+ tst .create_all_possible_edges ()
520
+ tst .execute_pipeline_steps ()
521
+
522
+ self .assertEqual (tst .graph .edge_of_type_exists ("X" , "Y" , DirectedEdge ()), True )
523
+ self .assertEqual (tst .graph .edge_of_type_exists ("Z" , "Y" , DirectedEdge ()), True )
524
+ self .assertEqual (tst .graph .edge_of_type_exists ("Y" , "W" , DirectedEdge ()), True )
0 commit comments