Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FlashMLA support #1159

Open
wants to merge 31 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
aae324f
FlashMLA support
EricLBuehler Feb 25, 2025
f54feeb
Merge branch 'master' into flash_mla
EricLBuehler Feb 25, 2025
ef11fd5
Include flash_fwd_mla_kernel.h again?
EricLBuehler Feb 25, 2025
ab2d86f
Include flash_fwd_mla_kernel.h only once
EricLBuehler Feb 25, 2025
1fdba40
Tweak
EricLBuehler Feb 25, 2025
e881bb4
Move CUDA_NVCC_FLAGS to last
EricLBuehler Feb 25, 2025
357519c
Permute caches
EricLBuehler Feb 25, 2025
37d9c1b
Fix reshape
EricLBuehler Feb 25, 2025
e47733f
Remove lora from q in dsv3
EricLBuehler Feb 25, 2025
7ff28b2
Add prefill part
EricLBuehler Feb 25, 2025
72f55a3
Use MLAAttention
EricLBuehler Feb 25, 2025
1da1a77
Add concat_and_cache_kernel_mla cuda kernel
EricLBuehler Feb 25, 2025
70c6c2f
Merge branch 'dsv3_matrix_absorb' into flash_mla
EricLBuehler Feb 25, 2025
e2f5d6c
Add decode using flashmla
EricLBuehler Feb 25, 2025
864b73b
Add it to dsv2
EricLBuehler Feb 26, 2025
7982be8
Add lora to dsv3
EricLBuehler Feb 26, 2025
0f4c076
Only k_c_k_pe cache, no k/v cache
EricLBuehler Feb 26, 2025
a4eedcf
Hack the cache
EricLBuehler Feb 26, 2025
425b2ca
Fix qproj weight narrow
EricLBuehler Feb 26, 2025
a9ed6e7
Add .contiguous()?
EricLBuehler Feb 26, 2025
77014cf
Some fixes
EricLBuehler Feb 26, 2025
01ccb7c
Fix rank checks in concat_and_cache mla
EricLBuehler Feb 26, 2025
7e2c662
Flatten for num_tokens in concat_and_cache
EricLBuehler Feb 26, 2025
8895bed
Fix cache?
EricLBuehler Feb 26, 2025
10d6aeb
Fix cache?
EricLBuehler Feb 26, 2025
1c097fd
Don't unsqueeze
EricLBuehler Feb 26, 2025
34a60bb
1 kv head, fix head dim
EricLBuehler Feb 26, 2025
3069b75
Fix cache, fix passing v head dim
EricLBuehler Feb 26, 2025
ea4edb8
Fix check for dummy v
EricLBuehler Feb 26, 2025
f6c94f7
Fix out shape
EricLBuehler Feb 26, 2025
10e3421
out-accum should be f32
EricLBuehler Feb 26, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ rust-version = "1.82"

[workspace.dependencies]
anyhow = "1.0.80"
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "8d3ea29" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "8d3ea29" }
candle-core = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "294dcfdf" }
candle-nn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "294dcfdf" }
serde = "1.0.197"
serde_json = "1.0.114"
indexmap = { version = "2.2.5", features = ["serde"] }
Expand Down
6 changes: 4 additions & 2 deletions mistralrs-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ candle-core.workspace = true
candle-nn.workspace = true
serde.workspace = true
serde_json.workspace = true
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "8d3ea29", optional = true }
candle-flash-attn = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "294dcfdf", optional = true }
dirs = "5.0.1"
hf-hub = { version = "0.3.3", package = "candle-hf-hub" }
thiserror = "1.0.57"
Expand Down Expand Up @@ -81,7 +81,8 @@ llguidance = { git = "https://github.com/microsoft/llguidance", rev = "cfef3df97
toktrie_hf_tokenizers = { git = "https://github.com/microsoft/llguidance", rev = "cfef3df97372a7b84d74976ff41cc9cb78bca6cc" }
objc = { version = "0.2.7", optional = true }
metal = { workspace = true, optional = true }
candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "8d3ea29", optional = true }
candle-flash-attn-v3 = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "294dcfdf", optional = true }
candle-flash-mla = { git = "https://github.com/EricLBuehler/candle.git", version = "0.8.0", rev = "294dcfdf", optional = true }
safetensors.workspace = true
serde-big-array = "0.5.1"
interprocess = "2.2.2"
Expand All @@ -96,6 +97,7 @@ cuda = [
"dep:mistralrs-paged-attn",
"mistralrs-paged-attn/cuda",
"float8/cuda",
"dep:candle-flash-mla",
]
cudnn = ["candle-core/cudnn"]
metal = [
Expand Down
Loading
Loading