vulkan/cuda: Fix im2col when KW!=KH (#14789)

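The tid is decomposed into "ow + ky*OW + kx*OW*KH", so "ksize" (the divisor that recovers kx) must be OW*KH. The previous value, OW * (KH > 1 ? KW : 1), matched this only when KW == KH or KH == 1. Change "ksize" to match.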
This commit is contained in:
Jeff Bolz 2025-07-21 06:35:40 -05:00 committed by GitHub
parent c82d48ec23
commit c2e058f1b4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 4 additions and 5 deletions

View File

@ -10,7 +10,7 @@ static __global__ void im2col_kernel(
return; return;
} }
const int64_t ksize = OW * (KH > 1 ? KW : 1); const int64_t ksize = OW * KH;
const int64_t kx = i / ksize; const int64_t kx = i / ksize;
const int64_t kd = kx * ksize; const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW; const int64_t ky = (i - kd) / OW;

View File

@ -40,12 +40,10 @@ void main() {
const uint src_base = ic * p.offset_delta + batch * p.batch_offset; const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH); const uint dst_base = ((batch * p.OH + oh) * p.OW) * p.CHW + ic * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1; const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * (p.KH > 1 ? p.KW : 1); const uint ksize = p.OW * p.KH;
const uint base_linear_idx = gidx * NUM_ITER; const uint base_linear_idx = gidx * NUM_ITER;
const uint max_ky = ksize / p.OW;
uint current_kx = base_linear_idx / ksize; uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize); const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW; uint current_ky = rem / p.OW;
@ -76,7 +74,7 @@ void main() {
if (++current_ix == p.OW) { if (++current_ix == p.OW) {
current_ix = 0; current_ix = 0;
if (++current_ky == max_ky) { if (++current_ky == p.KH) {
current_ky = 0; current_ky = 0;
current_kx++; current_kx++;
} }

View File

@ -5093,6 +5093,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2048}, {3, 3, 2, 2048}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 1, 2560}, {3, 3, 1, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true)); test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {12, 12, 2, 2560}, {3, 3, 2, 2560}, 1, 1, 1, 1, 1, 1, true));
test_cases.emplace_back(new test_im2col(GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_F16, {5, 5, 1, 32}, {3, 4, 1, 32}, 1, 1, 0, 0, 1, 1, true));
// Conv_2D test cases // Conv_2D test cases
#ifdef DETAILED_TESTS #ifdef DETAILED_TESTS