opencl : update upscale to support align corners (#14488)

This commit is contained in:
lhez 2025-07-02 00:07:42 -07:00 committed by GitHub
parent 611ba4b264
commit 603e43dc91
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 25 deletions

View File

@ -4453,7 +4453,8 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
const ggml_scale_mode mode = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0); const int mode_flags = (ggml_scale_mode) ggml_get_op_params_i32(dst, 0);
const ggml_scale_mode mode = (ggml_scale_mode) (mode_flags & 0xFF);
cl_kernel kernel = nullptr; cl_kernel kernel = nullptr;
if (mode == GGML_SCALE_MODE_NEAREST) { if (mode == GGML_SCALE_MODE_NEAREST) {
@ -4484,18 +4485,22 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3]; const cl_ulong nb03 = src0->nb[3];
const int ne00_src = src0->ne[0]; const int ne00 = src0->ne[0];
const int ne01_src = src0->ne[1]; const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const int ne10_dst = dst->ne[0]; const int ne0 = dst->ne[0];
const int ne11_dst = dst->ne[1]; const int ne1 = dst->ne[1];
const int ne12_dst = dst->ne[2]; const int ne2 = dst->ne[2];
const int ne13_dst = dst->ne[3]; const int ne3 = dst->ne[3];
const float sf0 = (float)dst->ne[0] / src0->ne[0]; float sf0 = (float)ne0 / ne00;
const float sf1 = (float)dst->ne[1] / src0->ne[1]; float sf1 = (float)ne1 / ne01;
const float sf2 = (float)dst->ne[2] / src0->ne[2]; float sf2 = (float)ne2 / ne02;
const float sf3 = (float)dst->ne[3] / src0->ne[3]; float sf3 = (float)ne3 / ne03;
float pixel_offset = 0.5f;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0));
@ -4507,29 +4512,36 @@ static void ggml_cl_upscale(ggml_backend_t backend, const ggml_tensor * src0, gg
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb03));
if (mode == GGML_SCALE_MODE_NEAREST) { if (mode == GGML_SCALE_MODE_NEAREST) {
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne10_dst)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11_dst)); CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12_dst)); CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne13_dst)); CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0)); CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &sf0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float), &sf1));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf2));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf3));
} else if (mode == GGML_SCALE_MODE_BILINEAR) { } else if (mode == GGML_SCALE_MODE_BILINEAR) {
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00_src)); if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01_src)); sf0 = (float)(ne0 - 1) / (ne00 - 1);
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10_dst)); sf1 = (float)(ne1 - 1) / (ne01 - 1);
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11_dst)); pixel_offset = 0.0f;
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12_dst)); }
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne13_dst));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne2));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne3));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0)); CL_CHECK(clSetKernelArg(kernel, 14, sizeof(float), &sf0));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1)); CL_CHECK(clSetKernelArg(kernel, 15, sizeof(float), &sf1));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2)); CL_CHECK(clSetKernelArg(kernel, 16, sizeof(float), &sf2));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3)); CL_CHECK(clSetKernelArg(kernel, 17, sizeof(float), &sf3));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(float), &pixel_offset));
} }
size_t dst_total_elements = (size_t)ne10_dst * ne11_dst * ne12_dst * ne13_dst; size_t dst_total_elements = (size_t)ne0 * ne1 * ne2 * ne3;
if (dst_total_elements == 0) { if (dst_total_elements == 0) {
return; return;
} }

View File

@ -60,7 +60,8 @@ kernel void kernel_upscale_bilinear(
float sf0, float sf0,
float sf1, float sf1,
float sf2, float sf2,
float sf3 float sf3,
float pixel_offset
) { ) {
global const char * src_base = (global const char *)p_src0 + off_src0; global const char * src_base = (global const char *)p_src0 + off_src0;
global float * dst_base = (global float *)((global char *)p_dst + off_dst); global float * dst_base = (global float *)((global char *)p_dst + off_dst);
@ -80,8 +81,6 @@ kernel void kernel_upscale_bilinear(
int i02_src = (int)(i12_dst / sf2); int i02_src = (int)(i12_dst / sf2);
int i03_src = (int)(i13_dst / sf3); int i03_src = (int)(i13_dst / sf3);
const float pixel_offset = 0.5f;
float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset; float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
long y0_src = (long)floor(y_src_f); long y0_src = (long)floor(y_src_f);
long y1_src = y0_src + 1; long y1_src = y0_src + 1;