@@ -204,9 +204,36 @@ def __eq__(self, other):
204
204
class MetaGraph (BaseGraph ):
205
205
206
206
def __init__ (self ):
207
- """ """
208
207
BaseGraph .__init__ (self )
209
208
209
+ def get_metanode (self , metanode ):
210
+ """
211
+ Return the metanode specified by the input, which can be either a:
212
+ - MetaNode (passthrough)
213
+ - metanode kind (str)
214
+ - metanode abbreviation (str)
215
+ """
216
+ if isinstance (metanode , MetaNode ):
217
+ return metanode
218
+ if metanode in self .node_dict :
219
+ return self .get_node (metanode )
220
+ # Assume metanode must be an abbreviation
221
+ return self .get_node (self .abbrev_to_kind [metanode ])
222
+
223
+ def get_metaedge (self , metaedge ):
224
+ """
225
+ Return the metaedge specified by the input, which can be either a:
226
+ - MetaEdge (passthrough)
227
+ - metaedge_id (tuple)
228
+ - metaedge abbreviation
229
+ """
230
+ if isinstance (metaedge , MetaEdge ):
231
+ return metaedge
232
+ if isinstance (metaedge , tuple ):
233
+ return self .get_edge (metaedge )
234
+ metaedge_id = hetio .abbreviation .metaedge_id_from_abbreviation (self , metaedge )
235
+ return self .get_edge (metaedge_id )
236
+
210
237
@staticmethod
211
238
def from_edge_tuples (metaedge_tuples , kind_to_abbrev = None ):
212
239
"""Create a new metagraph defined by its edges."""
@@ -231,6 +258,7 @@ def from_edge_tuples(metaedge_tuples, kind_to_abbrev=None):
231
258
def set_abbreviations (self , kind_to_abbrev ):
232
259
"""Add abbreviations as an attribute for metanodes and metaedges"""
233
260
self .kind_to_abbrev = kind_to_abbrev
261
+ self .abbrev_to_kind = {v : k for k , v in kind_to_abbrev .items ()}
234
262
for kind , metanode in self .node_dict .items ():
235
263
metanode .abbrev = kind_to_abbrev [kind ]
236
264
for metaedge in self .edge_dict .values ():
@@ -281,10 +309,9 @@ def extract_metapaths(self, source, target=None, max_length=4):
281
309
the metagraph from source to target. If target is None (default), then
282
310
metapaths to any target node are returned.
283
311
"""
284
- if not isinstance (source , MetaNode ):
285
- source = self .node_dict [source ]
286
- if target and not isinstance (target , MetaNode ):
287
- target = self .node_dict [target ]
312
+ source = self .get_metanode (source )
313
+ if target :
314
+ target = self .get_metanode (target )
288
315
289
316
assert max_length >= 0
290
317
if max_length == 0 :
0 commit comments