Skip to content

Commit c33bdba

Browse files
committed
Add metagraph.get_metanode
1 parent 4786dae commit c33bdba

File tree

2 files changed

+29
-5
lines changed

2 files changed

+29
-5
lines changed

hetio/hetnet.py

+21-5
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,25 @@ def __eq__(self, other):
204204
class MetaGraph(BaseGraph):
205205

206206
def __init__(self):
207-
""" """
208207
BaseGraph.__init__(self)
209208

209+
def get_metanode(self, metanode):
210+
"""
211+
Return the metanode specified by the input, which can be either a:
212+
- MetaNode (passthrough)
213+
- metanode kind (str)
214+
- metanode abbreviation (str)
215+
"""
216+
if isinstance(metanode, MetaNode):
217+
return metanode
218+
if metanode in self.node_dict:
219+
return self.get_node(metanode)
220+
# Assume metanode must be an abbreviation
221+
return self.get_node(self.abbrev_to_kind[metanode])
222+
223+
def get_metaedge(self, metanode):
224+
pass
225+
210226
@staticmethod
211227
def from_edge_tuples(metaedge_tuples, kind_to_abbrev=None):
212228
"""Create a new metagraph defined by its edges."""
@@ -231,6 +247,7 @@ def from_edge_tuples(metaedge_tuples, kind_to_abbrev=None):
231247
def set_abbreviations(self, kind_to_abbrev):
232248
"""Add abbreviations as an attribute for metanodes and metaedges"""
233249
self.kind_to_abbrev = kind_to_abbrev
250+
self.abbrev_to_kind = {v: k for k, v in kind_to_abbrev.items()}
234251
for kind, metanode in self.node_dict.items():
235252
metanode.abbrev = kind_to_abbrev[kind]
236253
for metaedge in self.edge_dict.values():
@@ -281,10 +298,9 @@ def extract_metapaths(self, source, target=None, max_length=4):
281298
the metagraph from source to target. If target is None (default), then
282299
metapaths to any target node are returned.
283300
"""
284-
if not isinstance(source, MetaNode):
285-
source = self.node_dict[source]
286-
if target and not isinstance(target, MetaNode):
287-
target = self.node_dict[target]
301+
source = self.get_metanode(source)
302+
if target:
303+
target = self.get_metanode(target)
288304

289305
assert max_length >= 0
290306
if max_length == 0:

test/graph_test.py

+8
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,14 @@ def test_disase_gene_example():
111111
('Gene', 'Gene', 'interaction', 'both'),
112112
]
113113
metagraph = hetio.hetnet.MetaGraph.from_edge_tuples(metaedge_tuples)
114+
115+
# Test metagraph getter methods
116+
gene_metanode = metagraph.node_dict['Gene']
117+
assert metagraph.get_metanode(gene_metanode) == gene_metanode
118+
assert metagraph.get_metanode('Gene') == gene_metanode
119+
assert metagraph.get_metanode('G') == gene_metanode
120+
121+
# Create graph
114122
graph = hetio.hetnet.Graph(metagraph)
115123
nodes = dict()
116124

0 commit comments

Comments
 (0)