-
Notifications
You must be signed in to change notification settings - Fork 802
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Supplement ND SBP signatures for reshape op (#9858)
对于绝大多数 op 来说,我们只需要列举其可能支持的 1-D SBP signatures,在 2-D (可推及到 ND) 时,对其 1-D SBP signature list 做叉积即可得到 2-D SBP signature list。 但 reshape op 某些情况就属于例外,见如下的例子: ``` (8, 4) reshape to (2, 4, 4) with 2x2 ranks, has the 1-D sbp signatrue list as below: S(0) -> S(0) S(1) -> S(2) by cross product, get the 2-D sbp signatrue list as below: [S(0), S(1)] -> [S(0), S(2)] [S(1), S(0)] -> [S(2), S(0)] (will bring huge comm cost, ignore it in later discuss) but below 2-D sbp signatrue is supported too: [S(0), S(0)] -> [S(0), S(1)] ``` 从上面的例子中可以发现一个规律:**高维的 SBP signatures 不能完全由低维组合而来**。 基于以上理由,为 op 提供一个新的重载函数 **EnumerateNdSbpSignatures**:其会在 1-D SBP signatures 被列举完后,并由 1-D 叉积产生了 2-D SBP signature list 后被调用。作为当 1-D 叉积不能产生全部的 2-D SBP signatures 的时候,提供一种手段来补充额外的 2-D SBP signatures。 为 reshape 实现了 EnumerateNdSbpSignatures,算法简单来说就是,找到那些被 reshape 的 dimension,从高到低连续按 rank num 切分,直到失败,或者能均匀切到每个 rank 上(从高到底是为了保证切分连续性)。 EnumerateNdSbpSignatures 与已有的 **GetNdSbpSignatureList** 重载区别是:EnumerateNdSbpSignatures 是在 1-D SBP signatures 叉积之后的额外补充。而 GetNdSbpSignatureList 是完全重载 2-D SBP signatures 的列举逻辑,不会包含 1-D SBP 的叉积生成,其主要作用是为了 source op,用户可以直接通过 attr 来设置输出的 sbp,而无需推导。 --------- Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>
- Loading branch information
1 parent
6442a25
commit 9f47744
Showing
20 changed files
with
1,476 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
/* | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
*/ | ||
#include "oneflow/core/framework/sbp_infer_util.h" | ||
#include "oneflow/core/framework/nd_sbp.h" | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace oneflow { | ||
namespace test { | ||
|
||
namespace { | ||
|
||
bool ParseNdSbpSignatureFromString(const std::string& nd_sbp_signature_str, | ||
NdSbpSignature& nd_sbp_signature) { | ||
auto* bn2nd_sbp = nd_sbp_signature.mutable_bn_in_op2nd_sbp(); | ||
std::string arg_name = "in"; | ||
bool meet_nd_sbp_group = false; | ||
bool meet_split = false; | ||
int nd_sbp_group_id = 0; | ||
std::vector<std::string> nd_sbp_str_group; | ||
size_t pos = 0; | ||
while (pos < nd_sbp_signature_str.size()) { | ||
const char& c = nd_sbp_signature_str[pos]; | ||
pos++; | ||
if (c == ' ') { | ||
continue; | ||
} else if (c == '(') { | ||
if (!meet_nd_sbp_group) { | ||
// enter a nd-sbp group | ||
meet_nd_sbp_group = true; | ||
nd_sbp_str_group.emplace_back(); | ||
continue; | ||
} else { | ||
// meet left parentheses of S(x) | ||
meet_split = true; | ||
} | ||
} else if (c == ')') { | ||
if (meet_split) { | ||
// meet right parentheses of S(x) | ||
meet_split = false; | ||
} else if (meet_nd_sbp_group) { | ||
// leave a nd-sbp group | ||
meet_nd_sbp_group = false; | ||
std::string bn = arg_name + "_" + std::to_string(nd_sbp_group_id); | ||
if (!ParseNdSbpFromStringList(nd_sbp_str_group, &(*bn2nd_sbp)[bn])) { return false; } | ||
nd_sbp_str_group.clear(); | ||
continue; | ||
} else { | ||
return false; | ||
} | ||
} else if (c == ',') { | ||
if (meet_nd_sbp_group) { | ||
nd_sbp_str_group.emplace_back(); | ||
} else { | ||
nd_sbp_group_id += 1; | ||
} | ||
continue; | ||
} else if (c == '-') { | ||
if (pos < nd_sbp_signature_str.size() && nd_sbp_signature_str[pos] == '>') { | ||
// in args parsing has finished, parse out args | ||
arg_name = "out"; | ||
nd_sbp_group_id = 0; | ||
// skip '>' in substr '->' | ||
pos++; | ||
continue; | ||
} else { | ||
return false; | ||
} | ||
} else { | ||
// do nothing | ||
} | ||
nd_sbp_str_group.back() += c; | ||
} | ||
return true; | ||
} | ||
|
||
std::string NdSbpSignature2String(const NdSbpSignature& nd_sbp_signature, | ||
const std::vector<std::string>& inputs, | ||
const std::vector<std::string>& outputs) { | ||
std::ostringstream ss; | ||
auto BnNdSbpToString = [&](const std::string& bn) { | ||
auto iter = nd_sbp_signature.bn_in_op2nd_sbp().find(bn); | ||
CHECK(iter != nd_sbp_signature.bn_in_op2nd_sbp().end()); | ||
ss << NdSbpToString(iter->second); | ||
}; | ||
auto ArgsNdSbpToString = [&](const std::vector<std::string>& arg_bns) { | ||
for (size_t i = 0; i < arg_bns.size(); ++i) { | ||
if (i > 0) { ss << ", "; } | ||
BnNdSbpToString(arg_bns[i]); | ||
} | ||
}; | ||
ArgsNdSbpToString(inputs); | ||
ss << " -> "; | ||
ArgsNdSbpToString(outputs); | ||
return ss.str(); | ||
} | ||
|
||
void TestDeduplicateNdSbpSignature(const std::vector<std::string>& nd_sbp_signature_str_list, | ||
const std::vector<std::string>& input_bns, | ||
const std::vector<std::string>& output_bns) { | ||
// parse | ||
std::vector<NdSbpSignature> nd_sbp_sig_list; | ||
nd_sbp_sig_list.reserve(nd_sbp_signature_str_list.size()); | ||
for (const auto& nd_sbp_signature_str : nd_sbp_signature_str_list) { | ||
nd_sbp_sig_list.emplace_back(); | ||
ASSERT_TRUE(ParseNdSbpSignatureFromString(nd_sbp_signature_str, nd_sbp_sig_list.back())); | ||
} | ||
|
||
// shuffle and repeat | ||
std::random_device rd; | ||
std::mt19937 gen(rd()); | ||
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen); | ||
nd_sbp_sig_list.reserve(nd_sbp_sig_list.size() + nd_sbp_sig_list.size() / 2); | ||
std::copy_n(nd_sbp_sig_list.begin(), nd_sbp_sig_list.size() / 2, | ||
std::back_inserter(nd_sbp_sig_list)); | ||
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen); | ||
|
||
// dedup and sort | ||
std::vector<std::string> bns; | ||
bns.insert(bns.end(), input_bns.begin(), input_bns.end()); | ||
bns.insert(bns.end(), output_bns.begin(), output_bns.end()); | ||
DeduplicateNdSbpSignatureList(&nd_sbp_sig_list, bns); | ||
|
||
// compare | ||
ASSERT_EQ(nd_sbp_signature_str_list.size(), nd_sbp_sig_list.size()); | ||
for (size_t i = 0; i < nd_sbp_sig_list.size(); ++i) { | ||
auto nd_sbp_sig_result = NdSbpSignature2String(nd_sbp_sig_list[i], input_bns, output_bns); | ||
ASSERT_EQ(nd_sbp_sig_result, nd_sbp_signature_str_list[i]); | ||
} | ||
} | ||
|
||
} // namespace | ||
|
||
TEST(SbpInferUtil, DeduplicateNdSbpSignatureList) { | ||
TestDeduplicateNdSbpSignature( | ||
{ | ||
"(B, B) -> (B, B)", | ||
"(B, P) -> (B, P)", | ||
"(B, S(0)) -> (B, S(0))", | ||
"(B, S(1)) -> (B, S(1))", | ||
"(B, S(3)) -> (B, S(2))", | ||
"(P, B) -> (P, B)", | ||
"(P, P) -> (P, P)", | ||
"(P, S(0)) -> (P, S(0))", | ||
"(P, S(1)) -> (P, S(1))", | ||
"(P, S(3)) -> (P, S(2))", | ||
"(S(0), B) -> (S(0), B)", | ||
"(S(0), P) -> (S(0), P)", | ||
"(S(0), S(0)) -> (S(0), S(0))", | ||
"(S(0), S(1)) -> (S(0), S(1))", | ||
"(S(0), S(3)) -> (S(0), S(2))", | ||
"(S(1), B) -> (S(1), B)", | ||
"(S(1), P) -> (S(1), P)", | ||
"(S(1), S(0)) -> (S(1), S(0))", | ||
"(S(1), S(1)) -> (S(1), S(1))", | ||
"(S(1), S(3)) -> (S(1), S(2))", | ||
"(S(3), B) -> (S(2), B)", | ||
"(S(3), P) -> (S(2), P)", | ||
"(S(3), S(0)) -> (S(2), S(0))", | ||
"(S(3), S(1)) -> (S(2), S(1))", | ||
"(S(3), S(3)) -> (S(2), S(2))", | ||
}, | ||
{"in_0"}, {"out_0"}); | ||
} | ||
|
||
} // namespace test | ||
} // namespace oneflow |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.