-
Notifications
You must be signed in to change notification settings - Fork 802
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[fmha] packed qkv and output_layout #9950
Conversation
"same as that of the query tensor."; | ||
|
||
} else { | ||
key_tensor = query; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里加个检查CHECK_OR_RETURN(query_layout=="BM(H3K)" || query_layout=="MB(H3K)")
|
||
} else { | ||
value_tensor = key_tensor; | ||
value_tensor_layout = key_tensor_layout; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里也加个检查CHECK_OR_RETURN(key_tensor_layout =="BM(H2K)" || key_tensor_layout =="MB(H2K)"|| key_tensor_layout =="BM(H3K)" || key_tensor_layout =="MB(H3K)")
Speed stats:
|
} else if (output_layout == "MB(HK)") { | ||
if (q_b == 1) { | ||
op_output_layout = output_layout; | ||
} else { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_output_layout是不是固定为"BM(HK)"?如果是的话,这里的q_b==1感觉可以删除?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
op_output_layout是不是固定为"BM(HK)"?如果是的话,这里的q_b==1感觉可以删除?
目前Kernel计算只支持 BM(HK),所以当output_layout 是其他值的时候需要转置,但是当B=1的时候, BM(HK) 和 MB(HK) 的数据分布是一致的,可以省掉一个转置
@@ -4910,43 +4912,63 @@ class FusedMultiHeadAttentionInferenceV2Functor { | |||
const Optional<int64_t>& head_size, int64_t* b, int64_t* m, | |||
int64_t* h, int64_t* k) -> Maybe<void> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个functor里面涉及到的layerout常量字符串要不要直接用constexpr来定义一下,可以避免硬编码和重复,可读性也会好点。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里其实 kBMHK和"BMHK"在可读性和避免出错的角度基本上是一致的,而且因为包含括号,所以命名还更麻烦一点
这里一个更好的处理方式其实是定义一个Layout类来处理这个问题
THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout", | ||
"query_head_size", "causal", "causal_diagonal_offset"); | ||
attrs.SetAllAttrs(query_layout, key_layout, value_layout, q_k, causal, causal_diagonal_offset); | ||
std::string op_output_layout; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个functor里面if else也有点多,虽然不影响功能但或许switch会好看一点点比如这里改成:
constexpr auto kBMHK = "BM(HK)";
constexpr auto kMBHK = "MB(HK)";
auto op_output_layout = output_layout;
switch (output_layout) {
case kBMHK:
break;
case kMBHK:
if (q_b != 1) {
op_output_layout = kBMHK;
}
break;
default:
UNIMPLEMENTED_THEN_RETURN() << "output_layout should be '" << kBMHK << "' or '" << kMBHK << "'";
}
只是提个comment,不一定需要修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
同上
Speed stats:
|
View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9950/ |
No description provided.