Add squared relu #10316

Merged
11 commits merged on Aug 22, 2023
Changes from 2 commits
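For reference, the forward and backward functors added below implement squared ReLU and its gradient:

f(x) = \begin{cases} x^{2}, & x > 0 \\ 0, & x \le 0 \end{cases}
\qquad
\frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot \begin{cases} 2x, & x > 0 \\ 0, & x \le 0 \end{cases}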
1 change: 1 addition & 0 deletions docs/source/nn.functional.rst
@@ -75,6 +75,7 @@ Non-linear activation functions
selu
celu
leaky_relu
square_relu
prelu
glu
gelu
1 change: 1 addition & 0 deletions docs/source/nn.rst
@@ -160,6 +160,7 @@ Non-linear Activations (weighted sum, nonlinearity)
nn.CELU
nn.GELU
nn.QuickGELU
nn.SquareReLU
nn.SiLU
nn.Sigmoid
nn.Mish
1 change: 1 addition & 0 deletions docs/source/oneflow.rst
@@ -270,6 +270,7 @@ Pointwise Ops
fmod
gelu
quick_gelu
square_relu
log
log1p
log2
31 changes: 31 additions & 0 deletions oneflow/core/autograd/gradient_funcs/activation.cpp
@@ -152,6 +152,36 @@ class QuickGeLU : public OpExprGradFunction<QuickGeluCaptureState> {
}
};

struct SquareReLUCaptureState : public AutoGradCaptureState {
bool requires_grad = false;
};

class SquareReLU : public OpExprGradFunction<SquareReLUCaptureState> {
public:
Maybe<void> Init(const OpExpr& op) override { return Maybe<void>::Ok(); }

Maybe<void> Capture(SquareReLUCaptureState* ctx, const TensorTuple& inputs,
const TensorTuple& outputs, const AttrMap& attrs) const override {
CHECK_EQ_OR_RETURN(inputs.size(), 1);
CHECK_EQ_OR_RETURN(outputs.size(), 1);
ctx->requires_grad = inputs.at(0)->requires_grad();
if (!ctx->requires_grad) { return Maybe<void>::Ok(); }
ctx->SaveTensorForBackward(inputs.at(0));
return Maybe<void>::Ok();
}

Maybe<void> Apply(const SquareReLUCaptureState* ctx, const TensorTuple& out_grads,
TensorTuple* in_grads) const override {
CHECK_EQ_OR_RETURN(out_grads.size(), 1);
in_grads->resize(1);
if (ctx->requires_grad) {
const auto& x = ctx->SavedTensors().at(0);
in_grads->at(0) = JUST(functional::SquareReLUGrad(out_grads.at(0), x));
}
return Maybe<void>::Ok();
}
};

class HardSigmoid : public BaseActivation {
public:
Maybe<void> Apply(const BaseActivationCaptureState* ctx, const TensorTuple& out_grads,
@@ -638,6 +668,7 @@ REGISTER_OP_EXPR_GRAD_FUNCTION("softplus", Softplus);
REGISTER_OP_EXPR_GRAD_FUNCTION("softshrink", SoftShrink);
REGISTER_OP_EXPR_GRAD_FUNCTION("fast_gelu", FastGeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("quick_gelu", QuickGeLU);
REGISTER_OP_EXPR_GRAD_FUNCTION("square_relu", SquareReLU);

} // namespace one
} // namespace oneflow
@@ -126,7 +126,8 @@ inline bool IsDimsEquals(size_t num_src0_dims, const int64_t* src0_dims, size_t
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kTanhBackwardWithDyY) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kThresholdBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kFastGeluBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX)
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kQuickGeluBackwardWithDyX) \
OF_PP_MAKE_TUPLE_SEQ(BinaryOp::kSquareReLUBackwardWithDyX)

#define BINARY_ACTIVATION_BACKWARD_OP_SEQ \
BINARY_ACTIVATION_BACKWARD_OP_SEQ_0 \
3 changes: 2 additions & 1 deletion oneflow/core/ep/common/primitive/elementwise_unary.h
@@ -86,7 +86,8 @@ namespace primitive {
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNotEqualZero) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kNanAssign) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kFastGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu)
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kQuickGelu) \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kSquareReLU)

#define UNARY_COMPLEX_C2C_OP_SEQ \
OF_PP_MAKE_TUPLE_SEQ(UnaryOp::kConj) \
10 changes: 10 additions & 0 deletions oneflow/core/ep/cpu/primitive/binary_functor.h
@@ -309,6 +309,16 @@ struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kQuickGeluBackwardWithDyX, Src,
static constexpr Src alpha = static_cast<Src>(1.702);
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kSquareReLUBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return static_cast<Dst>((x <= static_cast<Src>(0.0)) ? static_cast<Src>(0.0)
: static_cast<Src>(2.0) * x * dy);
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCPU, BinaryOp::kTanhBackwardWithDyY, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
11 changes: 11 additions & 0 deletions oneflow/core/ep/cpu/primitive/unary_functor.h
@@ -64,6 +64,16 @@ struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kQuickGelu, Dst, Src> {
static constexpr Src alpha = static_cast<Src>(1.702);
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kSquareReLU, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>((src > static_cast<Src>(0.0)) ? src * src : static_cast<Src>(0.0));
}
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCPU, UnaryOp::kTanh, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -371,6 +381,7 @@ SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_CPU_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);

12 changes: 12 additions & 0 deletions oneflow/core/ep/cuda/primitive/binary_functor.cuh
@@ -150,6 +150,16 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kQuickGeluBackwardWithDyX, Src
const Src alpha = static_cast<Src>(1.702);
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kSquareReLUBackwardWithDyX, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
return static_cast<Dst>((x <= static_cast<Src>(0.0)) ? static_cast<Src>(0.0)
: static_cast<Src>(2.0) * x * dy);
}
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyY, Src, Dst> {
OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -405,6 +415,7 @@ SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX);

SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);
@@ -479,6 +490,7 @@ SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSquareReLUBackwardWithDyX);

SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);
12 changes: 12 additions & 0 deletions oneflow/core/ep/cuda/primitive/unary_functor.cuh
@@ -70,6 +70,16 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kQuickGelu, Dst, Src> {
static constexpr Src alpha = static_cast<Src>(1.702);
};

template<typename Dst, typename Src>
struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kSquareReLU, Dst, Src> {
OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) {}

OF_DEVICE_FUNC Dst operator()(Src src) const {
return static_cast<Dst>((src > static_cast<Src>(0.0)) ? src * src : static_cast<Src>(0.0));
}
};

namespace unary_functor_internal {

namespace {
@@ -491,6 +501,7 @@ SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kReciprocalNoNan);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kNanAssign);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_PSEUDO_HALF_UNARY_FUNCTOR(UnaryOp::kSquareReLU);

/*********nv_bfloat16_kernel*******/

@@ -558,6 +569,7 @@ SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNotEqualZero);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kNanAssign);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kFastGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kQuickGelu);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kSquareReLU);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kDigamma);
SPECIALIZATION_PSEUDO_BFLOAT16_UNARY_FUNCTOR(UnaryOp::kTrigamma);

1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/binary_op.h
@@ -109,6 +109,7 @@ enum class BinaryOp {
kTanBackwardWithDyX,
kFastGeluBackwardWithDyX,
kQuickGeluBackwardWithDyX,
kSquareReLUBackwardWithDyX,
};

}
1 change: 1 addition & 0 deletions oneflow/core/ep/include/primitive/unary_op.h
@@ -43,6 +43,7 @@ enum class UnaryOp {
kThreshold,
kFastGelu,
kQuickGelu,
kSquareReLU,
// math op
kAbs,
kAcos,
8 changes: 8 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -755,6 +755,14 @@
signature: "Tensor (Tensor dy, Tensor x) => QuickGeluGrad"
bind_python: False

- name: "square_relu"
signature: "Tensor (Tensor x) => SquareReLU"
bind_python: True

- name: "square_relu_grad"
signature: "Tensor (Tensor dy, Tensor x) => SquareReLUGrad"
bind_python: False

- name: "gelu_with_approximate"
signature: 'Tensor (Tensor x, String approximate="none") => GeluWithApproximate'
bind_python: True
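A minimal usage sketch for this binding. It assumes the generated functional ends up reachable as oneflow._C.square_relu, the usual exposure for bind_python: True entries in this file; the flow.nn.functional.square_relu / nn.SquareReLU wrappers referenced in the docs hunks above are not part of these two commits, so the name is an assumption:

import oneflow as flow

x = flow.tensor([-1.5, 0.0, 2.0])
# Assumed binding name; the Python wrapper layer is not in these two commits.
y = flow._C.square_relu(x)
# Reference: squared ReLU is simply relu(x) squared.
ref = flow.relu(x) * flow.relu(x)
print(y)    # expected values: [0., 0., 4.]
print(ref)  # should match y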
17 changes: 17 additions & 0 deletions oneflow/core/functional/impl/activation_functor.cpp
@@ -247,6 +247,21 @@ class QuickGeluGradFunctor : public BinaryFunctor {
}
};

class SquareReLUFunctor : public UnaryFunctor {
public:
SquareReLUFunctor() {
op_ = CHECK_JUST(one::OpBuilder("square_relu").Input("x").Output("y").Build());
}
};

class SquareReLUGradFunctor : public BinaryFunctor {
public:
SquareReLUGradFunctor() {
op_ =
CHECK_JUST(one::OpBuilder("square_relu_grad").Input("dy").Input("x").Output("dx").Build());
}
};

class GluFunctor {
public:
GluFunctor() {}
@@ -779,6 +794,8 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::FastGeluGradFunctor>("FastGeluGrad");
m.add_functor<impl::QuickGeluFunctor>("QuickGelu");
m.add_functor<impl::QuickGeluGradFunctor>("QuickGeluGrad");
m.add_functor<impl::QuickGeluFunctor>("SquareReLU");
m.add_functor<impl::QuickGeluGradFunctor>("SquareReLUGrad");
m.add_functor<impl::GluFunctor>("Glu");
m.add_functor<impl::HardSigmoidFunctor>("HardSigmoid");
m.add_functor<impl::HardSigmoidGradFunctor>("HardSigmoidGrad");
27 changes: 27 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -344,6 +344,20 @@ def OneFlow_QuickGeluGradOp : OneFlow_BaseOp<"quick_gelu_grad", [NoMemoryEffect,
let has_data_type_infer_fn = 1;
}

def OneFlow_SquareReLUGradOp : OneFlow_BaseOp<"square_relu_grad", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x,
OneFlow_Tensor:$dy
);
let output = (outs
OneFlow_Tensor:$dx
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_GridSampleOp : OneFlow_BaseOp<"grid_sample", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$input,
@@ -10414,6 +10428,19 @@ def OneFlow_QuickGeluOp : OneFlow_BaseOp<"quick_gelu", [NoMemoryEffect, DeclareO
let has_data_type_infer_fn = 1;
}

def OneFlow_SquareReLUOp : OneFlow_BaseOp<"square_relu", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$x
);
let output = (outs
OneFlow_Tensor:$y
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}

def OneFlow_HardsigmoidOp : OneFlow_BaseOp<"hardsigmoid", [NoMemoryEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in
26 changes: 26 additions & 0 deletions oneflow/user/kernels/activation_kernels.cpp
@@ -282,6 +282,32 @@ REGISTER_USER_KERNEL("quick_gelu_grad")
})
.SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kQuickGeluBackwardWithDyX, "dx",
"dy"));
REGISTER_USER_KERNEL("square_relu")
.SetCreateFn([]() {
return user_op::NewOpKernel<UnaryPrimitiveKernel>(
"y", "x", [](user_op::KernelComputeContext* ctx) {
const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("x", 0);
const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("y", 0);
return ep::primitive::NewPrimitive<ep::primitive::ElementwiseUnaryFactory>(
ctx->device_type(), ep::primitive::UnaryOp::kSquareReLU, src->data_type(),
dst->data_type());
});
})
.SetIsMatchedHob(UnaryPrimitiveExists(ep::primitive::UnaryOp::kSquareReLU, "y", "x"));

REGISTER_USER_KERNEL("square_relu_grad")
.SetCreateFn([]() {
return user_op::NewOpKernel<BinaryPrimitiveKernel>(
"dx", "dy", "x", [](user_op::KernelComputeContext* ctx) {
const user_op::TensorDesc* src = ctx->TensorDesc4ArgNameAndIndex("dy", 0);
const user_op::TensorDesc* dst = ctx->TensorDesc4ArgNameAndIndex("dx", 0);
return ep::primitive::NewPrimitive<ep::primitive::BroadcastElementwiseBinaryFactory>(
ctx->device_type(), ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX,
src->data_type(), dst->data_type(), 1 /*max_num_dims*/);
});
})
.SetIsMatchedHob(BinaryPrimitiveExists(ep::primitive::BinaryOp::kSquareReLUBackwardWithDyX,
"dx", "dy"));

REGISTER_USER_KERNEL("leaky_relu")
.SetCreateFn([]() {
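And a quick backward sanity check against the square_relu_grad kernel and gradient function registered above, under the same assumption about the oneflow._C.square_relu binding name:

import oneflow as flow

x = flow.tensor([-2.0, 3.0], requires_grad=True)
y = flow._C.square_relu(x)  # assumed binding name, see note above
y.sum().backward()
# d/dx x^2 = 2x for x > 0, and 0 otherwise -> expected grad: [0., 6.]
print(x.grad)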