FusedAttentionConcatPastKeyValue #9963

Merged · 14 commits · Mar 9, 2023
5 changes: 2 additions & 3 deletions cmake/oneflow.cmake
@@ -334,9 +334,8 @@ if(BUILD_CUDA AND WITH_CUTLASS)
add_definitions(-DCUTLASS_ENABLE_TENSOR_CORE_MMA=1)
endif()

set_property(
SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_multi_head_attention_inference_kernel.cu
APPEND PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_INSTALL_DIR}/examples/xformers_fmha)
set_property(SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_attention_kernels.cu APPEND
PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_INSTALL_DIR}/examples/xformers_fmha)
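# The INCLUDE_DIRECTORIES source-file property adds the CUTLASS example headers only when
# compiling this single .cu file, rather than exposing them project-wide.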
set_property(SOURCE ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/fused_glu_kernel.cu APPEND
PROPERTY INCLUDE_DIRECTORIES ${CUTLASS_INSTALL_DIR}/examples/45_dual_gemm)
if("${CMAKE_CUDA_COMPILER_ID}" STREQUAL "NVIDIA")
4 changes: 4 additions & 0 deletions oneflow/core/functional/functional_api.yaml
@@ -2600,6 +2600,10 @@
signature: 'Tensor (*, Tensor query, String query_layout, Int64 query_head_size=None, Tensor key=None, String key_layout=None, Tensor value=None, String value_layout=None, Tensor attn_bias=None, String output_layout="BM(HK)", Bool causal=False, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInferenceV2'
bind_python: True

- name: "fused_attention_concat_past_key_value"
signature: 'TensorTuple (*, Tensor past_key, String past_key_layout, Tensor past_value, String past_value_layout, Tensor key, String key_layout, Tensor value, String value_layout, Int64 key_head_size=None) => FusedAttentionConcatPastKeyValue'
bind_python: True

- name: "fused_scale_mask_bias_softmax"
signature: 'Tensor (Tensor x, Tensor mask, Tensor bias=None, Float scale=0.35355, Bool inplace=False) => FusedScaleMaskBiasSoftmax'
bind_python: True
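
The new fused_attention_concat_past_key_value entry exposes the functor to C++ as functional::FusedAttentionConcatPastKeyValue and, via bind_python, to Python as flow._C.fused_attention_concat_past_key_value. A minimal C++ sketch of the intended call pattern, assuming "BM(HK)" layouts, a head size of 64, and a surrounding decode step (all illustrative values, not taken from this diff): the cached key/value are extended by the new step, then the concatenated tensors feed the existing fused attention.

// Sketch only: layouts, head size and the decode-loop context are assumptions.
const int64_t head_size = 64;
const auto kv = JUST(functional::FusedAttentionConcatPastKeyValue(
    past_key, "BM(HK)", past_value, "BM(HK)", key, "BM(HK)", value, "BM(HK)", head_size));
const auto& concat_key = (*kv)[0];    // output_key: past_key with key appended
const auto& concat_value = (*kv)[1];  // output_value: past_value with value appended
return functional::FusedMultiHeadAttentionInferenceV2(
    query, "BM(HK)", head_size, concat_key, std::string("BM(HK)"), concat_value,
    std::string("BM(HK)"), /*attn_bias=*/NullOpt, "BM(HK)", /*causal=*/true,
    /*causal_diagonal_offset=*/0);

The implementation of both functors lives in the new fused_attention_functor.cpp below.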
415 changes: 415 additions & 0 deletions oneflow/core/functional/impl/fused_attention_functor.cpp

Large diffs are not rendered by default.

302 changes: 0 additions & 302 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -4851,305 +4851,6 @@ class BatchNormBackwardElemtFunctor {
std::shared_ptr<OpExpr> op_;
};

class FusedMultiHeadAttentionInferenceFunctor {
public:
FusedMultiHeadAttentionInferenceFunctor() = default;
Maybe<Tensor> operator()(
const std::shared_ptr<one::Tensor>& query, const std::shared_ptr<one::Tensor>& key,
const std::shared_ptr<one::Tensor>& value, const int64_t& num_heads, const bool& causal,
const int64_t& query_hidden_slice_start, const int64_t& query_hidden_slice_end,
const int64_t& key_hidden_slice_start, const int64_t& key_hidden_slice_end,
const int64_t& value_hidden_slice_start, const int64_t& value_hidden_slice_end,
const Optional<one::Tensor>& attn_bias, const int64_t& causal_diagonal_offset) const {
CHECK_OR_RETURN(query_hidden_slice_start == 0 && key_hidden_slice_start == 0
&& value_hidden_slice_start == 0 && query_hidden_slice_end == -1
&& key_hidden_slice_end == -1 && value_hidden_slice_end == -1)
<< "The parameters 'query_hidden_slice_start', 'query_hidden_slice_end', "
"'key_hidden_slice_start', 'key_hidden_slice_end', 'value_hidden_slice_start', "
"'value_hidden_slice_end' have been deprecated.";

const int64_t query_hidden_size = query->shape()->At(2);
CHECK_EQ_OR_RETURN(query_hidden_size % num_heads, 0)
<< "The hidden size of the query tensor should be a multiple of num_heads.";
const int64_t query_head_size = query_hidden_size / num_heads;
return functional::FusedMultiHeadAttentionInferenceV2(query, "BM(HK)", query_head_size, key,
"BM(HK)", value, "BM(HK)", attn_bias,
"BM(HK)", causal, causal_diagonal_offset);
}
};
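// Illustrative example (values not from this PR): a (B, M, 768) query with num_heads = 12
// yields query_head_size = 768 / 12 = 64, so the call above is equivalent to invoking the
// V2 functor directly with "BM(HK)" layouts and query_head_size = 64.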

class FusedMultiHeadAttentionInferenceV2Functor {
public:
FusedMultiHeadAttentionInferenceV2Functor() {
op_ = CHECK_JUST(one::OpBuilder("fused_multi_head_attention_inference")
.Input("query")
.Input("key")
.Input("value")
.Output("out")
.Build());
op_with_attn_bias_ = CHECK_JUST(one::OpBuilder("fused_multi_head_attention_inference")
.Input("query")
.Input("key")
.Input("value")
.Input("attn_bias")
.Output("out")
.Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query,
const std::string& query_layout,
const Optional<int64_t>& query_head_size,
const Optional<one::Tensor>& key,
const Optional<std::string>& key_layout,
const Optional<one::Tensor>& value,
const Optional<std::string>& value_layout,
const Optional<one::Tensor>& attn_bias, const std::string& output_layout,
const bool& causal, const int64_t& causal_diagonal_offset) const {
CHECK_GE_OR_RETURN(causal_diagonal_offset, 0)
<< "The value of causal_diagonal_offset should be greater or equal to 0.";

const auto ParseDims = [](const std::string& name, const Shape& shape,
const std::string& layout, const Optional<int64_t>& num_heads,
const Optional<int64_t>& head_size, int64_t* b, int64_t* m,
int64_t* h, int64_t* k) -> Maybe<void> {
if (shape.NumAxes() == 3) {
if (layout == "BM(HK)" || layout == "MB(HK)" || layout == "BM(H2K)" || layout == "MB(H2K)"
|| layout == "BM(H3K)" || layout == "MB(H3K)") {
int64_t packed_n = 0;
if (layout == "BM(HK)") {
*b = shape.At(0);
*m = shape.At(1);
packed_n = 1;
} else if (layout == "MB(HK)") {
*b = shape.At(1);
*m = shape.At(0);
packed_n = 1;
} else if (layout == "BM(H2K)") {
CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'BM(H2K)'";
*b = shape.At(0);
*m = shape.At(1);
packed_n = 2;
} else if (layout == "MB(H2K)") {
CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'MB(H2K)'";
*b = shape.At(1);
*m = shape.At(0);
packed_n = 2;
} else if (layout == "BM(H3K)") {
*b = shape.At(0);
*m = shape.At(1);
packed_n = 3;
} else if (layout == "MB(H3K)") {
*b = shape.At(1);
*m = shape.At(0);
packed_n = 3;
} else {
UNIMPLEMENTED_THEN_RETURN();
}
const int64_t hidden_size = shape.At(2);
if (num_heads) {
const int64_t expected_h = JUST(num_heads);
const int64_t packed_h = packed_n * expected_h;
CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0)
<< "The size of the last dimension of the " << name
<< " tensor should be a multiple of " << packed_h << ".";
*h = expected_h;
*k = hidden_size / packed_h;
} else if (head_size) {
const int64_t expected_k = JUST(head_size);
const int64_t packed_k = expected_k * packed_n;
CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0)
<< "The size of the last dimension of the " << name
<< " tensor should be a multiple of " << packed_k << ".";
*h = hidden_size / packed_k;
*k = expected_k;
} else {
UNIMPLEMENTED_THEN_RETURN();
}
} else {
UNIMPLEMENTED_THEN_RETURN()
<< name
<< "_layout should be 'BM(HK)', 'MB(HK)', 'BM(H2K)', 'MB(H2K)', 'BM(H3K)' or "
"'MB(H3K)' when the number of dimensions of "
<< name << " tensor is 3.";
}
} else if (shape.NumAxes() == 4) {
if (layout == "BMHK") {
*b = shape.At(0);
*m = shape.At(1);
*h = shape.At(2);
*k = shape.At(3);
} else if (layout == "BHMK") {
*b = shape.At(0);
*m = shape.At(2);
*h = shape.At(1);
*k = shape.At(3);
} else if (layout == "MBHK") {
*b = shape.At(1);
*m = shape.At(0);
*h = shape.At(2);
*k = shape.At(3);
} else {
UNIMPLEMENTED_THEN_RETURN()
<< name
<< "_layout should be 'BMHK', 'BHMK' or 'MBHK' when the number of dimensions of "
<< name << " tensor is 4.";
}
if (num_heads) {
const int64_t expected_h = JUST(num_heads);
CHECK_EQ_OR_RETURN(*h, expected_h) << "The size of dimension 'H' of " << name
<< " tensor should be " << expected_h << ".";
}
if (head_size) {
const int64_t expected_k = JUST(head_size);
CHECK_EQ_OR_RETURN(*k, expected_k) << "The size of dimension 'K' of " << name
<< " tensor should be " << expected_k << ".";
}
} else {
UNIMPLEMENTED_THEN_RETURN()
<< "The number of dimensions of the " << name << " tensor should be 3 or 4";
};
return Maybe<void>::Ok();
};
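// Worked example of ParseDims (illustrative shapes): a key tensor with layout "BM(H2K)",
// shape (2, 77, 1280) and head_size = 40 gives b = 2, m = 77, packed_n = 2, hence
// packed_k = 2 * 40 = 80, h = 1280 / 80 = 16 and k = 40.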

std::shared_ptr<one::Tensor> key_tensor;
std::string key_tensor_layout;
std::shared_ptr<one::Tensor> value_tensor;
std::string value_tensor_layout;

int64_t q_b = 0;
int64_t q_m = 0;
int64_t q_h = 0;
int64_t q_k = 0;
JUST(ParseDims("query", *query->shape(), query_layout, Optional<int64_t>(), query_head_size,
&q_b, &q_m, &q_h, &q_k));
CHECK_EQ_OR_RETURN(q_k % 8, 0)
<< "The size of dimension 'K' of the query tensor should be a multiple of 8.";

int64_t k_b = 0;
int64_t k_m = 0;
int64_t k_h = 0;
int64_t k_k = 0;
if (key) {
key_tensor = JUST(key);
key_tensor_layout = *JUST(key_layout);
JUST(ParseDims("key", *key_tensor->shape(), key_tensor_layout, Optional<int64_t>(), q_k, &k_b,
&k_m, &k_h, &k_k));
CHECK_EQ_OR_RETURN(k_b, q_b) << "The size of dimension 'B' of the key tensor should be the "
"same as that of the query tensor.";
CHECK_EQ_OR_RETURN(k_h, q_h) << "The size of dimension 'H' of the key tensor should be the "
"same as that of the query tensor.";

} else {
CHECK_OR_RETURN(query_layout == "BM(H3K)" || query_layout == "MB(H3K)")
<< "The value of query_layout should be 'BM(H3K)' or 'MB(H3K)' when the key tensor is "
"None.";
key_tensor = query;
key_tensor_layout = query_layout;
k_b = q_b;
k_m = q_m;
k_h = q_h;
k_k = q_k;
}

int64_t v_b = 0;
int64_t v_m = 0;
int64_t v_h = 0;
int64_t v_k = 0;
if (value) {
value_tensor = JUST(value);
value_tensor_layout = *JUST(value_layout);
JUST(ParseDims("value", *value_tensor->shape(), value_tensor_layout, q_h, Optional<int64_t>(),
&v_b, &v_m, &v_h, &v_k));
CHECK_EQ_OR_RETURN(v_b, q_b) << "The size of dimension 'B' of the value tensor should be the "
"same as that of the query tensor.";
CHECK_EQ_OR_RETURN(v_m, k_m) << "The size of dimension 'M' of the value tensor should be the "
"same as that of the key tensor.";
CHECK_EQ_OR_RETURN(v_k % 8, 0)
<< "The size of dimension 'K' of the value tensor should be a multiple of 8.";

} else {
CHECK_OR_RETURN(key_tensor_layout == "BM(H2K)" || key_tensor_layout == "MB(H2K)"
|| key_tensor_layout == "BM(H3K)" || key_tensor_layout == "MB(H3K)")
<< "The value of key_layout should be 'BM(H3K)', 'MB(H3K)', 'BM(H2K)' or 'MB(H2K)' when "
"the value tensor is None.";
value_tensor = key_tensor;
value_tensor_layout = key_tensor_layout;
v_b = k_b;
v_m = k_m;
v_h = k_h;
v_k = k_k;
}

if (attn_bias) {
const auto attn_bias_shape = JUST(attn_bias)->shape();
const int64_t num_attn_bias_axes = attn_bias_shape->NumAxes();
CHECK_OR_RETURN(num_attn_bias_axes > 0 && num_attn_bias_axes <= 4)
<< "The number of dimensions of attn_bias should be greater than 0 and less than or "
"equal to 4.";
CHECK_GE_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 1), k_m)
<< "The size of the -1 dimension of attn_bias should be greater than or equal to the "
"dimension 'M' of the key tensor";
CHECK_EQ_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 1) % 8, 0)
<< "The size of the -1 dimension of attn_bias should be a multiple of 8.";
if (num_attn_bias_axes >= 2) {
CHECK_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 2) == 1
|| attn_bias_shape->At(num_attn_bias_axes - 2) >= q_m)
<< "The size of the -2 dimension of attn_bias should be greater than or equal to the "
"dimension 'M' of the query tensor or equal to 1.";
}
if (num_attn_bias_axes >= 3) {
CHECK_OR_RETURN(attn_bias_shape->At(num_attn_bias_axes - 3) == 1
|| attn_bias_shape->At(num_attn_bias_axes - 3) == q_h)
<< "The size of the -3 dimension of attn_bias should be equal to the dimension 'H' of "
"the query tensor or equal to 1.";
}
if (num_attn_bias_axes == 4) {
CHECK_OR_RETURN(attn_bias_shape->At(0) == 1 || attn_bias_shape->At(0) == q_b)
<< "The size of the -4 dimension of attn_bias should be equal to the dimension 'B' of "
"the query tensor or equal to 1.";
}
}
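// In summary: the last dim of attn_bias must be a multiple of 8 and at least the key 'M';
// dim -2 must be 1 or at least the query 'M'; dim -3 must be 1 or equal to 'H'; dim -4 must
// be 1 or equal to 'B'. E.g. for q_b = 2, q_h = 16, q_m = 77, k_m = 77, a bias of shape
// (1, 16, 77, 80) satisfies all of the checks above (illustrative numbers).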

std::string op_output_layout;
if (output_layout == "BM(HK)") {
op_output_layout = output_layout;
} else if (output_layout == "MB(HK)") {
if (q_b == 1) {
op_output_layout = output_layout;
} else {
op_output_layout = "BM(HK)";
}
} else {
UNIMPLEMENTED_THEN_RETURN() << "output_layout should be 'BM(HK)' or 'MB(HK)'";
}
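// With a single batch (q_b == 1), "MB(HK)" and "BM(HK)" describe the same contiguous memory,
// so the op can presumably emit "MB(HK)" directly; otherwise the op writes "BM(HK)" and the
// Transpose({1, 0, 2}) below recovers the requested "MB(HK)" layout.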
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout",
"output_layout", "query_head_size", "causal",
"causal_diagonal_offset");
attrs.SetAllAttrs(query_layout, key_tensor_layout, value_tensor_layout, op_output_layout, q_k,
causal, causal_diagonal_offset);
std::shared_ptr<one::Tensor> op_output;
if (attn_bias) {
op_output = JUST(OpInterpUtil::Dispatch<Tensor>(
*op_with_attn_bias_, {query, key_tensor, value_tensor, JUST(attn_bias)}, attrs));
} else {
op_output =
JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {query, key_tensor, value_tensor}, attrs));
}
if (op_output_layout == output_layout) {
return op_output;
} else {
if (op_output_layout == "BM(HK)" && output_layout == "MB(HK)") {
return functional::Transpose(op_output, {1, 0, 2});
} else {
UNIMPLEMENTED_THEN_RETURN();
}
}
}

private:
std::shared_ptr<OpExpr> op_;
std::shared_ptr<OpExpr> op_with_attn_bias_;
};

class FusedFastGeluMulFunctor {
public:
FusedFastGeluMulFunctor() {
@@ -5502,9 +5203,6 @@ ONEFLOW_FUNCTION_LIBRARY(m) {
m.add_functor<impl::BatchNormElemtFunctor>("BatchNormElemt");
m.add_functor<impl::BatchNormBackwardReduceFunctor>("BatchNormBackwardReduce");
m.add_functor<impl::BatchNormBackwardElemtFunctor>("BatchNormBackwardElemt");
m.add_functor<impl::FusedMultiHeadAttentionInferenceFunctor>("FusedMultiHeadAttentionInference");
m.add_functor<impl::FusedMultiHeadAttentionInferenceV2Functor>(
"FusedMultiHeadAttentionInferenceV2");
m.add_functor<impl::FusedFastGeluMulFunctor>("FusedFastGeluMul");
m.add_functor<impl::FusedFastGeluMulGradFunctor>("FusedFastGeluMulGrad");
m.add_functor<impl::GroupedMatmulBiasFunctor>("GroupedMatmulBias");
24 changes: 24 additions & 0 deletions oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -2873,6 +2873,30 @@ def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_hea
let has_data_type_infer_fn = 1;
}

def OneFlow_FusedAttentionConcatPastKeyValueOp : OneFlow_BaseOp<"fused_attention_concat_past_key_value", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$past_key,
OneFlow_Tensor:$past_value,
OneFlow_Tensor:$key,
OneFlow_Tensor:$value
);
let output = (outs
OneFlow_Tensor:$output_key,
OneFlow_Tensor:$output_value
);
let attrs = (ins
StrAttr:$past_key_layout,
StrAttr:$past_value_layout,
StrAttr:$key_layout,
StrAttr:$value_layout,
SI64Attr:$key_head_size
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
let has_get_sbp_fn = 1;
let has_data_type_infer_fn = 1;
}
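// Sketch of the intended semantics (an assumption from the op and operand names, not spelled
// out in this .td): output_key/output_value are past_key/past_value with key/value appended
// along the sequence dimension, e.g. past_key of shape (B, H, Mp, K) plus key of shape
// (B, H, Mn, K) giving output_key of shape (B, H, Mp + Mn, K) for a "BHMK" layout.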

def OneFlow_FusedFastGeluMulOp : OneFlow_BaseOp<"fused_fast_gelu_mul", [NoSideEffect, DeclareOpInterfaceMethods<UserOpCompatibleInterface>]> {
let input = (ins
OneFlow_Tensor:$in,