@@ -204,9 +204,25 @@ def __eq__(self, other):
204
204
class MetaGraph (BaseGraph ):
205
205
206
206
def __init__ (self ):
207
- """ """
208
207
BaseGraph .__init__ (self )
209
208
209
+ def get_metanode (self , metanode ):
210
+ """
211
+ Return the metanode specified by the input, which can be either a:
212
+ - MetaNode (passthrough)
213
+ - metanode kind (str)
214
+ - metanode abbreviation (str)
215
+ """
216
+ if isinstance (metanode , MetaNode ):
217
+ return metanode
218
+ if metanode in self .node_dict :
219
+ return self .get_node (metanode )
220
+ # Assume metanode must be an abbreviation
221
+ return self .get_node (self .abbrev_to_kind [metanode ])
222
+
223
+ def get_metaedge (self , metanode ):
224
+ pass
225
+
210
226
@staticmethod
211
227
def from_edge_tuples (metaedge_tuples , kind_to_abbrev = None ):
212
228
"""Create a new metagraph defined by its edges."""
@@ -231,6 +247,7 @@ def from_edge_tuples(metaedge_tuples, kind_to_abbrev=None):
231
247
def set_abbreviations (self , kind_to_abbrev ):
232
248
"""Add abbreviations as an attribute for metanodes and metaedges"""
233
249
self .kind_to_abbrev = kind_to_abbrev
250
+ self .abbrev_to_kind = {v : k for k , v in kind_to_abbrev .items ()}
234
251
for kind , metanode in self .node_dict .items ():
235
252
metanode .abbrev = kind_to_abbrev [kind ]
236
253
for metaedge in self .edge_dict .values ():
@@ -281,10 +298,9 @@ def extract_metapaths(self, source, target=None, max_length=4):
281
298
the metagraph from source to target. If target is None (default), then
282
299
metapaths to any target node are returned.
283
300
"""
284
- if not isinstance (source , MetaNode ):
285
- source = self .node_dict [source ]
286
- if target and not isinstance (target , MetaNode ):
287
- target = self .node_dict [target ]
301
+ source = self .get_metanode (source )
302
+ if target :
303
+ target = self .get_metanode (target )
288
304
289
305
assert max_length >= 0
290
306
if max_length == 0 :
0 commit comments