use max work group size for device to replace the magic number (#14732)

This commit is contained in:
Neo Zhang Jianyu 2025-07-18 10:23:14 +08:00 committed by GitHub
parent 670e1360cd
commit 349ea79fce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 5 additions and 2 deletions

View File

@ -3530,8 +3530,11 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
SYCL_CHECK(CHECK_TRY_ERROR( SYCL_CHECK(CHECK_TRY_ERROR(
stream->memset(dev_cur_src1_row.get(), 0, sizeof(int)))); stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device];
assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0);
{ {
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u)); sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, max_work_group_size));
sycl::range<3> grid_dims(1, n_ids, ids->ne[1]); sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
sycl_launch(stream, [&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
sycl::local_accessor<int, 0> src1_row_acc(cgh); sycl::local_accessor<int, 0> src1_row_acc(cgh);
@ -3575,7 +3578,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx,
ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row); ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
{ {
sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u)); sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, max_work_group_size));
sycl::range<3> grid_dims(1, 1, num_src1_rows); sycl::range<3> grid_dims(1, 1, num_src1_rows);
sycl_launch(stream, [&](sycl::handler & cgh) { sycl_launch(stream, [&](sycl::handler & cgh) {
const char *__restrict dst_contiguous_get = const char *__restrict dst_contiguous_get =