Supplement ND SBP signatures for reshape op (#9858)
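For the vast majority of ops, we only need to enumerate the 1-D SBP signatures they may support. In the 2-D case (which generalizes to N-D), taking the cross product of the 1-D
SBP signature list yields the 2-D SBP signature list.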

The reshape op, however, is an exception in some cases, as the following example shows:

```
(8, 4) reshape to (2, 4, 4)
with 2x2 ranks, has the 1-D sbp signature list as below:
  S(0) -> S(0)
  S(1) -> S(2)
by cross product, get the 2-D sbp signature list as below:
  [S(0), S(1)] -> [S(0), S(2)]
  [S(1), S(0)] -> [S(2), S(0)]  (will bring huge comm cost, ignore it in later discussion)
but the 2-D sbp signature below is supported too:
  [S(0), S(0)] -> [S(0), S(1)]
```

The example above reveals a pattern: **higher-dimensional SBP signatures cannot always be composed entirely from lower-dimensional ones**.

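For this reason, ops get a new overridable function, **EnumerateNdSbpSignatures**. It is called after the 1-D SBP
signatures have been enumerated and the 2-D SBP signature list has been generated from their cross product. It provides a way to
supplement additional 2-D SBP signatures when the 1-D cross product cannot produce all of them.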

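EnumerateNdSbpSignatures is implemented for reshape. In short, the algorithm finds the dimensions that are reshaped
and splits them consecutively by rank num, from the highest dimension to the lowest, until a split fails or the data is evenly split across every rank (going from high to low keeps the split contiguous).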
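The high-to-low splitting idea can be sketched roughly as follows. This is a simplified illustration under stated assumptions, not the implementation added by this commit: `SplitHighToLow` is a hypothetical name, it handles a single group of dimensions, and it only accepts even splits.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical simplification: for each hierarchy dim, pick the tensor axis to
// split, walking axes from highest (axis 0) to lowest. Returns std::nullopt if
// a split is uneven or the axes run out.
std::optional<std::vector<int>> SplitHighToLow(const std::vector<int64_t>& dims,
                                               const std::vector<int64_t>& hierarchy) {
  std::vector<int64_t> remaining = dims;  // extent left on each axis per rank
  std::vector<int> split_axes;
  size_t axis = 0;
  for (int64_t rank_num : hierarchy) {
    // An axis already reduced to extent 1 cannot be split further; move on.
    while (axis < remaining.size() && remaining[axis] == 1) { ++axis; }
    if (axis == remaining.size()) { return std::nullopt; }
    // Stop at the first uneven split instead of skipping to a lower axis,
    // so the resulting split stays contiguous.
    if (remaining[axis] % rank_num != 0) { return std::nullopt; }
    remaining[axis] /= rank_num;
    split_axes.push_back(static_cast<int>(axis));
  }
  return split_axes;
}
```

For the example in this message, the sketch gives axes `{0, 0}` for shape (8, 4) and `{0, 1}` for shape (2, 4, 4) on a 2x2 hierarchy, matching `[S(0), S(0)] -> [S(0), S(1)]`.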

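EnumerateNdSbpSignatures differs from the existing **GetNdSbpSignatureList**
override as follows. EnumerateNdSbpSignatures supplements the list after the 1-D SBP signature cross product, whereas
GetNdSbpSignatureList completely overrides the enumeration logic for 2-D SBP signatures and does not include the 1-D SBP
cross product. Its main purpose is to serve source ops, where the user can set the output sbp directly through an attr, with no inference needed.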
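The difference between the two hooks can be sketched with toy types. Everything here (`Sig`, `ToyOp`, `Build2DSignatureList`) is a hypothetical stand-in rather than the OneFlow API, and the sketch assumes the supplement hook is not invoked when the full override is set.

```cpp
#include <functional>
#include <string>
#include <vector>

using Sig = std::string;  // toy stand-in for NdSbpSignature
using SigListFn = std::function<void(std::vector<Sig>*)>;

// Toy stand-in for an op's registered hooks; both hooks are optional.
struct ToyOp {
  std::vector<Sig> one_d_sigs;    // enumerated 1-D signatures
  SigListFn get_nd_sbp_list_fn;   // full override (GetNdSbpSignatureList role)
  SigListFn enumerate_nd_sbp_fn;  // supplement (EnumerateNdSbpSignatures role)
};

std::vector<Sig> Build2DSignatureList(const ToyOp& op) {
  std::vector<Sig> list;
  if (op.get_nd_sbp_list_fn) {
    // Full override: the 1-D cross product is skipped entirely.
    op.get_nd_sbp_list_fn(&list);
    return list;
  }
  // Default path: cross product of the 1-D signatures...
  for (const Sig& a : op.one_d_sigs) {
    for (const Sig& b : op.one_d_sigs) { list.push_back("[" + a + ", " + b + "]"); }
  }
  // ...then the op may append signatures the cross product cannot produce.
  if (op.enumerate_nd_sbp_fn) { op.enumerate_nd_sbp_fn(&list); }
  return list;
}
```

With only the supplement hook set, the result is the cross product followed by the extra entries. With the override hook set, the result contains only what the override adds.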

---------

Co-authored-by: Yipeng Li <jamesonli1313@gmail.com>
leaves-zwx and Yipeng1994 authored Feb 28, 2023
1 parent 6442a25 commit 9f47744
Showing 20 changed files with 1,476 additions and 3 deletions.
1 change: 1 addition & 0 deletions oneflow/core/framework/get_nd_sbp_signature_list_context.h
@@ -33,6 +33,7 @@ class GetNdSbpSignatureListContext {
virtual ~GetNdSbpSignatureListContext() = default;

virtual void AddNdSbpSignature(NdSbpSignature&) = 0;
virtual std::vector<NdSbpSignature>* MutNdSbpSignatureList() = 0;
virtual const Shape& parallel_hierarchy() = 0;
virtual const Shape& BlobShape4InputArgNameAndIndex(const std::string& arg_name,
int32_t index) const = 0;
61 changes: 61 additions & 0 deletions oneflow/core/framework/sbp_infer_util.cpp
@@ -787,6 +787,67 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
}
}

namespace {

// give a measure value for NdSbp for sorting
size_t MesureNdSbp(const NdSbp& nd_sbp) {
// start from 1, B + P + max split axis (8)
constexpr size_t kMaxSplitAxis = 8;
constexpr size_t kCarryDigit = kMaxSplitAxis + 3;
size_t value = 0;
for (int i = 0; i < nd_sbp.sbp_parallel_size(); ++i) {
size_t cur_dim_value = 0;
const auto& sbp = nd_sbp.sbp_parallel(i);
if (sbp.has_broadcast_parallel()) {
cur_dim_value = 1;
} else if (sbp.has_partial_sum_parallel()) {
cur_dim_value = 2;
} else if (sbp.has_split_parallel()) {
CHECK_LT(sbp.split_parallel().axis(), kMaxSplitAxis);
// from 3 to 10
cur_dim_value = 3 + sbp.split_parallel().axis();
} else {
UNIMPLEMENTED();
}
value = value * kCarryDigit + cur_dim_value;
}
return value;
}

size_t MesureNdSbpSignature(const NdSbpSignature& nd_sbp_sig, const std::vector<std::string>& bns) {
// big enough for the 2d-sbp signature set
// to extend to 3d-sbp, consider increasing it to 170
constexpr size_t kCarryDigit = 97;
size_t value = 0;
for (size_t i = 0; i < bns.size(); ++i) {
auto nd_sbp_it = nd_sbp_sig.bn_in_op2nd_sbp().find(bns[i]);
CHECK(nd_sbp_it != nd_sbp_sig.bn_in_op2nd_sbp().end())
<< "can't find bn (" << bns[i] << ") in " << PbMessage2TxtString(nd_sbp_sig);
size_t cur_arg_value = MesureNdSbp(nd_sbp_it->second);
CHECK_LE(value + cur_arg_value / kCarryDigit, std::numeric_limits<size_t>::max() / kCarryDigit);
value = value * kCarryDigit + cur_arg_value;
}
return value;
}

} // namespace

void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,
const std::vector<std::string>& bns) {
if (bns.size() > 8) { return; }
std::map<size_t, NdSbpSignature> value2nd_sbp_sig;
for (auto& nd_sbp_sig : *nd_sbp_sig_list) {
size_t order_value = MesureNdSbpSignature(nd_sbp_sig, bns);
if (value2nd_sbp_sig.find(order_value) == value2nd_sbp_sig.end()) {
value2nd_sbp_sig.emplace(order_value, std::move(nd_sbp_sig));
}
}
nd_sbp_sig_list->clear();
for (auto& nd_sbp_pair : value2nd_sbp_sig) {
nd_sbp_sig_list->emplace_back(std::move(nd_sbp_pair.second));
}
}

// Compute storage per device for given NdSbp
double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy) {
if (nd_sbp.sbp_parallel_size() == 1) {
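The deduplication above keys each signature by a positional encoding and relies on `std::map` for both uniqueness and ordering. A stripped-down sketch of that idea on plain digit vectors might look like this. `EncodeDigits` and `DedupSorted` are hypothetical names, and the digit scheme (B = 1, P = 2, S(axis) = 3 + axis, base 11) is taken from the comments in the hunk above.

```cpp
#include <cstddef>
#include <map>
#include <utility>
#include <vector>

// Toy digits: B = 1, P = 2, S(axis) = 3 + axis; digit 0 is left unused.
constexpr size_t kBase = 11;

// Positional encoding: one base-11 digit per hierarchy dim.
size_t EncodeDigits(const std::vector<size_t>& digits) {
  size_t value = 0;
  for (size_t d : digits) { value = value * kBase + d; }
  return value;
}

// Deduplicate and sort by encoded key; the first occurrence of each key wins.
std::vector<std::vector<size_t>> DedupSorted(std::vector<std::vector<size_t>> items) {
  std::map<size_t, std::vector<size_t>> by_key;
  for (auto& item : items) {
    const size_t key = EncodeDigits(item);
    by_key.emplace(key, std::move(item));  // no-op if the key already exists
  }
  std::vector<std::vector<size_t>> result;
  result.reserve(by_key.size());
  for (auto& kv : by_key) { result.push_back(std::move(kv.second)); }
  return result;
}
```

Because the map iterates in ascending key order, the output is sorted by the encoded value, which puts B before P before S(0), S(1), and so on, the same ordering the unit test in this commit expects.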
3 changes: 3 additions & 0 deletions oneflow/core/framework/sbp_infer_util.h
@@ -77,6 +77,9 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,
std::vector<NdSbpSignature>* nd_sbp_sig_list);

void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list,
const std::vector<std::string>& bns);

// Compute storage for given NdSbp
double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy);

180 changes: 180 additions & 0 deletions oneflow/core/framework/sbp_infer_util_test.cpp
@@ -0,0 +1,180 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/framework/nd_sbp.h"

#include <gtest/gtest.h>

namespace oneflow {
namespace test {

namespace {

bool ParseNdSbpSignatureFromString(const std::string& nd_sbp_signature_str,
NdSbpSignature& nd_sbp_signature) {
auto* bn2nd_sbp = nd_sbp_signature.mutable_bn_in_op2nd_sbp();
std::string arg_name = "in";
bool meet_nd_sbp_group = false;
bool meet_split = false;
int nd_sbp_group_id = 0;
std::vector<std::string> nd_sbp_str_group;
size_t pos = 0;
while (pos < nd_sbp_signature_str.size()) {
const char& c = nd_sbp_signature_str[pos];
pos++;
if (c == ' ') {
continue;
} else if (c == '(') {
if (!meet_nd_sbp_group) {
// enter a nd-sbp group
meet_nd_sbp_group = true;
nd_sbp_str_group.emplace_back();
continue;
} else {
// meet left parentheses of S(x)
meet_split = true;
}
} else if (c == ')') {
if (meet_split) {
// meet right parentheses of S(x)
meet_split = false;
} else if (meet_nd_sbp_group) {
// leave a nd-sbp group
meet_nd_sbp_group = false;
std::string bn = arg_name + "_" + std::to_string(nd_sbp_group_id);
if (!ParseNdSbpFromStringList(nd_sbp_str_group, &(*bn2nd_sbp)[bn])) { return false; }
nd_sbp_str_group.clear();
continue;
} else {
return false;
}
} else if (c == ',') {
if (meet_nd_sbp_group) {
nd_sbp_str_group.emplace_back();
} else {
nd_sbp_group_id += 1;
}
continue;
} else if (c == '-') {
if (pos < nd_sbp_signature_str.size() && nd_sbp_signature_str[pos] == '>') {
// in args parsing has finished, parse out args
arg_name = "out";
nd_sbp_group_id = 0;
// skip '>' in substr '->'
pos++;
continue;
} else {
return false;
}
} else {
// do nothing
}
nd_sbp_str_group.back() += c;
}
return true;
}

std::string NdSbpSignature2String(const NdSbpSignature& nd_sbp_signature,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
std::ostringstream ss;
auto BnNdSbpToString = [&](const std::string& bn) {
auto iter = nd_sbp_signature.bn_in_op2nd_sbp().find(bn);
CHECK(iter != nd_sbp_signature.bn_in_op2nd_sbp().end());
ss << NdSbpToString(iter->second);
};
auto ArgsNdSbpToString = [&](const std::vector<std::string>& arg_bns) {
for (size_t i = 0; i < arg_bns.size(); ++i) {
if (i > 0) { ss << ", "; }
BnNdSbpToString(arg_bns[i]);
}
};
ArgsNdSbpToString(inputs);
ss << " -> ";
ArgsNdSbpToString(outputs);
return ss.str();
}

void TestDeduplicateNdSbpSignature(const std::vector<std::string>& nd_sbp_signature_str_list,
const std::vector<std::string>& input_bns,
const std::vector<std::string>& output_bns) {
// parse
std::vector<NdSbpSignature> nd_sbp_sig_list;
nd_sbp_sig_list.reserve(nd_sbp_signature_str_list.size());
for (const auto& nd_sbp_signature_str : nd_sbp_signature_str_list) {
nd_sbp_sig_list.emplace_back();
ASSERT_TRUE(ParseNdSbpSignatureFromString(nd_sbp_signature_str, nd_sbp_sig_list.back()));
}

// shuffle and repeat
std::random_device rd;
std::mt19937 gen(rd());
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);
nd_sbp_sig_list.reserve(nd_sbp_sig_list.size() + nd_sbp_sig_list.size() / 2);
std::copy_n(nd_sbp_sig_list.begin(), nd_sbp_sig_list.size() / 2,
std::back_inserter(nd_sbp_sig_list));
std::shuffle(nd_sbp_sig_list.begin(), nd_sbp_sig_list.end(), gen);

// dedup and sort
std::vector<std::string> bns;
bns.insert(bns.end(), input_bns.begin(), input_bns.end());
bns.insert(bns.end(), output_bns.begin(), output_bns.end());
DeduplicateNdSbpSignatureList(&nd_sbp_sig_list, bns);

// compare
ASSERT_EQ(nd_sbp_signature_str_list.size(), nd_sbp_sig_list.size());
for (size_t i = 0; i < nd_sbp_sig_list.size(); ++i) {
auto nd_sbp_sig_result = NdSbpSignature2String(nd_sbp_sig_list[i], input_bns, output_bns);
ASSERT_EQ(nd_sbp_sig_result, nd_sbp_signature_str_list[i]);
}
}

} // namespace

TEST(SbpInferUtil, DeduplicateNdSbpSignatureList) {
TestDeduplicateNdSbpSignature(
{
"(B, B) -> (B, B)",
"(B, P) -> (B, P)",
"(B, S(0)) -> (B, S(0))",
"(B, S(1)) -> (B, S(1))",
"(B, S(3)) -> (B, S(2))",
"(P, B) -> (P, B)",
"(P, P) -> (P, P)",
"(P, S(0)) -> (P, S(0))",
"(P, S(1)) -> (P, S(1))",
"(P, S(3)) -> (P, S(2))",
"(S(0), B) -> (S(0), B)",
"(S(0), P) -> (S(0), P)",
"(S(0), S(0)) -> (S(0), S(0))",
"(S(0), S(1)) -> (S(0), S(1))",
"(S(0), S(3)) -> (S(0), S(2))",
"(S(1), B) -> (S(1), B)",
"(S(1), P) -> (S(1), P)",
"(S(1), S(0)) -> (S(1), S(0))",
"(S(1), S(1)) -> (S(1), S(1))",
"(S(1), S(3)) -> (S(1), S(2))",
"(S(3), B) -> (S(2), B)",
"(S(3), P) -> (S(2), P)",
"(S(3), S(0)) -> (S(2), S(0))",
"(S(3), S(1)) -> (S(2), S(1))",
"(S(3), S(3)) -> (S(2), S(2))",
},
{"in_0"}, {"out_0"});
}

} // namespace test
} // namespace oneflow
5 changes: 5 additions & 0 deletions oneflow/core/framework/user_op_registry.cpp
@@ -218,6 +218,11 @@ OpRegistry& OpRegistry::SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn get_n
return *this;
}

OpRegistry& OpRegistry::SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn) {
result_.enumerate_nd_sbp_signatures_fn = std::move(fn);
return *this;
}

OpRegistry& OpRegistry::SetDumpNdSbpSignatureForOpConfFn(
Operator::DumpNdSbpSignatureForOpConfFn fn) {
result_.dump_nd_sbp_signature_for_op_conf_fn = std::move(fn);
3 changes: 3 additions & 0 deletions oneflow/core/framework/user_op_registry.h
@@ -64,6 +64,7 @@ using NdSbpInferFn = std::function<Maybe<void>(InferNdSbpFnContext*)>;
using ComputeComplexityFn = std::function<Maybe<double>(ComputeComplexityFnContext*)>;
// TODO: set up another context
using GetNdSbpSignatureListFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;
using EnumerateNdSbpSignaturesFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;

struct OpRegistryResult {
OpRegistryResult()
@@ -94,6 +95,7 @@ struct OpRegistryResult {
NdSbpInferFn nd_sbp_infer_fn;
ComputeComplexityFn compute_complexity_fn;
GetNdSbpSignatureListFn get_nd_sbp_list_fn;
EnumerateNdSbpSignaturesFn enumerate_nd_sbp_signatures_fn;
Operator::DumpNdSbpSignatureForOpConfFn dump_nd_sbp_signature_for_op_conf_fn;
};

@@ -143,6 +145,7 @@ class OpRegistry final {
OpRegistry& SetDeviceAndStreamInferFn(DeviceAndStreamInferFn fn);
OpRegistry& SetComputeComplexityFn(ComputeComplexityFn fn);
OpRegistry& SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn fn);
OpRegistry& SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn);
OpRegistry& SetDumpNdSbpSignatureForOpConfFn(Operator::DumpNdSbpSignatureForOpConfFn fn);

Maybe<OpRegistry&> Finish();
7 changes: 7 additions & 0 deletions oneflow/core/operator/operator.cpp
@@ -507,6 +507,12 @@ Maybe<void> Operator::GetSbpSignaturesIf(
return Maybe<void>::Ok();
}

Maybe<void> Operator::EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
return Maybe<void>::Ok();
}

Maybe<void> Operator::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
@@ -537,6 +543,7 @@ Maybe<void> Operator::GetNdSbpSignatureList(
CHECK_OR_RETURN(nd_sbp_sig_list->empty());
DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(),
hierarchy_value2sbp_sig_list, nd_sbp_sig_list);
JUST(EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list));
return Maybe<void>::Ok();
}

3 changes: 3 additions & 0 deletions oneflow/core/operator/operator.h
@@ -173,6 +173,9 @@ class Operator {
Maybe<void> GetSbpSignaturesIf(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const;
virtual Maybe<void> EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
virtual Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
15 changes: 15 additions & 0 deletions oneflow/core/operator/user_op.cpp
@@ -634,6 +634,8 @@ class UserOpGetNdSbpSignatureListContext : public user_op::GetNdSbpSignatureList
nd_sbp_sig_list_(nd_sbp_sig_list) {}
~UserOpGetNdSbpSignatureListContext() override = default;

std::vector<NdSbpSignature>* MutNdSbpSignatureList() override { return nd_sbp_sig_list_; }

void AddNdSbpSignature(NdSbpSignature& nd_sbp_sig) override {
nd_sbp_sig_list_->emplace_back(nd_sbp_sig);
}
@@ -992,6 +994,19 @@ Maybe<void> UserOp::InferNdSbpSignature(
return Maybe<void>::Ok();
}

Maybe<void> UserOp::EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
if (val_->enumerate_nd_sbp_signatures_fn) {
NdSbpSignature empty_sbp_signature;
UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context(
this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);
return val_->enumerate_nd_sbp_signatures_fn(&user_op_get_nd_sbp_list_context);
} else {
return Operator::EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);
}
}

Maybe<void> UserOp::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
4 changes: 4 additions & 0 deletions oneflow/core/operator/user_op.h
@@ -65,6 +65,10 @@ class UserOp final : public Operator {
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override;
Maybe<void> EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
std::vector<NdSbpSignature>* nd_sbp_sig_list) const override;
Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
1 change: 1 addition & 0 deletions oneflow/ir/include/OneFlow/OneFlowBase.td
@@ -89,6 +89,7 @@ class OneFlow_BaseOp<string mnemonic, list<Trait> traits = []> :
bit has_nd_sbp_infer_fn = 0;
bit has_compute_complexity_fn = 0;
bit has_get_nd_sbp_fn = 0;
bit has_enumerate_nd_sbp_signatures_fn = 0;
bit has_dump_nd_sbp_signature_for_op_conf_fn = 0;
}

1 change: 1 addition & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -8220,6 +8220,7 @@ def OneFlow_ReshapeOp : OneFlow_BaseOp<"reshape", [NoSideEffect, DeclareOpInterf
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_enumerate_nd_sbp_signatures_fn = 1;
let has_data_type_infer_fn = 1;
let hasFolder = 1;
}