# [fmha] packed qkv and output_layout #9950
```diff
@@ -4874,7 +4874,7 @@ class FusedMultiHeadAttentionInferenceFunctor {
     const int64_t query_head_size = query_hidden_size / num_heads;
     return functional::FusedMultiHeadAttentionInferenceV2(query, "BM(HK)", query_head_size, key,
                                                           "BM(HK)", value, "BM(HK)", attn_bias,
-                                                          causal, causal_diagonal_offset);
+                                                          "BM(HK)", causal, causal_diagonal_offset);
   }
 };
```
@@ -4898,9 +4898,11 @@ class FusedMultiHeadAttentionInferenceV2Functor { | |
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& query, | ||
const std::string& query_layout, | ||
const Optional<int64_t>& query_head_size, | ||
const std::shared_ptr<one::Tensor>& key, const std::string& key_layout, | ||
const std::shared_ptr<one::Tensor>& value, | ||
const std::string& value_layout, const Optional<one::Tensor>& attn_bias, | ||
const Optional<one::Tensor>& key, | ||
const Optional<std::string>& key_layout, | ||
const Optional<one::Tensor>& value, | ||
const Optional<std::string>& value_layout, | ||
const Optional<one::Tensor>& attn_bias, const std::string& output_layout, | ||
const bool& causal, const int64_t& causal_diagonal_offset) const { | ||
CHECK_GE_OR_RETURN(causal_diagonal_offset, 0) | ||
<< "The value of causal_diagonal_offset should be greater or equal to 0."; | ||
|
```diff
@@ -4910,43 +4912,63 @@ class FusedMultiHeadAttentionInferenceV2Functor {
                           const Optional<int64_t>& head_size, int64_t* b, int64_t* m,
                           int64_t* h, int64_t* k) -> Maybe<void> {
       if (shape.NumAxes() == 3) {
-        if (layout == "BM(HK)") {
-          *b = shape.At(0);
-          *m = shape.At(1);
-          const int64_t hidden_size = shape.At(2);
-          if (num_heads) {
-            const int64_t expected_h = JUST(num_heads);
-            CHECK_EQ_OR_RETURN(hidden_size % expected_h, 0);
-            *h = expected_h;
-            *k = hidden_size / expected_h;
-          } else if (head_size) {
-            const int64_t expected_k = JUST(head_size);
-            CHECK_EQ_OR_RETURN(hidden_size % expected_k, 0);
-            *h = hidden_size / expected_k;
-            *k = expected_k;
+        if (layout == "BM(HK)" || layout == "MB(HK)" || layout == "BM(H2K)" || layout == "MB(H2K)"
+            || layout == "BM(H3K)" || layout == "MB(H3K)") {
+          int64_t packed_n = 0;
+          if (layout == "BM(HK)") {
+            *b = shape.At(0);
+            *m = shape.At(1);
+            packed_n = 1;
+          } else if (layout == "MB(HK)") {
+            *b = shape.At(1);
+            *m = shape.At(0);
+            packed_n = 1;
+          } else if (layout == "BM(H2K)") {
+            CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'BM(H2K)'";
+            *b = shape.At(0);
+            *m = shape.At(1);
+            packed_n = 2;
+          } else if (layout == "MB(H2K)") {
+            CHECK_NE_OR_RETURN(name, "query") << "query_layout should not be 'MB(H2K)'";
+            *b = shape.At(1);
+            *m = shape.At(0);
+            packed_n = 2;
+          } else if (layout == "BM(H3K)") {
+            *b = shape.At(0);
+            *m = shape.At(1);
+            packed_n = 3;
+          } else if (layout == "MB(H3K)") {
+            *b = shape.At(1);
+            *m = shape.At(0);
+            packed_n = 3;
+          } else {
+            UNIMPLEMENTED_THEN_RETURN();
+          }
-        } else if (layout == "MB(HK)") {
-          *b = shape.At(1);
-          *m = shape.At(0);
           const int64_t hidden_size = shape.At(2);
           if (num_heads) {
             const int64_t expected_h = JUST(num_heads);
-            CHECK_EQ_OR_RETURN(hidden_size % expected_h, 0);
+            const int64_t packed_h = packed_n * expected_h;
+            CHECK_EQ_OR_RETURN(hidden_size % packed_h, 0)
+                << "The size of the last dimension of the " << name
+                << " tensor should be a multiple of " << packed_h << ".";
             *h = expected_h;
-            *k = hidden_size / expected_h;
+            *k = hidden_size / packed_h;
           } else if (head_size) {
             const int64_t expected_k = JUST(head_size);
-            CHECK_EQ_OR_RETURN(hidden_size % expected_k, 0);
-            *h = hidden_size / expected_k;
+            const int64_t packed_k = expected_k * packed_n;
+            CHECK_EQ_OR_RETURN(hidden_size % packed_k, 0)
+                << "The size of the last dimension of the " << name
+                << " tensor should be a multiple of " << packed_k << ".";
+            *h = hidden_size / packed_k;
             *k = expected_k;
           } else {
             UNIMPLEMENTED_THEN_RETURN();
           }
         } else {
           UNIMPLEMENTED_THEN_RETURN()
-              << name << "_layout should be 'BM(HK)' or 'MB(HK)' when the number of dimensions of "
+              << name
+              << "_layout should be 'BM(HK)', 'MB(HK)', 'BM(H2K)', 'MB(H2K)', 'BM(H3K)' or "
+                 "'MB(H3K)' when the number of dimensions of "
              << name << " tensor is 3.";
         }
       } else if (shape.NumAxes() == 4) {
```
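To make the packed-layout arithmetic above concrete: for a 3-D tensor whose layout is `BM(H2K)` or `BM(H3K)`, the last dimension packs `packed_n` tensors of `h * k` values each, so `h` is recovered by dividing the hidden size by `packed_n * head_size`. The sketch below is a standalone illustration with hypothetical helper names, not the OneFlow `ParseDims` itself:

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Standalone sketch (hypothetical helper, not the OneFlow API): derive the
// per-tensor (h, k) factors from the hidden size of a packed 3-D layout.
// packed_n is 1 for (HK), 2 for (H2K) (packed k+v), 3 for (H3K) (packed q+k+v).
struct Dims { int64_t b, m, h, k; };

Dims ParsePacked(int64_t d0, int64_t d1, int64_t hidden_size,
                 const std::string& layout, int64_t head_size) {
  const bool batch_first = layout[0] == 'B';  // "BM(...)" vs "MB(...)"
  const int64_t packed_n = layout.find("H3K") != std::string::npos   ? 3
                           : layout.find("H2K") != std::string::npos ? 2
                                                                     : 1;
  const int64_t packed_k = packed_n * head_size;
  assert(hidden_size % packed_k == 0);  // mirrors the CHECK_EQ_OR_RETURN above
  return Dims{batch_first ? d0 : d1,    // b
              batch_first ? d1 : d0,    // m
              hidden_size / packed_k,   // h
              head_size};               // k
}

int main() {
  // A packed qkv tensor of shape (B=2, M=128, H*3*K=3072) with head_size 64:
  const Dims d = ParsePacked(2, 128, 3072, "BM(H3K)", 64);
  assert(d.b == 2 && d.m == 128 && d.h == 16 && d.k == 64);
  return 0;
}
```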
```diff
@@ -4960,11 +4982,16 @@ class FusedMultiHeadAttentionInferenceV2Functor {
           *m = shape.At(2);
           *h = shape.At(1);
           *k = shape.At(3);
+        } else if (layout == "MBHK") {
+          *b = shape.At(1);
+          *m = shape.At(0);
+          *h = shape.At(2);
+          *k = shape.At(3);
         } else {
           UNIMPLEMENTED_THEN_RETURN()
-              << name << "_layout should be 'BMHK' or 'BHMK' when the number of dimensions of "
+              << name
+              << "_layout should be 'BMHK', 'BHMK' or 'MBHK' when the number of dimensions of "
              << name << " tensor is 4.";
-          ;
         }
         if (num_heads) {
           const int64_t expected_h = JUST(num_heads);
```
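The new `MBHK` branch only swaps which axis is batch and which is sequence; the head and head-size axes are untouched. A standalone illustration (assuming plain contiguous row-major storage) of how the three 4-D layouts address the same logical element:

```cpp
#include <cassert>
#include <cstdint>

// Row-major linear offset of element (b, m, h, k) under each 4-D layout.
// Standalone illustration only; B batch, M sequence, H heads, K head size.
int64_t OffsetBMHK(int64_t B, int64_t M, int64_t H, int64_t K,
                   int64_t b, int64_t m, int64_t h, int64_t k) {
  return ((b * M + m) * H + h) * K + k;
}
int64_t OffsetBHMK(int64_t B, int64_t M, int64_t H, int64_t K,
                   int64_t b, int64_t m, int64_t h, int64_t k) {
  return ((b * H + h) * M + m) * K + k;
}
int64_t OffsetMBHK(int64_t B, int64_t M, int64_t H, int64_t K,
                   int64_t b, int64_t m, int64_t h, int64_t k) {
  return ((m * B + b) * H + h) * K + k;
}

int main() {
  const int64_t B = 2, M = 4, H = 3, K = 8;
  // When B == 1, BMHK and MBHK address every element identically:
  for (int64_t m = 0; m < M; ++m) {
    assert(OffsetBMHK(1, M, H, K, 0, m, 1, 2) == OffsetMBHK(1, M, H, K, 0, m, 1, 2));
  }
  // In general they differ:
  assert(OffsetBMHK(B, M, H, K, 1, 0, 0, 0) != OffsetMBHK(B, M, H, K, 1, 0, 0, 0));
  return 0;
}
```

The `B == 1` case where the two orders coincide is the same observation the output-layout logic later exploits for `BM(HK)` vs `MB(HK)`.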
```diff
@@ -4983,6 +5010,11 @@ class FusedMultiHeadAttentionInferenceV2Functor {
       return Maybe<void>::Ok();
     };
 
+    std::shared_ptr<one::Tensor> key_tensor;
+    std::string key_tensor_layout;
+    std::shared_ptr<one::Tensor> value_tensor;
+    std::string value_tensor_layout;
+
     int64_t q_b = 0;
     int64_t q_m = 0;
     int64_t q_h = 0;
```
```diff
@@ -4996,25 +5028,56 @@ class FusedMultiHeadAttentionInferenceV2Functor {
     int64_t k_m = 0;
     int64_t k_h = 0;
     int64_t k_k = 0;
-    JUST(ParseDims("key", *key->shape(), key_layout, Optional<int64_t>(), q_k, &k_b, &k_m, &k_h,
-                   &k_k));
-    CHECK_EQ_OR_RETURN(k_b, q_b) << "The size of dimension 'B' of the key tensor should be the "
-                                    "same as that of the query tensor.";
-    CHECK_EQ_OR_RETURN(k_h, q_h) << "The size of dimension 'H' of the key tensor should be the "
-                                    "same as that of the query tensor.";
+    if (key) {
+      key_tensor = JUST(key);
+      key_tensor_layout = *JUST(key_layout);
+      JUST(ParseDims("key", *key_tensor->shape(), key_tensor_layout, Optional<int64_t>(), q_k, &k_b,
+                     &k_m, &k_h, &k_k));
+      CHECK_EQ_OR_RETURN(k_b, q_b) << "The size of dimension 'B' of the key tensor should be the "
+                                      "same as that of the query tensor.";
+      CHECK_EQ_OR_RETURN(k_h, q_h) << "The size of dimension 'H' of the key tensor should be the "
+                                      "same as that of the query tensor.";
+    } else {
+      CHECK_OR_RETURN(query_layout == "BM(H3K)" || query_layout == "MB(H3K)")
+          << "The value of query_layout should be 'BM(H3K)' or 'MB(H3K)' when the key tensor is "
+             "None.";
+      key_tensor = query;
```

> **Review comment** (on `key_tensor = query;`): Add a check here.

```diff
+      key_tensor_layout = query_layout;
+      k_b = q_b;
+      k_m = q_m;
+      k_h = q_h;
+      k_k = q_k;
+    }
 
     int64_t v_b = 0;
     int64_t v_m = 0;
     int64_t v_h = 0;
     int64_t v_k = 0;
-    JUST(ParseDims("value", *value->shape(), value_layout, q_h, Optional<int64_t>(), &v_b, &v_m,
-                   &v_h, &v_k));
-    CHECK_EQ_OR_RETURN(v_b, q_b) << "The size of dimension 'B' of the value tensor should be the "
-                                    "same as that of the query tensor.";
-    CHECK_EQ_OR_RETURN(v_m, k_m) << "The size of dimension 'M' of the value tensor should be the "
-                                    "same as that of the key tensor.";
-    CHECK_EQ_OR_RETURN(v_k % 8, 0)
-        << "The size of dimension 'K' of the value tensor should be a multiple of 8.";
+    if (value) {
+      value_tensor = JUST(value);
+      value_tensor_layout = *JUST(value_layout);
+      JUST(ParseDims("value", *value_tensor->shape(), value_tensor_layout, q_h, Optional<int64_t>(),
+                     &v_b, &v_m, &v_h, &v_k));
+      CHECK_EQ_OR_RETURN(v_b, q_b) << "The size of dimension 'B' of the value tensor should be the "
+                                      "same as that of the query tensor.";
+      CHECK_EQ_OR_RETURN(v_m, k_m) << "The size of dimension 'M' of the value tensor should be the "
+                                      "same as that of the key tensor.";
+      CHECK_EQ_OR_RETURN(v_k % 8, 0)
+          << "The size of dimension 'K' of the value tensor should be a multiple of 8.";
+    } else {
+      CHECK_OR_RETURN(key_tensor_layout == "BM(H2K)" || key_tensor_layout == "MB(H2K)"
+                      || key_tensor_layout == "BM(H3K)" || key_tensor_layout == "MB(H3K)")
+          << "The value of key_layout should be 'BM(H3K)', 'MB(H3K)', 'BM(H2K)' or 'MB(H2K)' when "
+             "the value tensor is None.";
+      value_tensor = key_tensor;
+      value_tensor_layout = key_tensor_layout;
```

> **Review comment** (on `value_tensor_layout = key_tensor_layout;`): Add a check here as well.

```diff
+      v_b = k_b;
+      v_m = k_m;
+      v_h = k_h;
+      v_k = k_k;
+    }
 
     if (attn_bias) {
       const auto attn_bias_shape = JUST(attn_bias)->shape();
```
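The two `else` branches above encode the packed-QKV calling convention: omitting `key` is only legal when the query tensor already carries Q, K and V (`BM(H3K)`/`MB(H3K)`), and omitting `value` is only legal when the key tensor carries at least K and V (an `(H2K)` or `(H3K)` layout). A standalone sketch of just these rules, using `std::optional` in place of OneFlow's `Optional` (hypothetical helper, assumed semantics):

```cpp
#include <cassert>
#include <optional>
#include <string>

// Sketch of the key/value fallback rules:
//  - key omitted   => query must carry packed qkv, layout BM(H3K)/MB(H3K);
//                     key then reuses the query tensor and layout.
//  - value omitted => key must carry at least packed kv, layout *(H2K)/*(H3K);
//                     value then reuses the key tensor and layout.
bool ResolveLayouts(const std::string& query_layout,
                    std::optional<std::string> key_layout,
                    std::optional<std::string> value_layout,
                    std::string* resolved_key, std::string* resolved_value) {
  const std::string key = key_layout.value_or(query_layout);
  if (!key_layout && query_layout != "BM(H3K)" && query_layout != "MB(H3K)") return false;
  const std::string value = value_layout.value_or(key);
  if (!value_layout && key != "BM(H2K)" && key != "MB(H2K)" && key != "BM(H3K)"
      && key != "MB(H3K)") {
    return false;
  }
  *resolved_key = key;
  *resolved_value = value;
  return true;
}

int main() {
  std::string k, v;
  // Fully packed qkv: only the query is passed.
  assert(ResolveLayouts("BM(H3K)", std::nullopt, std::nullopt, &k, &v) && k == "BM(H3K)");
  // Packed kv: a separate query plus a BM(H2K) key; value omitted.
  assert(ResolveLayouts("BM(HK)", std::string("BM(H2K)"), std::nullopt, &k, &v) && v == "BM(H2K)");
  // Invalid: key omitted but the query is not packed qkv.
  assert(!ResolveLayouts("BM(HK)", std::nullopt, std::nullopt, &k, &v));
  return 0;
}
```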
```diff
@@ -5046,15 +5109,39 @@ class FusedMultiHeadAttentionInferenceV2Functor {
       }
     }
 
-    auto& attrs =
-        THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout",
-                                       "query_head_size", "causal", "causal_diagonal_offset");
-    attrs.SetAllAttrs(query_layout, key_layout, value_layout, q_k, causal, causal_diagonal_offset);
+    std::string op_output_layout;
```

> **Review comment** (on `std::string op_output_layout;`): There are quite a few `if`/`else` chains in this functor. They don't affect functionality, but a `switch` might read a little nicer, e.g. this could become:
>
> ```cpp
> constexpr auto kBMHK = "BM(HK)";
> constexpr auto kMBHK = "MB(HK)";
> auto op_output_layout = output_layout;
> switch (output_layout) {
>   case kBMHK: break;
>   case kMBHK:
>     if (q_b != 1) { op_output_layout = kBMHK; }
>     break;
>   default:
>     UNIMPLEMENTED_THEN_RETURN() << "output_layout should be '" << kBMHK << "' or '" << kMBHK << "'";
> }
> ```
>
> Just a comment; it doesn't necessarily need to be changed.
>
> **Review comment:** Same as above.

```diff
+    if (output_layout == "BM(HK)") {
+      op_output_layout = output_layout;
+    } else if (output_layout == "MB(HK)") {
+      if (q_b == 1) {
+        op_output_layout = output_layout;
+      } else {
```

> **Review comment:** Isn't `op_output_layout` fixed to `"BM(HK)"` here? If so, it feels like this `q_b == 1` check could be removed?
>
> **Reply:** The kernel currently only supports `BM(HK)` for the computation, so when `output_layout` is any other value a transpose is needed. But when `B == 1`, the data distributions of `BM(HK)` and `MB(HK)` are identical, so one transpose can be saved.

```diff
+        op_output_layout = "BM(HK)";
+      }
+    } else {
+      UNIMPLEMENTED_THEN_RETURN() << "output_layout should be 'BM(HK)' or 'MB(HK)'";
+    }
+    auto& attrs = THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout",
+                                                 "output_layout", "query_head_size", "causal",
+                                                 "causal_diagonal_offset");
+    attrs.SetAllAttrs(query_layout, key_tensor_layout, value_tensor_layout, op_output_layout, q_k,
+                      causal, causal_diagonal_offset);
+    std::shared_ptr<one::Tensor> op_output;
     if (attn_bias) {
-      return OpInterpUtil::Dispatch<Tensor>(*op_with_attn_bias_,
-                                            {query, key, value, JUST(attn_bias)}, attrs);
+      op_output = JUST(OpInterpUtil::Dispatch<Tensor>(
+          *op_with_attn_bias_, {query, key_tensor, value_tensor, JUST(attn_bias)}, attrs));
     } else {
-      return OpInterpUtil::Dispatch<Tensor>(*op_, {query, key, value}, attrs);
+      op_output =
+          JUST(OpInterpUtil::Dispatch<Tensor>(*op_, {query, key_tensor, value_tensor}, attrs));
     }
+    if (op_output_layout == output_layout) {
+      return op_output;
+    } else {
+      if (op_output_layout == "BM(HK)" && output_layout == "MB(HK)") {
+        return functional::Transpose(op_output, {1, 0, 2});
+      } else {
+        UNIMPLEMENTED_THEN_RETURN();
+      }
+    }
   }
```
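As the reply above notes, the kernel always computes a `BM(HK)` result, so a requested `MB(HK)` output normally costs one extra `Transpose`, except when `q_b == 1`, where the two layouts describe the same bytes and the transpose is skipped. A standalone sketch of that decision (illustrative only, with a hypothetical `Plan` struct):

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Sketch of the output-layout decision mirrored from the functor's if/else:
// the kernel itself always writes BM(HK), so MB(HK) normally requires a
// transpose afterwards, except when B == 1, where the two layouts describe
// the same bytes.
struct Plan {
  std::string kernel_layout;  // what the op is asked to produce
  bool needs_transpose;       // whether a (1, 0, 2) transpose follows
};

Plan PlanOutput(const std::string& output_layout, int64_t b) {
  if (output_layout == "BM(HK)") return {"BM(HK)", false};
  if (output_layout == "MB(HK)") {
    if (b == 1) return {"MB(HK)", false};  // same data distribution as BM(HK)
    return {"BM(HK)", true};               // kernel writes BM(HK), then transpose
  }
  assert(false && "output_layout should be 'BM(HK)' or 'MB(HK)'");
  return {};
}

int main() {
  assert(!PlanOutput("BM(HK)", 8).needs_transpose);
  assert(!PlanOutput("MB(HK)", 1).needs_transpose);  // B == 1: transpose elided
  assert(PlanOutput("MB(HK)", 8).needs_transpose);
  return 0;
}
```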
> **Review comment:** Should the layout constant strings used in this functor be defined with `constexpr`? That would avoid hard-coding and repetition, and read a little better.
>
> **Reply:** In terms of readability and avoiding mistakes, `kBMHK` and `"BMHK"` are essentially equivalent here, and since the layout strings contain parentheses, naming them is actually a bit more awkward. A better way to handle this would be to define a `Layout` class.
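Following that reply, here is what such a `Layout` helper might look like. This is a purely hypothetical sketch, not code from this PR or from OneFlow: it parses a layout string such as `"BM(H3K)"` once, so the packed factor and axis order stop being string comparisons scattered through the functor.

```cpp
#include <cassert>
#include <cstdint>
#include <string>

// Hypothetical Layout helper in the spirit of the review discussion above;
// not part of this PR. Parses strings like "BM(HK)", "MB(H3K)", "BMHK".
class Layout {
 public:
  explicit Layout(const std::string& s)
      : batch_first_(s[0] == 'B'),
        packed_(s.find('(') != std::string::npos),
        packed_n_(s.find("H3K") != std::string::npos   ? 3
                  : s.find("H2K") != std::string::npos ? 2
                                                       : 1) {}

  bool batch_first() const { return batch_first_; }  // "B..." vs "M..."
  bool packed() const { return packed_; }            // 3-D "(...)" vs 4-D
  int64_t packed_n() const { return packed_n_; }     // 1, 2 (kv) or 3 (qkv)

 private:
  bool batch_first_;
  bool packed_;
  int64_t packed_n_;
};

int main() {
  assert(Layout("BM(H3K)").packed_n() == 3);
  assert(!Layout("MB(HK)").batch_first());
  assert(!Layout("BMHK").packed());
  return 0;
}
```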