Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dev empty op #4720

Merged
merged 25 commits into from
Apr 26, 2021
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
560165d
startup
doombeaker Apr 23, 2021
57a2ede
empty op finished
doombeaker Apr 23, 2021
3f45e02
Merge branch 'master' into dev_empty_op
doombeaker Apr 23, 2021
b8bc489
Merge branch 'master' into dev_empty_op
lixinqi Apr 24, 2021
684322c
add sbp signature
doombeaker Apr 24, 2021
9cd7f0b
Merge branch 'dev_empty_op' of https://github.com/Oneflow-Inc/oneflow…
doombeaker Apr 24, 2021
4e1130a
Merge branch 'master' into dev_empty_op
doombeaker Apr 24, 2021
4fea750
refine test case for fp16
doombeaker Apr 25, 2021
6078d85
try to fix sbp problem
doombeaker Apr 25, 2021
e8e354e
refien sbpGetFn
doombeaker Apr 25, 2021
88da3e1
Merge branch 'master' into dev_empty_op
doombeaker Apr 25, 2021
985df1c
add sbp config attr
doombeaker Apr 25, 2021
0ed2e56
Merge branch 'dev_empty_op' of https://github.com/Oneflow-Inc/oneflow…
doombeaker Apr 25, 2021
07145ba
refine
doombeaker Apr 25, 2021
3c76dea
refine
doombeaker Apr 26, 2021
d2c8d37
add balancedSpliter and add parallel check on py
doombeaker Apr 26, 2021
4c5ac69
Merge branch 'master' into dev_empty_op
doombeaker Apr 26, 2021
0ee633c
refine
doombeaker Apr 26, 2021
f9191fa
add partialSum parallel support
doombeaker Apr 26, 2021
5ba0f0b
unexported empty and rm its test case
doombeaker Apr 26, 2021
2af3f28
Merge branch 'master' into dev_empty_op
doombeaker Apr 26, 2021
810ea1e
rm python wrapper for empty op
doombeaker Apr 26, 2021
e232a51
Merge branch 'dev_empty_op' of https://github.com/Oneflow-Inc/oneflow…
doombeaker Apr 26, 2021
f219320
Merge branch 'master' into dev_empty_op
doombeaker Apr 26, 2021
10e1a56
Merge branch 'master' into dev_empty_op
oneflow-ci-bot Apr 26, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions oneflow/user/kernels/empty_kernel.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/kernel/new_kernel_util.h"

namespace oneflow {
namespace user_op {

// Kernel for the "empty" user op. Compute() writes nothing into the output buffer: it
// only asserts that the "out" tensor holds at least one element.
template<DeviceType device_type, typename T>
class EmptyKernel final : public OpKernel {
 public:
  EmptyKernel() = default;
  ~EmptyKernel() = default;

 private:
  // Looks up output "out" (index 0) and CHECK-fails unless its element count is > 0.
  void Compute(user_op::KernelComputeContext* ctx) const override {
    Tensor* const out = ctx->Tensor4ArgNameAndIndex("out", 0);
    CHECK_GT(out->shape().elem_cnt(), 0);
    // Intentionally no further work: the output contents are left as they are.
  }

  // NOTE(review): presumably tells the framework to skip Compute() when every output is
  // empty, which would keep the CHECK_GT above from firing on zero-sized outputs - confirm
  // against the OpKernel base class.
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

// Registers EmptyKernel<device, dtype> as a kernel for the "empty" user op. The match
// predicate requires the device tag to equal `device` and the op's "dtype" attribute to
// equal GetDataType<dtype>::value.
// The expansion already ends with ';', so call sites do not need to add one.
#define REGISTER_EMPTY_XPU_KERNEL(device, dtype) \
REGISTER_USER_KERNEL("empty").SetCreateFn<EmptyKernel<device, dtype>>().SetIsMatchedHob( \
(user_op::HobDeviceTag() == device) \
& (user_op::HobAttr<DataType>("dtype") == GetDataType<dtype>::value));

// Adapter for the sequence-product macro below: takes a (device, dtype_pair) tuple and
// forwards the first element of the pair as the C++ element type.
#define REGISTER_EMPTY_KERNEL(device, dtype_pair) \
REGISTER_EMPTY_XPU_KERNEL(device, OF_PP_PAIR_FIRST(dtype_pair))

// NOTE(review): presumably expands REGISTER_EMPTY_KERNEL once per (device, data type)
// combination of DEVICE_TYPE_SEQ x ARITHMETIC_DATA_TYPE_SEQ - confirm against the
// OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE definition.
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(REGISTER_EMPTY_KERNEL, DEVICE_TYPE_SEQ, ARITHMETIC_DATA_TYPE_SEQ)

// float16 is registered explicitly, for DeviceType::kGPU only, and only in CUDA builds.
// NOTE(review): presumably float16 is not part of ARITHMETIC_DATA_TYPE_SEQ - confirm.
#ifdef WITH_CUDA
REGISTER_EMPTY_XPU_KERNEL(DeviceType::kGPU, float16);
#endif // WITH_CUDA

} // namespace user_op
} // namespace oneflow
108 changes: 108 additions & 0 deletions oneflow/user/ops/empty_op.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/common/balanced_splitter.h"

namespace oneflow {
REGISTER_USER_OP("empty")
.Output("out")
.SetOutputBufferNum(1)
.Attr<DataType>("dtype")
.Attr<Shape>("shape")
.Attr<std::string>("sbp_parallel", "")
.SetLogicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
Shape* out_shape = ctx->Shape4ArgNameAndIndex("out", 0);
const Shape& shape = ctx->Attr<Shape>("shape");
DimVector dim_vec;
if (shape.NumAxes() > 0) {
dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
}
if (dim_vec.empty()) { dim_vec.push_back(1); }
*out_shape = Shape(dim_vec);
return Maybe<void>::Ok();
})
.SetPhysicalTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe<void> {
Shape* out_shape = ctx->Shape4ArgNameAndIndex("out", 0);
const Shape& shape = ctx->Attr<Shape>("shape");
DimVector dim_vec;
if (shape.NumAxes() > 0) {
dim_vec.insert(dim_vec.end(), shape.dim_vec().cbegin(), shape.dim_vec().cend());
}
if (dim_vec.empty()) { dim_vec.push_back(1); }

const SbpParallel& out_sbp_para = ctx->SbpParallel4ArgNameAndIndex("out", 0);
if (out_sbp_para.has_split_parallel()) {
const int64_t& parallel_num = ctx->parallel_ctx().parallel_num();
if (parallel_num > 1) {
const int64_t& split_axis = out_sbp_para.split_parallel().axis();
CHECK_LT_OR_RETURN(split_axis, dim_vec.size());
BalancedSplitter bs(shape.At(split_axis), parallel_num);
dim_vec[split_axis] = bs.At(ctx->parallel_ctx().parallel_id()).size();
}
}

*out_shape = Shape(dim_vec);
return Maybe<void>::Ok();
})
.SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe<void> {
const Shape& shape = ctx->Attr<Shape>("shape");
if (shape.NumAxes() > 0) {
FOR_RANGE(int64_t, i, 0, shape.NumAxes()) {
ctx->NewBuilder().Split(ctx->outputs(), i).Build();
}
}
ctx->NewBuilder().PartialSum(ctx->outputs()).Build();
return Maybe<void>::Ok();
})
.SetInferSbpSignatureFn([](user_op::InferSbpSignatureFnContext* ctx) -> Maybe<void> {
auto* bn2sbp = ctx->mutable_sbp_signature()->mutable_bn_in_op2sbp_parallel();
const std::string& obn = GenRepeatedBn("out", 0);
const auto& sbp_parallel_str = ctx->Attr<std::string>("sbp_parallel");
const std::string& ibn = GenRepeatedBn(user_op::kUserSourceOpTickInputArgName, 0);
SbpParallel sbp_parallel;
sbp_parallel.mutable_broadcast_parallel();
(*bn2sbp)[ibn] = sbp_parallel;
if (sbp_parallel_str.empty()) {
(*bn2sbp)[obn] = sbp_parallel;
} else {
CHECK_OR_RETURN(ParseSbpParallelFromString(sbp_parallel_str, &sbp_parallel))
<< "invalid sbp_parallel: " << sbp_parallel_str;
if (sbp_parallel.has_split_parallel()) {
int64_t split_axis = sbp_parallel.split_parallel().axis();
const Shape& shape = ctx->Attr<Shape>("shape");
CHECK_OR_RETURN(shape.NumAxes() > 0)
<< "Split parallel is not supported for shape whose value is None";
CHECK_GE_OR_RETURN(split_axis, 0);
CHECK_LT_OR_RETURN(split_axis, shape.NumAxes());
(*bn2sbp)[obn] = sbp_parallel;
} else if (sbp_parallel.has_broadcast_parallel()) {
(*bn2sbp)[obn] = sbp_parallel;
} else if (sbp_parallel.has_partial_sum_parallel()) {
(*bn2sbp)[obn] = sbp_parallel;
} else {
UNIMPLEMENTED() << "sbp parallel not supported";
}
}
return Maybe<void>::Ok();
})
.SetInferDataTypeFn([](user_op::InferContext* ctx) -> Maybe<void> {
const DataType dtype = ctx->Attr<DataType>("dtype");
*ctx->Dtype4ArgNameAndIndex("out", 0) = dtype;
return Maybe<void>::Ok();
});

} // namespace oneflow