opencl: add fused `rms_norm_mul` (#14841)

* opencl: add fused `rms_norm` + `mul`

* opencl: improve workgroup size for `rms_norm_mul`
lhez 2025-07-25 08:12:13 -07:00 committed by GitHub
parent e7fecba934
commit ce111d39d6
2 changed files with 240 additions and 2 deletions


@@ -333,6 +333,7 @@ struct ggml_backend_opencl_context {
size_t max_alloc_size;
bool fp16_support;
bool has_vector_subgroup_broadcast;
bool disable_fusion;
ggml_cl_compiler_version adreno_cl_compiler_version;
int adreno_wave_size;
@@ -411,7 +412,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_geglu, kernel_reglu, kernel_swiglu, kernel_geglu_erf, kernel_geglu_quick,
kernel_geglu_f16, kernel_reglu_f16, kernel_swiglu_f16, kernel_geglu_erf_f16, kernel_geglu_quick_f16;
cl_kernel kernel_norm;
cl_kernel kernel_rms_norm;
cl_kernel kernel_rms_norm, kernel_rms_norm_mul;
cl_kernel kernel_group_norm;
cl_kernel kernel_diag_mask_inf, kernel_diag_mask_inf_8;
cl_kernel kernel_soft_max, kernel_soft_max_4;
@@ -1100,7 +1101,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
backend_ctx->program_rms_norm =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_rms_norm = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm", &err), err));
CL_CHECK((backend_ctx->kernel_rms_norm_mul = clCreateKernel(backend_ctx->program_rms_norm, "kernel_rms_norm_mul", &err), err));
GGML_LOG_CONT(".");
}
@@ -2110,6 +2112,8 @@ static ggml_backend_opencl_context * ggml_cl2_init(ggml_backend_dev_t dev) {
CL_CHECK((backend_ctx->B_d_max = clCreateBuffer(context, 0, max_B_d_bytes, NULL, &err), err));
#endif // GGML_OPENCL_USE_ADRENO_KERNELS
backend_ctx->disable_fusion = getenv("GGML_OPENCL_DISABLE_FUSION") != nullptr;
dev_ctx->backend_ctx = backend_ctx.release();
return dev_ctx->backend_ctx;
}
@@ -2279,7 +2283,45 @@ static void sync_with_other_backends(ggml_backend_t backend) {
sync_with_other_backends(backend_ctx);
}
static bool ggml_opencl_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
if (ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
const ggml_tensor *mul = cgraph->nodes[node_idx+1];
GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32);
GGML_ASSERT(rms_norm->type == GGML_TYPE_F32);
// rms_norm only supports f32
if (mul->src[0]->type != GGML_TYPE_F32 ||
mul->src[1]->type != GGML_TYPE_F32 ||
mul->type != GGML_TYPE_F32) {
return false;
}
// if rms_norm is the B operand, then we don't handle broadcast
if (rms_norm == mul->src[1] &&
!ggml_are_same_shape(mul->src[0], rms_norm->src[1])) {
return false;
}
// rms_norm assumes contiguous rows
if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) {
return false;
}
}
return true;
}
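
For context, the pattern this check targets is the usual RMSNorm-then-weight-multiply pair in a ggml graph. A minimal sketch of how such a pair is built (the helper name and the tensors x and w are illustrative, not part of this change):

#include "ggml.h"

// Sketch: builds the GGML_OP_RMS_NORM -> GGML_OP_MUL pair that ggml_opencl_can_fuse
// looks for. ctx, x (activations) and w (norm weights) are assumed to be set up elsewhere.
static struct ggml_tensor * build_rms_norm_mul(struct ggml_context * ctx,
                                               struct ggml_tensor * x,
                                               struct ggml_tensor * w,
                                               float eps) {
    struct ggml_tensor * normed = ggml_rms_norm(ctx, x, eps); // GGML_OP_RMS_NORM node
    return ggml_mul(ctx, normed, w);                          // GGML_OP_MUL node consuming it
}
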
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor);
static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
@@ -2292,6 +2334,12 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
continue;
}
if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
ggml_opencl_op_rms_norm_fused(backend, node, cgraph->nodes[i+1]);
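// the MUL node was folded into the fused kernel, so skip it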
i++;
continue;
}
bool ok = ggml_cl_compute_forward(backend, node);
if (!ok) {
GGML_LOG_ERROR("%s: error: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
@@ -4455,6 +4503,117 @@ static void ggml_cl_rms_norm(ggml_backend_t backend, const ggml_tensor * src0, c
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_opencl_op_rms_norm_fused(ggml_backend_t backend, ggml_tensor * rms_norm_tensor, ggml_tensor * mul_tensor) {
GGML_ASSERT(mul_tensor);
GGML_ASSERT(rms_norm_tensor);
// src0 is the input of rms_norm; src1 is the operand of mul that is not rms_norm
const ggml_tensor * src0 = rms_norm_tensor->src[0];
const ggml_tensor * src1;
if (mul_tensor->src[0] == rms_norm_tensor) {
src1 = mul_tensor->src[1];
} else if (mul_tensor->src[1] == rms_norm_tensor) {
src1 = mul_tensor->src[0];
} else {
GGML_ASSERT(false && "Invalid args for rms_norm and mul");
}
const ggml_tensor * dst = mul_tensor;
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
float eps;
memcpy(&eps, rms_norm_tensor->op_params, sizeof(float));
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const int ne10 = src1->ne[0];
const int ne11 = src1->ne[1];
const int ne12 = src1->ne[2];
const int ne13 = src1->ne[3];
const cl_ulong nb11 = src1->nb[1];
const cl_ulong nb12 = src1->nb[2];
const cl_ulong nb13 = src1->nb[3];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
GGML_ASSERT(ne00 % 4 == 0);
size_t sgs;
if (backend_ctx->gpu_family == ADRENO) {
sgs = 64;
} else if (backend_ctx->gpu_family == INTEL) {
sgs = 32;
} else {
GGML_ASSERT(false && "Unsupported GPU");
}
cl_kernel kernel = backend_ctx->kernel_rms_norm_mul;
int nth = sgs;
int max_workgroup_size = backend_ctx->get_kernel_workgroup_size(kernel);
while (nth < ne00 && nth < max_workgroup_size) {
nth *= 2;
}
nth = MIN(nth, max_workgroup_size);
nth = MIN(nth, ne00);
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1};
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 23, sizeof(float), &eps));
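// arg 24: local memory scratch, one float per subgroup, for the partial sums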
CL_CHECK(clSetKernelArg(kernel, 24, sizeof(float)*nth/sgs, NULL));
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
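
To make the launch configuration concrete, here is a standalone sketch of the workgroup-size heuristic above with illustrative numbers (row length, subgroup size, and workgroup limit are assumed values, not queried from a real device):

#include <stdio.h>

int main(void) {
    const int ne00 = 4096;              // row length (asserted to be a multiple of 4)
    const int sgs  = 64;                // subgroup size chosen for ADRENO
    const int max_workgroup_size = 256; // assumed CL_KERNEL_WORK_GROUP_SIZE for this kernel

    // double the workgroup size until it covers the row or hits the kernel's limit
    int nth = sgs;
    while (nth < ne00 && nth < max_workgroup_size) {
        nth *= 2;                       // 64 -> 128 -> 256
    }
    if (nth > max_workgroup_size) nth = max_workgroup_size;
    if (nth > ne00)               nth = ne00;

    // one workgroup per row; the local scratch holds one float per subgroup
    printf("nth = %d, local floats = %d\n", nth, nth / sgs); // nth = 256, local floats = 4
    return 0;
}
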
static void ggml_cl_group_norm(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);


@@ -94,3 +94,82 @@ kernel void kernel_rms_norm(
}
}
}
//------------------------------------------------------------------------------
// rms_norm_mul
//------------------------------------------------------------------------------
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_32
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_rms_norm_mul(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
int ne10,
int ne11,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float eps,
local float * sum
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int i03 = get_group_id(2);
int i02 = get_group_id(1);
int i01 = get_group_id(0);
global float4 * x = (global float4 *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
global float4 * f = (global float4 *) (src1 + (i03%ne13)*nb13 + (i02%ne12)*nb12 + (i01%ne11)*nb11);
float sumf = 0;
// parallel sum
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
sumf += dot(x[i00], x[i00]);
}
sumf = sub_group_reduce_add(sumf);
if (get_sub_group_local_id() == 0) {
sum[get_sub_group_id()] = sumf;
}
barrier(CLK_LOCAL_MEM_FENCE);
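// tree-reduce the per-subgroup partial sums staged in local memory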
for (uint i = get_local_size(0) / get_max_sub_group_size() / 2; i > 0; i /= 2) {
if (get_local_id(0) < i) {
sum[get_local_id(0)] += sum[get_local_id(0) + i];
}
}
if (get_local_id(0) == 0) {
sum[0] /= ne00;
}
barrier(CLK_LOCAL_MEM_FENCE);
float mean = sum[0];
float scale = 1.0f/sqrt(mean + eps);
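// normalize the row and multiply by the mul operand (f is broadcast along dim 0)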
global float4 * y = (global float4 *) (dst + i03*nb3 + i02*nb2 + i01*nb1);
for (int i00 = get_local_id(0); i00 < ne00/4; i00 += get_local_size(0)) {
y[i00] = (x[i00] * scale) * f[i00%(ne10/4)];
}
}
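
As a scalar sanity check on the math, the per-row computation is y = x / sqrt(mean(x^2) + eps) * f, with f broadcast along dim 0. A reference sketch in plain C (the function name is illustrative; the real kernel works on float4 vectors and applies the broadcast per float4, which is equivalent when ne10 is a multiple of 4):

#include <math.h>

// Reference for one row of the fused op: RMS-normalize x, then multiply by f
// (broadcast along dim 0). Mirrors y[i00] = (x[i00] * scale) * f[i00 % (ne10/4)].
static void rms_norm_mul_row_ref(const float * x, const float * f, float * y,
                                 int ne00, int ne10, float eps) {
    float sumsq = 0.0f;
    for (int i = 0; i < ne00; i++) {
        sumsq += x[i] * x[i];
    }
    const float scale = 1.0f / sqrtf(sumsq / ne00 + eps);
    for (int i = 0; i < ne00; i++) {
        y[i] = (x[i] * scale) * f[i % ne10];
    }
}
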