
Commit 081bee8

hparams : add SWA rope parameters (#12374)
ggml-ci
1 parent 84d5475 commit 081bee8

5 files changed: +26, -20 lines

src/llama-context.cpp (+5, -9)

@@ -537,16 +537,12 @@ llm_graph_result_ptr llama_context::build_kv_self_shift(
         const int64_t n_head_kv    = hparams.n_head_kv(il);
         const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
 
-        float freq_base_l  = cparams.rope_freq_base;
-        float freq_scale_l = cparams.rope_freq_scale;
+        const bool is_swa = hparams.is_swa(il);
 
-        // TODO: improve
-        if (model.arch == LLM_ARCH_GEMMA3) {
-            const bool is_sliding = hparams.is_sliding(il);
-
-            freq_base_l  = is_sliding ? 10000.0f : cparams.rope_freq_base;
-            freq_scale_l = is_sliding ? 1.0f     : cparams.rope_freq_scale;
-        }
+        // note: the swa rope params could become part of the cparams in the future
+        //       if we decide to make them configurable, like the non-sliding ones
+        const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+        const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
         ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il);

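For context, this hunk replaces the Gemma 3 special case in the K-shift path with a per-layer lookup from the hyperparameters: SWA layers use the training-time rope_freq_base_train_swa / rope_freq_scale_train_swa values, while the remaining layers keep the user-configurable cparams values. A minimal standalone sketch of that selection (the hparams_t / cparams_t structs and the numeric values below are illustrative stand-ins, not the real llama.cpp types):

#include <cstdint>
#include <cstdio>

// illustrative stand-ins for the relevant llama_hparams / llama_cparams fields
// (hypothetical values; the real structs live in the llama.cpp sources)
struct hparams_t {
    float rope_freq_base_train_swa  = 10000.0f; // training-time rope base of the SWA layers
    float rope_freq_scale_train_swa = 1.0f;     // training-time rope scale of the SWA layers
};

struct cparams_t {
    float rope_freq_base  = 1000000.0f; // user-configurable, applies to the non-SWA layers
    float rope_freq_scale = 1.0f;
};

// per-layer rope parameter selection, mirroring the ternaries in the hunk above
static void rope_params(const hparams_t & hp, const cparams_t & cp, bool is_swa,
                        float & freq_base_l, float & freq_scale_l) {
    freq_base_l  = is_swa ? hp.rope_freq_base_train_swa  : cp.rope_freq_base;
    freq_scale_l = is_swa ? hp.rope_freq_scale_train_swa : cp.rope_freq_scale;
}

int main() {
    hparams_t hp;
    cparams_t cp;

    float base  = 0.0f;
    float scale = 0.0f;

    rope_params(hp, cp, /*is_swa =*/ true, base, scale);
    printf("swa layer     : freq_base = %9.1f, freq_scale = %.2f\n", base, scale);

    rope_params(hp, cp, /*is_swa =*/ false, base, scale);
    printf("non-swa layer : freq_base = %9.1f, freq_scale = %.2f\n", base, scale);

    return 0;
}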
src/llama-graph.cpp (+2, -2)

@@ -1403,9 +1403,9 @@ ggml_tensor * llm_graph_context::build_attn(
         ggml_build_forward_expand(gf, ggml_cpy(ctx0, v_cur, v_cache_view));
     }
 
-    const bool is_sliding = hparams.is_sliding(il);
+    const bool is_swa = hparams.is_swa(il);
 
-    const auto & kq_mask = is_sliding ? inp->get_kq_mask_swa() : inp->get_kq_mask();
+    const auto & kq_mask = is_swa ? inp->get_kq_mask_swa() : inp->get_kq_mask();
 
     const auto n_kv = kv_self->n;

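The rename only touches the mask selection here, but it helps to spell out what the two masks differ in: the regular kq_mask is causal, while kq_mask_swa additionally masks keys that fall outside the sliding window of the query position. A toy illustration of that difference (plain C++, not the actual ggml mask construction; the exact window-boundary convention is simplified):

#include <cstdint>
#include <cstdio>

// causal mask only: key positions after the query are masked
static bool masked_full(int64_t pos_q, int64_t pos_k) {
    return pos_k > pos_q;
}

// causal + sliding window: additionally mask keys further than n_swa behind the query
// (boundary convention simplified for illustration)
static bool masked_swa(int64_t pos_q, int64_t pos_k, int64_t n_swa) {
    return pos_k > pos_q || pos_q - pos_k >= n_swa;
}

int main() {
    const int64_t pos_q = 8;
    const int64_t n_swa = 4;

    for (int64_t pos_k = 0; pos_k <= 10; ++pos_k) {
        printf("q = %lld, k = %2lld: full = %d, swa = %d\n",
               (long long) pos_q, (long long) pos_k,
               masked_full(pos_q, pos_k), masked_swa(pos_q, pos_k, n_swa));
    }

    return 0;
}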
src/llama-hparams.cpp (+1, -1)

@@ -70,7 +70,7 @@ uint32_t llama_hparams::n_embd_v_s() const {
     return ssm_d_state * ssm_d_inner;
 }
 
-bool llama_hparams::is_sliding(uint32_t il) const {
+bool llama_hparams::is_swa(uint32_t il) const {
     if (il < n_layer) {
         return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
     }

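To make the renamed predicate concrete: with n_swa_pattern = 6 (the value set for Gemma 3 in llama-model.cpp below), il % 6 < 5 holds for five out of every six layers, so layers 0-4, 6-10, ... are SWA layers and every sixth layer (5, 11, ...) uses full attention. A small standalone check of the same predicate (the n_swa value is hypothetical; only its being non-zero matters here):

#include <cstdint>
#include <cstdio>

// same predicate as llama_hparams::is_swa(), with the member fields passed as parameters
static bool is_swa(uint32_t il, uint32_t n_swa, uint32_t n_swa_pattern) {
    return n_swa > 0 && n_swa_pattern > 0 && il % n_swa_pattern < (n_swa_pattern - 1);
}

int main() {
    const uint32_t n_swa         = 1024; // hypothetical sliding-window size
    const uint32_t n_swa_pattern = 6;    // Gemma 3: 5 SWA layers followed by 1 full-attention layer

    for (uint32_t il = 0; il < 12; ++il) {
        printf("layer %2u: %s\n", il, is_swa(il, n_swa, n_swa_pattern) ? "swa" : "full");
    }

    return 0;
}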
src/llama-hparams.h (+3, -1)

@@ -79,7 +79,9 @@ struct llama_hparams {
 
     float rope_attn_factor = 1.0f;
     float rope_freq_base_train;
+    float rope_freq_base_train_swa;
    float rope_freq_scale_train;
+    float rope_freq_scale_train_swa;
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul;
 

@@ -135,7 +137,7 @@ struct llama_hparams {
     // dimension of the recurrent state embeddings
     uint32_t n_embd_v_s() const;
 
-    bool is_sliding(uint32_t il) const;
+    bool is_swa(uint32_t il) const;
 };
 
 static_assert(std::is_trivially_copyable<llama_hparams>::value, "llama_hparams must be trivially copyable");

src/llama-model.cpp (+15, -7)

@@ -475,6 +475,10 @@ void llama_model::load_hparams(llama_model_loader & ml) {
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
 
+    // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers
+    hparams.rope_freq_base_train_swa  = hparams.rope_freq_base_train;
+    hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train;
+
     ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false);
 
     // non-transformer models do not have attention heads

@@ -877,6 +881,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             {
                 hparams.n_swa_pattern = 6;
 
+                hparams.rope_freq_base_train_swa  = 10000.0f;
+                hparams.rope_freq_scale_train_swa = 1.0f;
+
                 ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW,    hparams.n_swa);
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

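These two hunks follow a default-then-override pattern: load_hparams first mirrors the generic training rope parameters into the new *_swa fields, so architectures without distinct SWA scaling need no changes, and the Gemma 3 case then overwrites them with its fixed values. A simplified sketch of that flow (no GGUF key reading; the arch string check stands in for the LLM_ARCH_GEMMA3 switch case):

#include <cstdio>
#include <string>

// illustrative subset of llama_hparams (values are placeholders)
struct hparams_t {
    float rope_freq_base_train      = 1000000.0f;
    float rope_freq_scale_train     = 1.0f;
    float rope_freq_base_train_swa  = 0.0f;
    float rope_freq_scale_train_swa = 0.0f;
};

static void load_hparams(hparams_t & hp, const std::string & arch) {
    // default: assume the SWA layers use the same rope scaling as the other layers
    hp.rope_freq_base_train_swa  = hp.rope_freq_base_train;
    hp.rope_freq_scale_train_swa = hp.rope_freq_scale_train;

    // architecture-specific override, mirroring the Gemma 3 case above
    if (arch == "gemma3") {
        hp.rope_freq_base_train_swa  = 10000.0f;
        hp.rope_freq_scale_train_swa = 1.0f;
    }
}

int main() {
    hparams_t hp;

    load_hparams(hp, "gemma3");
    printf("gemma3 swa rope: base = %.1f, scale = %.2f\n", hp.rope_freq_base_train_swa, hp.rope_freq_scale_train_swa);

    load_hparams(hp, "llama");
    printf("llama  swa rope: base = %.1f, scale = %.2f\n", hp.rope_freq_base_train_swa, hp.rope_freq_scale_train_swa);

    return 0;
}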
@@ -1346,13 +1353,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
     const int i_gpu_start = std::max((int) hparams.n_layer - n_gpu_layers, (int) 0);
     const int act_gpu_layers = devices.empty() ? 0 : std::min(n_gpu_layers, (int)n_layer + 1);
     auto get_layer_buft_list = [&](int il) -> llama_model::impl::layer_dev {
+        const bool is_swa = il < (int) hparams.n_layer && hparams.is_swa(il);
         if (il < i_gpu_start || (il - i_gpu_start) >= act_gpu_layers) {
-            LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(cpu_dev));
+            LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(cpu_dev), is_swa);
             return {cpu_dev, &pimpl->cpu_buft_list};
         }
         const int layer_gpu = std::upper_bound(splits.begin(), splits.begin() + n_devices(), float(il - i_gpu_start)/act_gpu_layers) - splits.begin();
         auto * dev = devices.at(layer_gpu);
-        LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s\n", il, ggml_backend_dev_name(dev));
+        LLAMA_LOG_DEBUG("load_tensors: layer %3d assigned to device %s, is_swa = %d\n", il, ggml_backend_dev_name(dev), is_swa);
         return {dev, &pimpl->gpu_buft_list.at(dev)};
     };

@@ -7381,10 +7389,10 @@ struct llm_build_gemma3 : public llm_graph_context {
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = hparams.is_sliding(il);
+            const bool is_swa = hparams.is_swa(il);
 
-            const float freq_base_l  = is_sliding ? 10000.0f : freq_base;
-            const float freq_scale_l = is_sliding ? 1.0f     : freq_scale;
+            const float freq_base_l  = is_swa ? hparams.rope_freq_base_train_swa  : cparams.rope_freq_base;
+            const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale;
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);

@@ -7973,7 +7981,7 @@ struct llm_build_cohere2 : public llm_graph_context {
         auto * inp_attn = build_attn_inp_kv_unified(true, true);
 
         for (int il = 0; il < n_layer; ++il) {
-            const bool is_sliding = hparams.is_sliding(il);
+            const bool is_swa = hparams.is_swa(il);
 
             // norm
             cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il);

@@ -8007,7 +8015,7 @@ struct llm_build_cohere2 : public llm_graph_context {
                 cb(Vcur, "Vcur", il);
             }
 
-            if (is_sliding) {
+            if (is_swa) {
                 Qcur = ggml_rope_ext(ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, rope_factors,
                         n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, attn_factor,
                         beta_fast, beta_slow);

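Finally, for a sense of what the per-layer freq_base_l / freq_scale_l values actually change: in the builders above they are passed to ggml_rope_ext, where the base determines how quickly the rotary angles decay across dimensions. The sketch below uses the plain rotary-embedding formula theta(pos, d) = pos * freq_scale / freq_base^(2d / n_rot) to compare a SWA layer (base 10000, the value set for Gemma 3 in this commit) against a non-SWA layer with a hypothetical user-configured base of 1e6; the YaRN/ext_factor corrections that ggml_rope_ext also applies are omitted:

#include <cmath>
#include <cstdio>

// simplified rope angle: theta(pos, d) = (pos * freq_scale) / freq_base^(2d / n_rot)
// (standard rotary-embedding formula; extension/yarn corrections are omitted)
static float rope_theta(int pos, int d, int n_rot, float freq_base, float freq_scale) {
    return pos * freq_scale * powf(freq_base, -2.0f * d / n_rot);
}

int main() {
    const int n_rot = 128;
    const int pos   = 4096;

    // hypothetical per-layer values, as selected in the gemma3 builder above
    const float base_swa  = 10000.0f;   // hparams.rope_freq_base_train_swa
    const float base_full = 1000000.0f; // cparams.rope_freq_base (user-configurable)

    for (int d = 0; d < n_rot / 2; d += 16) {
        printf("dim %3d: theta_swa = %12.4f, theta_full = %12.4f\n",
               d,
               rope_theta(pos, d, n_rot, base_swa,  1.0f),
               rope_theta(pos, d, n_rot, base_full, 1.0f));
    }

    return 0;
}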