vulkan: fix rms_norm_mul to handle broadcasting dim0 (#14817)
This commit is contained in:
parent
d4d1522b20
commit
84712b6043
|
@ -10248,7 +10248,7 @@ static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, st
|
||||||
}
|
}
|
||||||
// if rms_norm is the B operand, then we don't handle broadcast
|
// if rms_norm is the B operand, then we don't handle broadcast
|
||||||
if (rms_norm == mul->src[1] &&
|
if (rms_norm == mul->src[1] &&
|
||||||
mul->src[0]->ne[1] != rms_norm->ne[1]) {
|
!ggml_are_same_shape(mul->src[0], rms_norm)) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
// rms_norm shader assumes contiguous rows
|
// rms_norm shader assumes contiguous rows
|
||||||
|
|
|
@ -50,9 +50,15 @@ void main() {
|
||||||
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
|
||||||
|
|
||||||
if (do_multiply) {
|
if (do_multiply) {
|
||||||
|
if (ncols > p.ne10) {
|
||||||
|
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||||
|
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)]));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
|
||||||
}
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
[[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
|
||||||
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
|
||||||
|
|
Loading…
Reference in New Issue