
Commit d1e130d

ggerganov authored and jpohhhh committed
graph : simplify attn input build for unified KV cache (ggml-org#12381)
ggml-ci
1 parent bb288b2 commit d1e130d

File tree

3 files changed: 53 additions, 58 deletions

src/llama-graph.cpp

+4 -10

@@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
-        bool causal,
-        bool swa) const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     const auto n_kv = kv_self->n;
 
-    inp->self_kq_mask = causal
-        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    if (swa) {
+    if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
-        inp->self_kq_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
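Note on the mask shape: after this change the KQ mask is always allocated as n_kv columns by GGML_PAD(n_tokens, GGML_KQ_MASK_PAD) rows, and the non-causal n_tokens-wide variant is gone; whether an SWA mask is built is now decided by hparams.n_swa_pattern rather than a caller-supplied flag. A minimal standalone sketch of the padding arithmetic, assuming GGML_PAD matches ggml.h's round-up-to-multiple macro and using hypothetical stand-in values for n_kv, n_tokens, and GGML_KQ_MASK_PAD:

// Sketch only: GGML_PAD is assumed to match ggml.h; the numbers are made up.
#include <cstdint>
#include <cstdio>

#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1)) // n must be a power of two

int main() {
    const int64_t n_kv        = 4096; // hypothetical: usable cells in the unified KV cache
    const int64_t n_tokens    = 100;  // hypothetical: tokens in the current batch
    const int64_t kq_mask_pad = 32;   // stand-in for GGML_KQ_MASK_PAD

    // After this commit the mask is always n_kv wide; the second dimension
    // is the batch size rounded up for alignment (100 -> 128 here).
    std::printf("self_kq_mask: %lld x %lld (f32)\n",
                (long long) n_kv,
                (long long) GGML_PAD(n_tokens, kq_mask_pad));
    return 0;
}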

src/llama-graph.h

+1 -3

@@ -509,9 +509,7 @@ struct llm_graph_context {
             float kq_scale,
             int   il) const;
 
-    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified(
-            bool causal,
-            bool swa) const;
+    llm_graph_input_attn_kv_unified * build_attn_inp_kv_unified() const;
 
     ggml_tensor * build_attn(
             llm_graph_input_attn_kv_unified * inp,
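With the two booleans removed from the declaration, every call site collapses to the zero-argument form and the causal/SWA decision lives inside the builder. A self-contained mock (not llama.cpp code; all names besides build_attn_inp_kv_unified, n_swa, and n_swa_pattern are invented for illustration) of that design shift:

// Mock sketch: the caller no longer chooses SWA; the builder reads hparams.
#include <cstdio>

struct hparams_t {
    int n_swa         = 0; // hypothetical sliding-window size
    int n_swa_pattern = 1; // >1 means some layers use sliding-window attention
};

struct graph_ctx {
    hparams_t hparams;

    // After the commit: no causal/swa parameters.
    void build_attn_inp_kv_unified() const {
        std::printf("build full-cache KQ mask\n");
        if (hparams.n_swa_pattern > 1) {
            // mirrors GGML_ASSERT(hparams.n_swa > 0) in the real code
            std::printf("build SWA KQ mask (n_swa=%d)\n", hparams.n_swa);
        }
    }
};

int main() {
    graph_ctx ctx;
    ctx.hparams = {4096, 2}; // hypothetical SWA model
    ctx.build_attn_inp_kv_unified();
    return 0;
}

The payoff of the design choice is visible in the diff stats: per-model callers shed boilerplate, and a model can no longer pass flags that contradict its own hyperparameters.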
