mtmd : add support for Voxtral (#14862)

* mtmd : add support for Voxtral

* clean up

* fix python requirements

* add [BEGIN_AUDIO] token

* also support Devstral conversion

* add docs and tests

* fix regression for ultravox

* minor coding style improvement

* correct projector activation fn

* Apply suggestions from code review

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>

---------

Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
Xuan-Son Nguyen authored on 2025-07-28 15:01:48 +02:00, committed by GitHub
parent 946b1f6859
commit 00fa15fedc
13 changed files with 546 additions and 46 deletions

.gitignore

@@ -82,6 +82,7 @@ models/*
models-mnt
!models/.editorconfig
!models/ggml-vocab-*.gguf*
!models/templates
# Zig
zig-out/


@@ -1900,6 +1900,7 @@ class StableLMModel(TextModel):
"MixtralForCausalLM",
"VLlama3ForCausalLM",
"LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration",
"LlamaModel")
class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA
@@ -1912,6 +1913,11 @@ class LlamaModel(TextModel):
self.hparams["num_attention_heads"] = self.hparams.get("num_attention_heads", 32)
def set_vocab(self):
path_tekken_json = self.dir_model / "tekken.json"
path_tokenizer_json = self.dir_model / "tokenizer.json"
if path_tekken_json.is_file() and not path_tokenizer_json.is_file():
return self.set_vocab_tekken()
try:
self._set_vocab_sentencepiece()
except FileNotFoundError:
@@ -1944,6 +1950,52 @@ class LlamaModel(TextModel):
if self.hparams.get("vocab_size", 32000) == 49152:
self.gguf_writer.add_add_bos_token(False)
def set_vocab_tekken(self):
vocab = gguf.vocab.MistralVocab(self.dir_model)
self.gguf_writer.add_tokenizer_model(vocab.gguf_tokenizer_model)
tokens = []
scores = []
toktypes = []
for text, score, toktype in vocab.all_tokens():
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
assert len(tokens) == vocab.vocab_size, (
f"token count ({len(tokens)}) != vocab size ({vocab.vocab_size})"
)
if vocab.tokenizer_type == gguf.vocab.MistralTokenizerType.tekken:
self.gguf_writer.add_tokenizer_pre("tekken")
self.gguf_writer.add_token_merges(
vocab.extract_vocab_merges_from_model()
)
logger.info(
f"Setting bos, eos, unk and pad token IDs to {vocab.bos_id}, {vocab.eos_id}, {vocab.unk_id}, {vocab.pad_id}."
)
self.gguf_writer.add_bos_token_id(vocab.bos_id)
self.gguf_writer.add_eos_token_id(vocab.eos_id)
self.gguf_writer.add_unk_token_id(vocab.unk_id)
self.gguf_writer.add_pad_token_id(vocab.pad_id)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
self.gguf_writer.add_vocab_size(vocab.vocab_size)
self.gguf_writer.add_add_bos_token(True)
self.gguf_writer.add_add_eos_token(False)
script_dir = Path(__file__).parent
template_path = script_dir / "models/templates/unsloth-mistral-Devstral-Small-2507.jinja"
with open(template_path, "r", encoding="utf-8") as f:
template = f.read()
self.gguf_writer.add_chat_template(template)
def set_gguf_parameters(self):
super().set_gguf_parameters()
hparams = self.hparams
@@ -1971,12 +2023,13 @@ class LlamaModel(TextModel):
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
n_head = self.hparams["num_attention_heads"]
n_kv_head = self.hparams.get("num_key_value_heads")
is_vision_tensor = "vision_tower" in name \
is_multimodal_tensor = "vision_tower" in name \
or "vision_model" in name \
or "audio_tower" in name \
or "model.connector" in name \
or "multi_modal_projector" in name
if is_vision_tensor:
if is_multimodal_tensor:
return [] # skip vision tensors
elif self.hf_arch == "LlamaModel":
name = "model." + name
@@ -7231,9 +7284,10 @@ class WhisperEncoderModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.hparams["hidden_size"] = self.hparams["d_model"]
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
if "hidden_size" not in self.hparams and "intermediate_size" not in self.hparams:
self.hparams["hidden_size"] = self.hparams["d_model"]
self.hparams["intermediate_size"] = self.hparams["encoder_ffn_dim"]
self.hparams["num_attention_heads"] = self.hparams["encoder_attention_heads"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
@@ -7272,9 +7326,21 @@ class UltravoxWhisperEncoderModel(WhisperEncoderModel):
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.ULTRAVOX)
self.gguf_writer.add_audio_stack_factor(self.global_config["stack_factor"])
@ModelBase.register("VoxtralForConditionalGeneration")
class VoxtralWhisperEncoderModel(WhisperEncoderModel):
has_vision_encoder = False # no vision encoder
has_audio_encoder = True
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.VOXTRAL)
self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size
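The hard-coded stack factor of 4 can be sanity-checked against the encoder dimensions. The concrete sizes below (a Whisper-large-v3-style encoder with hidden_size 1280 and intermediate_size 5120) are assumed from the upstream Voxtral audio config, not read from this diff:

```python
# Illustrative check only; the dimensions are assumptions, not values from this PR.
audio_hparams = {"hidden_size": 1280, "intermediate_size": 5120}
stack_factor = audio_hparams["intermediate_size"] // audio_hparams["hidden_size"]
assert stack_factor == 4  # matches gguf_writer.add_audio_stack_factor(4)
```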
@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
model_arch = gguf.MODEL_ARCH.FALCON_H1


@@ -97,6 +97,9 @@ NOTE: some models may require large context window, for example: `-c 8192`
# Qwen2-Audio and SeaLLM-Audio
# note: no pre-quantized GGUF for these models, as they give very poor results
# ref: https://github.com/ggml-org/llama.cpp/pull/13760
# Mistral's Voxtral
(tool_name) -hf ggml-org/Voxtral-Mini-3B-2507-GGUF
```
**Mixed modalities**:


@@ -2724,6 +2724,7 @@ class VisionProjectorType:
INTERNVL = "internvl"
QWEN2A = "qwen2a" # audio
QWEN25O = "qwen2.5o" # omni
VOXTRAL = "voxtral"
# Items here are (block size, type size)


@ -1,5 +1,6 @@
from __future__ import annotations
from enum import Enum
import re
import logging
import json
@@ -12,6 +13,25 @@ try:
except ImportError:
SentencePieceProcessor = None
try:
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
from mistral_common.tokens.tokenizers.utils import (
_filter_valid_tokenizer_files,
)
from mistral_common.tokens.tokenizers.sentencepiece import (
SentencePieceTokenizer,
)
except ImportError:
_mistral_common_installed = False
MistralTokenizer = None
Tekkenizer = None
SentencePieceTokenizer = None
_filter_valid_tokenizer_files = None
else:
_mistral_common_installed = True
import gguf
from .gguf_writer import GGUFWriter
@@ -592,3 +612,262 @@ class LlamaHfVocab(Vocab):
def __repr__(self) -> str:
return f"<LlamaHfVocab with {self.vocab_size_base} base tokens and {len(self.added_tokens_list)} added tokens>"
class MistralTokenizerType(str, Enum):
spm = "spm"
tekken = "tekken"
# Copied from Transformers (Apache 2.0)
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py#L1544
def bytes_to_unicode() -> dict[int, str]:
"""
Returns a mapping from utf-8 bytes to unicode strings. We specifically avoid mapping to whitespace/control
characters that the bpe code barfs on.
The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
tables between utf-8 bytes and unicode strings.
"""
bs = (
list(range(ord("!"), ord("~") + 1))
+ list(range(ord("¡"), ord("¬") + 1))
+ list(range(ord("®"), ord("ÿ") + 1))
)
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8 + n)
n += 1
cs_str = [chr(n) for n in cs]
return dict(zip(bs, cs_str))
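As a quick illustration of the mapping above (not part of the conversion code): printable ASCII bytes map to themselves, while bytes the BPE code cannot store directly, such as the space byte 0x20, are shifted into a printable range starting at U+0100.

```python
# Illustrative check of bytes_to_unicode(); safe to run next to the function above.
m = bytes_to_unicode()
assert m[ord("A")] == "A"              # printable ASCII is unchanged
assert m[ord(" ")] == chr(0x20 + 256)  # space becomes 'Ġ' (U+0120)
assert len(set(m.values())) == 256     # every byte gets a unique character
```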
class MistralVocab(Vocab):
tokenizer_model = "mistral"
name = "mistral"
added_tokens_dict: dict[str, int] = {}
added_tokens_list: list[str] = []
def __init__(self, base_path: Path):
if not _mistral_common_installed:
raise ImportError(
"To use MistralVocab, please install the `mistral-common` package. "
"You can install it with `pip install mistral-common`."
)
assert _filter_valid_tokenizer_files is not None, "mistral_common is not installed"
assert MistralTokenizer is not None, "mistral_common is not installed"
assert Tekkenizer is not None, "mistral_common is not installed"
logger.info(f"Loading Mistral tokenizer from {base_path}")
# Find the tokenizer files
all_files = [f.as_posix() for f in base_path.glob("**/*") if f.is_file()]
valid_tokenizer_files = _filter_valid_tokenizer_files(all_files)
if len(valid_tokenizer_files) == 0:
raise ValueError(f"No tokenizer file found in the directory: {base_path}")
# If there are multiple tokenizer files, we use tekken.json if it exists, otherwise the versioned one.
if len(valid_tokenizer_files) > 1:
if "tekken.json" in valid_tokenizer_files:
tokenizer_file = "tekken.json"
else:
tokenizer_file = sorted(valid_tokenizer_files)[-1]
logger.warning(
f"Multiple tokenizer files found in {base_path}. Using {tokenizer_file}"
)
else:
tokenizer_file = valid_tokenizer_files[0]
self.tokenizer = MistralTokenizer.from_file(
base_path / tokenizer_file
).instruct_tokenizer.tokenizer
self.tokenizer_type = (
MistralTokenizerType.tekken
if isinstance(self.tokenizer, Tekkenizer)
else MistralTokenizerType.spm
)
self.vocab_size = self.tokenizer.n_words
self.fname_tokenizer = base_path / tokenizer_file
self._name = (
"mistral-" + self.tokenizer_type.value + "-" + self.tokenizer.version
)
@property
def tokenizer_name(self) -> str:
return self._name
@property
def gguf_tokenizer_model(self) -> str:
return "llama" if self.tokenizer_type == MistralTokenizerType.spm else "gpt2"
def _sentencepiece_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
assert SentencePieceTokenizer is not None, "mistral_common is not installed"
assert isinstance(self.tokenizer, SentencePieceTokenizer), (
f"Expected SentencePieceTokenizer, got {type(self.tokenizer)}"
)
for i in range(self.tokenizer._model.vocab_size()):
piece = self.tokenizer._model.IdToPiece(i)
text = piece.encode("utf-8")
score: float = self.tokenizer._model.GetScore(i)
toktype = gguf.TokenType.NORMAL
if self.tokenizer._model.IsUnknown(i):
toktype = gguf.TokenType.UNKNOWN
if self.tokenizer._model.IsControl(i):
toktype = gguf.TokenType.CONTROL
if self.tokenizer._model.IsUnused(i):
toktype = gguf.TokenType.UNUSED
if self.tokenizer._model.IsByte(i):
toktype = gguf.TokenType.BYTE
yield text, score, toktype
def _tekken_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
assert Tekkenizer is not None, "mistral_common is not installed"
assert isinstance(self.tokenizer, Tekkenizer), (
f"Expected Tekkenizer, got {type(self.tokenizer)}"
)
byte_encoder = bytes_to_unicode()
for token_id in range(self.tokenizer.num_special_tokens):
yield (
self.tokenizer.id_to_piece(token_id).encode("utf-8"),
0,
gguf.TokenType.CONTROL
)
for token in self.tokenizer._tekken_token2id_nospecial:
yield (
self.token_bytes_to_string(token, byte_encoder).encode("utf-8"),
0,
gguf.TokenType.NORMAL,
)
def get_token_id(self, token: str) -> int:
assert SentencePieceTokenizer is not None and Tekkenizer is not None, "mistral_common is not installed"
if self.tokenizer_type == MistralTokenizerType.spm:
assert isinstance(self.tokenizer, SentencePieceTokenizer)
return self.tokenizer._vocab.index(token)
elif self.tokenizer_type == MistralTokenizerType.tekken:
assert isinstance(self.tokenizer, Tekkenizer)
return (
self.tokenizer._vocab.index(token) + self.tokenizer.num_special_tokens
)
else:
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
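The tekken branch offsets the raw BPE index by the number of special tokens because, as in `_tekken_tokens` above, the control tokens occupy the first IDs and the byte-level vocabulary follows. A toy illustration with made-up tokens:

```python
# Toy illustration of the ID layout assumed by the tekken branch of get_token_id():
# control tokens occupy ids [0, num_special_tokens), so a token found at position i
# in the BPE vocabulary gets id i + num_special_tokens.
num_special_tokens = 2          # e.g. "<s>", "</s>"
tekken_vocab = ["hi", "there"]  # made-up byte-level vocab
assert tekken_vocab.index("there") + num_special_tokens == 3
```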
@property
def bos_id(self) -> int:
return self.tokenizer.bos_id
@property
def eos_id(self) -> int:
return self.tokenizer.eos_id
@property
def pad_id(self) -> int:
if self.tokenizer.pad_id == -1:
return self.eos_id
return self.tokenizer.pad_id
@property
def unk_id(self) -> int:
return self.tokenizer.unk_id
@property
def bos_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.bos_id)
@property
def eos_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.eos_id)
@property
def pad_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.pad_id)
@property
def unk_token(self) -> str:
return self.tokenizer.id_to_piece(self.tokenizer.unk_id)
def all_tokens(self) -> Iterable[tuple[bytes, float, gguf.TokenType]]:
if self.tokenizer_type == MistralTokenizerType.spm:
yield from self._sentencepiece_tokens()
elif self.tokenizer_type == MistralTokenizerType.tekken:
yield from self._tekken_tokens()
else:
raise ValueError(f"Unknown tokenizer type: {self.tokenizer_type}")
@staticmethod
def token_bytes_to_string(b, byte_encoder):
return "".join([byte_encoder[ord(char)] for char in b.decode("latin-1")])
def extract_vocab_merges_from_model(self):
# Adapted from Transformers (Apache 2.0)
# https://github.com/huggingface/transformers/blob/main/src/transformers/convert_slow_tokenizer.py
assert Tekkenizer is not None and isinstance(self.tokenizer, Tekkenizer), (
f"Expected Tekkenizer, got {type(self.tokenizer)}"
)
mergeable_ranks = self.tokenizer._model._mergeable_ranks
token_bytes_map = {
rank: token_bytes for token_bytes, rank in mergeable_ranks.items()
}
merge_pairs = []
# Sort vocab by rank to ensure correct merge order
for i in range(256, self.vocab_size - self.tokenizer.num_special_tokens):
merged_token = token_bytes_map[i]
local = []
for j in range(1, len(merged_token)):
left = merged_token[:j]
right = merged_token[j:]
if (
left in mergeable_ranks
and right in mergeable_ranks
and (left + right) in mergeable_ranks
):
local.append((left, right, i))
if not local:
raise ValueError(
f"Could not find valid merge for token at rank {i}: {merged_token.decode('latin-1')}"
)
local = sorted(
local,
key=lambda x: (mergeable_ranks[x[0]], mergeable_ranks[x[1]]),
reverse=False,
)
merge_pairs.extend(local)
merge_pairs = sorted(merge_pairs, key=lambda val: val[2], reverse=False)
byte_encoder = bytes_to_unicode()
decoded_merge_pairs = [
[
self.token_bytes_to_string(val[0], byte_encoder),
self.token_bytes_to_string(val[1], byte_encoder),
]
for val in merge_pairs
]
merges = [
" ".join(
[
# ensure the spaces are properly encoded
"".join(chr(ord(c) + 256) if c == " " else c for c in part)
for part in pair
]
)
for pair in decoded_merge_pairs
]
return merges
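To see what the merge reconstruction is doing, here is a toy example with a made-up rank table (not Voxtral's real vocabulary). Each merged token is split into two lower-ranked pieces whose concatenation is itself in the table, which is the same validity condition used in the loop above:

```python
# Toy rank table: single bytes get low ranks, merged tokens get ranks >= 256.
toy_ranks = {b"h": 0, b"e": 1, b"l": 2, b"o": 3, b"he": 256, b"lo": 257, b"helo": 258}

def candidate_merges(token: bytes) -> list[tuple[bytes, bytes]]:
    # same condition as extract_vocab_merges_from_model: left, right and left+right
    # must all be known tokens
    return [
        (token[:j], token[j:])
        for j in range(1, len(token))
        if token[:j] in toy_ranks and token[j:] in toy_ranks and token in toy_ranks
    ]

assert candidate_merges(b"he") == [(b"h", b"e")]
assert candidate_merges(b"helo") == [(b"he", b"lo")]
```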

File diff suppressed because one or more lines are too long


@@ -1,3 +1,5 @@
mistral-common>=1.8.3
-r ./requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.2.1; platform_machine != "s390x"


@@ -1,3 +1,3 @@
docstring_parser~=0.15
pydantic~=2.6.3
pydantic~=2.11.7
requests


@@ -131,6 +131,7 @@ enum projector_type {
PROJECTOR_TYPE_LLAMA4,
PROJECTOR_TYPE_QWEN2A,
PROJECTOR_TYPE_QWEN25O, // will be replaced by QWEN2A or QWEN25VL depending on clip_ctx
PROJECTOR_TYPE_VOXTRAL,
PROJECTOR_TYPE_UNKNOWN,
};
@@ -150,6 +151,7 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
{ PROJECTOR_TYPE_LLAMA4, "llama4"},
{ PROJECTOR_TYPE_QWEN2A, "qwen2a"},
{ PROJECTOR_TYPE_QWEN25O, "qwen2.5o"},
{ PROJECTOR_TYPE_VOXTRAL, "voxtral"},
};
static projector_type clip_projector_type_from_string(const std::string & str) {


@@ -354,6 +354,16 @@ struct clip_model {
ggml_tensor * conv1d_2_b = nullptr;
ggml_tensor * mm_norm_pre_w = nullptr;
ggml_tensor * mm_norm_mid_w = nullptr;
bool audio_has_avgpool() const {
return proj_type == PROJECTOR_TYPE_QWEN2A
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
}
bool audio_has_stack_frames() const {
return proj_type == PROJECTOR_TYPE_ULTRAVOX
|| proj_type == PROJECTOR_TYPE_VOXTRAL;
}
};
struct clip_ctx {
@@ -1483,49 +1493,52 @@ struct clip_graph {
cb(cur, "after_transformer", -1);
if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
if (model.audio_has_stack_frames()) {
// StackAudioFrames
// https://huggingface.co/fixie-ai/ultravox-v0_5-llama-3_2-1b/blob/main/ultravox_model.py
{
int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur);
if (pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0);
int64_t stride = n_embd * hparams.proj_stack_factor;
int64_t padded_len = GGML_PAD(ggml_nelements(cur), stride);
int64_t pad = padded_len - ggml_nelements(cur);
if (pad > 0) {
cur = ggml_view_1d(ctx0, cur, ggml_nelements(cur), 0);
cur = ggml_pad(ctx0, cur, pad, 0, 0, 0);
}
cur = ggml_view_2d(ctx0, cur, stride, padded_len / stride,
ggml_row_size(cur->type, stride), 0);
cb(cur, "after_stacked", -1);
}
if (ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX) {
// UltravoxProjector
{
// pre-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
// pre-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_pre_w);
// ffn in
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
// ffn in
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
// swiglu
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
cur = ggml_swiglu_swapped(ctx0, cur);
// swiglu
// see SwiGLU in ultravox_model.py, the second half passed through is silu, not the first half
cur = ggml_swiglu_swapped(ctx0, cur);
// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
// mid-norm
cur = ggml_rms_norm(ctx0, cur, 1e-6);
cur = ggml_mul(ctx0, cur, model.mm_norm_mid_w);
// ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
}
// ffn out
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
} else if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
// projector
cur = ggml_mul_mat(ctx0, model.mm_fc_w, cur);
cur = ggml_add(ctx0, cur, model.mm_fc_b);
} else if (ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL) {
// projector
cur = ggml_mul_mat(ctx0, model.mm_1_w, cur);
cur = ggml_gelu_erf(ctx0, cur);
cur = ggml_mul_mat(ctx0, model.mm_2_w, cur);
} else {
GGML_ABORT("%s: unknown projector type", __func__);
}
@@ -1670,8 +1683,7 @@ private:
inpL = cur;
}
// TODO @ngxson : find a way to move this outside
if (ctx->proj_type() == PROJECTOR_TYPE_QWEN2A) {
if (ctx->model.audio_has_avgpool()) {
ggml_tensor * cur = inpL;
cur = ggml_transpose(ctx0, cur);
cur = ggml_cont(ctx0, cur);
@@ -1985,6 +1997,7 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
res = graph.build_llama4();
} break;
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_QWEN2A:
{
res = graph.build_whisper_enc();
@@ -2259,8 +2272,10 @@ struct clip_model_loader {
} break;
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_VOXTRAL:
{
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX;
bool require_stack = model.proj_type == PROJECTOR_TYPE_ULTRAVOX ||
model.proj_type == PROJECTOR_TYPE_VOXTRAL;
get_u32(KEY_A_PROJ_STACK_FACTOR, hparams.proj_stack_factor, require_stack);
if (hparams.n_mel_bins != 128) {
throw std::runtime_error(string_format("%s: only 128 mel bins are supported for ultravox\n", __func__));
@@ -2544,6 +2559,15 @@ struct clip_model_loader {
model.mm_fc_w = get_tensor(string_format(TN_MM_AUDIO_FC, "weight"));
model.mm_fc_b = get_tensor(string_format(TN_MM_AUDIO_FC, "bias"));
} break;
case PROJECTOR_TYPE_VOXTRAL:
{
model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight"));
model.conv1d_1_b = get_tensor(string_format(TN_CONV1D, 1, "bias"));
model.conv1d_2_w = get_tensor(string_format(TN_CONV1D, 2, "weight"));
model.conv1d_2_b = get_tensor(string_format(TN_CONV1D, 2, "bias"));
model.mm_1_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 1, "weight"));
model.mm_2_w = get_tensor(string_format(TN_MM_AUDIO_MLP, 2, "weight"));
} break;
case PROJECTOR_TYPE_INTERNVL:
{
model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight"));
@@ -3570,17 +3594,26 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im
int scale_factor = ctx->model.hparams.proj_scale_factor;
n_patches_sq /= (scale_factor * scale_factor);
} break;
case PROJECTOR_TYPE_VOXTRAL:
case PROJECTOR_TYPE_ULTRAVOX:
{
const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
const int n_len = CLIP_ALIGN(img->nx, proj_stack_factor);
n_patches_sq = n_len / proj_stack_factor / 2;
} break;
case PROJECTOR_TYPE_QWEN2A:
{
// divide by 2 because of whisper
// another divide by 2 because of nn.AvgPool1d(2, stride=2)
n_patches_sq = img->nx / 4;
n_patches_sq = img->nx;
const int proj_stack_factor = ctx->model.hparams.proj_stack_factor;
if (ctx->model.audio_has_stack_frames()) {
GGML_ASSERT(proj_stack_factor > 0);
const int n_len = CLIP_ALIGN(n_patches_sq, proj_stack_factor);
n_patches_sq = n_len / proj_stack_factor;
}
// whisper downscales input token by half after conv1d
n_patches_sq /= 2;
if (ctx->model.audio_has_avgpool()) {
// divide by 2 because of nn.AvgPool1d(2, stride=2)
n_patches_sq /= 2;
}
} break;
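To make the consolidated audio branch above easier to follow, here is a small Python transcription of the same token-count arithmetic — a sketch of what this switch case computes, using the flags defined earlier in this diff (Ultravox stacks frames, Qwen2-Audio average-pools, Voxtral does both):

```python
def audio_output_tokens(n_frames: int, stack_factor: int, has_stack: bool, has_avgpool: bool) -> int:
    # mirrors the ULTRAVOX / QWEN2A / VOXTRAL case of clip_n_output_tokens()
    n = n_frames
    if has_stack:
        # CLIP_ALIGN(n, stack_factor) followed by division == ceiling division
        n = -(-n // stack_factor)
    n //= 2  # whisper downscales the sequence by half after conv1d
    if has_avgpool:
        n //= 2  # nn.AvgPool1d(2, stride=2)
    return n

# e.g. 3000 input positions with Voxtral's stack factor of 4
print(audio_output_tokens(3000, 4, has_stack=True, has_avgpool=True))  # 187
```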
default:
GGML_ABORT("unsupported projector type");
@@ -3986,6 +4019,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima
case PROJECTOR_TYPE_INTERNVL:
case PROJECTOR_TYPE_QWEN2A:
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
{
// do nothing
} break;
@@ -4086,6 +4120,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
case PROJECTOR_TYPE_IDEFICS3:
return ctx->model.projection->ne[1];
case PROJECTOR_TYPE_ULTRAVOX:
case PROJECTOR_TYPE_VOXTRAL:
return ctx->model.mm_2_w->ne[1];
case PROJECTOR_TYPE_INTERNVL:
return ctx->model.mm_3_w->ne[1];
@@ -4132,7 +4167,8 @@ bool clip_has_audio_encoder(const struct clip_ctx * ctx) {
bool clip_has_whisper_encoder(const struct clip_ctx * ctx) {
return ctx->proj_type() == PROJECTOR_TYPE_ULTRAVOX
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A;
|| ctx->proj_type() == PROJECTOR_TYPE_QWEN2A
|| ctx->proj_type() == PROJECTOR_TYPE_VOXTRAL;
}
bool clip_encode_float_image (struct clip_ctx * ctx, int n_threads, float * img, int h, int w, float * vec) {


@@ -289,6 +289,10 @@ struct mtmd_context {
aud_beg = "<|audio_bos|>";
aud_end = "<|audio_eos|>";
} else if (proj == PROJECTOR_TYPE_ULTRAVOX) {
// [BEGIN_AUDIO] ... (embeddings) ...
aud_beg = "[BEGIN_AUDIO]";
}
}


@@ -1,5 +1,5 @@
-r ../../requirements/requirements-convert_legacy_llama.txt
--extra-index-url https://download.pytorch.org/whl/cpu
pillow~=10.2.0
pillow~=11.3.0
torch~=2.2.1
torchvision~=0.17.1


@@ -71,6 +71,7 @@ add_test_vision "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/ultravox-v0_5-llama-3_2-1b-GGUF:Q8_0"
add_test_audio "ggml-org/Qwen2.5-Omni-3B-GGUF:Q4_K_M"
add_test_audio "ggml-org/Voxtral-Mini-3B-2507-GGUF:Q4_K_M"
# to test the big models, run: ./tests.sh big
if [ "$RUN_BIG_TESTS" = true ]; then