Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supplement ND SBP signatures for reshape op #9858

Merged
merged 35 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
663cf30
EnumerateNdSbpSignatures
leaves-zwx Feb 9, 2023
dccb762
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 11, 2023
7cd2dc1
fix algo
leaves-zwx Feb 12, 2023
1eed7aa
fix
leaves-zwx Feb 12, 2023
3dc723b
cpp test
leaves-zwx Feb 12, 2023
cbf7261
fix
leaves-zwx Feb 12, 2023
a440e8d
rename
leaves-zwx Feb 12, 2023
6998d52
refine UserOp::GetNdSbpSignatureList
leaves-zwx Feb 12, 2023
94870a6
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 12, 2023
bff5183
py test
leaves-zwx Feb 12, 2023
2db180a
revert changes
leaves-zwx Feb 13, 2023
84b3acf
new ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 13, 2023
ad4a227
add test case
leaves-zwx Feb 13, 2023
7c22e1a
refine algorithm
leaves-zwx Feb 15, 2023
c51364a
update test and fix bug
leaves-zwx Feb 16, 2023
db658eb
rm comment
leaves-zwx Feb 16, 2023
278f577
rm redundant condition
leaves-zwx Feb 16, 2023
c500114
DeduplicateNdSbpSignatureList
leaves-zwx Feb 17, 2023
a4ddd55
refine GenRankMeshSubset
leaves-zwx Feb 17, 2023
508cc6f
add comments
leaves-zwx Feb 17, 2023
e5451c2
suppress warning
leaves-zwx Feb 18, 2023
b3c9ca9
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 18, 2023
66d3923
suppress warning
leaves-zwx Feb 18, 2023
5e4aebc
suppress warning
leaves-zwx Feb 18, 2023
670eca3
refine sort algorithm
leaves-zwx Feb 21, 2023
351083e
refine sort algorithm and test
leaves-zwx Feb 21, 2023
e6b579b
rm calling FilterNdSbpIn2OutSignatures in EnumerateNdSbpSignatures
leaves-zwx Feb 21, 2023
ca9eb24
change example
leaves-zwx Feb 21, 2023
abe42db
Update oneflow/user/ops/reshape_user_op_util.cpp
leaves-zwx Feb 21, 2023
4836822
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 21, 2023
7bacbcf
refine sort algorithm
leaves-zwx Feb 25, 2023
f578fcf
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 25, 2023
d3cf09d
change sort order
leaves-zwx Feb 28, 2023
13dd1be
use FilterNdSbpByLogicalShape
leaves-zwx Feb 28, 2023
70ae77e
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ class GetNdSbpSignatureListContext {
virtual ~GetNdSbpSignatureListContext() = default;

virtual void AddNdSbpSignature(NdSbpSignature&) = 0;
virtual std::vector<NdSbpSignature>* MutNdSbpSignatureList() = 0;
virtual const Shape& parallel_hierarchy() = 0;
virtual const Shape& BlobShape4InputArgNameAndIndex(const std::string& arg_name,
int32_t index) const = 0;
Expand Down
88 changes: 88 additions & 0 deletions oneflow/core/framework/sbp_infer_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,6 +787,94 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
}
}

// Three-way comparison of two SbpParallel values.
// Returns -1 if lhs orders before rhs, 1 if it orders after, 0 if they tie.
// Ordering: split < broadcast < anything else; two splits are ordered by axis.
int CompareSbp(const SbpParallel& lhs_sbp, const SbpParallel& rhs_sbp) {
  const bool lhs_is_split = lhs_sbp.has_split_parallel();
  const bool rhs_is_split = rhs_sbp.has_split_parallel();
  // Exactly one side is a split: the split side orders first.
  if (lhs_is_split != rhs_is_split) { return lhs_is_split ? -1 : 1; }
  if (lhs_is_split) {
    // Both are splits: the smaller axis orders first.
    const auto lhs_axis = lhs_sbp.split_parallel().axis();
    const auto rhs_axis = rhs_sbp.split_parallel().axis();
    if (lhs_axis == rhs_axis) { return 0; }
    return lhs_axis < rhs_axis ? -1 : 1;
  }
  // Neither side is a split: broadcast orders before the remaining kind.
  const bool lhs_is_broadcast = lhs_sbp.has_broadcast_parallel();
  const bool rhs_is_broadcast = rhs_sbp.has_broadcast_parallel();
  if (lhs_is_broadcast == rhs_is_broadcast) { return 0; }
  return lhs_is_broadcast ? -1 : 1;
}

// Three-way comparison of two NdSbp values.
// Returns -1 if lhs orders before rhs, 1 if it orders after, 0 if they tie.
// The NdSbp with fewer dimensions orders first; NdSbps of equal length are
// compared dimension by dimension with CompareSbp, first difference wins.
int CompareNdSbp(const NdSbp& lhs_nd_sbp, const NdSbp& rhs_nd_sbp) {
  const int lhs_size = lhs_nd_sbp.sbp_parallel_size();
  const int rhs_size = rhs_nd_sbp.sbp_parallel_size();
  if (lhs_size != rhs_size) { return lhs_size < rhs_size ? -1 : 1; }
  for (int dim = 0; dim < lhs_size; ++dim) {
    const int order = CompareSbp(lhs_nd_sbp.sbp_parallel(dim), rhs_nd_sbp.sbp_parallel(dim));
    if (order != 0) { return order; }
  }
  return 0;
}

// Three-way comparison of two NdSbpSignatures.
// Returns -1 if lhs orders before rhs, 1 if it orders after, 0 if they tie.
// Both signatures must contain exactly the same blob names (CHECK-enforced).
// Blob names are visited in sorted order and the first blob whose NdSbp
// differs (per CompareNdSbp) decides the result.
int CompareNdSbpSignature(const NdSbpSignature& lhs_nd_sbp_sig,
                          const NdSbpSignature& rhs_nd_sbp_sig) {
  CHECK_EQ(lhs_nd_sbp_sig.bn_in_op2nd_sbp_size(), rhs_nd_sbp_sig.bn_in_op2nd_sbp_size());
  const auto& lhs_bn2nd_sbp = lhs_nd_sbp_sig.bn_in_op2nd_sbp();
  const auto& rhs_bn2nd_sbp = rhs_nd_sbp_sig.bn_in_op2nd_sbp();
  // The iteration order of a protobuf map is unspecified and may differ between
  // two maps holding the same keys. Walking the keys in sorted order guarantees
  // that Compare(a, b) and Compare(b, a) look at the blobs in the same sequence,
  // which is required for the result to be usable as a sort order.
  std::vector<std::string> sorted_bns;
  sorted_bns.reserve(lhs_bn2nd_sbp.size());
  for (const auto& bn_and_nd_sbp : lhs_bn2nd_sbp) { sorted_bns.emplace_back(bn_and_nd_sbp.first); }
  std::sort(sorted_bns.begin(), sorted_bns.end());
  for (const std::string& bn : sorted_bns) {
    const auto lhs_it = lhs_bn2nd_sbp.find(bn);
    const auto rhs_it = rhs_bn2nd_sbp.find(bn);
    CHECK(rhs_it != rhs_bn2nd_sbp.end());
    const int cmp_ret = CompareNdSbp(lhs_it->second, rhs_it->second);
    if (cmp_ret != 0) { return cmp_ret; }
  }
  return 0;
}

void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list) {
std::vector<size_t> indices(nd_sbp_sig_list->size());
std::iota(indices.begin(), indices.end(), 0);
std::sort(indices.begin(), indices.end(), [nd_sbp_sig_list](size_t lhs, size_t rhs) {
int cmp_ret = CompareNdSbpSignature(nd_sbp_sig_list->at(lhs), nd_sbp_sig_list->at(rhs));
if (cmp_ret == 0) {
return true;
} else if (cmp_ret == -1) {
return true;
} else if (cmp_ret == 1) {
return false;
} else {
UNIMPLEMENTED();
}
});
auto new_end =
std::unique(indices.begin(), indices.end(), [nd_sbp_sig_list](size_t lhs, size_t rhs) {
return nd_sbp_sig_list->at(lhs) == nd_sbp_sig_list->at(rhs);
});
indices.erase(new_end, indices.end());
std::vector<NdSbpSignature> new_nd_sbp_sig_list;
for (size_t index : indices) {
new_nd_sbp_sig_list.emplace_back(std::move(nd_sbp_sig_list->at(index)));
}
*nd_sbp_sig_list = std::move(new_nd_sbp_sig_list);
}

// Compute storage per device for given NdSbp
double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy) {
if (nd_sbp.sbp_parallel_size() == 1) {
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/framework/sbp_infer_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,13 @@ void DfsGetNdSbpSignature(NdSbpSignature& nd_sbp_sig, int32_t depth, int32_t dim
const HashMap<int32_t, SbpSignatureList>& hierarchy_value2sbp_sig_list,
std::vector<NdSbpSignature>* nd_sbp_sig_list);

// Three-way comparison helpers. Each returns -1 if lhs orders before rhs,
// 1 if lhs orders after rhs, and 0 if the two compare equal.
// CompareSbp: split orders before broadcast, broadcast before anything else;
// two splits are ordered by axis.
int CompareSbp(const SbpParallel& lhs_sbp, const SbpParallel& rhs_sbp);
// CompareNdSbp: fewer dimensions orders first, then element-wise CompareSbp.
int CompareNdSbp(const NdSbp& lhs_nd_sbp, const NdSbp& rhs_nd_sbp);
// CompareNdSbpSignature: both signatures must hold the same blob names;
// the first blob whose NdSbp differs decides the result.
int CompareNdSbpSignature(const NdSbpSignature& lhs_nd_sbp_sig,
const NdSbpSignature& rhs_nd_sbp_sig);
// Sorts the list with CompareNdSbpSignature and drops repeated signatures in place.
void DeduplicateNdSbpSignatureList(std::vector<NdSbpSignature>* nd_sbp_sig_list);

// Compute storage for given NdSbp
double Storage4NdSbp(const NdSbp& nd_sbp, Shape& logical_shape, const Shape& parallel_hierarchy);

Expand Down
5 changes: 5 additions & 0 deletions oneflow/core/framework/user_op_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,11 @@ OpRegistry& OpRegistry::SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn get_n
return *this;
}

// Stores `fn` in result_.enumerate_nd_sbp_signatures_fn (taking it by move)
// and returns *this so that registration calls can be chained.
OpRegistry& OpRegistry::SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn) {
result_.enumerate_nd_sbp_signatures_fn = std::move(fn);
return *this;
}

OpRegistry& OpRegistry::SetDumpNdSbpSignatureForOpConfFn(
Operator::DumpNdSbpSignatureForOpConfFn fn) {
result_.dump_nd_sbp_signature_for_op_conf_fn = std::move(fn);
Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/framework/user_op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ using NdSbpInferFn = std::function<Maybe<void>(InferNdSbpFnContext*)>;
using ComputeComplexityFn = std::function<Maybe<double>(ComputeComplexityFnContext*)>;
// TODO: set up another context
using GetNdSbpSignatureListFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;
using EnumerateNdSbpSignaturesFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;

struct OpRegistryResult {
OpRegistryResult()
Expand Down Expand Up @@ -94,6 +95,7 @@ struct OpRegistryResult {
NdSbpInferFn nd_sbp_infer_fn;
ComputeComplexityFn compute_complexity_fn;
GetNdSbpSignatureListFn get_nd_sbp_list_fn;
EnumerateNdSbpSignaturesFn enumerate_nd_sbp_signatures_fn;
Operator::DumpNdSbpSignatureForOpConfFn dump_nd_sbp_signature_for_op_conf_fn;
};

Expand Down Expand Up @@ -143,6 +145,7 @@ class OpRegistry final {
OpRegistry& SetDeviceAndStreamInferFn(DeviceAndStreamInferFn fn);
OpRegistry& SetComputeComplexityFn(ComputeComplexityFn fn);
OpRegistry& SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn fn);
OpRegistry& SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn);
OpRegistry& SetDumpNdSbpSignatureForOpConfFn(Operator::DumpNdSbpSignatureForOpConfFn fn);

Maybe<OpRegistry&> Finish();
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/operator/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,12 @@ Maybe<void> Operator::GetSbpSignaturesIf(
return Maybe<void>::Ok();
}

// Base-class implementation of the ND SBP enumeration hook.
// It ignores all of its arguments, leaves `nd_sbp_sig_list` untouched and
// reports success; the method is declared virtual so subclasses can override it.
Maybe<void> Operator::EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
return Maybe<void>::Ok();
}

Maybe<void> Operator::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
Expand Down Expand Up @@ -536,6 +542,7 @@ Maybe<void> Operator::GetNdSbpSignatureList(
CHECK_OR_RETURN(nd_sbp_sig_list->empty());
DfsGetNdSbpSignature(nd_sbp_sig, 0, sbp_dimension, *parallel_desc.hierarchy(),
hierarchy_value2sbp_sig_list, nd_sbp_sig_list);
JUST(EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list));
return Maybe<void>::Ok();
}

Expand Down
3 changes: 3 additions & 0 deletions oneflow/core/operator/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,9 @@ class Operator {
Maybe<void> GetSbpSignaturesIf(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const;
// Hook called by Operator::GetNdSbpSignatureList after the DFS-generated
// candidates have been written to `nd_sbp_sig_list`; an override may edit
// that list. The base implementation does nothing and returns Ok.
virtual Maybe<void> EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
virtual Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const;
Expand Down
15 changes: 15 additions & 0 deletions oneflow/core/operator/user_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,8 @@ class UserOpGetNdSbpSignatureListContext : public user_op::GetNdSbpSignatureList
nd_sbp_sig_list_(nd_sbp_sig_list) {}
~UserOpGetNdSbpSignatureListContext() override = default;

// Returns the signature list pointer this context was constructed with, for in-place edits.
std::vector<NdSbpSignature>* MutNdSbpSignatureList() override { return nd_sbp_sig_list_; }

// Appends a copy of `nd_sbp_sig` to the end of the signature list.
void AddNdSbpSignature(NdSbpSignature& nd_sbp_sig) override {
nd_sbp_sig_list_->emplace_back(nd_sbp_sig);
}
Expand Down Expand Up @@ -992,6 +994,19 @@ Maybe<void> UserOp::InferNdSbpSignature(
return Maybe<void>::Ok();
}

// Runs the op's registered enumerate_nd_sbp_signatures_fn, if any, on
// `nd_sbp_sig_list`; otherwise defers to Operator::EnumerateNdSbpSignatures.
// The callback receives a UserOpGetNdSbpSignatureListContext built from this
// op, `LogicalBlobDesc4Ibn`, `parallel_desc` and `nd_sbp_sig_list`, and its
// result is returned unchanged.
Maybe<void> UserOp::EnumerateNdSbpSignatures(
    const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
    const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
  if (!val_->enumerate_nd_sbp_signatures_fn) {
    // No op-specific enumerator registered: use the base-class behaviour.
    return Operator::EnumerateNdSbpSignatures(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);
  }
  UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context(
      this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);
  return val_->enumerate_nd_sbp_signatures_fn(&user_op_get_nd_sbp_list_context);
}

Maybe<void> UserOp::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
Expand Down
4 changes: 4 additions & 0 deletions oneflow/core/operator/user_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,10 @@ class UserOp final : public Operator {
Maybe<void> GetSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
int32_t hierarchy_value, SbpSignatureList* sbp_sig_list) const override;
// Override that forwards to the registered enumerate_nd_sbp_signatures_fn when
// one is set, and otherwise defers to Operator::EnumerateNdSbpSignatures.
Maybe<void> EnumerateNdSbpSignatures(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
std::vector<NdSbpSignature>* nd_sbp_sig_list) const override;
Maybe<void> GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc,
Expand Down
1 change: 1 addition & 0 deletions oneflow/ir/include/OneFlow/OneFlowBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ class OneFlow_BaseOp<string mnemonic, list<Trait> traits = []> :
bit has_nd_sbp_infer_fn = 0;
bit has_compute_complexity_fn = 0;
bit has_get_nd_sbp_fn = 0;
bit has_enumerate_nd_sbp_signatures_fn = 0;
bit has_dump_nd_sbp_signature_for_op_conf_fn = 0;
}

Expand Down
1 change: 1 addition & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -8265,6 +8265,7 @@ def OneFlow_ReshapeOp : OneFlow_BaseOp<"reshape", [NoSideEffect, DeclareOpInterf
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_enumerate_nd_sbp_signatures_fn = 1;
let has_data_type_infer_fn = 1;
let hasFolder = 1;
}
Expand Down
41 changes: 38 additions & 3 deletions oneflow/user/ops/reshape_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/ops/reshape_user_op_util.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/op_generated.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/user/ops/reshape_user_op_util.h"

namespace oneflow {

Expand All @@ -30,6 +32,39 @@ namespace oneflow {
in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->hierarchy_value(), &builder);
}

// Enumerates ND SBP signature candidates for reshape.
// Steps:
//   1. derive the logical output shape from the input shape and the "shape" attr;
//   2. let ReshapeUserOpUtil::EnumerateNdSbpSignatures fill the context's list;
//   3. drop every candidate whose output storage (Storage4NdSbp) exceeds
//      GetValidMaxCopyCost();
//   4. de-duplicate the remaining candidates.
// Fails if a candidate carries no NdSbp for "out_0".
/*static*/ Maybe<void> ReshapeOp::EnumerateNdSbpSignatures(
    user_op::GetNdSbpSignatureListContext* ctx) {
  const Shape& in_shape = ctx->BlobShape4InputArgNameAndIndex("in", 0);
  const Shape& shape_attr = ctx->Attr<Shape>("shape");
  std::shared_ptr<Shape> out_shape_ptr =
      JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape_attr));

  std::vector<NdSbpSignature>* candidates = ctx->MutNdSbpSignatureList();
  JUST(ReshapeUserOpUtil::EnumerateNdSbpSignatures({{"in", 0}}, in_shape, {{"out", 0}},
                                                   *out_shape_ptr, ctx->parallel_hierarchy(),
                                                   candidates));

  // Walk from the tail towards the head: a rejected candidate is swapped with
  // the current last element and popped, so the slots still to be visited are
  // never disturbed.
  int32_t idx = candidates->size();
  while (--idx >= 0) {
    NdSbpSignature& candidate = (*candidates)[idx];
    const auto& bn2nd_sbp = candidate.bn_in_op2nd_sbp();
    const auto out_it = bn2nd_sbp.find("out_0");
    CHECK_OR_RETURN(out_it != bn2nd_sbp.end()) << "can't get sbp for out_0";
    // Storage4NdSbp takes the shape by non-const reference, so pass a copy.
    Shape out_logical_shape = *out_shape_ptr;
    // Only the output needs filtering here; filtering by input is done in
    // Operator::FilterNdSbpSignatureListByLogicalShape.
    if (!(Storage4NdSbp(out_it->second, out_logical_shape, ctx->parallel_hierarchy())
          > GetValidMaxCopyCost())) {
      continue;
    }
    // Remove the Nd SBP candidate.
    std::swap(candidate, candidates->back());
    candidates->pop_back();
  }

  DeduplicateNdSbpSignatureList(candidates);
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Shape shape = ctx->Attr<Shape>("shape");
const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0);
Expand Down
Loading