Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supplement ND SBP signatures for reshape op #9858

Merged
merged 35 commits into from
Feb 28, 2023
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
663cf30
EnumerateNdSbpSignatures
leaves-zwx Feb 9, 2023
dccb762
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 11, 2023
7cd2dc1
fix algo
leaves-zwx Feb 12, 2023
1eed7aa
fix
leaves-zwx Feb 12, 2023
3dc723b
cpp test
leaves-zwx Feb 12, 2023
cbf7261
fix
leaves-zwx Feb 12, 2023
a440e8d
rename
leaves-zwx Feb 12, 2023
6998d52
refine UserOp::GetNdSbpSignatureList
leaves-zwx Feb 12, 2023
94870a6
ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 12, 2023
bff5183
py test
leaves-zwx Feb 12, 2023
2db180a
revert changes
leaves-zwx Feb 13, 2023
84b3acf
new ReshapeOp::EnumerateNdSbpSignatures
leaves-zwx Feb 13, 2023
ad4a227
add test case
leaves-zwx Feb 13, 2023
7c22e1a
refine algorithm
leaves-zwx Feb 15, 2023
c51364a
update test and fix bug
leaves-zwx Feb 16, 2023
db658eb
rm comment
leaves-zwx Feb 16, 2023
278f577
rm redundant condition
leaves-zwx Feb 16, 2023
c500114
DeduplicateNdSbpSignatureList
leaves-zwx Feb 17, 2023
a4ddd55
refine GenRankMeshSubset
leaves-zwx Feb 17, 2023
508cc6f
add comments
leaves-zwx Feb 17, 2023
e5451c2
suppress warning
leaves-zwx Feb 18, 2023
b3c9ca9
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 18, 2023
66d3923
suppress warning
leaves-zwx Feb 18, 2023
5e4aebc
suppress warning
leaves-zwx Feb 18, 2023
670eca3
refine sort algorithm
leaves-zwx Feb 21, 2023
351083e
refine sort algorithm and test
leaves-zwx Feb 21, 2023
e6b579b
rm calling FilterNdSbpIn2OutSignatures in EnumerateNdSbpSignatures
leaves-zwx Feb 21, 2023
ca9eb24
change example
leaves-zwx Feb 21, 2023
abe42db
Update oneflow/user/ops/reshape_user_op_util.cpp
leaves-zwx Feb 21, 2023
4836822
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 21, 2023
7bacbcf
refine sort algorithm
leaves-zwx Feb 25, 2023
f578fcf
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 25, 2023
d3cf09d
change sort order
leaves-zwx Feb 28, 2023
13dd1be
use FilterNdSbpByLogicalShape
leaves-zwx Feb 28, 2023
70ae77e
Merge branch 'master' into enlarge_reshape_sbp
leaves-zwx Feb 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions oneflow/core/framework/get_nd_sbp_signature_list_context.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ namespace user_op {

class UserOpDefWrapper;

class GetNdSbpSignatureListContext {
class EnumerateNdSbpSignaturesContext {
public:
virtual ~GetNdSbpSignatureListContext() = default;
virtual ~EnumerateNdSbpSignaturesContext() = default;

virtual void AddNdSbpSignature(NdSbpSignature&) = 0;
virtual const Shape& parallel_hierarchy() = 0;
Expand All @@ -44,7 +44,7 @@ class GetNdSbpSignatureListContext {
const UserOpConfWrapper& user_op_conf() const { return conf_; }

protected:
explicit GetNdSbpSignatureListContext(UserOpConfWrapper&& conf) : conf_(std::move(conf)) {}
explicit EnumerateNdSbpSignaturesContext(UserOpConfWrapper&& conf) : conf_(std::move(conf)) {}

private:
UserOpConfWrapper conf_;
Expand Down
4 changes: 2 additions & 2 deletions oneflow/core/framework/user_op_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -213,8 +213,8 @@ OpRegistry& OpRegistry::SetComputeComplexityFn(ComputeComplexityFn compute_compl
return *this;
}

OpRegistry& OpRegistry::SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn get_nd_sbp_list_fn) {
result_.get_nd_sbp_list_fn = std::move(get_nd_sbp_list_fn);
OpRegistry& OpRegistry::SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn) {
result_.enumerate_nd_sbp_signatures_fn = std::move(fn);
return *this;
}

Expand Down
8 changes: 4 additions & 4 deletions oneflow/core/framework/user_op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class InferOutputBlobTimeShapeFnContext;
class InferNdSbpFnContext;
class DeviceAndStreamInferContext;
class ComputeComplexityFnContext;
class GetNdSbpSignatureListContext;
class EnumerateNdSbpSignaturesContext;

using CheckAttrFn = std::function<Maybe<void>(const UserOpDefWrapper&, const UserOpConfWrapper&)>;
using TensorDescInferFn = std::function<Maybe<void>(InferContext*)>;
Expand All @@ -63,7 +63,7 @@ using OutputBlobTimeShapeInferFn = std::function<Maybe<void>(InferOutputBlobTime
using NdSbpInferFn = std::function<Maybe<void>(InferNdSbpFnContext*)>;
using ComputeComplexityFn = std::function<Maybe<double>(ComputeComplexityFnContext*)>;
// TODO: set up another context
using GetNdSbpSignatureListFn = std::function<Maybe<void>(GetNdSbpSignatureListContext*)>;
using EnumerateNdSbpSignaturesFn = std::function<Maybe<void>(EnumerateNdSbpSignaturesContext*)>;

struct OpRegistryResult {
OpRegistryResult()
Expand Down Expand Up @@ -93,7 +93,7 @@ struct OpRegistryResult {
OutputBlobTimeShapeInferFn output_blob_time_shape_infer_fn;
NdSbpInferFn nd_sbp_infer_fn;
ComputeComplexityFn compute_complexity_fn;
GetNdSbpSignatureListFn get_nd_sbp_list_fn;
EnumerateNdSbpSignaturesFn enumerate_nd_sbp_signatures_fn;
Operator::DumpNdSbpSignatureForOpConfFn dump_nd_sbp_signature_for_op_conf_fn;
};

Expand Down Expand Up @@ -142,7 +142,7 @@ class OpRegistry final {
OpRegistry& SetDataTypeInferFn(DataTypeInferFn fn);
OpRegistry& SetDeviceAndStreamInferFn(DeviceAndStreamInferFn fn);
OpRegistry& SetComputeComplexityFn(ComputeComplexityFn fn);
OpRegistry& SetGetNdSbpSignatureListFn(GetNdSbpSignatureListFn fn);
OpRegistry& SetEnumerateNdSbpSignaturesFn(EnumerateNdSbpSignaturesFn fn);
OpRegistry& SetDumpNdSbpSignatureForOpConfFn(Operator::DumpNdSbpSignatureForOpConfFn fn);

Maybe<OpRegistry&> Finish();
Expand Down
14 changes: 7 additions & 7 deletions oneflow/core/operator/user_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -621,18 +621,18 @@ class UserOpComputeComplexityFnContext : public user_op::ComputeComplexityFnCont
HashMap<std::pair<std::string, int32_t>, user_op::NaiveTensorDesc> arg2tensor_desc_;
};

class UserOpGetNdSbpSignatureListContext : public user_op::GetNdSbpSignatureListContext {
class UserOpEnumerateNdSbpSignaturesContext : public user_op::EnumerateNdSbpSignaturesContext {
public:
UserOpGetNdSbpSignatureListContext(
UserOpEnumerateNdSbpSignaturesContext(
const UserOp* op,
std::function<Maybe<const BlobDesc&>(const std::string&)> LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list)
: user_op::GetNdSbpSignatureListContext(user_op::UserOpConfWrapper(op->user_op_conf())),
: user_op::EnumerateNdSbpSignaturesContext(user_op::UserOpConfWrapper(op->user_op_conf())),
op_(op),
logical_blob_desc4ibn_(std::move(LogicalBlobDesc4Ibn)),
parallel_desc_(parallel_desc),
nd_sbp_sig_list_(nd_sbp_sig_list) {}
~UserOpGetNdSbpSignatureListContext() override = default;
~UserOpEnumerateNdSbpSignaturesContext() override = default;

void AddNdSbpSignature(NdSbpSignature& nd_sbp_sig) override {
nd_sbp_sig_list_->emplace_back(nd_sbp_sig);
Expand Down Expand Up @@ -995,11 +995,11 @@ Maybe<void> UserOp::InferNdSbpSignature(
Maybe<void> UserOp::GetNdSbpSignatureList(
const std::function<Maybe<const BlobDesc&>(const std::string&)>& LogicalBlobDesc4Ibn,
const ParallelDesc& parallel_desc, std::vector<NdSbpSignature>* nd_sbp_sig_list) const {
if (val_->get_nd_sbp_list_fn) {
if (val_->enumerate_nd_sbp_signatures_fn) {
NdSbpSignature empty_sbp_signature;
UserOpGetNdSbpSignatureListContext user_op_get_nd_sbp_list_context(
UserOpEnumerateNdSbpSignaturesContext user_op_get_nd_sbp_list_context(
this, LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list);
return val_->get_nd_sbp_list_fn(&user_op_get_nd_sbp_list_context);
return val_->enumerate_nd_sbp_signatures_fn(&user_op_get_nd_sbp_list_context);
} else {
JUST(Operator::GetNdSbpSignatureList(LogicalBlobDesc4Ibn, parallel_desc, nd_sbp_sig_list));
}
Expand Down
2 changes: 1 addition & 1 deletion oneflow/ir/include/OneFlow/OneFlowBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class OneFlow_BaseOp<string mnemonic, list<Trait> traits = []> :
bit has_output_blob_time_shape_infer_fn = 0;
bit has_nd_sbp_infer_fn = 0;
bit has_compute_complexity_fn = 0;
bit has_get_nd_sbp_fn = 0;
bit has_enumerate_nd_sbp_signatures_fn = 0;
bit has_dump_nd_sbp_signature_for_op_conf_fn = 0;
}

Expand Down
5 changes: 3 additions & 2 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -1582,7 +1582,7 @@ def OneFlow_OFRecordReaderOp : OneFlow_BaseOp<"OFRecordReader", [NoSideEffect, N
let has_data_type_infer_fn = 1;
let has_output_arg_modify_fn = 1;
let has_nd_sbp_infer_fn = 1;
let has_get_nd_sbp_fn = 1;
let has_enumerate_nd_sbp_signatures_fn = 1;
let has_compute_complexity_fn = 1;
}

Expand Down Expand Up @@ -7637,7 +7637,7 @@ def OneFlow_HierarchicalParallelCastOp : OneFlow_BaseOp<"hierarchical_parallel_c
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
let has_nd_sbp_infer_fn = 1;
let has_get_nd_sbp_fn = 1;
let has_enumerate_nd_sbp_signatures_fn = 1;
}

def OneFlow_HierarchicalParallelCastLikeOp : OneFlow_BaseOp<"hierarchical_parallel_cast_like", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
Expand Down Expand Up @@ -8195,6 +8195,7 @@ def OneFlow_ReshapeOp : OneFlow_BaseOp<"reshape", [NoSideEffect, DeclareOpInterf
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_enumerate_nd_sbp_signatures_fn = 1;
let has_data_type_infer_fn = 1;
let hasFolder = 1;
}
Expand Down
4 changes: 2 additions & 2 deletions oneflow/user/ops/hierarchical_parallel_cast_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ namespace oneflow {
return Maybe<void>::Ok();
}

/* static */ Maybe<void> HierarchicalParallelCastOp::GetNdSbpSignatureList(
user_op::GetNdSbpSignatureListContext* ctx) {
/* static */ Maybe<void> HierarchicalParallelCastOp::EnumerateNdSbpSignatures(
user_op::EnumerateNdSbpSignaturesContext* ctx) {
const auto& conf = ctx->Attr<std::vector<std::string>>("nd_sbp");
NdSbpSignature nd_sbp_signature;
for (const std::string& sbp_str : conf) {
Expand Down
4 changes: 2 additions & 2 deletions oneflow/user/ops/ofrecord_reader_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,8 @@ namespace oneflow {
return Maybe<void>::Ok();
}

/* static */ Maybe<void> OFRecordReaderOp::GetNdSbpSignatureList(
user_op::GetNdSbpSignatureListContext* ctx) {
/* static */ Maybe<void> OFRecordReaderOp::EnumerateNdSbpSignatures(
user_op::EnumerateNdSbpSignaturesContext* ctx) {
NdSbpSignature nd_sbp_signature;
SbpParallel split_sbp_parallel;
split_sbp_parallel.mutable_split_parallel()->set_axis(0);
Expand Down
60 changes: 57 additions & 3 deletions oneflow/user/ops/reshape_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/balanced_splitter.h"
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/ops/reshape_user_op_util.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/framework/op_generated.h"
#include "oneflow/core/framework/sbp_infer_util.h"
#include "oneflow/core/framework/user_op_conf.h"
#include "oneflow/core/framework/nd_sbp.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/user/ops/reshape_user_op_util.h"

namespace oneflow {

Expand All @@ -30,6 +32,58 @@ namespace oneflow {
in_shape, *outshape, {{"in", 0}}, {{"out", 0}}, ctx->hierarchy_value(), &builder);
}

// Enumerates candidate ND SBP signatures for the reshape op and registers each surviving
// candidate on `ctx` through AddNdSbpSignature.
//
// Steps:
//   1. For every distinct value found in the parallel hierarchy, collect the 1D SBP signatures
//      produced by ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures. An empty result for any
//      value is reported as an error.
//   2. Build ND signatures as the direct product of those 1D signatures (DfsGetNdSbpSignature).
//   3. Hand the list to ReshapeUserOpUtil::EnumerateNdSplitSignatures
//      (presumably it adds further ND split candidates -- see that helper).
//   4. Drop every candidate whose "out_0" storage, as computed by Storage4NdSbp, is greater
//      than GetValidMaxCopyCost().
//
// Returns an error when the parallel hierarchy has no axes, when a helper fails, or when a
// candidate carries no ND SBP for "out_0".
/*static*/ Maybe<void> ReshapeOp::EnumerateNdSbpSignatures(
    user_op::EnumerateNdSbpSignaturesContext* ctx) {
  const Shape& in_shape = ctx->BlobShape4InputArgNameAndIndex("in", 0);
  const Shape& shape_attr = ctx->Attr<Shape>("shape");
  const Shape& out_shape = *JUST(ReshapeUserOpUtil::GetLogicalOutBlobShape(in_shape, shape_attr));
  const Shape& hierarchy = ctx->parallel_hierarchy();

  // The seed signature below is read from the first map entry, so the hierarchy must
  // contribute at least one value; otherwise begin() would be dereferenced on an empty map.
  const int32_t sbp_dimension = hierarchy.NumAxes();
  CHECK_GT_OR_RETURN(sbp_dimension, 0)
      << "reshape can't enumerate ND SBP signatures with an empty parallel hierarchy";

  // 1D signatures are computed once per distinct hierarchy value.
  HashMap<int32_t, SbpSignatureList> hierarchy_value2sbp_sig_list;
  for (int32_t hierarchy_value : hierarchy) {
    if (hierarchy_value2sbp_sig_list.find(hierarchy_value) == hierarchy_value2sbp_sig_list.end()) {
      auto& sbp_sig_list = hierarchy_value2sbp_sig_list[hierarchy_value];
      user_op::UserOpSbpSignatureBuilder builder(&sbp_sig_list);
      JUST(ReshapeUserOpUtil::GetReshapeUserOpSbpSignatures(
          in_shape, out_shape, {{"in", 0}}, {{"out", 0}}, hierarchy_value, &builder));
      CHECK_GT_OR_RETURN(sbp_sig_list.sbp_signature_size(), 0)
          << "reshape can't enumerate any SBP signatures with rank dim (" << hierarchy_value << ")";
    }
  }

  // Seed an ND signature from any available 1D signature and resize it to the hierarchy rank;
  // it serves as the working signature passed to the DFS.
  NdSbpSignature seed_nd_sbp_sig;
  SbpSignatureToNdSbpSignature(hierarchy_value2sbp_sig_list.begin()->second.sbp_signature(0),
                               &seed_nd_sbp_sig);
  ResizeNdSbpSignature(seed_nd_sbp_sig, sbp_dimension);

  // ND sbp signature list would be direct product of 1D sbp signatures
  std::vector<NdSbpSignature> nd_sbp_sig_list;
  DfsGetNdSbpSignature(seed_nd_sbp_sig, 0, sbp_dimension, hierarchy, hierarchy_value2sbp_sig_list,
                       &nd_sbp_sig_list);

  JUST(ReshapeUserOpUtil::EnumerateNdSplitSignatures({{"in", 0}}, in_shape, {{"out", 0}}, out_shape,
                                                     hierarchy, &nd_sbp_sig_list));

  // Go down from the tail to the head, since we might drop the tail.
  // A rejected candidate is swapped with the current last element and popped; the element
  // swapped into its slot has already been visited, so nothing is skipped.
  for (int32_t sbp_id = nd_sbp_sig_list.size() - 1; sbp_id >= 0; sbp_id--) {
    auto& candidate = nd_sbp_sig_list[sbp_id];
    const auto& out_nd_sbp_it = candidate.bn_in_op2nd_sbp().find("out_0");
    CHECK_OR_RETURN(out_nd_sbp_it != candidate.bn_in_op2nd_sbp().end())
        << "can't get sbp for out_0";
    // A fresh copy per candidate (presumably Storage4NdSbp may modify the shape it is
    // given -- TODO confirm against its declaration).
    Shape out_logical_shape = out_shape;
    // filter output only here
    // filter input in Operator::FilterNdSbpSignatureListByLogicalShape
    if (Storage4NdSbp(out_nd_sbp_it->second, out_logical_shape, hierarchy)
        > GetValidMaxCopyCost()) {
      // Remove the Nd SBP candidate
      std::swap(candidate, nd_sbp_sig_list.back());
      nd_sbp_sig_list.pop_back();
    }
  }
  for (auto& nd_sbp_sig : nd_sbp_sig_list) { ctx->AddNdSbpSignature(nd_sbp_sig); }
  return Maybe<void>::Ok();
}

/*static*/ Maybe<void> ReshapeOp::InferLogicalTensorDesc(user_op::InferContext* ctx) {
Shape shape = ctx->Attr<Shape>("shape");
const user_op::TensorDesc& in_tensor_desc = ctx->InputTensorDesc("in", 0);
Expand Down
Loading