Skip to content

Commit 98f7275

Browse files
allow to create edges after directed edges
1 parent a14b12b commit 98f7275

File tree

2 files changed

+37
-19
lines changed

2 files changed

+37
-19
lines changed

causy/graph.py

+24-18
Original file line numberDiff line numberDiff line change
@@ -650,6 +650,28 @@ def get_edge(self, u: Node, v: Node) -> Edge:
650650
raise GraphError(f"Edge {u} -> {v} does not exist")
651651
return self.edges[u.id][v.id]
652652

653+
def _init_edge(self, u: Node, v: Node):
654+
"""
655+
Initialize an edge between two nodes
656+
:param u:
657+
:param v:
658+
:return:
659+
"""
660+
661+
if u.id not in self.edges:
662+
self.edges[u.id] = self.__init_dict()
663+
self._reverse_edges[u.id] = self.__init_dict()
664+
self._deleted_edges[u.id] = self.__init_dict()
665+
if v.id not in self.edges:
666+
self.edges[v.id] = self.__init_dict()
667+
self._reverse_edges[v.id] = self.__init_dict()
668+
self._deleted_edges[v.id] = self.__init_dict()
669+
670+
if (u.id, v.id) not in self.edge_history:
671+
self.edge_history[(u.id, v.id)] = []
672+
if (v.id, u.id) not in self.edge_history:
673+
self.edge_history[(v.id, u.id)] = []
674+
653675
def add_edge(
654676
self,
655677
u: Node,
@@ -674,14 +696,7 @@ def add_edge(
674696
if u.id == v.id:
675697
raise GraphError("Self loops are currently not allowed")
676698

677-
if u.id not in self.edges:
678-
self.edges[u.id] = self.__init_dict()
679-
self._reverse_edges[u.id] = self.__init_dict()
680-
self._deleted_edges[u.id] = self.__init_dict()
681-
if v.id not in self.edges:
682-
self.edges[v.id] = self.__init_dict()
683-
self._reverse_edges[v.id] = self.__init_dict()
684-
self._deleted_edges[v.id] = self.__init_dict()
699+
self._init_edge(u, v)
685700

686701
a_edge = Edge(u=u, v=v, edge_type=edge_type, metadata=metadata)
687702
self.edges[u.id][v.id] = a_edge
@@ -691,9 +706,6 @@ def add_edge(
691706
self.edges[v.id][u.id] = b_edge
692707
self._reverse_edges[u.id][v.id] = b_edge
693708

694-
self.edge_history[(u.id, v.id)] = []
695-
self.edge_history[(v.id, u.id)] = []
696-
697709
def add_directed_edge(
698710
self,
699711
u: Node,
@@ -718,19 +730,13 @@ def add_directed_edge(
718730
if u.id == v.id:
719731
raise GraphError("Self loops are currently not allowed")
720732

721-
if u.id not in self.edges:
722-
self.edges[u.id] = self.__init_dict()
723-
self._deleted_edges[u.id] = self.__init_dict()
724-
if v.id not in self._reverse_edges:
725-
self._reverse_edges[v.id] = self.__init_dict()
733+
self._init_edge(u, v)
726734

727735
edge = Edge(u=u, v=v, edge_type=edge_type, metadata=metadata)
728736

729737
self.edges[u.id][v.id] = edge
730738
self._reverse_edges[v.id][u.id] = edge
731739

732-
self.edge_history[(u.id, v.id)] = []
733-
734740
def add_edge_history(self, u, v, action: TestResult):
735741
"""
736742
Add an action to the edge history

tests/test_graph.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -57,11 +57,23 @@ def test_add_directed_edge(self):
5757
node2 = graph.add_node("test2", [1, 2, 3])
5858
graph.add_directed_edge(node1, node2, {"test": "test"})
5959
self.assertEqual(len(graph.nodes), 2)
60-
self.assertEqual(len(graph.edges), 1)
6160
self.assertEqual(graph.edge_value(node1, node2), {"test": "test"})
6261
self.assertTrue(graph.directed_edge_exists(node1, node2))
6362
self.assertFalse(graph.directed_edge_exists(node2, node1))
6463

64+
def test_create_undirected_edge_after_directed_edge(self):
65+
graph = GraphManager()
66+
node1 = graph.add_node("test1", [1, 2, 3])
67+
node2 = graph.add_node("test2", [1, 2, 3])
68+
graph.add_directed_edge(node1, node2, {"test": "test"})
69+
graph.add_edge(node1, node2, {"test": "test"})
70+
self.assertEqual(len(graph.nodes), 2)
71+
self.assertEqual(graph.edge_value(node1, node2), {"test": "test"})
72+
self.assertTrue(graph.undirected_edge_exists(node1, node2))
73+
self.assertTrue(graph.undirected_edge_exists(node2, node1))
74+
self.assertTrue(graph.edge_exists(node1, node2))
75+
self.assertTrue(graph.edge_exists(node2, node1))
76+
6577
def test_add_edge_with_non_existing_node(self):
6678
graph = GraphManager()
6779
node1 = graph.add_node("test1", [1, 2, 3])

0 commit comments

Comments
 (0)