CUDA: fix pointer incrementation in FA (#14916)

This commit is contained in:
Johannes Gäßler 2025-07-28 14:30:22 +02:00 committed by GitHub
parent 6c6e397aff
commit 946b1f6859
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 8 additions and 10 deletions

View File

@ -174,7 +174,10 @@ static __global__ void flash_attn_vec_ext_f16(
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
@ -291,10 +294,6 @@ static __global__ void flash_attn_vec_ext_f16(
}
}
K += gridDim.y*D * nb11;
V += gridDim.y*D * nb21;
maskh += gridDim.y*D;
__syncthreads();
}

View File

@ -180,7 +180,10 @@ static __global__ void flash_attn_vec_ext_f32(
K += blockIdx.y*D * nb11;
V += blockIdx.y*D * nb21;
maskh += blockIdx.y*D;
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D,
// Increment pointers after each loop:
K += gridDim.y*D*nb11, V += gridDim.y*D*nb21, maskh += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
if (mask) {
@ -286,10 +289,6 @@ static __global__ void flash_attn_vec_ext_f32(
}
}
K += gridDim.y*D * nb11;
V += gridDim.y*D * nb21;
maskh += gridDim.y*D;
__syncthreads();
}