@@ -1311,29 +1311,23 @@ ggml_tensor * llm_graph_context::build_attn(
     return cur;
 }
 
-llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified(
-        bool causal,
-        bool swa) const {
+llm_graph_input_attn_kv_unified * llm_graph_context::build_attn_inp_kv_unified() const {
     const llama_kv_cache_unified * kv_self = static_cast<const llama_kv_cache_unified *>(memory);
 
     auto inp = std::make_unique<llm_graph_input_attn_kv_unified>(hparams, cparams, kv_self);
 
     const auto n_kv = kv_self->n;
 
-    inp->self_kq_mask = causal
-        ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-        : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+    inp->self_kq_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
     //cb(inp->self_kq_mask, "KQ_mask", -1);
     ggml_set_input(inp->self_kq_mask);
 
     inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
 
-    if (swa) {
+    if (hparams.n_swa_pattern > 1) {
         GGML_ASSERT(hparams.n_swa > 0);
 
-        inp->self_kq_mask_swa = causal
-            ? ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv,     GGML_PAD(n_tokens, GGML_KQ_MASK_PAD))
-            : ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
+        inp->self_kq_mask_swa = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD));
         //cb(inp->self_kq_mask_swa, "KQ_mask_swa", -1);
         ggml_set_input(inp->self_kq_mask_swa);
 
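For context on what the now parameter-less builder does: the KQ mask is always allocated as an F32 tensor of shape [n_kv, n_tokens padded up to a multiple of GGML_KQ_MASK_PAD], and the sliding-window mask is built whenever the model's hyperparameters request an SWA pattern (`hparams.n_swa_pattern > 1`) rather than when a caller passes `swa = true`. Below is a minimal sketch of the shape computation, not llama.cpp code: `GGML_PAD` is assumed to round up to a multiple of its second argument, the pad constant is illustrative, and the `n_kv`/`n_tokens`/`n_swa_pattern` values are made up.

```cpp
#include <cstdint>
#include <cstdio>

// Stand-ins for the ggml macros used in the diff (assumed semantics:
// GGML_PAD rounds x up to a multiple of n; the pad constant is illustrative).
#define GGML_PAD(x, n) (((x) + (n) - 1) / (n) * (n))
#define GGML_KQ_MASK_PAD 32

int main() {
    const int64_t n_kv     = 4096; // hypothetical unified KV cache view size (kv_self->n)
    const int64_t n_tokens = 100;  // hypothetical number of tokens in the current batch

    // After this change the mask is unconditionally [n_kv, pad(n_tokens)];
    // the old non-causal branch that used n_tokens for the first dim is gone.
    const int64_t ne0 = n_kv;
    const int64_t ne1 = GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); // 100 -> 128

    printf("self_kq_mask:     %lld x %lld (F32)\n", (long long) ne0, (long long) ne1);

    // The SWA mask is gated on the model hyperparameters instead of a flag:
    const uint32_t n_swa_pattern = 2; // hypothetical; > 1 means the model interleaves SWA layers
    if (n_swa_pattern > 1) {
        printf("self_kq_mask_swa: %lld x %lld (F32)\n", (long long) ne0, (long long) ne1);
    }
    return 0;
}
```

The upshot is that causal vs. non-causal attention no longer changes how the mask is allocated (presumably any non-causal handling happens when the mask data is filled in), and whether a model gets the extra SWA mask is decided once by its hyperparameters rather than at each call site.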