Skip to content

Commit 9c0f65f

Browse files
bwastifacebook-github-bot
authored andcommitted
Remove While op stuff (pytorch#10102)
Summary: Pull Request resolved: pytorch#10102 these codepaths are unused, deleting them Reviewed By: yinghai Differential Revision: D9109764 fbshipit-source-id: 8ace42a399806632bfbcada96b383268f0a8ae89
1 parent c54d71b commit 9c0f65f

File tree

2 files changed

+3
-208
lines changed

2 files changed

+3
-208
lines changed

caffe2/opt/converter.cc

+3-146
Original file line numberDiff line numberDiff line change
@@ -270,145 +270,6 @@ std::unique_ptr<repr::NeuralNetOperator> convertToNeuralNetOperator(
270270
return nnOp;
271271
}
272272

273-
void handleWhileOp(
274-
repr::NNGraph& dfg,
275-
repr::NNCFGraph& cfg,
276-
repr::NNGraph::NodeRef& opNode,
277-
repr::NNCFGraph::NodeRef& bbNode,
278-
OperatorDef& op,
279-
std::unordered_map<std::string, repr::NNGraph::NodeRef>& blobMap
280-
) {
281-
opNode->resetData(util::make_unique<repr::While>());
282-
auto argMap = Converter::getArgumentsFromOperator(op);
283-
std::string bodyNetSerialized = argMap["body"].s();
284-
auto bodyNet = caffe2::NetDef();
285-
bodyNet.ParseFromString(bodyNetSerialized);
286-
287-
std::unordered_map<std::string, repr::NNGraph::NodeRef> bodyBlobMap;
288-
auto bodyNN = convertToNNModule(bodyNet, &bodyBlobMap);
289-
repr::NNGraph bodyGraph = std::move(bodyNN.dataFlow);
290-
repr::NNCFGraph bodyCFGraph = std::move(bodyNN.controlFlow);
291-
292-
auto rev_sorted = algorithm::tarjans(&bodyGraph);
293-
294-
for (auto& k : bodyBlobMap) {
295-
auto name = k.first;
296-
if (blobMap.count(name)) {
297-
auto oldNode = blobMap[name];
298-
printf("Exit tensor %s is in the parent scope, inserting Phi node...\n", k.first.c_str());
299-
auto phiNode = dfg.createNode(util::make_unique<repr::NNPhi>()); // NN variant of a Phi node
300-
// Clone the operator.
301-
auto tensor = dyn_cast<repr::NeuralNetData>(blobMap[name]->data().get());
302-
auto* clonedTensor = tensor->clone();
303-
auto phiOut = dfg.createNode(std::unique_ptr<repr::NeuralNetData>(clonedTensor));
304-
dfg.createEdge(phiNode, phiOut);
305-
dfg.createEdge(oldNode, phiNode);
306-
dfg.createEdge(bodyBlobMap[name], phiNode);
307-
blobMap[name] = phiOut;
308-
for (auto& inEdge : opNode->getInEdges()) {
309-
if (inEdge->tail() == oldNode) {
310-
dfg.deleteEdge(inEdge);
311-
dfg.createEdge(phiOut, opNode);
312-
}
313-
}
314-
}
315-
}
316-
317-
// Dependencies simply have no producers
318-
std::unordered_map<repr::NNGraph::NodeRef, repr::NNGraph::NodeRef> inNodeMap;
319-
for (auto& n : bodyGraph.getMutableNodes()) {
320-
if (!isa<repr::NeuralNetData>(n->data())) { continue; }
321-
if (n->getInEdges().size() == 0) {
322-
auto name = dyn_cast<repr::NeuralNetData>(n->data().get())->getName();
323-
// TODO(bwasti): this may be needed, depending on constraints
324-
//assert(blobMap.count(name) != 0 && "Loop body takes undefined dependency.");
325-
if (blobMap.count(name)) {
326-
inNodeMap[n] = blobMap[name];
327-
}
328-
}
329-
}
330-
331-
CAFFE_ENFORCE(rev_sorted.front().getNodes().size() == 1,
332-
"More than one exit node.");
333-
CAFFE_ENFORCE(rev_sorted.back().getNodes().size() == 1,
334-
"More than one entry node.");
335-
336-
auto exit_tensor = *(rev_sorted.front().getNodes().begin());
337-
CAFFE_ENFORCE(isa<repr::NeuralNetData>(exit_tensor->data()),
338-
"Exit node is not a tensor.");
339-
340-
auto bodyNodes = bodyGraph.getMutableNodes();
341-
auto bodyEdges = bodyGraph.getMutableEdges();
342-
343-
for (auto node : bodyNodes) {
344-
bodyGraph.importNode(node, dfg);
345-
}
346-
347-
for (auto edge : bodyEdges) {
348-
bodyGraph.importEdge(edge, dfg);
349-
}
350-
351-
// Merge all dependencies
352-
for (auto node : dfg.getMutableNodes()) {
353-
if (inNodeMap.count(node)) {
354-
dfg.replaceNode(node, inNodeMap[node]);
355-
dfg.deleteNode(node);
356-
}
357-
}
358-
359-
for (const auto& inEdge : opNode->getInEdges()) {
360-
auto* inputData = dyn_cast<repr::NeuralNetData>(inEdge->tail()->data().get());
361-
auto* exitData = dyn_cast<repr::NeuralNetData>(exit_tensor->data().get());
362-
if (inputData->getName() == exitData->getName()) {
363-
dfg.replaceNode(exit_tensor, inEdge->tail());
364-
dfg.deleteNode(exit_tensor);
365-
}
366-
}
367-
368-
// CFG Handling
369-
auto bodyCFNodes = bodyCFGraph.getMutableNodes();
370-
auto bodyCFEdges = bodyCFGraph.getMutableEdges();
371-
372-
// Create a while loop CFG node.
373-
auto whileBasicBlock = util::make_unique<repr::BasicBlockType<repr::NNGraph>>();
374-
for (auto& inEdge : opNode->getInEdges()) {
375-
auto node = inEdge->tail();
376-
for (auto& parentInEdge : node->getInEdges()) {
377-
auto parentNode = parentInEdge->tail();
378-
if (isa<repr::Phi>(parentNode->data().get())) {
379-
whileBasicBlock->pushInstructionNode(parentNode);
380-
}
381-
}
382-
}
383-
whileBasicBlock->pushInstructionNode(opNode);
384-
385-
auto whileCFNode = cfg.createNode(std::move(whileBasicBlock));
386-
cfg.createEdge(bbNode, whileCFNode, 0);
387-
388-
// The true path executes the body of the loop, so we
389-
// take that BB and point to it.
390-
for (auto cfNode : bodyCFNodes) {
391-
bodyCFGraph.importNode(cfNode, cfg);
392-
// If the CFG node has no children, we loop back to the top of the
393-
// while loop.
394-
if (cfNode->getOutEdges().size() == 0) {
395-
cfg.createEdge(cfNode, whileCFNode, 0);
396-
}
397-
// TODO check for a single entry point
398-
if (cfNode->getInEdges().size() == 0) {
399-
cfg.createEdge(whileCFNode, cfNode, 1);
400-
}
401-
}
402-
for (auto cfEdge : bodyCFEdges) {
403-
bodyCFGraph.importEdge(cfEdge, cfg);
404-
}
405-
406-
// Now create the false case.
407-
bbNode =
408-
cfg.createNode(util::make_unique<repr::BasicBlockType<repr::NNGraph>>());
409-
cfg.createEdge(whileCFNode, bbNode, -1);
410-
}
411-
412273

413274
/// \brief Ingest a caffe2 protobuf model and output an NNModule.
414275
/// \param net The caffe2 protobuf NetDef
@@ -455,13 +316,9 @@ repr::NNModule convertToNNModule(caffe2::NetDef &net, std::unordered_map<std::st
455316
blobMap[output] = tensorNode;
456317
}
457318

458-
if (op.type() == "While") {
459-
handleWhileOp(dfg, cfg, opNode, bbNode, op, blobMap);
460-
} else {
461-
opNode->resetData(convertToNeuralNetOperator(op));
462-
auto currentBasicBlock = bbNode->mutableData()->get();
463-
currentBasicBlock->pushInstructionNode(opNode);
464-
}
319+
opNode->resetData(convertToNeuralNetOperator(op));
320+
auto currentBasicBlock = bbNode->mutableData()->get();
321+
currentBasicBlock->pushInstructionNode(opNode);
465322
}
466323

467324
repr::NNModule module;

caffe2/opt/converter_nomigraph_test.cc

-62
Original file line numberDiff line numberDiff line change
@@ -48,65 +48,3 @@ TEST(Converter, UnknownType) {
4848
auto new_netdef = caffe2::convertToCaffe2Proto(nn);
4949
}
5050

51-
/* Temporarily disabled While conversion tests
52-
TEST(Converter, While) {
53-
caffe2::NetDef net;
54-
55-
caffe2::OperatorDef *def = net.add_op();
56-
def->set_type("While");
57-
def->add_input("X");
58-
59-
caffe2::NetDef body_net;
60-
{
61-
caffe2::OperatorDef *rdef = body_net.add_op();
62-
rdef->set_type("Relu");
63-
rdef->add_input("X");
64-
rdef->add_output("X");
65-
}
66-
std::string body_net_serialized;
67-
assert(body_net.SerializeToString(&body_net_serialized));
68-
ADD_ARG(def, "body", s, body_net_serialized);
69-
70-
auto nn = caffe2::convertToNNModule(net);
71-
}
72-
73-
TEST(Converter, ComplexWhile) {
74-
caffe2::NetDef net;
75-
76-
{
77-
caffe2::OperatorDef *rdef = net.add_op();
78-
rdef->set_type("Relu");
79-
rdef->add_input("X");
80-
rdef->add_output("X");
81-
}
82-
83-
caffe2::OperatorDef *def = net.add_op();
84-
def->set_type("While");
85-
def->add_input("X");
86-
87-
caffe2::NetDef body_net;
88-
{
89-
caffe2::OperatorDef *rdef = body_net.add_op();
90-
rdef->set_type("Instr1");
91-
rdef->add_input("X");
92-
rdef->add_output("X");
93-
}
94-
{
95-
caffe2::OperatorDef *rdef = body_net.add_op();
96-
rdef->set_type("Instr2");
97-
rdef->add_input("X");
98-
rdef->add_output("X");
99-
}
100-
{
101-
caffe2::OperatorDef *rdef = body_net.add_op();
102-
rdef->set_type("Instr3");
103-
rdef->add_input("X");
104-
rdef->add_output("X");
105-
}
106-
std::string body_net_serialized;
107-
assert(body_net.SerializeToString(&body_net_serialized));
108-
ADD_ARG(def, "body", s, body_net_serialized);
109-
110-
auto nn = caffe2::convertToNNModule(net);
111-
}
112-
*/

0 commit comments

Comments
 (0)