Supplement ND SBP signatures for reshape op #9858

Merged
merged 35 commits into from
Feb 28, 2023
Changes from 31 commits
Commits
663cf30
EnumerateNdSbpSignatures
leaves-zwx Feb 9, 2023
dccb762
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 11, 2023
7cd2dc1
fix algo
leaves-zwx Feb 12, 2023
1eed7aa
fix
leaves-zwx Feb 12, 2023
3dc723b
cpp test
leaves-zwx Feb 12, 2023
cbf7261
fix
leaves-zwx Feb 12, 2023
a440e8d
rename
leaves-zwx Feb 12, 2023
6998d52
refine UserOp::GetNdSbpSignatureList
leaves-zwx Feb 12, 2023
94870a6
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 12, 2023
bff5183
py test
leaves-zwx Feb 12, 2023
2db180a
revert changes
leaves-zwx Feb 13, 2023
84b3acf
new ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 13, 2023
ad4a227
add test case
leaves-zwx Feb 13, 2023
7c22e1a
refine algorithm
leaves-zwx Feb 15, 2023
c51364a
update test and fix bug
leaves-zwx Feb 16, 2023
db658eb
rm comment
leaves-zwx Feb 16, 2023
278f577
rm redundant condition
leaves-zwx Feb 16, 2023
c500114
DeduplicateNdSbpSignatureList
leaves-zwx Feb 17, 2023
a4ddd55
refine GenRankMeshSubset
leaves-zwx Feb 17, 2023
508cc6f
add comments
leaves-zwx Feb 17, 2023
e5451c2
suppress warning
leaves-zwx Feb 18, 2023
b3c9ca9
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 18, 2023
66d3923
suppress warning
leaves-zwx Feb 18, 2023
5e4aebc
suppress warning
leaves-zwx Feb 18, 2023
670eca3
refine sort algorithm
leaves-zwx Feb 21, 2023
351083e
refine sort algorithm and test
leaves-zwx Feb 21, 2023
e6b579b
rm calling FilterNdSbpIn2OutSignatures in EnumerateNdSbpSignatures
leaves-zwx Feb 21, 2023
ca9eb24
change example
leaves-zwx Feb 21, 2023
abe42db
Update oneflow/user/ops/reshape_user_op_util.cpp
leaves-zwx Feb 21, 2023
4836822
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 21, 2023
7bacbcf
refine sort algorithm
leaves-zwx Feb 25, 2023
f578fcf
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 25, 2023
d3cf09d
change sort order
leaves-zwx Feb 28, 2023
13dd1be
use FilterNdSbpByLogicalShape
leaves-zwx Feb 28, 2023
70ae77e
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 28, 2023
@@ -33,6 +33,7 @@ class GetNdSbpSignatureListContext {
virtual ~GetNdSbpSignatureListContext() = default;

virtual void AddNdSbpSignature(NdSbpSignature&) = 0;
virtual std::vector<NdSbpSignature>* MutNdSbpSignatureList() = 0;
virtual const Shape& parallel_hierarchy() = 0;
virtual const Shape& BlobShape4InputArgNameAndIndex(const std::string& arg_name,
int32_t index) const = 0;
63 changes: 63 additions & 0 deletions oneflow/core/framework/sbp_infer_util.cpp
@@ -787,6 +787,69 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
}
}

namespace {

// give a measure value for NdSbp for sorting
size_t MesureNdSbp(const NdSbp& nd_sbp) {
// max split axis = 8 (start from 1), + B + P = 10
constexpr size_t kMaxSplitAxis = 8;
constexpr size_t kCarryDigit = kMaxSplitAxis + 3;
size_t value = 0;
for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {
size_t cur_dim_value = 0;
const auto& sbp = nd_sbp.sbp_parallel(i);
if (sbp.has_split_parallel()) {
CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis);
// from 1 to 8
cur_dim_value = sbp.split_parallel().axis() + 1;
} else if (sbp.has_broadcast_parallel()) {
// 9
cur_dim_value = kMaxSplitAxis + 1;
} else if (sbp.has_partial_sum_parallel()) {
// 10
cur_dim_value = kMaxSplitAxis + 2;
} else {
UNIMPLEMENTED();
}
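Contributor review comment:

Everything else looks fine.

I thought about this part. In our last meeting we discussed using 97 as the carry between different blobs, right? As long as the NdSbp value does not exceed 96, there is no risk at all.
Now B is mapped to 9 and P to 10, while the carry is 11, so (B, B) becomes 9*11+9 = 108, which exceeds 97.
B is a very common SBP, so the risk would show up fairly often.
But if we map B to 1, P to 2, and Si -> i+3, then exceeding 96 takes at least 88+9, so only (S5, S6), (S5, S7), or (S6, any SBP) could be at risk.

In practice S5 is rare, let alone (S5, S6), and even (S5, S5) is fine. For a problem to occur there must be at least one S6 or S7, which means the tensor needs at least 7 dimensions.

So moving the values for B and P to the front avoids most of the risk quite effectively. (Of course, a risk does not necessarily turn into an actual problem, and a prime carry sidesteps some cases, but if most of the risk can be avoided, it is better to avoid it.)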
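The arithmetic in the review comment can be checked with a small sketch. It is a simplified stand-in, not the OneFlow code: an SBP entry is modeled as a plain `int`, and the names `kB`, `kP`, `DigitInDiff`, `DigitInReview`, and `MeasureToyNdSbp` are hypothetical. Only the two digit mappings and the carry of 11 are taken from the diff and the comment.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Toy stand-in for one SBP entry (hypothetical, not OneFlow's protobuf type):
// a value >= 0 means S(axis); kB and kP mark broadcast and partial-sum.
constexpr int kB = -1;
constexpr int kP = -2;
constexpr std::size_t kMaxSplitAxis = 8;
constexpr std::size_t kCarryDigit = kMaxSplitAxis + 3;  // 11, as in the diff

// Mapping used in the diff: S(i) -> i + 1, B -> 9, P -> 10.
std::size_t DigitInDiff(int sbp) {
  if (sbp == kB) { return kMaxSplitAxis + 1; }
  if (sbp == kP) { return kMaxSplitAxis + 2; }
  assert(sbp >= 0 && static_cast<std::size_t>(sbp) < kMaxSplitAxis);
  return static_cast<std::size_t>(sbp) + 1;
}

// Mapping proposed in the review: B -> 1, P -> 2, S(i) -> i + 3.
std::size_t DigitInReview(int sbp) {
  if (sbp == kB) { return 1; }
  if (sbp == kP) { return 2; }
  assert(sbp >= 0 && static_cast<std::size_t>(sbp) < kMaxSplitAxis);
  return static_cast<std::size_t>(sbp) + 3;
}

// Base-11 positional value of an NdSbp under the given digit mapping.
std::size_t MeasureToyNdSbp(const std::vector<int>& nd_sbp, std::size_t (*digit)(int)) {
  std::size_t value = 0;
  for (int sbp : nd_sbp) { value = value * kCarryDigit + digit(sbp); }
  return value;
}
```

Under these assumptions, `MeasureToyNdSbp({kB, kB}, DigitInDiff)` gives 108, while the reviewed mapping gives 12 for the same NdSbp. With the reviewed mapping, (S5, S5) evaluates to 96 and (S5, S6) to 97, which matches the threshold described in the comment.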

value = value * kCarryDigit + cur_dim_value;
}
return value;
}

size_t MesureNdSbpSignature(const NdSbpSignature& nd_sbp_sig, const std::vector<std::string>& bns) {
// big enough for a 2d-sbp signature set
// to extend to 3d-sbp, consider increasing it to 170
constexpr size_t kCarryDigit = 97;
size_t value = 0;
for (size_t i = 0; i < bns.size(); ++i) {
auto nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find(bns[i]);
CHECK(nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end())
<< "can't find bn (" << bns[i] << ") in " << PbMessage2TxtString(nd_sbp_sig);
size_t cur_arg_value = MesureNdSbp(nd_sbp_it->second);
CHECK_LE(value + cur_arg_value / kCarryDigit, std::numeric_limits<size_t>::max() / kCarryDigit);
value = value * kCarryDigit + cur_arg_value;
}
return value;
}
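The `CHECK_LE` in `MesureNdSbpSignature` guards the multiply-add `value * kCarryDigit + cur_arg_value` against overflowing `size_t`. As a rough illustration of that kind of guard, here is a sketch of an exact overflow test. `CheckedMulAdd` is a hypothetical helper, not part of OneFlow, and it is not the condition used in the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>

// Hypothetical helper: computes value * radix + digit into *out.
// Returns false, leaving *out untouched, if the result would not fit in size_t.
bool CheckedMulAdd(std::size_t value, std::size_t radix, std::size_t digit, std::size_t* out) {
  assert(out != nullptr);
  const std::size_t max = std::numeric_limits<std::size_t>::max();
  // value * radix + digit <= max  <=>  value <= (max - digit) / radix  (for radix > 0)
  if (radix != 0 && value > (max - digit) / radix) { return false; }
  *out = value * radix + digit;
  return true;
}
```

Called as `CheckedMulAdd(12, 97, 12, &out)`, the helper stores 1176, which is the value two (S(0), S(0)) blobs would get under the diff's mapping.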

} // namespace

void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,
const std::vector<std::string>& bns) {
if (bns.size() > 8) { return; }
std::map<size_t, NdSbpSignature> value2nd_sbp_sig;
for (auto& nd_sbp_sig : *nd_sbp_sig_list) {
size_t order_value = MesureNdSbpSignature(nd_sbp_sig, bns);
if (value2nd_sbp_sig.find(order_value) == value2nd_sbp_sig.end()) {
value2nd_sbp_sig.emplace(order_value, std::move(nd_sbp_sig));
}
}
nd_sbp_sig_list->clear();
for (auto& nd_sbp_pair : value2nd_sbp_sig) {
nd_sbp_sig_list->emplace_back(std::move(nd_sbp_pair.second));
}
}
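`DeduplicateNdSbpSignatureList` keys each signature by its measured value in a `std::map`, so the output is both unique and sorted by that value. The sketch below reproduces the idea on toy types. `ToyNdSbp`, `ToySignature`, `MeasureToySignature`, and `DeduplicateToySignatures` are hypothetical names, and each SBP is assumed to be already mapped to a digit in 1..10.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Toy types (assumptions): digits are already mapped to 1..10 as in the diff,
// and a signature lists one NdSbp per blob name, in bns order.
using ToyNdSbp = std::vector<std::size_t>;
using ToySignature = std::vector<ToyNdSbp>;

// Two-level positional encoding: carry 11 inside an NdSbp, carry 97 across blobs.
std::size_t MeasureToySignature(const ToySignature& sig) {
  constexpr std::size_t kNdSbpCarry = 11;
  constexpr std::size_t kBlobCarry = 97;
  std::size_t value = 0;
  for (const ToyNdSbp& nd_sbp : sig) {
    std::size_t cur = 0;
    for (std::size_t digit : nd_sbp) {
      assert(digit >= 1 && digit <= 10);
      cur = cur * kNdSbpCarry + digit;
    }
    value = value * kBlobCarry + cur;
  }
  return value;
}

// Keep the first signature seen for each value; std::map iteration yields ascending keys.
void DeduplicateToySignatures(std::vector<ToySignature>* sigs) {
  std::map<std::size_t, ToySignature> value2sig;
  for (auto& sig : *sigs) {
    const std::size_t key = MeasureToySignature(sig);
    if (value2sig.count(key) == 0) { value2sig.emplace(key, std::move(sig)); }
  }
  sigs->clear();
  for (auto& pair : value2sig) { sigs->emplace_back(std::move(pair.second)); }
}
```

In this toy encoding, the signatures with digits {(1, 1), (9, 10)} and {(1, 2), (1, 1)} both evaluate to 1273, so the second one seen would be dropped as a duplicate. In the diff's mapping these digits correspond to in = (S(0), S(0)), out = (B, P) and in = (S(0), S(1)), out = (S(0), S(0)). The sketch does not establish whether such a pair can occur together for a real op.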

// Compute storage per device for given NdSbp
double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy) {
if (nd_sbp.sbp_parallel_size() == 1) {
3 changes: 3 additions & 0 deletions oneflow/core/framework/sbp_infer_util.h
@@ -77,6 +77,9 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,
std::vector<NdSbpSignature>* nd_sbp_sig_list);

void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,
const std::vector<std::string>& bns);

// Compute storage for given NdSbp
double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy);

142 changes: 142 additions & 0 deletions oneflow/core/framework/sbp_infer_util_test.cpp
@@ -0,0 +1,142 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/sbp_infer_util.h"

#include <gtest/gtest.h>

namespace oneflow {
namespace test {

namespace {

bool ParseNdSbpSignatureFromString(const std::string& nd_sbp_signature_str,
NdSbpSignature& nd_sbp_signature) {
auto* bn2nd_sbp = nd_sbp_signature.mutable_bn_in_op2nd_sbp();
std::string arg_name = "in";
bool meet_nd_sbp_group = false;
bool meet_split = false;
int nd_sbp_group_id = 0;
std::vector<std::string> nd_sbp_str_group;
size_t pos = 0;
while (pos < nd_sbp_signature_str.size()) {
const char& c = nd_sbp_signature_str[pos];
pos++;
if (c == ' ') {
continue;
} else if (c == '(') {
if (!meet_nd_sbp_group) {
// enter a nd-sbp group
meet_nd_sbp_group = true;
nd_sbp_str_group.emplace_back();
continue;
} else {
// meet left parentheses of S(x)
meet_split = true;
}
} else if (c == ')') {
if (meet_split) {
// meet right parentheses of S(x)
meet_split = false;
} else if (meet_nd_sbp_group) {
// leave a nd-sbp group
meet_nd_sbp_group = false;
std::string bn = arg_name + "_" + std::to_string(nd_sbp_group_id);
if (!ParseNdSbpFromStringList(nd_sbp_str_group, &(*bn2nd_sbp)[bn])) { return false; }
nd_sbp_str_group.clear();
continue;
} else {
return false;
}
} else if (c == ',') {
if (meet_nd_sbp_group) {
nd_sbp_str_group.emplace_back();
} else {
nd_sbp_group_id += 1;
}
continue;
} else if (c == '-') {
if (pos < nd_sbp_signature_str.size() && nd_sbp_signature_str[pos] == '>') {
// in args parsing has finished, parse out args
arg_name = "out";
nd_sbp_group_id = 0;
// skip '>' in substr '->'
pos++;
continue;
} else {
return false;
}
} else {
// do nothing
}
nd_sbp_str_group.back() += c;
}
return true;
}

void TestDeduplicateNdSbpSignature(const std::vector<std::string>& nd_sbp_signature_str_list,
const std::vector<std::string>& bns) {
std::vector<NdSbpSignature> nd_sbp_sig_list;
nd_sbp_sig_list.reserve(nd_sbp_signature_str_list.size());
for (const auto& nd_sbp_signature_str : nd_sbp_signature_str_list) {
nd_sbp_sig_list.emplace_back();
ASSERT_TRUE(ParseNdSbpSignatureFromString(nd_sbp_signature_str, nd_sbp_sig_list.back()));
}
std::random_device rd;
std::mt19937 gen(rd());
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);
nd_sbp_sig_list.reserve(nd_sbp_sig_list.size() + nd_sbp_sig_list.size() / 2);
std::copy_n(nd_sbp_sig_list.begin(), nd_sbp_sig_list.size() / 2,
std::back_inserter(nd_sbp_sig_list));
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);
DeduplicateNdSbpSignatureList(&nd_sbp_sig_list, bns);
}

} // namespace

TEST(SbpInferUtil, DeduplicateNdSbpSignatureList) {
TestDeduplicateNdSbpSignature(
{
"(S(0), S(0)) -> (S(0), S(0))",
"(S(0), S(1)) -> (S(0), S(1))",
"(S(0), S(3)) -> (S(0), S(2))",
"(S(0), B) -> (S(0), B)",
"(S(0), P) -> (S(0), P)",
"(S(1), S(0)) -> (S(1), S(0))",
"(S(1), S(1)) -> (S(1), S(1))",
"(S(1), S(3)) -> (S(1), S(2))",
"(S(1), B) -> (S(1), B)",
"(S(1), P) -> (S(1), P)",
"(S(3), S(0)) -> (S(2), S(0))",
"(S(3), S(1)) -> (S(2), S(1))",
"(S(3), S(3)) -> (S(2), S(2))",
"(S(3), B) -> (S(2), B)",
"(S(3), P) -> (S(2), P)",
"(B, S(0)) -> (B, S(0))",
"(B, S(1)) -> (B, S(1))",
"(B, S(3)) -> (B, S(2))",
"(B, B) -> (B, B)",
"(B, P) -> (B, P)",
"(P, S(0)) -> (P, S(0))",
"(P, S(1)) -> (P, S(1))",
"(P, S(3)) -> (P, S(2))",
"(P, B) -> (P, B)",
"(P, P) -> (P, P)",
},
{"in_0", "out_0"});
}
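The test strings use the form `(in NdSbp), ... -> (out NdSbp), ...`, with each group keyed as `in_<n>` or `out_<n>`. A standalone sketch of the same state machine follows. It stores the raw SBP strings instead of calling `ParseNdSbpFromStringList`, and `ToyParsed` and `ParseToySignature` are hypothetical names. Unlike the test helper, it rejects characters that appear outside any group and an unclosed group.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical result type: blob name ("in_0", "out_0", ...) -> SBP strings per hierarchy dim.
using ToyParsed = std::map<std::string, std::vector<std::string>>;

bool ParseToySignature(const std::string& text, ToyParsed* parsed) {
  std::string arg_name = "in";
  bool in_group = false;  // inside "( ... )" of one NdSbp
  bool in_split = false;  // inside the "(x)" of S(x)
  int group_id = 0;
  std::vector<std::string> group;
  std::size_t pos = 0;
  while (pos < text.size()) {
    const char c = text[pos++];
    if (c == ' ') { continue; }
    if (c == '(') {
      if (!in_group) {
        in_group = true;
        group.emplace_back();
        continue;
      }
      in_split = true;  // keep the '(' as part of "S(x)"
    } else if (c == ')') {
      if (in_split) {
        in_split = false;  // keep the ')' as part of "S(x)"
      } else if (in_group) {
        in_group = false;
        (*parsed)[arg_name + "_" + std::to_string(group_id)] = group;
        group.clear();
        continue;
      } else {
        return false;
      }
    } else if (c == ',') {
      if (in_group) { group.emplace_back(); } else { group_id += 1; }
      continue;
    } else if (c == '-') {
      if (pos < text.size() && text[pos] == '>') {
        arg_name = "out";  // inputs are done, switch to outputs
        group_id = 0;
        ++pos;  // skip '>'
        continue;
      }
      return false;
    }
    if (group.empty()) { return false; }  // character outside any group
    group.back() += c;
  }
  return !in_group;  // reject an unclosed group
}
```

For example, parsing `"(S(0), B) -> (S(0), B)"` yields two entries, `in_0` and `out_0`, each holding the strings `S(0)` and `B`.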

} // namespace test
} // namespace oneflow
5 changes: 5 additions & 0 deletions oneflow/core/framework/user_op_registry.cpp
@@ -218,6 +218,11 @@ OpRegistry& OpRegistry::SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn get_n
return *this;
}

OpRegistry& OpRegistry::SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn) {
result_.enumerate_nd_sbp_signatures_fn = std::move(fn);
return *this;
}

OpRegistry& OpRegistry::SetDumpNdSbpSignatureForOpConfFn(
Operator::DumpNdSbpSignatureForOpConfFn fn) {
result_.dump_nd_sbp_signature_for_op_conf_fn = std::move(fn);
3 changes: 3 additions & 0 deletions oneflow/core/framework/user_op_registry.h
@@ -64,6 +64,7 @@ using NdSbpInferFn = std::function<Maybe<void>(InferNdSbpFnContext*)>;
using ComputeComplexityFn = std::function<Maybe<double>(ComputeComplexityFnContext*)>;
// TODO: set up another context
using GetNdSbpSignatureListFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;
using EnumerateNdSbpSignaturesFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;

struct OpRegistryResult {
OpRegistryResult()
@@ -94,6 +95,7 @@ struct OpRegistryResult {
NdSbpInferFn nd_sbp_infer_fn;
ComputeComplexityFn compute_complexity_fn;
GetNdSbpSignatureListFn get_nd_sbp_list_fn;
EnumerateNdSbpSignaturesFn enumerate_nd_sbp_signatures_fn;
Operator::DumpNdSbpSignatureForOpConfFn dump_nd_sbp_signature_for_op_conf_fn;
};

@@ -143,6 +145,7 @@ class OpRegistry final {
OpRegistry& SetDeviceAndStreamInferFn(DeviceAndStreamInferFn fn);
OpRegistry& SetComputeComplexityFn(ComputeComplexityFn fn);
OpRegistry& SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn fn);
OpRegistry& SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn);
OpRegistry& SetDumpNdSbpSignatureForOpConfFn(Operator::DumpNdSbpSignatureForOpConfFn fn);

Maybe<OpRegistry&> Finish();
7 changes: 7 additions & 0 deletions oneflow/core/operator/operator.cpp
@@ -506,6 +506,12 @@ Maybe<void> Operator::GetSbpSignaturesIf(
return Maybe<void>::Ok();
}

Maybe<void> Operator::EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
return Maybe<void>::Ok();
}

Maybe<void> Operator::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
@@ -536,6 +542,7 @@ Maybe<void> Operator::GetNdSbpSignatureList(
CHECK_OR_RETURN(nd_sbp_sig_list->empty());
DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(),
hierarchy_value2sbp_sig_list, nd_sbp_sig_list);
JUST(EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list));
return Maybe<void>::Ok();
}

3 changes: 3 additions & 0 deletions oneflow/core/operator/operator.h
@@ -173,6 +173,9 @@ class Operator {
Maybe<void> GetSbpSignaturesIf(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const;
virtual Maybe<void> EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
virtual Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
15 changes: 15 additions & 0 deletions oneflow/core/operator/user_op.cpp
@@ -634,6 +634,8 @@ class UserOpGetNdSbpSignatureListContext : public user_op::GetNdSbpSignatureList
nd_sbp_sig_list_(nd_sbp_sig_list) {}
~UserOpGetNdSbpSignatureListContext() override = default;

std::vector<NdSbpSignature>* MutNdSbpSignatureList() override { return nd_sbp_sig_list_; }

void AddNdSbpSignature(NdSbpSignature& nd_sbp_sig) override {
nd_sbp_sig_list_->emplace_back(nd_sbp_sig);
}
@@ -992,6 +994,19 @@ Maybe<void> UserOp::InferNdSbpSignature(
return Maybe<void>::Ok();
}

Maybe<void> UserOp::EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
if (val_->enumerate_nd_sbp_signatures_fn) {
NdSbpSignature empty_sbp_signature;
UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context(
this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);
return val_->enumerate_nd_sbp_signatures_fn(&user_op_get_nd_sbp_list_context);
} else {
return Operator::EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);
}
}
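The hook added here follows a common pattern: the base operator has a no-op virtual, the user op forwards to a registered `std::function` when one is set, and the list-building path calls the hook after generating its default signatures. A minimal sketch of that pattern on toy types is below. `ToyOperator`, `ToyUserOp`, `ToySignatureList`, and `ToyEnumerateFn` are hypothetical, and `bool` stands in for `Maybe<void>`.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>
#include <vector>

using ToySignatureList = std::vector<std::string>;
using ToyEnumerateFn = std::function<bool(ToySignatureList*)>;

class ToyOperator {
 public:
  virtual ~ToyOperator() = default;

  // Default hook: supplements nothing.
  virtual bool EnumerateSignatures(ToySignatureList* /*list*/) const { return true; }

  // Builds the default list first, then lets the hook add op-specific signatures.
  bool GetSignatureList(ToySignatureList* list) const {
    list->push_back("(B, B) -> (B, B)");  // stands in for the DFS-generated signatures
    return EnumerateSignatures(list);
  }
};

class ToyUserOp final : public ToyOperator {
 public:
  explicit ToyUserOp(ToyEnumerateFn fn) : enumerate_fn_(std::move(fn)) {}

  bool EnumerateSignatures(ToySignatureList* list) const override {
    // Use the registered function if present, otherwise fall back to the base no-op.
    if (enumerate_fn_) { return enumerate_fn_(list); }
    return ToyOperator::EnumerateSignatures(list);
  }

 private:
  ToyEnumerateFn enumerate_fn_;
};
```

An op constructed with `nullptr` keeps only the default signature, while one constructed with a lambda that appends `"(S(0), S(3)) -> (S(0), S(2))"` ends up with both. This mirrors how only ops that set `has_enumerate_nd_sbp_signatures_fn` contribute extra signatures.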

Maybe<void> UserOp::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
4 changes: 4 additions & 0 deletions oneflow/core/operator/user_op.h
@@ -65,6 +65,10 @@ class UserOp final : public Operator {
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override;
Maybe<void> EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
std::vector<NdSbpSignature>* nd_sbp_sig_list) const override;
Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
1 change: 1 addition & 0 deletions oneflow/ir/include/OneFlow/OneFlowBase.td
@@ -89,6 +89,7 @@ class OneFlow_BaseOp<string mnemonic, list<Trait> traits = []> :
bit has_nd_sbp_infer_fn = 0;
bit has_compute_complexity_fn = 0;
bit has_get_nd_sbp_fn = 0;
bit has_enumerate_nd_sbp_signatures_fn = 0;
bit has_dump_nd_sbp_signature_for_op_conf_fn = 0;
}

1 change: 1 addition & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -8264,6 +8264,7 @@ def OneFlow_ReshapeOp : OneFlow_BaseOp<"reshape", [NoSideEffect, DeclareOpInterf
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_enumerate_nd_sbp_signatures_fn = 1;
let has_data_type_infer_fn = 1;
let hasFolder = 1;
}