
[fmha] packed qkv and output_layout #9950

Merged
merged 4 commits into master from dev_fmha_pack_qkv on Mar 7, 2023

Conversation

liujuncheng
Collaborator

No description provided.

"same as that of the query tensor.";

} else {
key_tensor = query;
Contributor

Add a check here: CHECK_OR_RETURN(query_layout == "BM(H3K)" || query_layout == "MB(H3K)")


} else {
value_tensor = key_tensor;
value_tensor_layout = key_tensor_layout;
Contributor

Add a check here as well: CHECK_OR_RETURN(key_tensor_layout == "BM(H2K)" || key_tensor_layout == "MB(H2K)" || key_tensor_layout == "BM(H3K)" || key_tensor_layout == "MB(H3K)")
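The two suggested checks can be sketched as small helper predicates. This is a standalone illustration; the function names are hypothetical and not part of the PR:

```cpp
#include <string>

// Hypothetical helpers mirroring the suggested CHECK_OR_RETURN conditions.
// A packed-QKV layout ("H3K") stores Q, K, and V together along the last dim.
inline bool IsPackedQKVLayout(const std::string& layout) {
  return layout == "BM(H3K)" || layout == "MB(H3K)";
}

// When the value tensor is omitted, the key may come from a packed KV
// ("H2K") tensor or a packed QKV ("H3K") tensor.
inline bool IsPackedKVLayout(const std::string& layout) {
  return layout == "BM(H2K)" || layout == "MB(H2K)"
      || layout == "BM(H3K)" || layout == "MB(H3K)";
}
```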

@github-actions
Contributor

github-actions bot commented Mar 6, 2023

Speed stats:

} else if (output_layout == "MB(HK)") {
if (q_b == 1) {
op_output_layout = output_layout;
} else {
Contributor

Isn't op_output_layout fixed to "BM(HK)"? If so, can the q_b == 1 check here be removed?

Collaborator Author

Isn't op_output_layout fixed to "BM(HK)"? If so, can the q_b == 1 check here be removed?

Currently the kernel computation only supports BM(HK), so when output_layout is any other value a transpose is required. However, when B == 1 the data layouts of BM(HK) and MB(HK) are identical, so one transpose can be skipped.
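The B == 1 equivalence can be verified by computing the flat offset of an element under both layouts. A minimal sketch (not code from the PR; function names are made up for illustration):

```cpp
#include <cstdint>

// Flat offset of element (b, m, h, k) in a contiguous BM(HK) buffer,
// i.e. shape [B, M, H, K] in row-major order.
inline int64_t OffsetBMHK(int64_t b, int64_t m, int64_t h, int64_t k,
                          int64_t B, int64_t M, int64_t H, int64_t K) {
  (void)B;  // unused: B is the outermost dim and does not affect strides
  return ((b * M + m) * H + h) * K + k;
}

// Flat offset in a contiguous MB(HK) buffer, i.e. shape [M, B, H, K].
inline int64_t OffsetMBHK(int64_t b, int64_t m, int64_t h, int64_t k,
                          int64_t B, int64_t M, int64_t H, int64_t K) {
  (void)M;  // unused: M is the outermost dim here
  return ((m * B + b) * H + h) * K + k;
}
```

When B == 1, b is always 0 and both formulas reduce to (m * H + h) * K + k, so the two buffers are byte-identical and the transpose is unnecessary; for B > 1 the offsets diverge.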

@@ -4910,43 +4912,63 @@ class FusedMultiHeadAttentionInferenceV2Functor {
const Optional<int64_t>& head_size, int64_t* b, int64_t* m,
int64_t* h, int64_t* k) -> Maybe<void> {
Contributor

Should the layout constant strings used in this functor be defined with constexpr? That would avoid hard-coding and duplication, and improve readability.

Collaborator Author

Here kBMHK and "BM(HK)" are basically equivalent in terms of readability and avoiding mistakes, and because the strings contain parentheses, naming the constants is actually a bit more awkward.
A better way to handle this would be to define a Layout class.
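A minimal sketch of such a Layout class, purely illustrative (the struct name, fields, and parsing rules are assumptions, not the PR's code):

```cpp
#include <string>

// Parses layout strings such as "BM(HK)", "MB(H2K)", "BM(H3K)".
// batch_first: whether the batch dim precedes the sequence dim.
// pack_size:   1 for plain HK, 2 for packed KV (H2K), 3 for packed QKV (H3K).
struct Layout {
  bool batch_first = true;
  int pack_size = 1;
  bool valid = false;

  explicit Layout(const std::string& s) {
    if (s.size() < 6 || s[2] != '(' || s.back() != ')') { return; }
    if (s.compare(0, 2, "BM") == 0) {
      batch_first = true;
    } else if (s.compare(0, 2, "MB") == 0) {
      batch_first = false;
    } else {
      return;
    }
    const std::string inner = s.substr(3, s.size() - 4);  // e.g. "H3K"
    if (inner == "HK") {
      pack_size = 1;
    } else if (inner == "H2K") {
      pack_size = 2;
    } else if (inner == "H3K") {
      pack_size = 3;
    } else {
      return;
    }
    valid = true;
  }
};
```

Centralizing the parsing this way replaces scattered string comparisons with structured fields, at the cost of one small class.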

THREAD_CACHED_MUTABLE_ATTR_MAP("query_layout", "key_layout", "value_layout",
"query_head_size", "causal", "causal_diagonal_offset");
attrs.SetAllAttrs(query_layout, key_layout, value_layout, q_k, causal, causal_diagonal_offset);
std::string op_output_layout;
Contributor

There are quite a few if/else branches in this functor. It doesn't affect functionality, but it might read a little better with named constants. (C++ cannot switch on std::string, so an if/else chain over constexpr constants is the closest equivalent.) For example, this could become:

constexpr auto kBMHK = "BM(HK)";
constexpr auto kMBHK = "MB(HK)";
auto op_output_layout = output_layout;
if (output_layout == kBMHK) {
  // already the layout the kernel supports
} else if (output_layout == kMBHK) {
  if (q_b != 1) { op_output_layout = kBMHK; }
} else {
  UNIMPLEMENTED_THEN_RETURN() << "output_layout should be '" << kBMHK << "' or '" << kMBHK << "'";
}

Just a comment; no change is necessarily required.

Collaborator Author

Same as above.

@liujuncheng liujuncheng enabled auto-merge (squash) March 7, 2023 02:42
@github-actions
Copy link
Contributor

github-actions bot commented Mar 7, 2023

Speed stats:
GPU Name: GeForce GTX 1080 

❌ OneFlow resnet50 time: 141.3ms (= 14126.3ms / 100, input_shape=[16, 3, 224, 224])
PyTorch resnet50 time: 143.1ms (= 14306.5ms / 100, input_shape=[16, 3, 224, 224])
❌ Relative speed: 1.01 (= 143.1ms / 141.3ms)

OneFlow resnet50 time: 82.6ms (= 8263.1ms / 100, input_shape=[8, 3, 224, 224])
PyTorch resnet50 time: 89.7ms (= 8972.8ms / 100, input_shape=[8, 3, 224, 224])
✔️ Relative speed: 1.09 (= 89.7ms / 82.6ms)

OneFlow resnet50 time: 51.1ms (= 10216.5ms / 200, input_shape=[4, 3, 224, 224])
PyTorch resnet50 time: 63.8ms (= 12762.9ms / 200, input_shape=[4, 3, 224, 224])
✔️ Relative speed: 1.25 (= 63.8ms / 51.1ms)

OneFlow resnet50 time: 33.9ms (= 6783.0ms / 200, input_shape=[2, 3, 224, 224])
PyTorch resnet50 time: 42.5ms (= 8491.3ms / 200, input_shape=[2, 3, 224, 224])
✔️ Relative speed: 1.25 (= 42.5ms / 33.9ms)

OneFlow resnet50 time: 25.4ms (= 5081.6ms / 200, input_shape=[1, 3, 224, 224])
PyTorch resnet50 time: 42.7ms (= 8534.3ms / 200, input_shape=[1, 3, 224, 224])
✔️ Relative speed: 1.68 (= 42.7ms / 25.4ms)

OneFlow swin dataloader time: 0.237s (= 47.405s / 200, num_workers=1)
PyTorch swin dataloader time: 0.150s (= 29.903s / 200, num_workers=1)
Relative speed: 0.631 (= 0.150s / 0.237s)

OneFlow swin dataloader time: 0.070s (= 13.999s / 200, num_workers=4)
PyTorch swin dataloader time: 0.045s (= 8.910s / 200, num_workers=4)
Relative speed: 0.636 (= 0.045s / 0.070s)

OneFlow swin dataloader time: 0.041s (= 8.125s / 200, num_workers=8)
PyTorch swin dataloader time: 0.023s (= 4.521s / 200, num_workers=8)
Relative speed: 0.556 (= 0.023s / 0.041s)

❌ OneFlow resnet50 time: 153.7ms (= 15371.3ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 165.1ms (= 16511.2ms / 100, input_shape=[16, 3, 224, 224], ddp, world size=2)
❌ Relative speed: 1.07 (= 165.1ms / 153.7ms)

OneFlow resnet50 time: 94.3ms (= 9431.8ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 103.3ms (= 10335.0ms / 100, input_shape=[8, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.10 (= 103.3ms / 94.3ms)

OneFlow resnet50 time: 61.3ms (= 12264.0ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 77.2ms (= 15438.3ms / 200, input_shape=[4, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.26 (= 77.2ms / 61.3ms)

OneFlow resnet50 time: 43.5ms (= 8707.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 66.7ms (= 13346.4ms / 200, input_shape=[2, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.53 (= 66.7ms / 43.5ms)

OneFlow resnet50 time: 36.7ms (= 7331.6ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
PyTorch resnet50 time: 67.5ms (= 13495.0ms / 200, input_shape=[1, 3, 224, 224], ddp, world size=2)
✔️ Relative speed: 1.84 (= 67.5ms / 36.7ms)

@github-actions
Contributor

github-actions bot commented Mar 7, 2023

View latest API docs preview at: https://staging.oneflow.info/docs/Oneflow-Inc/oneflow/pr/9950/

@liujuncheng liujuncheng merged commit 368f054 into master Mar 7, 2023
@liujuncheng liujuncheng deleted the dev_fmha_pack_qkv branch March 7, 2023 04:50
3 participants