llama : fix parallel processing for plamo2 (#14716)
parent 538cc77f7f
commit e4841d24d3
@@ -15763,6 +15763,7 @@ private:
         cb(zx, "mamba_in_proj", il);
         // {8192, 5, 1, 1} -> {8192, 1, 5, 1}
         zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
+        zx = ggml_cont(ctx0, zx);
         zx = ggml_reshape_4d(ctx0, zx, head_dim * 2, n_heads, n_seq_tokens, n_seqs);
         cb(zx, "mamba_in_proj_out", il);
 
@@ -15780,7 +15781,6 @@ private:
         // conv1d
         {
             // => {d_conv - 1 + n_seq_tokens, d_inner, n_seqs}
-            x = ggml_view_2d(ctx0, x, d_inner, n_seq_tokens * n_seqs, d_inner * x->nb[0], 0);
             ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, x), 0);
             cb(conv_x, "mamba_conv1d_input", il);
 
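For context (not part of the commit): `ggml_reshape_4d` asserts that its input is contiguous, while `ggml_permute` only returns a strided view, so the added `ggml_cont` materializes the permuted layout before the reshape. Presumably the permuted `zx` happens to stay contiguous for a single-token ubatch but not once several tokens/sequences are processed in parallel, which is what this commit addresses. Below is a minimal, self-contained sketch of that contiguity requirement; the shape values (8192 columns, 5 tokens, head_dim = 128, n_heads = 32), the 64 MB context size, and the standalone `main` are illustrative assumptions, not code from the model, and it assumes the ggml headers are on the include path.

```cpp
// Sketch only: shows why ggml_cont is needed between ggml_permute and
// ggml_reshape_4d. Shapes and sizes below are made-up illustrative values.
#include "ggml.h"
#include <cstdio>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 64 * 1024 * 1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    // a batch of 5 tokens with hidden size 8192, as in the diff comment
    struct ggml_tensor * zx = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, 8192, 5, 1, 1);

    // {8192, 5, 1, 1} -> {8192, 1, 5, 1}: permute only rewires strides, no copy
    zx = ggml_permute(ctx0, zx, 0, 2, 1, 3);
    printf("contiguous after permute: %d\n", ggml_is_contiguous(zx));  // prints 0

    // reshape requires contiguous data, so materialize the permuted layout first
    zx = ggml_cont(ctx0, zx);

    // head_dim = 128, n_heads = 32 chosen so that head_dim * 2 * n_heads == 8192
    zx = ggml_reshape_4d(ctx0, zx, 128 * 2, 32, 5, 1);

    ggml_free(ctx0);
    return 0;
}
```

Dropping the `ggml_cont` call in this sketch should trip the contiguity assertion inside `ggml_reshape_4d` for the 5-token batch, which matches the kind of failure a parallel (multi-token) ubatch would hit without the fix.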