opencl : broadcast for soft_max (#14510)

This commit is contained in:
lhez 2025-07-03 11:22:24 -07:00 committed by GitHub
parent 2b72bedec1
commit bee28421be
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 132 additions and 59 deletions

View File

@ -5763,19 +5763,31 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0; cl_ulong offset1 = extra1 ? extra1->offset + src1->view_offs : offset0;
const int ne00 = src0 ? src0->ne[0] : 0; const int ne00 = src0->ne[0];
const int ne01 = src0 ? src0->ne[1] : 0; const int ne01 = src0->ne[1];
const int ne02 = src0 ? src0->ne[2] : 0; const int ne02 = src0->ne[2];
const int ne03 = src0 ? src0->ne[3] : 0; const int ne03 = src0->ne[3];
const cl_long nb01 = src0->nb[1];
const cl_long nb02 = src0->nb[2];
const cl_long nb03 = src0->nb[3];
const int ne12 = src1 ? src1->ne[2] : 0;
const int ne13 = src1 ? src1->ne[3] : 0;
const cl_long nb11 = src1 ? src1->nb[1] : 0;
const cl_long nb12 = src1 ? src1->nb[2] : 0;
const cl_long nb13 = src1 ? src1->nb[3] : 0;
const cl_long nb1 = dst->nb[1];
const cl_long nb2 = dst->nb[2];
const cl_long nb3 = dst->nb[3];
float scale, max_bias; float scale, max_bias;
memcpy(&scale, dst->op_params + 0, sizeof(float)); memcpy(&scale, dst->op_params + 0, sizeof(float));
memcpy(&max_bias, dst->op_params + 1, sizeof(float)); memcpy(&max_bias, dst->op_params + 1, sizeof(float));
const int nrows_x = ggml_nrows(src0); const int n_head = src0->ne[2];
const int nrows_y = src0->ne[1];
const int n_head = nrows_x/nrows_y;
const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head)); const int n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
@ -5820,13 +5832,22 @@ static void ggml_cl_soft_max(ggml_backend_t backend, const ggml_tensor * src0, c
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(float), &scale)); CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(float), &max_bias)); CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &m0)); CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &m1)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &n_head_log2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &nb3));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(float), &max_bias));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(float), &m0));
CL_CHECK(clSetKernelArg(kernel, 21, sizeof(float), &m1));
CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &n_head_log2));
size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)nth, 1, 1}; size_t local_work_size[] = {(size_t)nth, 1, 1};

View File

@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max_4_f16( kernel void kernel_soft_max_4_f16(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global half * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float *)((global char *)src0 + offset0); src0 = src0 + offset0;
src1 = (global half *)((global char *)src1 + offset1); src1 = src1 + offset1;
dst = (global float *)((global char *)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i13 = i03%ne13;
global half4 * pmask = (global char *)src1 != (global char *)src0 ? (global half4 *)(src1 + i01*ne00) : 0; int i12 = i02%ne12;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i11 = i01;
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half4 * pmask = src1 != src0 ? (global half4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;

View File

@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max_4( kernel void kernel_soft_max_4(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global float * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float*)((global char*)src0 + offset0); src0 = src0 + offset0;
src1 = (global float*)((global char*)src1 + offset1); src1 = src1 + offset1;
dst = (global float*)((global char*)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float4 * psrc4 = (global float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i13 = i03%ne13;
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i01*ne00) : 0; int i12 = i02%ne12;
global float4 * pdst4 = (global float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); int i11 = i01;
global float4 * psrc4 = (global float4 *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float4 * pmask = src1 != src0 ? (global float4 *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float4 * pdst4 = (global float4 *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;

View File

@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max_f16( kernel void kernel_soft_max_f16(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global half * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float *)((global char *)src0 + offset0); src0 = src0 + offset0;
src1 = (global half *)((global char *)src1 + offset1); src1 = src1 + offset1;
dst = (global float *)((global char *)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i13 = i03%ne13;
global half * pmask = (global char *)src1 != (global char *)src0 ? src1 + i01*ne00 : 0; int i12 = i02%ne12;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i11 = i01;
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global half * pmask = src1 != src0 ? (global half *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;

View File

@ -22,32 +22,45 @@
REQD_SUBGROUP_SIZE_64 REQD_SUBGROUP_SIZE_64
#endif #endif
kernel void kernel_soft_max( kernel void kernel_soft_max(
global float * src0, global char * src0,
ulong offset0, ulong offset0,
global float * src1, global char * src1,
ulong offset1, ulong offset1,
global float * dst, global char * dst,
ulong offsetd, ulong offsetd,
int ne00, int ne00,
int ne01, ulong nb01,
int ne02, ulong nb02,
ulong nb03,
int ne12,
int ne13,
ulong nb11,
ulong nb12,
ulong nb13,
ulong nb1,
ulong nb2,
ulong nb3,
float scale, float scale,
float max_bias, float max_bias,
float m0, float m0,
float m1, float m1,
int n_head_log2 int n_head_log2
) { ) {
src0 = (global float*)((global char*)src0 + offset0); src0 = src0 + offset0;
src1 = (global float*)((global char*)src1 + offset1); src1 = src1 + offset1;
dst = (global float*)((global char*)dst + offsetd); dst = dst + offsetd;
int i03 = get_group_id(2); int i03 = get_group_id(2);
int i02 = get_group_id(1); int i02 = get_group_id(1);
int i01 = get_group_id(0); int i01 = get_group_id(0);
global float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i13 = i03%ne13;
global float * pmask = src1 != src0 ? src1 + i01*ne00 : 0; int i12 = i02%ne12;
global float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; int i11 = i01;
global float * psrc0 = (global float *)(src0 + i01*nb01 + i02*nb02 + i03*nb03);
global float * pmask = src1 != src0 ? (global float *)(src1 + i11*nb11 + i12*nb12 + i13*nb13) : 0;
global float * pdst = (global float *)(dst + i01*nb1 + i02*nb2 + i03*nb3);
float slope = 1.0f; float slope = 1.0f;