[fmha] packed qkv and output_layout #9950

Merged · 4 commits · Mar 7, 2023
2 changes: 1 addition & 1 deletion oneflow/core/functional/functional_api.yaml
@@ -2597,7 +2597,7 @@
bind_python: True

- name: "fused_multi_head_attention_inference_v2"
signature: "Tensor (*, Tensor query, String query_layout, Int64 query_head_size=None, Tensor key, String key_layout, Tensor value, String value_layout, Tensor attn_bias=None, Bool causal=False, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInferenceV2"
signature: 'Tensor (*, Tensor query, String query_layout, Int64 query_head_size=None, Tensor key=None, String key_layout=None, Tensor value=None, String value_layout=None, Tensor attn_bias=None, String output_layout="BM(HK)", Bool causal=False, Int64 causal_diagonal_offset=0) => FusedMultiHeadAttentionInferenceV2'
bind_python: True

- name: "fused_scale_mask_bias_softmax"
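For context on the widened signature above, here is a hedged sketch of how a caller inside OneFlow might invoke the op with a packed QKV tensor. The helper name, the header paths and the use of NullOpt are assumptions for illustration, not part of this PR.

// Hedged sketch, assuming an OneFlow build environment; `packed_qkv` is a
// hypothetical (B, M, H*3*K) tensor holding query, key and value side by side.
#include "oneflow/core/common/optional.h"
#include "oneflow/core/functional/functional.h"

namespace oneflow {

Maybe<one::Tensor> PackedSelfAttention(const std::shared_ptr<one::Tensor>& packed_qkv,
                                       int64_t head_size) {
  // key/value tensors and layouts are omitted (NullOpt), so the op slices
  // query, key and value out of the single "BM(H3K)" tensor.
  return functional::FusedMultiHeadAttentionInferenceV2(
      packed_qkv, /*query_layout=*/"BM(H3K)", /*query_head_size=*/head_size,
      /*key=*/NullOpt, /*key_layout=*/NullOpt,
      /*value=*/NullOpt, /*value_layout=*/NullOpt,
      /*attn_bias=*/NullOpt, /*output_layout=*/"BM(HK)",
      /*causal=*/false, /*causal_diagonal_offset=*/0);
}

}  // namespace oneflow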
185 changes: 136 additions & 49 deletions oneflow/core/functional/impl/nn_functor.cpp
@@ -4874,7 +4874,7 @@ class FusedMultiHeadAttentionInferenceFunctor {
const int64_t query_head_size = query_hidden_size / num_heads;
return functional::FusedMultiHeadAttentionInferenceV2(query, "BM(HK)", query_head_size, key,
"BM(HK)", value, "BM(HK)", attn_bias,
causal, causal_diagonal_offset);
"BM(HK)", causal, causal_diagonal_offset);
}
};

@@ -4898,9 +4898,11 @@ class FusedMultiHeadAttentionInferenceV2Functor {
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query,
const std::string& query_layout,
const Optional<int64_t>& query_head_size,
const std::shared_ptr<one::Tensor>& key, const std::string& key_layout,
const std::shared_ptr<one::Tensor>& value,
const std::string& value_layout, const Optional<one::Tensor>& attn_bias,
const Optional<one::Tensor>& key,
const Optional<std::string>& key_layout,
const Optional<one::Tensor>& value,
const Optional<std::string>& value_layout,
const Optional<one::Tensor>& attn_bias, const std::string& output_layout,
const bool& causal, const int64_t& causal_diagonal_offset) const {
CHECK_GE_OR_RETURN(causal_diagonal_offset, 0)
<< "The value of causal_diagonal_offset should be greater or equal to 0.";
@@ -4910,43 +4912,63 @@ class FusedMultiHeadAttentionInferenceV2Functor {
const Optional<int64_t>& head_size, int64_t* b, int64_t* m,
int64_t* h, int64_t* k) -> Maybe<void> {
Contributor

Should the layout constant strings used in this functor be defined with constexpr? That would avoid hard-coding and repetition, and the code would read a little better.

Collaborator (Author)

Here kBMHK and "BMHK" are about equal in readability and in how error-prone they are, and because the layouts contain parentheses, naming the constants is actually a bit more awkward. A better way to handle this would be to define a Layout class.

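As a purely illustrative sketch of the Layout-class idea mentioned in the reply above (names and structure are hypothetical, not part of this PR), the repeated string comparisons could be centralized like this:

#include <cstdint>
#include <stdexcept>
#include <string>

// Hypothetical helper: parses a layout string once and exposes the two facts
// that ParseDims keeps re-deriving from string comparisons.
struct Layout {
  bool batch_first;  // true for "BM...", false for "MB..."
  int64_t packed_n;  // 1 for "(HK)", 2 for "(H2K)", 3 for "(H3K)"

  static Layout Parse(const std::string& s) {
    if (s == "BM(HK)") { return {true, 1}; }
    if (s == "MB(HK)") { return {false, 1}; }
    if (s == "BM(H2K)") { return {true, 2}; }
    if (s == "MB(H2K)") { return {false, 2}; }
    if (s == "BM(H3K)") { return {true, 3}; }
    if (s == "MB(H3K)") { return {false, 3}; }
    throw std::invalid_argument("unsupported layout: " + s);
  }
};

ParseDims could then branch on batch_first and packed_n instead of repeating the same string comparisons for every layout variant.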
if (shape.NumAxes() == 3) {
if (layout == "BM(HK)") {
*b = shape.At(0);
*m = shape.At(1);
const int64_t hidden_size = shape.At(2);
if (num_heads) {
const int64_t expected_h = JUST(num_heads);
CHECK_EQ_OR_RETURN(hidden_size % expected_h, 0);
*h = expected_h;
*k = hidden_size / expected_h;
} else if (head_size) {
const int64_t expected_k = JUST(head_size);
CHECK_EQ_OR_RETURN(hidden_size % expected_k, 0);
*h = hidden_size / expected_k;
*k = expected_k;
if (layout == "BM(HK)" || layout == "MB(HK)" || layout == "BM(H2K)" || layout == "MB(H2K)"
|| layout == "BM(H3K)" || layout == "MB(H3K)") {
int64_t packed_n = 0;
if (layout == "BM(HK)") {
*b = shape.At(0);
*m = shape.At(1);
packed_n = 1;
} else if (layout == "MB(HK)") {
*b = shape.At(1);
*m = shape.At(0);
packed_n = 1;
} else if (layout == "BM(H2K)") {
CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'BM(H2K)'";
*b = shape.At(0);
*m = shape.At(1);
packed_n = 2;
} else if (layout == "MB(H2K)") {
CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'MB(H2K)'";
*b = shape.At(1);
*m = shape.At(0);
packed_n = 2;
} else if (layout == "BM(H3K)") {
*b = shape.At(0);
*m = shape.At(1);
packed_n = 3;
} else if (layout == "MB(H3K)") {
*b = shape.At(1);
*m = shape.At(0);
packed_n = 3;
} else {
UNIMPLEMENTED_THEN_RETURN();
}
} else if (layout == "MB(HK)") {
*b = shape.At(1);
*m = shape.At(0);
const int64_t hidden_size = shape.At(2);
if (num_heads) {
const int64_t expected_h = JUST(num_heads);
CHECK_EQ_OR_RETURN(hidden_size % expected_h, 0);
const int64_t packed_h = packed_n * expected_h;
CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0)
<< "The size of the last dimension of the " << name
<< " tensor should be a multiple of " << packed_h << ".";
*h = expected_h;
*k = hidden_size / expected_h;
*k = hidden_size / packed_h;
} else if (head_size) {
const int64_t expected_k = JUST(head_size);
CHECK_EQ_OR_RETURN(hidden_size % expected_k, 0);
*h = hidden_size / expected_k;
const int64_t packed_k = expected_k * packed_n;
CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0)
<< "The size of the last dimension of the " << name
<< " tensor should be a multiple of " << packed_k << ".";
*h = hidden_size / packed_k;
*k = expected_k;
} else {
UNIMPLEMENTED_THEN_RETURN();
}
} else {
UNIMPLEMENTED_THEN_RETURN()
<< name << "_layout should be 'BM(HK)' or 'MB(HK)' when the number of dimensions of "
<< name
<< "_layout should be 'BM(HK)', 'MB(HK)', 'BM(H2K)', 'MB(H2K)', 'BM(H3K)' or "
"'MB(H3K)' when the number of dimensions of "
<< name << " tensor is 3.";
}
} else if (shape.NumAxes() == 4) {
@@ -4960,11 +4982,16 @@ class FusedMultiHeadAttentionInferenceV2Functor {
*m = shape.At(2);
*h = shape.At(1);
*k = shape.At(3);
} else if (layout == "MBHK") {
*b = shape.At(1);
*m = shape.At(0);
*h = shape.At(2);
*k = shape.At(3);
} else {
UNIMPLEMENTED_THEN_RETURN()
<< name << "_layout should be 'BMHK' or 'BHMK' when the number of dimensions of "
<< name
<< "_layout should be 'BMHK', 'BHMK' or 'MBHK' when the number of dimensions of "
<< name << " tensor is 4.";
;
}
if (num_heads) {
const int64_t expected_h = JUST(num_heads);
@@ -4983,6 +5010,11 @@
return Maybe<void>::Ok();
};

std::shared_ptr<one::Tensor> key_tensor;
std::string key_tensor_layout;
std::shared_ptr<one::Tensor> value_tensor;
std::string value_tensor_layout;

int64_t q_b = 0;
int64_t q_m = 0;
int64_t q_h = 0;
@@ -4996,25 +5028,56 @@
int64_t k_m = 0;
int64_t k_h = 0;
int64_t k_k = 0;
JUST(ParseDims("key", *key->shape(), key_layout, Optional<int64_t>(), q_k, &k_b, &k_m, &k_h,
&k_k));
CHECK_EQ_OR_RETURN(k_b, q_b) << "The size of dimension 'B' of the key tensor should be the "
"same as that of the query tensor.";
CHECK_EQ_OR_RETURN(k_h, q_h) << "The size of dimension 'H' of the key tensor should be the "
"same as that of the query tensor.";
if (key) {
key_tensor = JUST(key);
key_tensor_layout = *JUST(key_layout);
JUST(ParseDims("key", *key_tensor->shape(), key_tensor_layout, Optional<int64_t>(), q_k, &k_b,
&k_m, &k_h, &k_k));
CHECK_EQ_OR_RETURN(k_b, q_b) << "The size of dimension 'B' of the key tensor should be the "
"same as that of the query tensor.";
CHECK_EQ_OR_RETURN(k_h, q_h) << "The size of dimension 'H' of the key tensor should be the "
"same as that of the query tensor.";

} else {
CHECK_OR_RETURN(query_layout == "BM(H3K)" || query_layout == "MB(H3K)")
<< "The value of query_layout should be 'BM(H3K)' or 'MB(H3K)' when the key tensor is "
"None.";
key_tensor = query;
Contributor

Add a check here: CHECK_OR_RETURN(query_layout == "BM(H3K)" || query_layout == "MB(H3K)")

key_tensor_layout = query_layout;
k_b = q_b;
k_m = q_m;
k_h = q_h;
k_k = q_k;
}

int64_t v_b = 0;
int64_t v_m = 0;
int64_t v_h = 0;
int64_t v_k = 0;
JUST(ParseDims("value", *value->shape(), value_layout, q_h, Optional<int64_t>(), &v_b, &v_m,
&v_h, &v_k));
CHECK_EQ_OR_RETURN(v_b, q_b) << "The size of dimension 'B' of the value tensor should be the "
"same as that of the query tensor.";
CHECK_EQ_OR_RETURN(v_m, k_m) << "The size of dimension 'M' of the value tensor should be the "
"same as that of the key tensor.";
CHECK_EQ_OR_RETURN(v_k % 8, 0)
<< "The size of dimension 'K' of the value tensor should be a multiple of 8.";
if (value) {
value_tensor = JUST(value);
value_tensor_layout = *JUST(value_layout);
JUST(ParseDims("value", *value_tensor->shape(), value_tensor_layout, q_h, Optional<int64_t>(),
&v_b, &v_m, &v_h, &v_k));
CHECK_EQ_OR_RETURN(v_b, q_b) << "The size of dimension 'B' of the value tensor should be the "
"same as that of the query tensor.";
CHECK_EQ_OR_RETURN(v_m, k_m) << "The size of dimension 'M' of the value tensor should be the "
"same as that of the key tensor.";
CHECK_EQ_OR_RETURN(v_k % 8, 0)
<< "The size of dimension 'K' of the value tensor should be a multiple of 8.";

} else {
CHECK_OR_RETURN(key_tensor_layout == "BM(H2K)" || key_tensor_layout == "MB(H2K)"
|| key_tensor_layout == "BM(H3K)" || key_tensor_layout == "MB(H3K)")
<< "The value of key_layout should be 'BM(H3K)', 'MB(H3K)', 'BM(H2K)' or 'MB(H2K)' when "
"the value tensor is None.";
value_tensor = key_tensor;
value_tensor_layout = key_tensor_layout;
Contributor

Also add a check here: CHECK_OR_RETURN(key_tensor_layout == "BM(H2K)" || key_tensor_layout == "MB(H2K)" || key_tensor_layout == "BM(H3K)" || key_tensor_layout == "MB(H3K)")

v_b = k_b;
v_m = k_m;
v_h = k_h;
v_k = k_k;
}

if (attn_bias) {
const auto attn_bias_shape = JUST(attn_bias)->shape();
@@ -5046,15 +5109,39 @@ class FusedMultiHeadAttentionInferenceV2Functor {
}
}

auto& attrs =
THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout",
"query_head_size", "causal", "causal_diagonal_offset");
attrs.SetAllAttrs(query_layout, key_layout, value_layout, q_k, causal, causal_diagonal_offset);
std::string op_output_layout;
Contributor

There are quite a few if/else chains in this functor. It doesn't affect the functionality, but a switch might read a little better; for example, this part could be rewritten as:

constexpr auto kBMHK = "BM(HK)";
constexpr auto kMBHK = "MB(HK)";
auto op_output_layout = output_layout;
switch (output_layout) {
  case kBMHK:
    break;
  case kMBHK:
    if (q_b != 1) {
      op_output_layout = kBMHK;
    }
    break;
  default:
    UNIMPLEMENTED_THEN_RETURN() << "output_layout should be '" << kBMHK << "' or '" << kMBHK << "'";
}

Just a comment; this doesn't necessarily need to be changed.

Collaborator (Author)

Same as above.

if (output_layout == "BM(HK)") {
op_output_layout = output_layout;
} else if (output_layout == "MB(HK)") {
if (q_b == 1) {
op_output_layout = output_layout;
} else {
Contributor

Isn't op_output_layout always "BM(HK)" here? If so, it seems the q_b == 1 branch could be removed.

Collaborator (Author)

Currently the kernel only computes in the BM(HK) layout, so when output_layout is any other value a transpose is needed. But when B == 1 the data layouts of BM(HK) and MB(HK) are identical (a contiguous (1, M, H*K) tensor and a contiguous (M, 1, H*K) tensor store the same elements in the same order), so the transpose can be skipped.

op_output_layout = "BM(HK)";
}
} else {
UNIMPLEMENTED_THEN_RETURN() << "output_layout should be 'BM(HK)' or 'MB(HK)'";
}
auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout",
"output_layout", "query_head_size", "causal",
"causal_diagonal_offset");
attrs.SetAllAttrs(query_layout, key_tensor_layout, value_tensor_layout, op_output_layout, q_k,
causal, causal_diagonal_offset);
std::shared_ptr<one::Tensor> op_output;
if (attn_bias) {
return OpInterpUtil::Dispatch<Tensor>(*op_with_attn_bias_,
{query, key, value, JUST(attn_bias)}, attrs);
op_output = JUST(OpInterpUtil::Dispatch<Tensor>(
*op_with_attn_bias_, {query, key_tensor, value_tensor, JUST(attn_bias)}, attrs));
} else {
return OpInterpUtil::Dispatch<Tensor>(*op_, {query, key, value}, attrs);
op_output =
JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {query, key_tensor, value_tensor}, attrs));
}
if (op_output_layout == output_layout) {
return op_output;
} else {
if (op_output_layout == "BM(HK)" && output_layout == "MB(HK)") {
return functional::Transpose(op_output, {1, 0, 2});
} else {
UNIMPLEMENTED_THEN_RETURN();
}
}
}

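To make the packed-layout arithmetic in ParseDims above concrete, here is a small standalone sketch; the sizes (batch 2, sequence 4096, 8 heads of size 40) are hypothetical examples, not values from the PR.

#include <cassert>
#include <cstdint>

int main() {
  const int64_t b = 2, m = 4096, num_heads = 8, head_size = 40;
  const int64_t packed_n = 3;                               // "BM(H3K)": Q, K and V packed together
  const int64_t hidden = num_heads * packed_n * head_size;  // last-dimension size: 960

  // Given query_head_size, ParseDims checks that the hidden size is a multiple
  // of packed_n * head_size and derives the head count from it.
  const int64_t packed_k = packed_n * head_size;            // 120
  assert(hidden % packed_k == 0);
  const int64_t h = hidden / packed_k;                      // 8
  const int64_t k = head_size;                              // 40

  assert(b == 2 && m == 4096 && h == 8 && k == 40);
  return 0;
}

The same arithmetic applies to "BM(H2K)" with packed_n = 2 (key and value packed) and to the "MB..." variants with the first two axes swapped.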
3 changes: 2 additions & 1 deletion oneflow/ir/include/OneFlow/OneFlowUserOps.td
@@ -2864,7 +2864,8 @@ def OneFlow_FusedMultiHeadAttentionInferenceOp : OneFlow_BaseOp<"fused_multi_hea
DefaultValuedAttr<SI64Attr, "0">:$causal_diagonal_offset,
StrAttr:$query_layout,
StrAttr:$key_layout,
StrAttr:$value_layout
StrAttr:$value_layout,
StrAttr:$output_layout
);
let has_logical_tensor_desc_infer_fn = 1;
let has_physical_tensor_desc_infer_fn = 1;
6 changes: 4 additions & 2 deletions oneflow/ir/lib/OneFlow/PDLL/ForwardOpPatterns.pdll
@@ -64,7 +64,8 @@ Pattern {
query_head_size = GetHeadSizeFromTranpose(query_transpose),
query_layout = attr<"\"BM(HK)\"">,
key_layout = attr<"\"BM(HK)\"">,
value_layout = attr<"\"BM(HK)\"">
value_layout = attr<"\"BM(HK)\"">,
output_layout = attr<"\"BM(HK)\"">
} -> (out_t));
};
}
@@ -105,7 +106,8 @@ Pattern {
query_head_size = GetHeadSizeFromTranpose(query_permute),
query_layout = attr<"\"BM(HK)\"">,
key_layout = attr<"\"BM(HK)\"">,
value_layout = attr<"\"BM(HK)\"">
value_layout = attr<"\"BM(HK)\"">,
output_layout = attr<"\"BM(HK)\"">
} -> (out_t));
};
}
4 changes: 2 additions & 2 deletions oneflow/ir/test/OneFlow/fuse_forward_ops.mlir
@@ -43,7 +43,7 @@ module {
%out_transpose = "oneflow.transpose"(%out) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "transpose-11", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 12 : i64} : (tensor<2x8x4096x40xf16>) -> tensor<2x4096x8x40xf16>
%out_reshape = "oneflow.reshape"(%out_transpose) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "reshape-12", scope_symbol_id = 12 : i64, shape = [2 : si64, 4096 : si64, 320 : si64]} : (tensor<2x4096x8x40xf16>) -> tensor<2x4096x320xf16>
// CHECK: func.func @fuse_mha(%[[QUERY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[KEY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[VALUE:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>)
// CHECK: "oneflow.fused_multi_head_attention_inference"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {causal = false, causal_diagonal_offset = 0 : si64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], key_layout = "BM(HK)", op_name = [[OP_NAME:".*"]], query_head_size = 40 : si64, query_layout = "BM(HK)", scope_symbol_id = 12 : i64, value_layout = "BM(HK)"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16>
// CHECK: "oneflow.fused_multi_head_attention_inference"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {causal = false, causal_diagonal_offset = 0 : si64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], key_layout = "BM(HK)", op_name = [[OP_NAME:".*"]], output_layout = "BM(HK)", query_head_size = 40 : si64, query_layout = "BM(HK)", scope_symbol_id = 12 : i64, value_layout = "BM(HK)"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16>
return %out_reshape : tensor<2x4096x320xf16>
}

@@ -65,7 +65,7 @@ module {
%311 = "oneflow.transpose"(%310) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-transpose-133", perm = [0 : si32, 2 : si32, 1 : si32, 3 : si32], scope_symbol_id = 661 : i64} : (tensor<2x8x4096x40xf16>) -> tensor<2x4096x8x40xf16>
%out_reshape_to_heads = "oneflow.reshape"(%311) {device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], op_name = "unet.down_blocks.0.attentions.1.transformer_blocks.0.attn1-reshape-134", scope_symbol_id = 661 : i64, shape = [2 : si64, 4096 : si64, 320 : si64]} : (tensor<2x4096x8x40xf16>) -> tensor<2x4096x320xf16>
// CHECK: func.func @fuse_mha2(%[[QUERY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[KEY:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>, %[[VALUE:[a-zA-Z0-9_]+]]: tensor<2x4096x320xf16>)
// CHECK: "oneflow.fused_multi_head_attention_inference"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {causal = false, causal_diagonal_offset = 0 : si64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], key_layout = "BM(HK)", op_name = [[OP_NAME:".*"]], query_head_size = 40 : si64, query_layout = "BM(HK)", scope_symbol_id = 661 : i64, value_layout = "BM(HK)"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16>
// CHECK: "oneflow.fused_multi_head_attention_inference"(%[[QUERY]], %[[KEY]], %[[VALUE]]) {causal = false, causal_diagonal_offset = 0 : si64, device_name = ["@0:0"], device_tag = "cuda", hierarchy = [1], key_layout = "BM(HK)", op_name = [[OP_NAME:".*"]], output_layout = "BM(HK)", query_head_size = 40 : si64, query_layout = "BM(HK)", scope_symbol_id = 661 : i64, value_layout = "BM(HK)"} : (tensor<2x4096x320xf16>, tensor<2x4096x320xf16>, tensor<2x4096x320xf16>) -> tensor<2x4096x320xf16>
return %out_reshape_to_heads : tensor<2x4096x320xf16>
}
