model : add support for Falcon-H1 family (#14534)

* v1

* push more fixes

* another fix

* fix

* more fixes

* minor fix

* more cleaning on python code

* python fixes

* changed precision for multipliers float 32->64

* fixes

* another fix

* fix

* pre-norm -> norm

* fix

* Revert "fix"

This reverts commit 243e4d1a50.

* fix

* small fix ffn_norm

* try

* mix instead of max

* fix vocab size

* conflict solve

* fixed multipliers

* falcon-h1 specefic vocab resolved

* read arch from gguf.MODEL_ARCH

* mamba_d_ssm added to d_inner find_hparam

* remove unused functions from gguf_writer.py

* override modify_tensors instead of get_tensors

* fix conversion and d_inner

* added some cb functions for debugging puposes

* inp_out_ids moved outside of layers loop

* mup_vec create as float64

* fix rope_theta

* injected mup

* clean ups

* rm extra space

* rm unused MAMBA_CHUNK_SIZE

* rm unused key

* add bos False

* changed ROPE_TYPE

* cleaning debugging stuff

* cleaning debug quant

* fix comment

* some cleanups

* some cleanups

* Update src/llama-model-loader.cpp

* more cleanups

* moe cleanuips

* d_ssm -> d_inner;

* cleaning unused hparams

* cleanup

* more cleanups

* more cleanups on python conversion;

* minor cleanups

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* remove todo

* added falcon-h1

* tensor not required

* clean

* remove unneeded attributes

* more cleanups and fixed conversion

* remove final_norm

* flake8 fixes

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* flake8 fixes

* Update src/llama-hparams.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-arch.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update convert_hf_to_gguf.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* added hashes

* Update src/llama-arch.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Update src/llama-vocab.cpp

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* update the update file

* Revert "update the update file"

This reverts commit 082ab4ad2a.

* fix: address suggestions

* fix: update convert_hf_to_gguf.py

* Update gguf-py/gguf/constants.py

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* Update src/llama-model-loader.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* d_inner fixed

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* reshaping ssm_norm for 34B

* removing generate_mup

* remove duplicates metadata keys

* rm comment

* final comment

* fix unused args

* fix constants

* fix bad merge

* Update src/llama-model.cpp

Co-authored-by: compilade <git@compilade.net>

* falcon-h1: remove unused ssm_in_b and bad merge

* Update src/llama-model.cpp

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

* falcon-h1: fix last comment

* Update convert_hf_to_gguf.py

Co-authored-by: compilade <git@compilade.net>

* falcon-h1: revert add_add_bos(False)

* falcon-h1: fix tied weights

* falcon-h1: remove whitespace

* falcon-h1: fix wrong size param

* falcon-h1: fix whitespace issues

---------

Co-authored-by: younesbelkada <younes.belkada@tii.ae>
Co-authored-by: Younes B <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Co-authored-by: compilade <git@compilade.net>
This commit is contained in:
ibrahim khadraoui 2025-07-09 12:03:49 +04:00 committed by GitHub
parent 20b7bf8a32
commit 04655063c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 585 additions and 9 deletions

View File

@ -818,6 +818,18 @@ class TextModel(ModelBase):
if chkhsh == "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664":
# ref: https://huggingface.co/tencent/Hunyuan-A13B-Instruct
res = "hunyuan"
if chkhsh == "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6":
# ref: https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base
res = "falcon-h1"
if chkhsh == "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86":
# ref: https://huggingface.co/tiiuae/Falcon-H1-1B-Base
res = "falcon-h1"
if chkhsh == "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896":
# ref: https://huggingface.co/tiiuae/Falcon-H1-7B-Base
res = "falcon-h1"
if chkhsh == "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b":
# ref: https://huggingface.co/tiiuae/Falcon-H1-34B-Base
res = "falcon-h1"
if res is None:
logger.warning("\n")
@ -4899,17 +4911,19 @@ class Mamba2Model(TextModel):
def set_gguf_parameters(self):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128
head_dim = self.find_hparam(["head_dim"], optional=True) or 64
head_dim = self.find_hparam(["mamba_d_head", "head_dim"], optional=True) or 64
n_group = self.find_hparam(["n_groups"], optional=True) or 1
rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5
# Fail early for models which don't have a block expansion factor of 2
# TODO: does this really matter?
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
# skip the assertion for FalconH1 Model
if self.model_arch != gguf.MODEL_ARCH.FALCON_H1:
assert d_inner == 2 * d_model
assert d_inner % head_dim == 0
self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default
self.gguf_writer.add_embedding_length(d_model)
@ -4946,7 +4960,7 @@ class Mamba2Model(TextModel):
data_torch = data_torch.reshape((*data_torch.shape, 1))
elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid):
d_model = self.find_hparam(["hidden_size", "d_model", "dim"])
d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model
d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * d_model
n_group = self.hparams.get("n_groups", 1)
data_torch = data_torch.reshape((n_group, d_inner // n_group))
@ -6539,6 +6553,113 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1
def __init__(self, *args, **kwargs):
# Set the hparam prefixes for Falcon Mamba2
self.hparam_prefixes = ["mamba"]
# Initialize the base Mamba2Model
super().__init__(*args, **kwargs)
# Use Llama conversion for attention
self._transformer_model_class = LlamaModel
# n_group and d_inner are used during reshape_tensors for mamaba2
self.n_group = self.find_hparam(["n_groups"])
self.d_inner = self.find_hparam(["mamba_d_ssm"])
self.d_head = self.find_hparam(["d_head"])
# Initialize any Falcon Mamba2 specific attributes
self.has_attention = True # Falcon Mamba2 has attention components
# Load Falcon-H1 multipliers from hyperparameters
self.attention_in_multiplier = self.find_hparam(["attention_in_multiplier"], optional=True)
self.attention_out_multiplier = self.find_hparam(["attention_out_multiplier"], optional=True)
self.ssm_in_multiplier = self.find_hparam(["ssm_in_multiplier"], optional=True)
self.ssm_out_multiplier = self.find_hparam(["ssm_out_multiplier"], optional=True)
self.mlp_multipliers = self.find_hparam(["mlp_multipliers"], optional=True)
self.ssm_multipliers = self.find_hparam(["ssm_multipliers"], optional=True)
self.intermediate_size = self.find_hparam(["intermediate_size"])
self.key_multiplier = self.find_hparam(["key_multiplier"], optional=True)
def find_hparam(self, keys: Iterable[str], *args, **kwargs) -> Any:
prefixed = []
for pfx in self.hparam_prefixes:
prefixed.extend(
"_".join([pfx, k])
for k in keys
)
keys = list(keys) + prefixed
return super().find_hparam(keys, *args, **kwargs)
def set_vocab(self):
self._set_vocab_gpt2()
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
tensors = list(super().modify_tensors(data_torch, name, bid))
tensor = tensors[0][1]
if "down_proj" in name:
tensor = tensor * self.mlp_multipliers[1]
elif "gate_proj" in name:
tensor = tensor * self.mlp_multipliers[0]
elif "k_proj" in name:
tensor = tensor * self.key_multiplier * self.attention_in_multiplier
elif "q_proj" in name:
tensor = tensor * self.attention_in_multiplier
elif "v_proj" in name:
tensor = tensor * self.attention_in_multiplier
elif "o_proj" in name:
tensor = tensor * self.attention_out_multiplier
elif "out_proj" in name:
tensor = tensor * self.ssm_out_multiplier
elif "in_proj" in name:
tensor = tensor * self.ssm_in_multiplier
zxbcdt_multipliers = self.hparams["ssm_multipliers"]
intermediate_size = self.hparams["mamba_d_ssm"]
groups_time_state_size = self.hparams["mamba_n_groups"] * self.hparams["mamba_d_state"]
tensor[:intermediate_size, :] *= zxbcdt_multipliers[0]
tensor[intermediate_size:2 * intermediate_size, :] *= zxbcdt_multipliers[1]
tensor[2 * intermediate_size:2 * intermediate_size + groups_time_state_size, :] *= zxbcdt_multipliers[2]
tensor[2 * intermediate_size + groups_time_state_size:2 * intermediate_size + 2 * groups_time_state_size, :] *= zxbcdt_multipliers[3]
tensor[2 * intermediate_size + 2 * groups_time_state_size:, :] *= zxbcdt_multipliers[4]
elif "lm_head" in name:
tensor = tensor * self.hparams["lm_head_multiplier"]
elif "embed_tokens" in name:
tensor = tensor * self.hparams["embedding_multiplier"]
elif "mamba.norm" in name:
tensor = tensor.reshape(self.n_group, self.d_inner // self.n_group)
tensors = [(tensors[0][0], tensor)]
return tensors
def set_gguf_parameters(self):
super().set_gguf_parameters()
## General Params ##
self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
# Override some Mamba2 defaults
self.gguf_writer.add_block_count(self.block_count)
self.gguf_writer.add_context_length(self.hparams.get("max_position_embeddings", 0))
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
## Attention params ##
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"]) # Override value 0 from Mamba2
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
self.gguf_writer.add_key_length(self.hparams["head_dim"])
self.gguf_writer.add_value_length(self.hparams["head_dim"])
## Validation ##
assert self.hparams.get("hidden_act") in [None, "silu"], "Only SILU activation supported"
assert self.d_inner % self.d_head == 0, f"SSM inner size {self.d_inner} not a multiple of head dim {self.d_head}"
# Add any other Falcon Mamba2 specific configuration
self.gguf_writer.add_rope_freq_base(self.find_hparam(["rope_theta"]))
@ModelBase.register("HunYuanMoEV1ForCausalLM")
class HunYuanMoEModel(TextModel):
model_arch = gguf.MODEL_ARCH.HUNYUAN_MOE

View File

@ -138,6 +138,11 @@ pre_computed_hashes = [
{"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"},
{"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"},
{"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"},
# falcon-h1 series uses 4 different tokenizers across model sizes (0.5b - 34b), hence we need to define 4 different hashes
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-0.5B-Base", "chkhsh": "a6b57017d60e6edb4d88ecc2845188e0eb333a70357e45dcc9b53964a73bbae6"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-1B-Base", "chkhsh": "60476e1243776c4fb1b993dbd7a5f15ac22f83c80afdf425fa5ae01c8d44ef86"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-7B-Base", "chkhsh": "3eda48b4c4dc7de733d1a8b3e3b4a85243dbbf704da2ee9d42c6beced8897896"},
{"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
]

View File

@ -288,6 +288,7 @@ class MODEL_ARCH(IntEnum):
LLAMA4 = auto()
DECI = auto()
FALCON = auto()
FALCON_H1 = auto()
BAICHUAN = auto()
GROK = auto()
GPT2 = auto()
@ -662,6 +663,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.DOTS1: "dots1",
MODEL_ARCH.ARCEE: "arcee",
MODEL_ARCH.ERNIE4_5: "ernie4_5",
MODEL_ARCH.FALCON_H1: "falcon-h1",
MODEL_ARCH.HUNYUAN_MOE: "hunyuan-moe",
MODEL_ARCH.SMOLLM3: "smollm3",
}
@ -2215,6 +2217,40 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
],
MODEL_ARCH.FALCON_H1: [
# Token embedding
MODEL_TENSOR.TOKEN_EMBD,
# Input layernorm
MODEL_TENSOR.ATTN_NORM,
# Attention components
MODEL_TENSOR.ATTN_Q, # Query projection
MODEL_TENSOR.ATTN_K, # Key projection
MODEL_TENSOR.ATTN_V, # Value projection
MODEL_TENSOR.ATTN_OUT, # Output projection
# SSM components (Mamba2 specific)
MODEL_TENSOR.SSM_IN, # Input projection for SSM
MODEL_TENSOR.SSM_CONV1D, # Convolution layer
MODEL_TENSOR.SSM_DT, # Delta time projection
MODEL_TENSOR.SSM_A, # A parameter (log form)
MODEL_TENSOR.SSM_D, # D parameter
MODEL_TENSOR.SSM_NORM, # Normalization in SSM
MODEL_TENSOR.SSM_OUT, # Output projection
# Pre-feedforward layernorm
MODEL_TENSOR.FFN_PRE_NORM,
# Feed-forward network components
MODEL_TENSOR.FFN_GATE, # Gate projection (SwiGLU)
MODEL_TENSOR.FFN_DOWN, # Down projection
MODEL_TENSOR.FFN_UP, # Up projection
# Post-feedforward layernorm
MODEL_TENSOR.OUTPUT_NORM, # Final layer norm
MODEL_TENSOR.OUTPUT, # Output projection (lm_head)
],
MODEL_ARCH.HUNYUAN_MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@ -286,12 +286,14 @@ class TensorNameMap:
# Post feed-forward norm
MODEL_TENSOR.FFN_PRE_NORM: (
"model.layers.{bid}.pre_feedforward_layernorm", # gemma2
"model.layers.{bid}.pre_ff_layernorm.weight",
),
# Post feed-forward norm
MODEL_TENSOR.FFN_POST_NORM: (
"model.layers.{bid}.post_feedforward_layernorm", # gemma2 olmo2
"model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
"model.layers.{bid}.feed_forward.up_proj",
),
MODEL_TENSOR.FFN_GATE_INP: (
@ -363,6 +365,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_expert.up_proj", # qwen2moe
"model.layers.{bid}.mlp.shared_experts.up_proj", # deepseek deepseek2
"model.layers.{bid}.feed_forward.shared_expert.up_proj", # llama4
"model.layers.{bid}.feed_forward.down_proj",
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
),
@ -553,11 +556,13 @@ class TensorNameMap:
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj",
"backbone.layers.{bid}.mixer.in_proj",
"model.layers.{bid}.mamba.in_proj",
),
MODEL_TENSOR.SSM_CONV1D: (
"model.layers.{bid}.conv1d",
"backbone.layers.{bid}.mixer.conv1d",
"model.layers.{bid}.mamba.conv1d",
),
MODEL_TENSOR.SSM_X: (
@ -568,25 +573,30 @@ class TensorNameMap:
MODEL_TENSOR.SSM_DT: (
"model.layers.{bid}.dt_proj",
"backbone.layers.{bid}.mixer.dt_proj",
"model.layers.{bid}.mamba.dt_proj",
),
MODEL_TENSOR.SSM_A: (
"model.layers.{bid}.A_log",
"backbone.layers.{bid}.mixer.A_log",
"model.layers.{bid}.mamba.A_log",
),
MODEL_TENSOR.SSM_D: (
"model.layers.{bid}.D",
"backbone.layers.{bid}.mixer.D",
"model.layers.{bid}.mamba.D",
),
MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm", # falcon-h1
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj",
"backbone.layers.{bid}.mixer.out_proj",
"model.layers.{bid}.mamba.out_proj", # falcon-h1
),
MODEL_TENSOR.TIME_MIX_W0: (

View File

@ -46,6 +46,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_STARCODER2, "starcoder2" },
{ LLM_ARCH_MAMBA, "mamba" },
{ LLM_ARCH_MAMBA2, "mamba2" },
{ LLM_ARCH_FALCON_H1, "falcon-h1" },
{ LLM_ARCH_XVERSE, "xverse" },
{ LLM_ARCH_COMMAND_R, "command-r" },
{ LLM_ARCH_COHERE2, "cohere2" },
@ -1024,6 +1025,30 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_FALCON_H1,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
{
LLM_ARCH_XVERSE,
{
@ -1967,9 +1992,10 @@ bool llm_arch_is_recurrent(const llm_arch & arch) {
}
bool llm_arch_is_hybrid(const llm_arch & arch) {
// TODO: There are currently no hybrid models! Once there are, this will be
// the place to identify them
// List all mamba-attention hybrid models here
switch (arch) {
case LLM_ARCH_FALCON_H1:
return true;
default:
return false;
}

View File

@ -50,6 +50,7 @@ enum llm_arch {
LLM_ARCH_STARCODER2,
LLM_ARCH_MAMBA,
LLM_ARCH_MAMBA2,
LLM_ARCH_FALCON_H1,
LLM_ARCH_XVERSE,
LLM_ARCH_COMMAND_R,
LLM_ARCH_COHERE2,

View File

@ -1550,6 +1550,37 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_FALCON_H1:
{
// Common parameters
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
// SSM parameters
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
std::fill(hparams.recurrent_layer_arr.begin(), hparams.recurrent_layer_arr.end(), true);
switch (hparams.n_layer) {
case 36:
type = LLM_TYPE_0_5B; break;
case 24:
type = LLM_TYPE_1_5B; break;
case 66:
type = LLM_TYPE_1B; break;
case 32:
type = LLM_TYPE_3B; break;
case 44:
type = LLM_TYPE_7B; break;
case 72:
type = LLM_TYPE_34B; break;
default:
type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@ -4497,6 +4528,83 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_FALCON_H1:
{
// Common
const int64_t hidden_size = hparams.n_embd; // hidden_size
// mamba2 Mixer SSM params
const int64_t ssm_conv_kernel_size = hparams.ssm_d_conv; // ssm_conv_kernel_size
const int64_t ssm_n_groups = hparams.ssm_n_group; // ssm_n_groups
const int64_t ssm_state_size = hparams.ssm_d_state; // ssm_state_size
const int64_t ssm_intermediate_size = hparams.ssm_d_inner; // TODO expand
const int64_t ssm_num_heads = hparams.ssm_dt_rank; // ssm_num_heads
const int64_t ssm_conv_dim = ssm_intermediate_size + 2 * ssm_n_groups * ssm_state_size;
const int64_t ssm_projection_size = ssm_intermediate_size + ssm_conv_dim + ssm_num_heads;
// attn params
const int64_t attn_num_attention_head = hparams.n_head(0); // rename to: attn_num_attention_head
const int64_t attn_num_key_value_head = hparams.n_head_kv(0);
// ffn params
const int64_t ffn_intermediate_size = hparams.n_ff(0);
// embeddings
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, 0);
// output
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hidden_size, n_vocab}, TENSOR_NOT_REQUIRED);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {hidden_size}, 0);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hidden_size, n_vocab}, TENSOR_DUPLICATED);
}
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
/*SSM LAYERS*/
// ssm in
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {hidden_size, ssm_projection_size}, 0);
// ssm 1d conv
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {ssm_conv_kernel_size, ssm_conv_dim}, 0);
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {ssm_conv_dim}, TENSOR_NOT_REQUIRED);
// ssm_dt
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {ssm_num_heads}, 0);
// no "weight" suffix for these
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, ssm_num_heads}, 0);
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, ssm_num_heads}, 0);
// ssm_norm
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {ssm_intermediate_size / ssm_n_groups, ssm_n_groups}, TENSOR_NOT_REQUIRED);
// out_proj
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {ssm_intermediate_size, hidden_size}, 0);
/*ATTENTION LAYERS*/
// attention layers (with optional bias)
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {hidden_size, n_embd_head_k * attn_num_attention_head}, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_k}, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {hidden_size, attn_num_key_value_head * n_embd_head_v}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * attn_num_attention_head, hidden_size}, 0);
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED);
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {attn_num_key_value_head * n_embd_head_k}, TENSOR_NOT_REQUIRED);
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {attn_num_key_value_head * n_embd_head_v}, TENSOR_NOT_REQUIRED);
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED);
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {hidden_size}, 0);
// feed forward (w/ optional biases)
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, i), {hidden_size}, 0);
layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {hidden_size, ffn_intermediate_size}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { ffn_intermediate_size, hidden_size}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {hidden_size, ffn_intermediate_size}, 0);
layer.ffn_gate_b = create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED);
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {hidden_size}, TENSOR_NOT_REQUIRED);
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {ffn_intermediate_size}, TENSOR_NOT_REQUIRED);
}
} break;
case LLM_ARCH_HUNYUAN_MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@ -10147,7 +10255,7 @@ struct llm_build_mamba : public llm_graph_context {
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
// cb(cur, "mamba_out", il);
cb(cur, "mamba_out", il);
return cur;
}
@ -14598,6 +14706,267 @@ struct llm_build_ernie4_5 : public llm_graph_context {
}
};
struct llm_build_falcon_h1 : public llm_graph_context {
const llama_model & model;
llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params), model(model) {
const int64_t n_embd_head = hparams.n_embd_head_v;
ggml_tensor * cur;
ggml_tensor * inpL;
inpL = build_inp_embd(model.tok_embd);
// inp_pos - contains the positions
ggml_tensor * inp_pos = build_inp_pos();
// Build the inputs in the recurrent & kv cache
auto * inp = build_inp_mem_hybrid();
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
ggml_tensor * inp_out_ids = build_inp_out_ids();
for (int il = 0; il < n_layer; ++il) {
ggml_tensor * inpSA = inpL;
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "attn_norm", il);
// self-attention
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
cb(Qcur, "Qcur", il);
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
cb(Kcur, "Kcur", il);
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
cb(Vcur, "Vcur", il);
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, hparams.rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur-post-rope", il);
cb(Kcur, "Kcur-post-rope", il);
cb(Vcur, "Vcur-post-rope", il);
ggml_tensor * attn_out = build_attn(inp, gf,
model.layers[il].wo, NULL,
Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
cb(attn_out, "attn_out", il);
cur = build_norm(inpL,
model.layers[il].attn_norm, NULL,
LLM_NORM_RMS, il);
// Mamba2 layer
cb(cur, "ssm_in", il);
ggml_tensor * ssm_out = build_mamba2_layer(inp, gf, cur, ubatch, il);
cb(ssm_out, "ssm_out", il);
// // Aggregation
cur = ggml_add(ctx0, attn_out, ssm_out);
inpSA = ggml_add(ctx0, cur, inpSA);
cb(cur, "layer_out", il);
if (il == n_layer - 1 && inp_out_ids) {
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
}
ggml_tensor * ffn_inp = inpSA;
cb(ffn_inp, "ffn_inp", il);
// feed-forward network
cur = build_norm(ffn_inp,
model.layers[il].ffn_norm, NULL,
LLM_NORM_RMS, il);
cb(cur, "ffn_norm", il);
cur = build_ffn(cur,
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
NULL,
LLM_FFN_SILU, LLM_FFN_PAR, il);
cb(cur, "ffn_out", il);
cur = ggml_add(ctx0, cur, inpSA);
cur = build_cvec(cur, il);
cb(cur, "l_out", il);
// input for next layer
inpL = cur;
}
cur = inpL;
cur = build_norm(cur,
model.output_norm, NULL,
LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
// lm_head
cur = build_lora_mm(model.output, cur);
cb(cur, "result_output", -1);
res->t_logits = cur;
ggml_build_forward_expand(gf, cur);
}
ggml_tensor * build_mamba2_layer(
llm_graph_input_mem_hybrid * inp,
ggml_cgraph * gf,
ggml_tensor * cur,
const llama_ubatch & ubatch,
int il) const {
const auto * kv_state = static_cast<const llama_memory_hybrid_context *>(mctx)->get_recr();
const auto kv_head = kv_state->get_head();
const int64_t d_conv = hparams.ssm_d_conv;
const int64_t d_inner = hparams.ssm_d_inner;
const int64_t d_state = hparams.ssm_d_state;
const int64_t n_head = hparams.ssm_dt_rank;
const int64_t head_dim = d_inner / n_head;
const int64_t n_group = hparams.ssm_n_group;
const int64_t n_seqs = ubatch.n_seqs;
const int64_t n_seq_tokens = ubatch.n_seq_tokens;
GGML_ASSERT(n_seqs != 0);
GGML_ASSERT(ubatch.equal_seqs);
GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs);
ggml_tensor * conv_states_all = kv_state->get_r_l(il);
ggml_tensor * ssm_states_all = kv_state->get_s_l(il);
ggml_tensor * conv = build_rs(inp, gf, conv_states_all, hparams.n_embd_r(), n_seqs);
conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs);
// d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads
// {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs}
ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur);
cb(zxBCdt, "zxBCdt", il);
// split the above in three
ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0);
ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt));
ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt));
// conv
{
// => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs}
ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0);
// copy last (d_conv - 1) columns back into the state cache
ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0]));
ggml_build_forward_expand(gf,
ggml_cpy(ctx0, last_conv,
ggml_view_1d(ctx0, conv_states_all,
(d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));
// 1D convolution
// The equivalent is to make a self-overlapping view of conv_x
// over d_conv columns at each stride in the 3rd dimension,
// then element-wise multiply that with the conv1d weight,
// then sum the elements of each row,
// (the last two steps are a dot product over rows (also doable with mul_mat))
// then permute away the ne[0] dimension,
// and then you're left with the resulting x tensor.
// For simultaneous sequences, all sequences need to have the same length.
xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d);
// bias
xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b);
xBC = ggml_silu(ctx0, xBC);
}
// ssm
{
// These correspond to V K Q in SSM/attention duality
ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0);
ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC));
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
// {n_head, n_seq_tokens, n_seqs}
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
ggml_tensor * A = model.layers[il].ssm_a;
// use the states and the indices provided by build_rs
// (this is necessary in order to properly use the states before they are overwritten,
// while avoiding to make unnecessary copies of the states)
auto get_ssm_rows = [&](ggml_context * ctx, ggml_tensor * states, ggml_tensor * ids) {
ggml_tensor * ssm = ggml_reshape_4d(ctx, states, d_state, head_dim, n_head, kv_state->get_size());
// TODO: use semistructured matrices to implement state-space duality
// => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs}
return ggml_ssm_scan(ctx, ssm, x, dt, A, B, C, ids);
};
ggml_tensor * y_ssm = build_rs(inp, gf, ssm_states_all, hparams.n_embd_s(), ubatch.n_seqs, get_ssm_rows);
// store last states
ggml_build_forward_expand(gf,
ggml_cpy(ctx0,
ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]),
ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all))));
ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0);
// TODO: skip computing output earlier for unused tokens
y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d));
y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y);
// grouped RMS norm
if (model.layers[il].ssm_norm) {
y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs);
y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il);
}
y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs);
// {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs}
cur = build_lora_mm(model.layers[il].ssm_out, y);
}
// {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens}
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs);
cb(cur, "mamba_out", il);
return cur;
}
};
struct llm_build_arcee : public llm_graph_context {
llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
@ -15077,7 +15446,9 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
/* recurrent_type_v */ GGML_TYPE_F32,
/* recurrent_kv_size */ std::max((uint32_t) 1, cparams.n_seq_max),
/* n_seq_max */ cparams.n_seq_max,
/* offload */ cparams.offload_kqv);
/* offload */ cparams.offload_kqv,
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
} else {
const auto padding = llama_kv_cache_unified::get_padding(cparams);
@ -15419,6 +15790,10 @@ llm_graph_result_ptr llama_model::build_graph(
{
llm = std::make_unique<llm_build_smollm3>(*this, params, gf);
} break;
case LLM_ARCH_FALCON_H1:
{
llm = std::make_unique<llm_build_falcon_h1>(*this, params, gf);
} break;
default:
GGML_ABORT("fatal error");
}
@ -15577,6 +15952,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
// the pairs of head values are offset by n_rot/2
case LLM_ARCH_FALCON:
case LLM_ARCH_FALCON_H1:
case LLM_ARCH_GROK:
case LLM_ARCH_DBRX:
case LLM_ARCH_BERT:

View File

@ -1523,6 +1523,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "llama-v3" ||
tokenizer_pre == "llama-bpe"||
tokenizer_pre == "falcon3" ||
tokenizer_pre == "falcon-h1" ||
tokenizer_pre == "pixtral") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3;
ignore_merges = true;