@@ -270,145 +270,6 @@ std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
270
270
return nnOp;
271
271
}
272
272
273
- void handleWhileOp (
274
- repr::NNGraph& dfg,
275
- repr::NNCFGraph& cfg,
276
- repr::NNGraph::NodeRef& opNode,
277
- repr::NNCFGraph::NodeRef& bbNode,
278
- OperatorDef& op,
279
- std::unordered_map<std::string, repr::NNGraph::NodeRef>& blobMap
280
- ) {
281
- opNode->resetData (util::make_unique<repr::While>());
282
- auto argMap = Converter::getArgumentsFromOperator (op);
283
- std::string bodyNetSerialized = argMap[" body" ].s ();
284
- auto bodyNet = caffe2::NetDef ();
285
- bodyNet.ParseFromString (bodyNetSerialized);
286
-
287
- std::unordered_map<std::string, repr::NNGraph::NodeRef> bodyBlobMap;
288
- auto bodyNN = convertToNNModule (bodyNet, &bodyBlobMap);
289
- repr::NNGraph bodyGraph = std::move (bodyNN.dataFlow );
290
- repr::NNCFGraph bodyCFGraph = std::move (bodyNN.controlFlow );
291
-
292
- auto rev_sorted = algorithm::tarjans (&bodyGraph);
293
-
294
- for (auto & k : bodyBlobMap) {
295
- auto name = k.first ;
296
- if (blobMap.count (name)) {
297
- auto oldNode = blobMap[name];
298
- printf (" Exit tensor %s is in the parent scope, inserting Phi node...\n " , k.first .c_str ());
299
- auto phiNode = dfg.createNode (util::make_unique<repr::NNPhi>()); // NN variant of a Phi node
300
- // Clone the operator.
301
- auto tensor = dyn_cast<repr::NeuralNetData>(blobMap[name]->data ().get ());
302
- auto * clonedTensor = tensor->clone ();
303
- auto phiOut = dfg.createNode (std::unique_ptr<repr::NeuralNetData>(clonedTensor));
304
- dfg.createEdge (phiNode, phiOut);
305
- dfg.createEdge (oldNode, phiNode);
306
- dfg.createEdge (bodyBlobMap[name], phiNode);
307
- blobMap[name] = phiOut;
308
- for (auto & inEdge : opNode->getInEdges ()) {
309
- if (inEdge->tail () == oldNode) {
310
- dfg.deleteEdge (inEdge);
311
- dfg.createEdge (phiOut, opNode);
312
- }
313
- }
314
- }
315
- }
316
-
317
- // Dependencies simply have no producers
318
- std::unordered_map<repr::NNGraph::NodeRef, repr::NNGraph::NodeRef> inNodeMap;
319
- for (auto & n : bodyGraph.getMutableNodes ()) {
320
- if (!isa<repr::NeuralNetData>(n->data ())) { continue ; }
321
- if (n->getInEdges ().size () == 0 ) {
322
- auto name = dyn_cast<repr::NeuralNetData>(n->data ().get ())->getName ();
323
- // TODO(bwasti): this may be needed, depending on constraints
324
- // assert(blobMap.count(name) != 0 && "Loop body takes undefined dependency.");
325
- if (blobMap.count (name)) {
326
- inNodeMap[n] = blobMap[name];
327
- }
328
- }
329
- }
330
-
331
- CAFFE_ENFORCE (rev_sorted.front ().getNodes ().size () == 1 ,
332
- " More than one exit node." );
333
- CAFFE_ENFORCE (rev_sorted.back ().getNodes ().size () == 1 ,
334
- " More than one entry node." );
335
-
336
- auto exit_tensor = *(rev_sorted.front ().getNodes ().begin ());
337
- CAFFE_ENFORCE (isa<repr::NeuralNetData>(exit_tensor->data ()),
338
- " Exit node is not a tensor." );
339
-
340
- auto bodyNodes = bodyGraph.getMutableNodes ();
341
- auto bodyEdges = bodyGraph.getMutableEdges ();
342
-
343
- for (auto node : bodyNodes) {
344
- bodyGraph.importNode (node, dfg);
345
- }
346
-
347
- for (auto edge : bodyEdges) {
348
- bodyGraph.importEdge (edge, dfg);
349
- }
350
-
351
- // Merge all dependencies
352
- for (auto node : dfg.getMutableNodes ()) {
353
- if (inNodeMap.count (node)) {
354
- dfg.replaceNode (node, inNodeMap[node]);
355
- dfg.deleteNode (node);
356
- }
357
- }
358
-
359
- for (const auto & inEdge : opNode->getInEdges ()) {
360
- auto * inputData = dyn_cast<repr::NeuralNetData>(inEdge->tail ()->data ().get ());
361
- auto * exitData = dyn_cast<repr::NeuralNetData>(exit_tensor->data ().get ());
362
- if (inputData->getName () == exitData->getName ()) {
363
- dfg.replaceNode (exit_tensor, inEdge->tail ());
364
- dfg.deleteNode (exit_tensor);
365
- }
366
- }
367
-
368
- // CFG Handling
369
- auto bodyCFNodes = bodyCFGraph.getMutableNodes ();
370
- auto bodyCFEdges = bodyCFGraph.getMutableEdges ();
371
-
372
- // Create a while loop CFG node.
373
- auto whileBasicBlock = util::make_unique<repr::BasicBlockType<repr::NNGraph>>();
374
- for (auto & inEdge : opNode->getInEdges ()) {
375
- auto node = inEdge->tail ();
376
- for (auto & parentInEdge : node->getInEdges ()) {
377
- auto parentNode = parentInEdge->tail ();
378
- if (isa<repr::Phi>(parentNode->data ().get ())) {
379
- whileBasicBlock->pushInstructionNode (parentNode);
380
- }
381
- }
382
- }
383
- whileBasicBlock->pushInstructionNode (opNode);
384
-
385
- auto whileCFNode = cfg.createNode (std::move (whileBasicBlock));
386
- cfg.createEdge (bbNode, whileCFNode, 0 );
387
-
388
- // The true path executes the body of the loop, so we
389
- // take that BB and point to it.
390
- for (auto cfNode : bodyCFNodes) {
391
- bodyCFGraph.importNode (cfNode, cfg);
392
- // If the CFG node has no children, we loop back to the top of the
393
- // while loop.
394
- if (cfNode->getOutEdges ().size () == 0 ) {
395
- cfg.createEdge (cfNode, whileCFNode, 0 );
396
- }
397
- // TODO check for a single entry point
398
- if (cfNode->getInEdges ().size () == 0 ) {
399
- cfg.createEdge (whileCFNode, cfNode, 1 );
400
- }
401
- }
402
- for (auto cfEdge : bodyCFEdges) {
403
- bodyCFGraph.importEdge (cfEdge, cfg);
404
- }
405
-
406
- // Now create the false case.
407
- bbNode =
408
- cfg.createNode (util::make_unique<repr::BasicBlockType<repr::NNGraph>>());
409
- cfg.createEdge (whileCFNode, bbNode, -1 );
410
- }
411
-
412
273
413
274
// / \brief Ingest a caffe2 protobuf model and output an NNModule.
414
275
// / \param net The caffe2 protobuf NetDef
@@ -455,13 +316,9 @@ repr::NNModule convertToNNModule(caffe2::NetDef &net, std::unordered_map<std::st
455
316
blobMap[output] = tensorNode;
456
317
}
457
318
458
- if (op.type () == " While" ) {
459
- handleWhileOp (dfg, cfg, opNode, bbNode, op, blobMap);
460
- } else {
461
- opNode->resetData (convertToNeuralNetOperator (op));
462
- auto currentBasicBlock = bbNode->mutableData ()->get ();
463
- currentBasicBlock->pushInstructionNode (opNode);
464
- }
319
+ opNode->resetData (convertToNeuralNetOperator (op));
320
+ auto currentBasicBlock = bbNode->mutableData ()->get ();
321
+ currentBasicBlock->pushInstructionNode (opNode);
465
322
}
466
323
467
324
repr::NNModule module;
0 commit comments