llama : fix kq_scale for the attention layers of PLaMo2 (#14892)

* Fix dimensions for expand

* Change dimensions to copy states to cache

* Fix the default value for plamo2 conversion

* Fix scale given to build_attn

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
Shunta Saito 2025-07-27 16:38:44 +09:00 committed by GitHub
parent 446595b9b3
commit 1dc9614e06
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 10 additions and 9 deletions

View File

@ -3791,7 +3791,7 @@ class Plamo2Model(TextModel):
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(hparams.get("num_attention_heads", 32))
self.gguf_writer.add_layer_norm_rms_eps(hparams.get("rms_norm_eps", 1e-06))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 1000000.0))
self.gguf_writer.add_rope_freq_base(hparams.get("rope_theta", 10000))
# Mamba parameters
self.gguf_writer.add_ssm_state_size(hparams.get("mamba_d_state", 64))
@ -3802,7 +3802,7 @@ class Plamo2Model(TextModel):
self.gguf_writer.add_ssm_group_count(0)
# MLP feed forward parameters (for attention layers)
self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 16384))
self.gguf_writer.add_feed_forward_length(hparams.get("intermediate_size", 13312))
self.gguf_writer.add_file_type(self.ftype)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:

View File

@ -16191,7 +16191,7 @@ private:
{
// PLaMo-2 uses combined QKV tensor
ggml_tensor * qkv = build_lora_mm(model.layers[il].wqkv, cur);
cb(qkv, "qkv", il);
cb(qkv, "wqkv", il);
// split QKV tensor into Q, K, V
const int64_t n_embd_head_q = hparams.n_embd_head_k;
@ -16231,7 +16231,7 @@ private:
ext_factor, attn_factor, beta_fast, beta_slow
);
cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f, il);
cur = build_attn(inp, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, NULL, NULL, 1.0f/sqrtf(float(n_embd_head_v)), il);
}
cb(cur, "attn_out", il);
@ -16306,8 +16306,9 @@ private:
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, last_conv,
ggml_view_1d(ctx0, conv_states_all,
(d_conv - 1)*(d_inner)*(n_seqs),
kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
cb(conv_states_all, "mamba_conv1d_state", il);
// 1D convolution
x = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
@ -16370,9 +16371,9 @@ private:
// store last states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]),
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs,
kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
ggml_view_1d(ctx0, y_ssm, n_heads*head_dim*d_state*n_seqs, n_heads*head_dim*n_seq_tokens*n_seqs*ggml_element_size(y_ssm)),
ggml_view_1d(ctx0, ssm_states_all, n_heads*head_dim*d_state*n_seqs, kv_head*n_seqs*n_heads*head_dim*d_state*ggml_element_size(ssm_states_all))));
cb(ssm_states_all, "mamba_ssm_states", il);
ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_heads, n_seq_tokens, n_seqs, head_dim * ggml_element_size(x), head_dim * n_heads * ggml_element_size(x), head_dim * n_heads * n_seq_tokens * ggml_element_size(x), 0);
cb(y, "mamba_y_view", il);