llama : add high-throughput mode (#14363)
* kv-cache : prepare K/V buffers for separation ggml-ci * batched-bench : fix oob write ggml-ci * llama : add "virtual sequences" ggml-ci * llama : use "stream" vs "virtual sequence" ggml-ci * graph : fix stream splitting when KV cache is not used ggml-ci * kv-cache : add multi-stream save/load support ggml-ci * llama : add "--attn-streams" flag ggml-ci * kv-cache : fix handling when find_slot fails ggml-ci * kv-cache : restore find_slot impl ggml-ci * kv-cache : add comments * kv-cache : add bounds checks for sequence id ggml-ci * cont : add n_seq_max to batch allocr ggml-ci * kv-cache : perform stream copies lazily after llama_synchronize ggml-ci * kv-cache : avoid throwing exceptions across the C boundary ggml-ci * CUDA: 4D FlashAttention support (#14628) * CUDA: 4D FlashAttention support * CUDA: fix WMMA FA kernel * llama : rename attn_streams -> kv_unified ggml-ci * common : rename kv_split -> kv_unified ggml-ci --------- Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
parent
ab14019821
commit
225e7a1438
|
@ -1464,6 +1464,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
|
|||
params.swa_full = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_SWA_FULL"));
|
||||
add_opt(common_arg(
|
||||
{"--kv-unified", "-kvu"},
|
||||
string_format("use single unified KV buffer for the KV cache of all sequences (default: %s)\n"
|
||||
"[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)", params.kv_unified ? "true" : "false"),
|
||||
[](common_params & params) {
|
||||
params.kv_unified = true;
|
||||
}
|
||||
).set_env("LLAMA_ARG_KV_SPLIT"));
|
||||
add_opt(common_arg(
|
||||
{"--no-context-shift"},
|
||||
string_format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
|
||||
|
|
|
@ -1163,6 +1163,7 @@ struct llama_context_params common_context_params_to_llama(const common_params &
|
|||
cparams.no_perf = params.no_perf;
|
||||
cparams.op_offload = !params.no_op_offload;
|
||||
cparams.swa_full = params.swa_full;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
cparams.type_k = params.cache_type_k;
|
||||
cparams.type_v = params.cache_type_v;
|
||||
|
|
|
@ -341,6 +341,7 @@ struct common_params {
|
|||
bool no_perf = false; // disable performance metrics
|
||||
bool ctx_shift = true; // context shift on inifinite text generation
|
||||
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
bool kv_unified = false; // enable unified KV cache
|
||||
|
||||
bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
|
||||
bool use_mmap = true; // use mmap for faster loads
|
||||
|
|
|
@ -107,7 +107,7 @@ int main(int argc, char ** argv) {
|
|||
const llama_vocab * vocab = llama_model_get_vocab(model);
|
||||
|
||||
const int n_ctx_train = llama_model_n_ctx_train(model);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
||||
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
|
||||
|
||||
|
|
|
@ -224,6 +224,7 @@ int main(int argc, char ** argv) {
|
|||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.smpl = common_sampler_init(model, params.sampling);
|
||||
//params.sampling.seed++;
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
|
@ -345,7 +346,7 @@ int main(int argc, char ** argv) {
|
|||
client.n_decoded = 0;
|
||||
client.i_batch = batch.n_tokens - 1;
|
||||
|
||||
LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur);
|
||||
LOG_INF("\033[31mClient %3d, seq %4d, junk = %4d, prompt = %d, started decoding ...\033[0m\n", client.id, client.seq_id, n_junk_cur, client.n_prompt);
|
||||
|
||||
g_seq_id += 1;
|
||||
|
||||
|
|
|
@ -33,8 +33,10 @@ typedef void (* fattn_kernel_t)(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -521,7 +523,7 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
|
|||
template<int D, int ncols1, int ncols2> // D == head size
|
||||
__launch_bounds__(D, 1)
|
||||
static __global__ void flash_attn_stream_k_fixup(
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
|
||||
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne03, const int ne11) {
|
||||
constexpr int ncols = ncols1*ncols2;
|
||||
|
||||
const int bidx0 = blockIdx.x;
|
||||
|
@ -535,8 +537,8 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
const int iter_k = ne11 / FATTN_KQ_STRIDE;
|
||||
const int iter_j = (ne01 + (ncols1 - 1)) / ncols1;
|
||||
|
||||
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc0 = (bidx0 + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc0_stop = (bidx0 + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
|
||||
const bool did_not_have_any_data = kbc0 == kbc0_stop;
|
||||
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
|
||||
|
@ -545,14 +547,15 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc0 / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
|
||||
const int sequence = kbc0 / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc0 - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
if (jt*ncols1 + j >= ne01) {
|
||||
return;
|
||||
}
|
||||
|
||||
dst += jt*ne02*(ncols1*D) + channel*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
dst += sequence*ne02*ne01*D + jt*ne02*(ncols1*D) + head*(ncols2*D) + (j*ne02 + c)*D + tid;
|
||||
|
||||
// Load the partial result that needs a fixup:
|
||||
float dst_val = 0.0f;
|
||||
|
@ -571,7 +574,7 @@ static __global__ void flash_attn_stream_k_fixup(
|
|||
int bidx = bidx0 - 1;
|
||||
int kbc_stop = kbc0;
|
||||
while(true) {
|
||||
const int kbc = bidx*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc = bidx*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
if (kbc == kbc_stop) { // Did not have any data.
|
||||
bidx--;
|
||||
kbc_stop = kbc;
|
||||
|
@ -617,16 +620,31 @@ static __global__ void flash_attn_combine_results(
|
|||
const float2 * __restrict__ VKQ_meta,
|
||||
float * __restrict__ dst,
|
||||
const int parallel_blocks) {
|
||||
VKQ_parts += parallel_blocks*D * gridDim.z*blockIdx.x;
|
||||
VKQ_meta += parallel_blocks * gridDim.z*blockIdx.x;
|
||||
dst += D * gridDim.z*blockIdx.x;
|
||||
// Dimension 0: threadIdx.x
|
||||
// Dimension 1: blockIdx.x
|
||||
// Dimension 2: blockIdx.y
|
||||
// Dimension 3: blockIdx.z
|
||||
// Memory layout is permuted with [0, 2, 1, 3]
|
||||
|
||||
const int ne01 = gridDim.x;
|
||||
const int ne02 = gridDim.y;
|
||||
|
||||
const int col = blockIdx.x;
|
||||
const int head = blockIdx.y;
|
||||
const int sequence = blockIdx.z;
|
||||
|
||||
const int j_dst_unrolled = (sequence*ne01 + col)*ne02 + head;
|
||||
|
||||
VKQ_parts += j_dst_unrolled * parallel_blocks*D;
|
||||
VKQ_meta += j_dst_unrolled * parallel_blocks;
|
||||
dst += j_dst_unrolled * D;
|
||||
|
||||
const int tid = threadIdx.x;
|
||||
__builtin_assume(tid < D);
|
||||
|
||||
extern __shared__ float2 meta[];
|
||||
for (int i = tid; i < 2*parallel_blocks; i += D) {
|
||||
((float *) meta)[i] = ((const float *)VKQ_meta) [blockIdx.z*(2*parallel_blocks) + i];
|
||||
((float *) meta)[i] = ((const float *)VKQ_meta) [i];
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
@ -644,11 +662,11 @@ static __global__ void flash_attn_combine_results(
|
|||
const uint32_t ftz_mask = 0xFFFFFFFF * (diff > SOFTMAX_FTZ_THRESHOLD);
|
||||
*((uint32_t *) &KQ_max_scale) &= ftz_mask;
|
||||
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*gridDim.z*D + blockIdx.z*D + tid];
|
||||
VKQ_numerator += KQ_max_scale * VKQ_parts[l*D + tid];
|
||||
VKQ_denominator += KQ_max_scale * meta[l].y;
|
||||
}
|
||||
|
||||
dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator;
|
||||
dst[tid] = VKQ_numerator / VKQ_denominator;
|
||||
}
|
||||
|
||||
[[noreturn]]
|
||||
|
@ -705,8 +723,6 @@ void launch_fattn(
|
|||
|
||||
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
|
||||
|
||||
GGML_ASSERT(Q->ne[3] == 1);
|
||||
|
||||
ggml_cuda_pool & pool = ctx.pool();
|
||||
cudaStream_t main_stream = ctx.stream();
|
||||
const int id = ggml_cuda_get_device();
|
||||
|
@ -853,8 +869,8 @@ void launch_fattn(
|
|||
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
|
||||
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
|
||||
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0,
|
||||
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0,
|
||||
mask ? mask->ne[1] : 0, mask ? mask->ne[2] : 0, mask ? mask->ne[3] : 0,
|
||||
mask ? mask->nb[1] : 0, mask ? mask->nb[2] : 0, mask ? mask->nb[3] : 0,
|
||||
Q->nb[1], Q->nb[2], Q->nb[3],
|
||||
nb11, nb12, nb13,
|
||||
nb21, nb22, nb23,
|
||||
|
@ -869,11 +885,11 @@ void launch_fattn(
|
|||
|
||||
flash_attn_stream_k_fixup<DV, ncols1, ncols2>
|
||||
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
|
||||
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], Q->ne[3], K->ne[1]);
|
||||
}
|
||||
} else if (parallel_blocks > 1) {
|
||||
const dim3 block_dim_combine(DV, 1, 1);
|
||||
const dim3 blocks_num_combine(Q->ne[1], 1, blocks_num.z);
|
||||
const dim3 blocks_num_combine(Q->ne[1], Q->ne[2], Q->ne[3]);
|
||||
const size_t nbytes_shared_combine = parallel_blocks*sizeof(float2);
|
||||
|
||||
flash_attn_combine_results<DV>
|
||||
|
|
|
@ -1224,8 +1224,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -1274,8 +1276,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
constexpr int kb_niter = FATTN_KQ_STRIDE / c::nbatch_fa; // Number of kernel iterations per assigned KQ slice.
|
||||
|
||||
// kbc == k block continuous, current index in continuous ijk space.
|
||||
int kbc = (blockIdx.x + 0)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*(ne02/ncols2) / gridDim.x;
|
||||
int kbc = (blockIdx.x + 0)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
const int kbc_stop = (blockIdx.x + 1)*(iter_k*iter_j*(ne02/ncols2)*ne03) / gridDim.x;
|
||||
|
||||
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
|
||||
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
|
||||
|
@ -1285,18 +1287,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
int kb0_start = kbc % iter_k;
|
||||
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
|
||||
while (kbc < kbc_stop && kb0_stop == iter_k) {
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
|
@ -1325,18 +1328,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
return;
|
||||
}
|
||||
|
||||
const int channel = kbc / (iter_k*iter_j);
|
||||
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
|
||||
const int sequence = kbc / (iter_k*iter_j*(ne02/ncols2));
|
||||
const int head = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence) / (iter_k*iter_j);
|
||||
const int jt = (kbc - iter_k*iter_j*(ne02/ncols2)*sequence - iter_k*iter_j*head) / iter_k; // j index of current tile.
|
||||
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel*ncols2);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(channel*ncols2 / gqa_ratio));
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03*sequence + nb02*(head*ncols2));
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head*ncols2 / gqa_ratio));
|
||||
const half2 * mask_h2 = ncols2 == 1 && !mask ? nullptr :
|
||||
(const half2 *) (mask + nb32*(channel % ne32) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + channel*(ncols2 * DV/2);
|
||||
(const half2 *) (mask + nb33*(sequence % ne33) + nb31*jt*ncols1);
|
||||
float2 * dstk = ((float2 *) dst) + (sequence*ne01*ne02 + head*ncols2) * (DV/2);
|
||||
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb22*(channel*ncols2 / gqa_ratio));
|
||||
const half2 * V_h2 = mla ? K_h2 + (DKQ/2 - DV/2) : (const half2 *) (V + nb23*sequence + nb22*(head*ncols2 / gqa_ratio));
|
||||
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f;
|
||||
const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head, n_head_log2, m0, m1) : 1.0f;
|
||||
|
||||
const int kb0_start_kernel = kb0_start * kb_niter;
|
||||
const int kb0_stop_kernel = kb0_stop * kb_niter;
|
||||
|
|
|
@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -62,15 +64,17 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
@ -255,6 +259,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
__syncthreads();
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
|
@ -266,21 +272,21 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
|
||||
kqsum_j = warp_reduce_sum((float)kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val /= __half2half2(kqsum_j);
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
@ -290,8 +296,8 @@ static __global__ void flash_attn_tile_ext_f16(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
|
|
@ -31,8 +31,10 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -74,15 +76,17 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const int stride_KV2 = nb11 / sizeof(half2);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
||||
|
@ -265,6 +269,8 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
__syncthreads();
|
||||
}
|
||||
|
||||
float2 * dst2 = (float2 *) dst;
|
||||
|
||||
#pragma unroll
|
||||
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
|
||||
const int j_VKQ = j_VKQ_0 + threadIdx.y;
|
||||
|
@ -276,22 +282,22 @@ static __global__ void flash_attn_tile_ext_f32(
|
|||
float kqsum_j = kqsum[j_VKQ_0/nwarps];
|
||||
kqsum_j = warp_reduce_sum(kqsum_j);
|
||||
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D; i00 += 2*WARP_SIZE) {
|
||||
const int i0 = i00 + 2*threadIdx.x;
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
|
||||
#pragma unroll
|
||||
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
|
||||
const int i0 = i00 + threadIdx.x;
|
||||
|
||||
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
|
||||
if (gridDim.y == 1) {
|
||||
dst_val.x /= kqsum_j;
|
||||
dst_val.y /= kqsum_j;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
|
||||
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && threadIdx.x == 0) {
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
|
||||
}
|
||||
}
|
||||
#else
|
||||
|
|
|
@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -65,14 +67,16 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio);
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
|
@ -330,12 +334,11 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
|
@ -344,8 +347,8 @@ static __global__ void flash_attn_vec_ext_f16(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
|
|
@ -28,8 +28,10 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -53,8 +55,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
|
||||
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
|
||||
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
|
||||
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
|
||||
|
@ -77,14 +79,16 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
|
||||
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
Q += nb02* blockIdx.z + nb01*ic0;
|
||||
K += nb12*(blockIdx.z / gqa_ratio);
|
||||
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
|
||||
Q += nb03*sequence + nb02* head + nb01*ic0;
|
||||
K += nb13*sequence + nb12*(head / gqa_ratio);
|
||||
V += nb23*sequence + nb22*(head / gqa_ratio);
|
||||
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
|
||||
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
|
||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||
constexpr int nwarps = D / WARP_SIZE;
|
||||
|
@ -326,12 +330,11 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
if (gridDim.y == 1) {
|
||||
dst_val /= kqsum[j_VKQ];
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
|
||||
dst[(((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y)*D + tid] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
|
||||
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
dst_meta[((sequence*ne01 + ic0 + tid)*ne02 + head)*gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
|
@ -340,8 +343,8 @@ static __global__ void flash_attn_vec_ext_f32(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33);
|
||||
GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb33);
|
||||
GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
|
||||
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
|
|
|
@ -47,8 +47,10 @@ static __global__ void flash_attn_ext_f16(
|
|||
const int ne13,
|
||||
const int ne31,
|
||||
const int ne32,
|
||||
const int ne33,
|
||||
const int nb31,
|
||||
const int nb32,
|
||||
const int nb33,
|
||||
const int nb01,
|
||||
const int nb02,
|
||||
const int nb03,
|
||||
|
@ -95,17 +97,19 @@ static __global__ void flash_attn_ext_f16(
|
|||
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
|
||||
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
|
||||
|
||||
const int sequence = blockIdx.z / ne02;
|
||||
const int head = blockIdx.z - sequence*ne02;
|
||||
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
|
||||
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb32*(blockIdx.z % ne32) + nb31*ic0);
|
||||
const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0);
|
||||
const half * K_h = (const half *) (K + nb13* sequence + nb12*(head / gqa_ratio));
|
||||
const half * V_h = (const half *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
|
||||
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
|
||||
const half2 * mask2 = (const half2 *) maskh;
|
||||
|
||||
const int stride_Q = nb01 / sizeof(float);
|
||||
const int stride_KV = nb11 / sizeof(half);
|
||||
|
||||
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
|
||||
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
|
||||
const half slopeh = __float2half(slopef);
|
||||
const half2 slope2 = make_half2(slopef, slopef);
|
||||
|
||||
|
@ -400,7 +404,6 @@ static __global__ void flash_attn_ext_f16(
|
|||
if (ic0 + j_VKQ >= ne01) {
|
||||
return;
|
||||
}
|
||||
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
|
||||
|
||||
float KQ_rowsum_j;
|
||||
if (std::is_same<KQ_acc_t, float>::value) {
|
||||
|
@ -409,6 +412,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
|
||||
}
|
||||
|
||||
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
|
||||
|
||||
#pragma unroll
|
||||
for (int i0 = 0; i0 < D; i0 += warp_size) {
|
||||
const int i = i0 + threadIdx.x;
|
||||
|
@ -419,7 +424,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
if (gridDim.y == 1) {
|
||||
dst_val /= KQ_rowsum_j;
|
||||
}
|
||||
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
|
||||
dst[j_dst_unrolled*D + i] = dst_val;
|
||||
}
|
||||
|
||||
if (gridDim.y == 1 || threadIdx.x != 0) {
|
||||
|
@ -433,7 +438,7 @@ static __global__ void flash_attn_ext_f16(
|
|||
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
|
||||
}
|
||||
dst_meta_val.y = KQ_rowsum_j;
|
||||
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
|
||||
dst_meta[j_dst_unrolled] = dst_meta_val;
|
||||
}
|
||||
#else
|
||||
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
|
||||
|
@ -442,7 +447,8 @@ static __global__ void flash_attn_ext_f16(
|
|||
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
|
||||
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
|
||||
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(nb31); GGML_UNUSED(nb32); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(ne31); GGML_UNUSED(ne32); GGML_UNUSED(ne33); GGML_UNUSED(nb31);
|
||||
GGML_UNUSED(nb32); GGML_UNUSED(nb33); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
|
||||
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
|
||||
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
|
||||
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
|
||||
|
|
|
@ -3413,12 +3413,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
if (op->src[0]->ne[0] == 192) {
|
||||
return false;
|
||||
}
|
||||
// TODO: support broadcast
|
||||
// note: this was initially implemented in https://github.com/ggml-org/llama.cpp/pull/14500, but
|
||||
// the interface of ggml_flash_attn_ext() changed in https://github.com/ggml-org/llama.cpp/pull/14505
|
||||
if (op->src[0]->ne[3] != 1) {
|
||||
return false;
|
||||
}
|
||||
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
|
||||
return false;
|
||||
}
|
||||
|
@ -3431,6 +3425,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
|||
if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
|
||||
return true;
|
||||
}
|
||||
if (op->src[3] && op->src[3]->ne[2] != 1) {
|
||||
return false;
|
||||
}
|
||||
return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
|
||||
op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
|
||||
}
|
||||
|
|
|
@ -335,6 +335,9 @@ extern "C" {
|
|||
bool swa_full; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
|
||||
// NOTE: setting to false when n_seq_max > 1 can cause bad performance in some cases
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/13845#issuecomment-2924800573
|
||||
bool kv_unified; // use a unified buffer across the input sequences when computing the attention
|
||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||
};
|
||||
|
||||
// model quantization parameters
|
||||
|
|
|
@ -27,6 +27,7 @@ bool llama_batch_allocr::init(
|
|||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all) {
|
||||
clear();
|
||||
|
||||
|
@ -40,6 +41,11 @@ bool llama_batch_allocr::init(
|
|||
// validate input batch
|
||||
//
|
||||
|
||||
if (n_seq_max > LLAMA_MAX_SEQ) {
|
||||
LLAMA_LOG_ERROR("%s: n_seq_max = %d > %d\n", __func__, n_seq_max, LLAMA_MAX_SEQ);
|
||||
return false;
|
||||
}
|
||||
|
||||
if (batch.token) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
if (batch.token[i] < 0 || (uint32_t) batch.token[i] >= vocab.n_tokens()) {
|
||||
|
@ -52,8 +58,8 @@ bool llama_batch_allocr::init(
|
|||
if (batch.seq_id) {
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
for (int32_t s = 0; s < batch.n_seq_id[i]; ++s) {
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= LLAMA_MAX_SEQ)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], LLAMA_MAX_SEQ);
|
||||
if (batch.seq_id && (batch.seq_id[i][s] < 0 || batch.seq_id[i][s] >= (llama_seq_id) n_seq_max)) {
|
||||
LLAMA_LOG_ERROR("%s: invalid seq_id[%d][%d] = %d > %d\n", __func__, i, s, batch.seq_id[i][s], (llama_seq_id) n_seq_max);
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
@ -86,7 +92,7 @@ bool llama_batch_allocr::init(
|
|||
|
||||
// initialize the starting position for each sequence based on the positions in the memory
|
||||
llama_pos p0[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (!memory) {
|
||||
// if no memory -> start from 0
|
||||
p0[s] = 0;
|
||||
|
@ -143,7 +149,8 @@ bool llama_batch_allocr::init(
|
|||
// compute stats
|
||||
//
|
||||
|
||||
this->n_embd = n_embd;
|
||||
this->n_embd = n_embd;
|
||||
this->n_seq_max = n_seq_max;
|
||||
|
||||
// count the outputs in this batch
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
|
@ -189,7 +196,7 @@ bool llama_batch_allocr::init(
|
|||
seq_set_map[cur].push_back(i);
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
seq_idx[s] = seq_id_unq.size();
|
||||
seq_id_unq.push_back(s);
|
||||
|
@ -241,7 +248,7 @@ bool llama_batch_allocr::init(
|
|||
// consistency checks
|
||||
//
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_pos[s].empty()) {
|
||||
continue;
|
||||
}
|
||||
|
@ -284,8 +291,8 @@ bool llama_batch_allocr::init(
|
|||
}
|
||||
|
||||
if (memory) {
|
||||
for (int32_t s0 = 0; s0 < LLAMA_MAX_SEQ; ++s0) {
|
||||
for (int32_t s1 = 0; s1 < LLAMA_MAX_SEQ; ++s1) {
|
||||
for (uint32_t s0 = 0; s0 < n_seq_max; ++s0) {
|
||||
for (uint32_t s1 = 0; s1 < n_seq_max; ++s1) {
|
||||
if (seq_cpl[s0][s1]) {
|
||||
if (memory->seq_pos_min(s0) != memory->seq_pos_min(s1) ||
|
||||
memory->seq_pos_max(s0) != memory->seq_pos_max(s1)) {
|
||||
|
@ -316,12 +323,12 @@ bool llama_batch_allocr::init(
|
|||
//
|
||||
{
|
||||
seq_set_t cur_seq_set[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_set[s].set();
|
||||
}
|
||||
|
||||
llama_pos cur_seq_pos[LLAMA_MAX_SEQ];
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
cur_seq_pos[s] = -1;
|
||||
}
|
||||
|
||||
|
@ -692,7 +699,7 @@ llama_ubatch llama_batch_allocr::ubatch_add(const std::vector<int32_t> & idxs, u
|
|||
}
|
||||
}
|
||||
|
||||
for (int32_t s = 0; s < LLAMA_MAX_SEQ; ++s) {
|
||||
for (uint32_t s = 0; s < n_seq_max; ++s) {
|
||||
if (seq_set_unq.test(s)) {
|
||||
ubatch.seq_idx[s] = ubatch.seq_id_unq.size();
|
||||
ubatch.seq_id_unq.push_back(s);
|
||||
|
|
|
@ -48,6 +48,7 @@ public:
|
|||
const llama_vocab & vocab,
|
||||
const llama_memory_i * memory,
|
||||
uint32_t n_embd,
|
||||
uint32_t n_seq_max,
|
||||
bool output_all);
|
||||
|
||||
const llama_batch & get_batch() const;
|
||||
|
@ -100,6 +101,7 @@ private:
|
|||
const uint32_t n_pos_per_embd;
|
||||
|
||||
uint32_t n_embd;
|
||||
uint32_t n_seq_max;
|
||||
uint32_t n_outputs;
|
||||
|
||||
std::array<llama_seq_id, 1> seq_id_0 = { 0 }; // default sequence id
|
||||
|
|
|
@ -98,10 +98,20 @@ llama_context::llama_context(
|
|||
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
|
||||
cparams.n_batch = GGML_KQ_MASK_PAD;
|
||||
}
|
||||
|
||||
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
|
||||
|
||||
cparams.op_offload = params.op_offload;
|
||||
cparams.kv_unified = params.kv_unified;
|
||||
|
||||
{
|
||||
const char * LLAMA_SET_ROWS = getenv("LLAMA_SET_ROWS");
|
||||
const bool supports_set_rows = LLAMA_SET_ROWS ? atoi(LLAMA_SET_ROWS) : 0;
|
||||
|
||||
if (!supports_set_rows && !cparams.kv_unified) {
|
||||
LLAMA_LOG_WARN("%s: non-unified KV cache requires ggml_set_rows() - forcing unified KV cache\n", __func__);
|
||||
cparams.kv_unified = true;
|
||||
}
|
||||
}
|
||||
|
||||
const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max;
|
||||
|
||||
|
@ -112,6 +122,7 @@ llama_context::llama_context(
|
|||
LLAMA_LOG_INFO("%s: n_ubatch = %u\n", __func__, cparams.n_ubatch);
|
||||
LLAMA_LOG_INFO("%s: causal_attn = %d\n", __func__, cparams.causal_attn);
|
||||
LLAMA_LOG_INFO("%s: flash_attn = %d\n", __func__, cparams.flash_attn);
|
||||
LLAMA_LOG_INFO("%s: kv_unified = %s\n", __func__, cparams.kv_unified ? "true" : "false");
|
||||
LLAMA_LOG_INFO("%s: freq_base = %.1f\n", __func__, cparams.rope_freq_base);
|
||||
LLAMA_LOG_INFO("%s: freq_scale = %g\n", __func__, cparams.rope_freq_scale);
|
||||
|
||||
|
@ -267,7 +278,7 @@ llama_context::llama_context(
|
|||
|
||||
// reserve worst-case graph
|
||||
if (!hparams.vocab_only && memory) {
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
@ -300,7 +311,7 @@ llama_context::llama_context(
|
|||
|
||||
// reserve with tg graph to get the number of splits and nodes
|
||||
{
|
||||
auto * gf = graph_reserve(1, 1, 1, mctx.get());
|
||||
auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute tg buffers");
|
||||
}
|
||||
|
@ -311,6 +322,10 @@ llama_context::llama_context(
|
|||
|
||||
// reserve again with pp graph to avoid ggml-alloc reallocations during inference
|
||||
{
|
||||
// TODO: not sure if the following graph would be worster case for multi-stream KV caches:
|
||||
//
|
||||
// auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
|
||||
//
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
if (!gf) {
|
||||
throw std::runtime_error("failed to allocate compute pp buffers");
|
||||
|
@ -475,7 +490,7 @@ bool llama_context::kv_self_update(bool optimize) {
|
|||
throw std::runtime_error("failed to initialize memory context");
|
||||
}
|
||||
|
||||
const uint32_t n_seqs = cparams.n_seq_max;
|
||||
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
|
||||
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
|
||||
|
||||
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
|
||||
|
@ -735,13 +750,15 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
const int32_t n_vocab = model.vocab.n_tokens();
|
||||
|
||||
// note: during encode, we always pass the full sequence starting from pos = 0
|
||||
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, true)) {
|
||||
if (!balloc->init(batch_inp, model.vocab, nullptr, n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
||||
const uint32_t n_tokens = balloc->get_n_tokens();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: add new split mode where we pad the input sequences so that ubatch.equal_seqs == true
|
||||
const llama_ubatch ubatch = balloc->split_simple(n_tokens);
|
||||
|
||||
// micro-batching is not possible for non-causal encoding, so we process the batch in a single shot
|
||||
|
@ -910,7 +927,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
// when computing embeddings, all tokens are output
|
||||
const bool output_all = cparams.embeddings;
|
||||
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, output_all)) {
|
||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return -1;
|
||||
}
|
||||
|
@ -2039,7 +2056,7 @@ void llama_context::opt_epoch_iter(
|
|||
batch.logits [pos_batch] = true;
|
||||
}
|
||||
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, true)) {
|
||||
if (!balloc->init(batch, model.vocab, nullptr, model.hparams.n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, true)) {
|
||||
LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
|
||||
return;
|
||||
}
|
||||
|
@ -2198,6 +2215,7 @@ llama_context_params llama_context_default_params() {
|
|||
/*.no_perf =*/ true,
|
||||
/*.op_offload =*/ true,
|
||||
/*.swa_full =*/ true,
|
||||
/*.kv_unified =*/ false,
|
||||
};
|
||||
|
||||
return result;
|
||||
|
|
|
@ -11,8 +11,8 @@ struct llama_cparams {
|
|||
uint32_t n_batch;
|
||||
uint32_t n_ubatch;
|
||||
uint32_t n_seq_max;
|
||||
int n_threads; // number of threads to use for generation
|
||||
int n_threads_batch; // number of threads to use for batch processing
|
||||
int32_t n_threads; // number of threads to use for generation
|
||||
int32_t n_threads_batch; // number of threads to use for batch processing
|
||||
|
||||
float rope_freq_base;
|
||||
float rope_freq_scale;
|
||||
|
@ -33,6 +33,7 @@ struct llama_cparams {
|
|||
bool no_perf;
|
||||
bool warmup;
|
||||
bool op_offload;
|
||||
bool kv_unified;
|
||||
|
||||
enum llama_pooling_type pooling_type;
|
||||
|
||||
|
|
|
@ -982,13 +982,16 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
float kq_scale) const {
|
||||
const bool v_trans = v->nb[1] > v->nb[2];
|
||||
|
||||
// split the batch into streams if needed
|
||||
const auto n_stream = k->ne[3];
|
||||
|
||||
q = ggml_reshape_4d(ctx0, q, q->ne[0], q->ne[1], q->ne[2]/n_stream, n_stream);
|
||||
|
||||
q = ggml_permute(ctx0, q, 0, 2, 1, 3);
|
||||
k = ggml_permute(ctx0, k, 0, 2, 1, 3);
|
||||
v = ggml_permute(ctx0, v, 0, 2, 1, 3);
|
||||
|
||||
const auto n_tokens = q->ne[1];
|
||||
const auto n_head = q->ne[2];
|
||||
const auto n_kv = k->ne[1];
|
||||
const auto n_kv = k->ne[1];
|
||||
|
||||
ggml_tensor * cur;
|
||||
|
||||
|
@ -1030,7 +1033,7 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
#endif
|
||||
}
|
||||
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
cur = ggml_reshape_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
} else {
|
||||
ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
|
||||
|
||||
|
@ -1075,7 +1078,8 @@ ggml_tensor * llm_graph_context::build_attn_mha(
|
|||
|
||||
cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3);
|
||||
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens);
|
||||
// recombine streams
|
||||
cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*cur->ne[1], cur->ne[2]*cur->ne[3]);
|
||||
|
||||
if (!cparams.offload_kqv) {
|
||||
// all nodes between the KV store and the attention output are run on the CPU
|
||||
|
@ -1122,6 +1126,10 @@ ggml_tensor * llm_graph_context::build_attn(
|
|||
|
||||
const auto & kq_mask = inp->get_kq_mask();
|
||||
|
||||
// [TAG_NO_CACHE_PAD]
|
||||
// TODO: if ubatch.equal_seqs == true, we can split the three tensors below into ubatch.n_seqs_unq streams
|
||||
assert(ubatch.equal_seqs == false);
|
||||
|
||||
ggml_tensor * q = q_cur;
|
||||
ggml_tensor * k = k_cur;
|
||||
ggml_tensor * v = v_cur;
|
||||
|
@ -1156,13 +1164,14 @@ static std::unique_ptr<llm_graph_input_attn_kv_unified> build_attn_inp_kv_unifie
|
|||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_unified_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
@ -1362,13 +1371,15 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_unified_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
@ -1382,7 +1393,7 @@ llm_graph_input_attn_kv_unified_iswa * llm_graph_context::build_attn_inp_kv_unif
|
|||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), 1, 1);
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
|
||||
inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
|
||||
|
|
|
@ -255,10 +255,10 @@ public:
|
|||
ggml_tensor * get_kq_mask() const { return self_kq_mask_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
@ -289,14 +289,14 @@ public:
|
|||
ggml_tensor * get_kq_mask_swa() const { return self_kq_mask_swa_cnv; }
|
||||
|
||||
ggml_tensor * self_k_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
ggml_tensor * self_k_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch]
|
||||
ggml_tensor * self_v_idxs_swa = nullptr; // I64 [n_batch] or [n_batch*n_embd_v_gqa]
|
||||
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch, 1, 1]
|
||||
ggml_tensor * self_kq_mask = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa = nullptr; // F32 [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
ggml_tensor * self_kq_mask_swa_cnv = nullptr; // [n_kv, n_batch/n_stream, 1, n_stream]
|
||||
|
||||
const llama_hparams & hparams;
|
||||
const llama_cparams & cparams;
|
||||
|
|
|
@ -65,6 +65,46 @@ uint32_t llama_hparams::n_embd_v_gqa(uint32_t il) const {
|
|||
return n_embd_head_v * n_head_kv;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_k_gqa_variable() const {
|
||||
const uint32_t val = n_embd_k_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (val != n_embd_k_gqa(il)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
bool llama_hparams::is_n_embd_v_gqa_variable() const {
|
||||
const uint32_t val = n_embd_v_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
if (val != n_embd_v_gqa(il)) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_k_gqa_max() const {
|
||||
uint32_t val = n_embd_k_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
val = std::max(val, n_embd_k_gqa(il));
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_v_gqa_max() const {
|
||||
uint32_t val = n_embd_v_gqa();
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
val = std::max(val, n_embd_v_gqa(il));
|
||||
}
|
||||
|
||||
return val;
|
||||
}
|
||||
|
||||
uint32_t llama_hparams::n_embd_r() const {
|
||||
if (wkv_head_size != 0) {
|
||||
// for RWKV models
|
||||
|
|
|
@ -191,6 +191,14 @@ struct llama_hparams {
|
|||
// dimension of value embeddings across all k-v heads
|
||||
uint32_t n_embd_v_gqa(uint32_t il = 0) const;
|
||||
|
||||
// true if any layer has a different n_embd_k_gqa/n_embd_v_gqa
|
||||
bool is_n_embd_k_gqa_variable() const;
|
||||
bool is_n_embd_v_gqa_variable() const;
|
||||
|
||||
// return the maximum n_embd_k_gqa/n_embd_v_gqa across all layers
|
||||
uint32_t n_embd_k_gqa_max() const;
|
||||
uint32_t n_embd_v_gqa_max() const;
|
||||
|
||||
// dimension of the rolling state embeddings
|
||||
// corresponds to Mamba's conv_states size or RWKV's token_shift states size
|
||||
uint32_t n_embd_r() const;
|
||||
|
|
|
@ -18,16 +18,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
uint32_t n_pad) : hparams(model.hparams) {
|
||||
uint32_t n_pad) : hparams(model.hparams), unified(unified) {
|
||||
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
|
||||
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
|
||||
|
||||
const uint32_t size_base = kv_size;
|
||||
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_ubatch, n_pad));
|
||||
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*(unified ? n_seq_max : 1) + n_ubatch, n_pad));
|
||||
|
||||
// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size
|
||||
if (swa_full) {
|
||||
|
@ -41,14 +42,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
|
|||
|
||||
kv_base = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_base), type_k, type_v,
|
||||
v_trans, offload, size_base, n_seq_max, n_pad,
|
||||
v_trans, offload, unified, size_base, n_seq_max, n_pad,
|
||||
0, LLAMA_SWA_TYPE_NONE);
|
||||
|
||||
LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
|
||||
|
||||
kv_swa = std::make_unique<llama_kv_cache_unified>(
|
||||
model, std::move(filter_swa), type_k, type_v,
|
||||
v_trans, offload, size_swa, n_seq_max, n_pad,
|
||||
v_trans, offload, unified, size_swa, n_seq_max, n_pad,
|
||||
hparams.n_swa, hparams.swa_type);
|
||||
}
|
||||
|
||||
|
@ -100,6 +101,11 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
// first try simple split
|
||||
do {
|
||||
if (!unified) {
|
||||
// requires equal splits, so we skip the simple split
|
||||
break;
|
||||
}
|
||||
|
||||
balloc.split_reset();
|
||||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
|
@ -140,7 +146,7 @@ llama_memory_context_ptr llama_kv_cache_unified_iswa::init_batch(llama_batch_all
|
|||
|
||||
std::vector<llama_ubatch> ubatches;
|
||||
while (true) {
|
||||
auto ubatch = balloc.split_equal(n_ubatch, false);
|
||||
auto ubatch = balloc.split_equal(n_ubatch, !unified);
|
||||
|
||||
if (ubatch.n_tokens == 0) {
|
||||
break;
|
||||
|
|
|
@ -20,6 +20,7 @@ public:
|
|||
bool v_trans,
|
||||
bool offload,
|
||||
bool swa_full,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_ubatch,
|
||||
|
@ -68,6 +69,8 @@ public:
|
|||
private:
|
||||
const llama_hparams & hparams;
|
||||
|
||||
const bool unified;
|
||||
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_base;
|
||||
std::unique_ptr<llama_kv_cache_unified> kv_swa;
|
||||
};
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -35,16 +35,50 @@ public:
|
|||
std::vector<uint32_t> ids;
|
||||
};
|
||||
|
||||
struct stream_copy_info {
|
||||
bool empty() const {
|
||||
assert(ssrc.size() == sdst.size());
|
||||
return ssrc.empty();
|
||||
}
|
||||
|
||||
std::vector<uint32_t> ssrc;
|
||||
std::vector<uint32_t> sdst;
|
||||
};
|
||||
|
||||
// for each ubatch, create a slot_info that contains information about where the ubatch should be inserted in the
|
||||
// KV cells. for example, cell indices for each token, such that: token[i] -> goes to cells[idxs[i]]
|
||||
struct slot_info {
|
||||
// data for ggml_set_rows
|
||||
using idx_vec_t = std::vector<uint32_t>;
|
||||
|
||||
idx_vec_t idxs;
|
||||
// number of streams: ns = s1 - s0 + 1
|
||||
llama_seq_id s0;
|
||||
llama_seq_id s1;
|
||||
|
||||
std::vector<llama_seq_id> strm; // [ns]
|
||||
std::vector<idx_vec_t> idxs; // [ns]
|
||||
|
||||
uint32_t head() const {
|
||||
return idxs.at(0);
|
||||
GGML_ASSERT(idxs.size() == 1);
|
||||
GGML_ASSERT(!idxs[0].empty());
|
||||
|
||||
return idxs[0][0];
|
||||
}
|
||||
|
||||
void resize(size_t n) {
|
||||
strm.resize(n);
|
||||
idxs.resize(n);
|
||||
}
|
||||
|
||||
size_t size() const {
|
||||
GGML_ASSERT(idxs.size() == strm.size());
|
||||
GGML_ASSERT(!idxs.empty());
|
||||
|
||||
return idxs[0].size();
|
||||
}
|
||||
|
||||
size_t n_stream() const {
|
||||
return strm.size();
|
||||
}
|
||||
|
||||
bool empty() const {
|
||||
|
@ -54,9 +88,6 @@ public:
|
|||
void clear() {
|
||||
idxs.clear();
|
||||
}
|
||||
|
||||
// TODO: implement
|
||||
//std::vector<idx_vec_t> seq_idxs;
|
||||
};
|
||||
|
||||
using slot_info_vec_t = std::vector<slot_info>;
|
||||
|
@ -68,6 +99,7 @@ public:
|
|||
ggml_type type_v,
|
||||
bool v_trans,
|
||||
bool offload,
|
||||
bool unified,
|
||||
uint32_t kv_size,
|
||||
uint32_t n_seq_max,
|
||||
uint32_t n_pad,
|
||||
|
@ -111,7 +143,8 @@ public:
|
|||
// llama_kv_cache_unified specific API
|
||||
//
|
||||
|
||||
uint32_t get_size() const;
|
||||
uint32_t get_size() const;
|
||||
uint32_t get_n_stream() const;
|
||||
|
||||
bool get_has_shift() const;
|
||||
|
||||
|
@ -122,8 +155,8 @@ public:
|
|||
uint32_t get_n_kv() const;
|
||||
|
||||
// get views of the current state of the cache
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv) const;
|
||||
ggml_tensor * get_k(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
ggml_tensor * get_v(ggml_context * ctx, int32_t il, uint32_t n_kv, const slot_info & sinfo) const;
|
||||
|
||||
// store k_cur and v_cur in the cache based on the provided head location
|
||||
ggml_tensor * cpy_k(ggml_context * ctx, ggml_tensor * k_cur, ggml_tensor * k_idxs, int32_t il, const slot_info & sinfo) const;
|
||||
|
@ -137,7 +170,7 @@ public:
|
|||
// return empty vector on failure
|
||||
slot_info_vec_t prepare(const std::vector<llama_ubatch> & ubatches);
|
||||
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo);
|
||||
bool update(llama_context * lctx, bool do_shift, const defrag_info & dinfo, const stream_copy_info & sc_info);
|
||||
|
||||
// find a slot of kv cells that can hold the ubatch
|
||||
// if cont == true, then the slot must be continuous
|
||||
|
@ -157,8 +190,9 @@ public:
|
|||
void set_input_k_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
void set_input_v_idxs(ggml_tensor * dst, const llama_ubatch * ubatch, const slot_info & sinfo) const;
|
||||
|
||||
void set_input_k_shift(ggml_tensor * dst) const;
|
||||
|
||||
void set_input_kq_mask (ggml_tensor * dst, const llama_ubatch * ubatch, bool causal_attn) const;
|
||||
void set_input_k_shift (ggml_tensor * dst) const;
|
||||
void set_input_pos_bucket(ggml_tensor * dst, const llama_ubatch * ubatch) const;
|
||||
|
||||
private:
|
||||
|
@ -172,15 +206,15 @@ private:
|
|||
|
||||
ggml_tensor * k;
|
||||
ggml_tensor * v;
|
||||
|
||||
std::vector<ggml_tensor *> k_stream;
|
||||
std::vector<ggml_tensor *> v_stream;
|
||||
};
|
||||
|
||||
bool v_trans = true; // the value tensor is transposed
|
||||
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
uint32_t head = 0;
|
||||
|
||||
const uint32_t n_seq_max = 1;
|
||||
const uint32_t n_stream = 1;
|
||||
|
||||
// required padding
|
||||
const uint32_t n_pad = 1;
|
||||
|
@ -200,7 +234,17 @@ private:
|
|||
std::vector<ggml_context_ptr> ctxs;
|
||||
std::vector<ggml_backend_buffer_ptr> bufs;
|
||||
|
||||
llama_kv_cells_unified cells;
|
||||
// the current index from where we start searching for a free slot in the ring buffer of KV cells (see find_slot())
|
||||
// note: this is not part of the KV state and it's only used to speed-up the find_slot() method
|
||||
std::vector<uint32_t> v_heads;
|
||||
|
||||
std::vector<llama_kv_cells_unified> v_cells;
|
||||
|
||||
// maps from a sequence id to a stream id
|
||||
std::vector<uint32_t> seq_to_stream;
|
||||
|
||||
// pending stream copies that will be applied during the next update
|
||||
stream_copy_info sc_info;
|
||||
|
||||
std::vector<kv_layer> layers;
|
||||
|
||||
|
@ -237,18 +281,25 @@ private:
|
|||
ggml_cgraph * gf,
|
||||
const defrag_info & dinfo) const;
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) const;
|
||||
struct cell_ranges_t {
|
||||
uint32_t strm;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t cell_count);
|
||||
std::vector<std::pair<uint32_t, uint32_t>> data; // ranges, from inclusive, to exclusive
|
||||
};
|
||||
|
||||
void state_write_meta(llama_io_write_i & io, const cell_ranges_t & cr, llama_seq_id seq_id = -1) const;
|
||||
void state_write_data(llama_io_write_i & io, const cell_ranges_t & cr) const;
|
||||
|
||||
bool state_read_meta(llama_io_read_i & io, uint32_t strm, uint32_t cell_count, llama_seq_id dest_seq_id = -1);
|
||||
bool state_read_data(llama_io_read_i & io, uint32_t strm, uint32_t cell_count);
|
||||
};
|
||||
|
||||
class llama_kv_cache_unified_context : public llama_memory_context_i {
|
||||
public:
|
||||
// some shorthands
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using slot_info_vec_t = llama_kv_cache_unified::slot_info_vec_t;
|
||||
using defrag_info = llama_kv_cache_unified::defrag_info;
|
||||
using stream_copy_info = llama_kv_cache_unified::stream_copy_info;
|
||||
|
||||
// used for errors
|
||||
llama_kv_cache_unified_context(llama_memory_status status);
|
||||
|
@ -262,7 +313,8 @@ public:
|
|||
llama_kv_cache_unified * kv,
|
||||
llama_context * lctx,
|
||||
bool do_shift,
|
||||
defrag_info dinfo);
|
||||
defrag_info dinfo,
|
||||
stream_copy_info sc_info);
|
||||
|
||||
// used to create a batch procesing context from a batch
|
||||
llama_kv_cache_unified_context(
|
||||
|
@ -320,6 +372,8 @@ private:
|
|||
|
||||
defrag_info dinfo;
|
||||
|
||||
stream_copy_info sc_info;
|
||||
|
||||
//
|
||||
// batch processing context
|
||||
//
|
||||
|
|
|
@ -40,6 +40,7 @@ llama_memory_hybrid::llama_memory_hybrid(
|
|||
offload,
|
||||
kv_size,
|
||||
n_seq_max,
|
||||
1,
|
||||
n_pad,
|
||||
n_swa,
|
||||
swa_type
|
||||
|
|
|
@ -16647,7 +16647,18 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
} else {
|
||||
const auto padding = llama_kv_cache_unified::get_padding(cparams);
|
||||
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||
uint32_t n_ctx_per_stream = cparams.n_ctx;
|
||||
|
||||
if (!cparams.kv_unified) {
|
||||
n_ctx_per_stream = (cparams.n_ctx + cparams.n_seq_max - 1)/cparams.n_seq_max;
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream*cparams.n_seq_max;
|
||||
} else {
|
||||
n_ctx_per_stream = GGML_PAD(n_ctx_per_stream, padding);
|
||||
|
||||
cparams.n_ctx = n_ctx_per_stream;
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx);
|
||||
|
||||
|
@ -16661,7 +16672,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
params.swa_full,
|
||||
cparams.n_ctx,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
cparams.n_ubatch,
|
||||
padding);
|
||||
|
@ -16675,7 +16687,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
params.type_v,
|
||||
!cparams.flash_attn,
|
||||
cparams.offload_kqv,
|
||||
cparams.n_ctx,
|
||||
cparams.kv_unified,
|
||||
n_ctx_per_stream,
|
||||
cparams.n_seq_max,
|
||||
padding,
|
||||
hparams.n_swa,
|
||||
|
|
|
@ -4282,7 +4282,7 @@ struct test_flash_attn_ext : public test_case {
|
|||
|
||||
ggml_tensor * m = nullptr;
|
||||
if (mask) {
|
||||
m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), nr23[0], nr23[1]);
|
||||
m = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, nr23[1]);
|
||||
ggml_set_name(m, "m");
|
||||
}
|
||||
|
||||
|
|
|
@ -127,10 +127,9 @@ int main(int argc, char ** argv) {
|
|||
|
||||
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
|
||||
for (int i = 0; i < pp; ++i) {
|
||||
common_batch_add(batch, 0, i, { j }, false);
|
||||
common_batch_add(batch, 0, i, { j }, i == pp - 1);
|
||||
}
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
const auto t_pp_start = ggml_time_us();
|
||||
|
||||
|
|
Loading…
Reference in New Issue