Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supplement ND SBP signatures for reshape op #9858

Merged
merged 35 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
663cf30
EnumerateNdSbpSignatures
leaves-zwx Feb 9, 2023
dccb762
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 11, 2023
7cd2dc1
fix algo
leaves-zwx Feb 12, 2023
1eed7aa
fix
leaves-zwx Feb 12, 2023
3dc723b
cpp test
leaves-zwx Feb 12, 2023
cbf7261
fix
leaves-zwx Feb 12, 2023
a440e8d
rename
leaves-zwx Feb 12, 2023
6998d52
refine UserOp::GetNdSbpSignatureList
leaves-zwx Feb 12, 2023
94870a6
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 12, 2023
bff5183
py test
leaves-zwx Feb 12, 2023
2db180a
revert changes
leaves-zwx Feb 13, 2023
84b3acf
new ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 13, 2023
ad4a227
add test case
leaves-zwx Feb 13, 2023
7c22e1a
refine algorithm
leaves-zwx Feb 15, 2023
c51364a
update test and fix bug
leaves-zwx Feb 16, 2023
db658eb
rm comment
leaves-zwx Feb 16, 2023
278f577
rm redundant condition
leaves-zwx Feb 16, 2023
c500114
DeduplicateNdSbpSignatureList
leaves-zwx Feb 17, 2023
a4ddd55
refine GenRankMeshSubset
leaves-zwx Feb 17, 2023
508cc6f
add comments
leaves-zwx Feb 17, 2023
e5451c2
suppress warning
leaves-zwx Feb 18, 2023
b3c9ca9
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 18, 2023
66d3923
suppress warning
leaves-zwx Feb 18, 2023
5e4aebc
suppress warning
leaves-zwx Feb 18, 2023
670eca3
refine sort algorithm
leaves-zwx Feb 21, 2023
351083e
refine sort algorithm and test
leaves-zwx Feb 21, 2023
e6b579b
rm calling FilterNdSbpIn2OutSignatures in EnumerateNdSbpSignatures
leaves-zwx Feb 21, 2023
ca9eb24
change example
leaves-zwx Feb 21, 2023
abe42db
Update oneflow/user/ops/reshape_user_op_util.cpp
leaves-zwx Feb 21, 2023
4836822
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 21, 2023
7bacbcf
refine sort algorithm
leaves-zwx Feb 25, 2023
f578fcf
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 25, 2023
d3cf09d
change sort order
leaves-zwx Feb 28, 2023
13dd1be
use FilterNdSbpByLogicalShape
leaves-zwx Feb 28, 2023
70ae77e
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 52 additions & 70 deletions oneflow/core/framework/sbp_infer_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,85 +787,67 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
}
}

// return -1 indicate lhs < rhs, 1 indicate lhs > rhs, 0 indicate lhs == rhs
int CompareNdSbp(const NdSbp& lhs_nd_sbp, const NdSbp& rhs_nd_sbp) {
CHECK_EQ(lhs_nd_sbp.sbp_parallel_size(), rhs_nd_sbp.sbp_parallel_size());
// max split axis = 8, + B + P
const size_t kMaxSplitAxis = 8;
const size_t kCarryDigit = kMaxSplitAxis + 2;
auto Mesure = [](const NdSbp& nd_sbp) -> size_t {
size_t value = 0;
for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {
size_t cur_dim_value = 0;
const auto& sbp = nd_sbp.sbp_parallel(i);
if (sbp.has_split_parallel()) {
CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis);
cur_dim_value = sbp.split_parallel().axis();
} else if (sbp.has_broadcast_parallel()) {
cur_dim_value = kMaxSplitAxis;
} else if (sbp.has_partial_sum_parallel()) {
cur_dim_value = kMaxSplitAxis + 1;
} else {
UNIMPLEMENTED();
}
value += cur_dim_value * std::pow(kCarryDigit, (nd_sbp.sbp_parallel_size() - i - 1));
}
return value;
};
size_t lhs_value = Mesure(lhs_nd_sbp);
size_t rhs_value = Mesure(rhs_nd_sbp);
namespace {

if (lhs_value < rhs_value) {
return -1;
} else if (lhs_value > rhs_value) {
return 1;
} else {
return 0;
// give a measure value for NdSbp for sorting
size_t MesureNdSbp(const NdSbp& nd_sbp) {
// max split axis = 8 (start from 1), + B + P = 10
constexpr size_t kMaxSplitAxis = 8;
constexpr size_t kCarryDigit = kMaxSplitAxis + 3;
size_t value = 0;
for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {
size_t cur_dim_value = 0;
const auto& sbp = nd_sbp.sbp_parallel(i);
if (sbp.has_split_parallel()) {
CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis);
// from 1 to 8
cur_dim_value = sbp.split_parallel().axis() + 1;
} else if (sbp.has_broadcast_parallel()) {
// 9
cur_dim_value = kMaxSplitAxis + 1;
} else if (sbp.has_partial_sum_parallel()) {
// 10
cur_dim_value = kMaxSplitAxis + 2;
} else {
UNIMPLEMENTED();
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

别的地方没什么问题了。

这里我想了一下,上次我们开会时是讨论到说用97作为不同blob的进位对吧,如果说NdSbp的数字不超过96,是不会有任何风险的。
那现在把B映射到9,P映射到10,而进位是11,这样(B, B)就会是9*11+9 = 108,是超过了97的。
而B是出现很频繁的SBP,也就是风险会出现得频繁一些。
但是如果把B映射到1,P映射到2,Si -> i+3,这样要超过96起码是 88+9,也就是 (S5, S6),(S5, S7)或者 (S6, 任意SBP) 才有可能有风险。

实际中S5少见,更不用说 (S5, S6)了,甚至 (S5, S5) 都是没有问题的。要出问题,至少有一个S6或者 S7,也就是张量起码要有7维。

所以把B,P映射的数字前调能够很有效地避免大部分的风险。(当然即使有风险也不一定会出问题,素数能有效地规避掉一些,但是如果能够避免大部分的风险,还是避免的好)

value = value * kCarryDigit + cur_dim_value;
}
return value;
}

// return -1 indicate lhs < rhs, 1 indicate lhs > rhs, 0 indicate lhs == rhs
int CompareNdSbpSignature(const NdSbpSignature& lhs_nd_sbp_sig,
const NdSbpSignature& rhs_nd_sbp_sig) {
CHECK_EQ(lhs_nd_sbp_sig.bn_in_op2nd_sbp_size(), rhs_nd_sbp_sig.bn_in_op2nd_sbp_size());
for (const auto& lhs_bn_nd_sbp : lhs_nd_sbp_sig.bn_in_op2nd_sbp()) {
const auto& bn = lhs_bn_nd_sbp.first;
auto rhs_bn_nd_sbp_it = rhs_nd_sbp_sig.bn_in_op2nd_sbp().find(bn);
CHECK(rhs_bn_nd_sbp_it != rhs_nd_sbp_sig.bn_in_op2nd_sbp().end());
const auto& lhs_nd_sbp = lhs_bn_nd_sbp.second;
const auto& rhs_nd_sbp = rhs_bn_nd_sbp_it->second;
int cmp_ret = CompareNdSbp(lhs_nd_sbp, rhs_nd_sbp);
if (cmp_ret != 0) { return cmp_ret; }
// Folds the NdSbp of every blob named in `bns`, in the order given, into one integer:
//   value = value * kCarryDigit + MesureNdSbp(nd_sbp of that blob)
// so that the result can be used as a sort / lookup key for the signature.
//
// nd_sbp_sig: signature whose bn_in_op2nd_sbp map is read.
// bns:        blob names to encode; every name must be present in nd_sbp_sig.
// Returns the accumulated key.
// CHECK-fails when a name in `bns` is missing from the signature, or when the
// accumulated key would not fit in size_t.
size_t MesureNdSbpSignature(const NdSbpSignature& nd_sbp_sig, const std::vector<std::string>& bns) {
  // big enough for 2d-sbp signature set
  // if extending to 3d-sbp, consider increasing this to 170
  constexpr size_t kCarryDigit = 97;
  constexpr size_t kMaxValue = std::numeric_limits<size_t>::max();
  const auto& bn2nd_sbp = nd_sbp_sig.bn_in_op2nd_sbp();
  size_t value = 0;
  for (const auto& bn : bns) {
    const auto nd_sbp_it = bn2nd_sbp.find(bn);
    CHECK(nd_sbp_it != bn2nd_sbp.end())
        << "can't find bn (" << bn << ") in " << PbMessage2TxtString(nd_sbp_sig);
    const size_t cur_arg_value = MesureNdSbp(nd_sbp_it->second);
    // Exact overflow guard: value * kCarryDigit + cur_arg_value <= kMaxValue
    // holds iff value <= (kMaxValue - cur_arg_value) / kCarryDigit. Comparing
    // truncated quotients instead can let a wrap-around slip through.
    CHECK_LE(value, (kMaxValue - cur_arg_value) / kCarryDigit);
    value = value * kCarryDigit + cur_arg_value;
  }
  return value;
}

void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list) {
auto CompareIndices = [nd_sbp_sig_list](size_t lhs, size_t rhs) -> bool {
int cmp_ret = CompareNdSbpSignature(nd_sbp_sig_list->at(lhs), nd_sbp_sig_list->at(rhs));
if (cmp_ret == 0) {
return true;
} else if (cmp_ret == -1) {
return true;
} else if (cmp_ret == 1) {
return false;
} else {
UNIMPLEMENTED();
} // namespace

void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,
const std::vector<std::string>& bns) {
if (bns.size() > 8) { return; }
std::map<size_t, NdSbpSignature> value2nd_sbp_sig;
for (auto& nd_sbp_sig : *nd_sbp_sig_list) {
size_t order_value = MesureNdSbpSignature(nd_sbp_sig, bns);
if (value2nd_sbp_sig.find(order_value) == value2nd_sbp_sig.end()) {
value2nd_sbp_sig.emplace(order_value, std::move(nd_sbp_sig));
}
};
std::vector<size_t> indices(nd_sbp_sig_list->size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), CompareIndices);
auto new_end =
std::unique(indices.begin(), indices.end(), [nd_sbp_sig_list](size_t lhs, size_t rhs) {
return nd_sbp_sig_list->at(lhs) == nd_sbp_sig_list->at(rhs);
});
indices.erase(new_end, indices.end());
std::vector<NdSbpSignature> new_nd_sbp_sig_list;
for (size_t index : indices) {
new_nd_sbp_sig_list.emplace_back(std::move(nd_sbp_sig_list->at(index)));
}
*nd_sbp_sig_list = std::move(new_nd_sbp_sig_list);
nd_sbp_sig_list->clear();
for (auto& nd_sbp_pair : value2nd_sbp_sig) {
nd_sbp_sig_list->emplace_back(std::move(nd_sbp_pair.second));
}
}

// Compute storage per device for given NdSbp
Expand Down
7 changes: 2 additions & 5 deletions oneflow/core/framework/sbp_infer_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,8 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,
std::vector<NdSbpSignature>* nd_sbp_sig_list);

// return -1 indicate lhs < rhs, 1 indicate lhs > rhs, 0 indicate lhs == rhs
int CompareNdSbp(const NdSbp& lhs_nd_sbp, const NdSbp& rhs_nd_sbp);
int CompareNdSbpSignature(const NdSbpSignature& lhs_nd_sbp_sig,
const NdSbpSignature& rhs_nd_sbp_sig);
void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list);
void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,
const std::vector<std::string>& bns);

// Compute storage for given NdSbp
double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy);
Expand Down
78 changes: 42 additions & 36 deletions oneflow/core/framework/sbp_infer_util_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ bool ParseNdSbpSignatureFromString(const std::string& nd_sbp_signature_str,
if (pos < nd_sbp_signature_str.size() && nd_sbp_signature_str[pos] == '>') {
// in args parsing has finished, parse out args
arg_name = "out";
nd_sbp_group_id = 0;
// skip '>' in substr '->'
pos++;
continue;
Expand All @@ -85,51 +86,56 @@ bool ParseNdSbpSignatureFromString(const std::string& nd_sbp_signature_str,
return true;
}

void TestCompareNdSbpSignature(const std::vector<std::string>& nd_sbp_signature_str_list) {
void TestDeduplicateNdSbpSignature(const std::vector<std::string>& nd_sbp_signature_str_list,
const std::vector<std::string>& bns) {
std::vector<NdSbpSignature> nd_sbp_sig_list;
nd_sbp_sig_list.reserve(nd_sbp_signature_str_list.size());
for (const auto& nd_sbp_signature_str : nd_sbp_signature_str_list) {
nd_sbp_sig_list.emplace_back();
ASSERT_TRUE(ParseNdSbpSignatureFromString(nd_sbp_signature_str, nd_sbp_sig_list.back()));
}
for (const auto& lhs_nd_sbp_sig : nd_sbp_sig_list) {
for (const auto& rhs_nd_sbp_sig : nd_sbp_sig_list) {
int cmp_ret1 = CompareNdSbpSignature(lhs_nd_sbp_sig, rhs_nd_sbp_sig);
int cmp_ret2 = CompareNdSbpSignature(rhs_nd_sbp_sig, lhs_nd_sbp_sig);
ASSERT_TRUE(cmp_ret1 + cmp_ret2 == 0);
}
}
std::random_device rd;
std::mt19937 gen(rd());
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);
nd_sbp_sig_list.reserve(nd_sbp_sig_list.size() + nd_sbp_sig_list.size() / 2);
std::copy_n(nd_sbp_sig_list.begin(), nd_sbp_sig_list.size() / 2,
std::back_inserter(nd_sbp_sig_list));
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);
DeduplicateNdSbpSignatureList(&nd_sbp_sig_list, bns);
}

} // namespace

TEST(SbpInferUtil, CompareNdSbpSignature) {
TestCompareNdSbpSignature({
"(S(0), S(0)) -> (S(0), S(0))",
"(S(0), S(1)) -> (S(0), S(1))",
"(S(0), S(3)) -> (S(0), S(2))",
"(S(0), B) -> (S(0), B)",
"(S(0), P) -> (S(0), P)",
"(S(1), S(0)) -> (S(1), S(0))",
"(S(1), S(1)) -> (S(1), S(1))",
"(S(1), S(3)) -> (S(1), S(2))",
"(S(1), B) -> (S(1), B)",
"(S(1), P) -> (S(1), P)",
"(S(3), S(0)) -> (S(2), S(0))",
"(S(3), S(1)) -> (S(2), S(1))",
"(S(3), S(3)) -> (S(2), S(2))",
"(S(3), B) -> (S(2), B)",
"(S(3), P) -> (S(2), P)",
"(B, S(0)) -> (B, S(0))",
"(B, S(1)) -> (B, S(1))",
"(B, S(3)) -> (B, S(2))",
"(B, B) -> (B, B)",
"(B, P) -> (B, P)",
"(P, S(0)) -> (P, S(0))",
"(P, S(1)) -> (P, S(1))",
"(P, S(3)) -> (P, S(2))",
"(P, B) -> (P, B)",
"(P, P) -> (P, P)",
});
TEST(SbpInferUtil, DeduplicateNdSbpSignatureList) {
TestDeduplicateNdSbpSignature(
{
"(S(0), S(0)) -> (S(0), S(0))",
"(S(0), S(1)) -> (S(0), S(1))",
"(S(0), S(3)) -> (S(0), S(2))",
"(S(0), B) -> (S(0), B)",
"(S(0), P) -> (S(0), P)",
"(S(1), S(0)) -> (S(1), S(0))",
"(S(1), S(1)) -> (S(1), S(1))",
"(S(1), S(3)) -> (S(1), S(2))",
"(S(1), B) -> (S(1), B)",
"(S(1), P) -> (S(1), P)",
"(S(3), S(0)) -> (S(2), S(0))",
"(S(3), S(1)) -> (S(2), S(1))",
"(S(3), S(3)) -> (S(2), S(2))",
"(S(3), B) -> (S(2), B)",
"(S(3), P) -> (S(2), P)",
"(B, S(0)) -> (B, S(0))",
"(B, S(1)) -> (B, S(1))",
"(B, S(3)) -> (B, S(2))",
"(B, B) -> (B, B)",
"(B, P) -> (B, P)",
"(P, S(0)) -> (P, S(0))",
"(P, S(1)) -> (P, S(1))",
"(P, S(3)) -> (P, S(2))",
"(P, B) -> (P, B)",
"(P, P) -> (P, P)",
},
{"in_0", "out_0"});
}

} // namespace test
Expand Down
2 changes: 1 addition & 1 deletion oneflow/user/ops/reshape_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ namespace oneflow {
}
}

DeduplicateNdSbpSignatureList(nd_sbp_sig_list);
DeduplicateNdSbpSignatureList(nd_sbp_sig_list, {"in_0", "out_0"});
return Maybe<void>::Ok();
}

Expand Down