ggml : implement REGLU/GEGLU/SWIGLU ops (#14158)
* implement unary REGLU/GEGLU/SWIGLU cpu ops * relax constraints * duplicate shape of source * fix ggml_vec_geglu_f16 * special case gated ops * implement unary REGLU/GEGLU/SWIGLU cuda ops * tighten constraints again * refactor into GGML_GLU_OP * metal : add glu kernels ggml-ci * add CUDA_GLU_BLOCK_SIZE [no ci] * more constraints and use 64bit ints ggml-ci * 64bit multiplication [no ci] * implement swapped variants (cpu/cuda) * update comment [no ci] ggml-ci * Vulkan: Add GLU ops and shaders * SYCL: Implement fused kernel GEGLU, SWIGLU and REGLU for single up+gate * ggml : implement GLU for split up/gate (#14181) * implement GLU for split up/gate * add tests for ggml_glu_split * Vulkan: Implement glu_split logic and shader support * add split to logging [no ci] * SYCL: refactor element_size ops and add split up and gate support to gated kernels * SYCL: switch GEGLU to use tanh approximation --------- Co-authored-by: 0cc4m <picard12@live.de> Co-authored-by: Akarshan <akarshan@menlo.ai> * GGML: increase OP count in assertion * Refactor: Optimize SYCL element-wise operations with unary function inlining This commit refactors the SYCL element-wise operations to improve performance by: - Inlining unary operations (sgn, abs, elu, gelu, silu, etc.) to reduce kernel launch overhead. - Introducing helper functions `op_xxx` for each unary operation to encapsulate the logic. - Replacing direct kernel calls with calls to these inlined functions. - Using `__dpct_inline__` to encourage compiler inlining. - Minor code cleanup and consistency improvements. The changes aim to reduce kernel launch overhead and improve the overall efficiency of element-wise operations on SYCL devices. * vulkan: Increase workgroup size for GLU, for performance (#14345) * vulkan: Increase workgroup size for GLU, for performance * vulkan: change GLU shaders to do one element per invocation rather than one row per workgroup * merge fix * metal : add support for split and swap ggml-ci --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: 0cc4m <picard12@live.de> Co-authored-by: Akarshan <akarshan@menlo.ai> Co-authored-by: Jeff Bolz <jbolz@nvidia.com>
This commit is contained in:
parent
bd9c981d72
commit
a0535ffa0d
|
@ -520,6 +520,8 @@ extern "C" {
|
||||||
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
|
GGML_OP_CROSS_ENTROPY_LOSS_BACK,
|
||||||
GGML_OP_OPT_STEP_ADAMW,
|
GGML_OP_OPT_STEP_ADAMW,
|
||||||
|
|
||||||
|
GGML_OP_GLU,
|
||||||
|
|
||||||
GGML_OP_COUNT,
|
GGML_OP_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -543,6 +545,14 @@ extern "C" {
|
||||||
GGML_UNARY_OP_COUNT,
|
GGML_UNARY_OP_COUNT,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
enum ggml_glu_op {
|
||||||
|
GGML_GLU_OP_REGLU,
|
||||||
|
GGML_GLU_OP_GEGLU,
|
||||||
|
GGML_GLU_OP_SWIGLU,
|
||||||
|
|
||||||
|
GGML_GLU_OP_COUNT,
|
||||||
|
};
|
||||||
|
|
||||||
enum ggml_object_type {
|
enum ggml_object_type {
|
||||||
GGML_OBJECT_TYPE_TENSOR,
|
GGML_OBJECT_TYPE_TENSOR,
|
||||||
GGML_OBJECT_TYPE_GRAPH,
|
GGML_OBJECT_TYPE_GRAPH,
|
||||||
|
@ -658,6 +668,7 @@ extern "C" {
|
||||||
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
GGML_API const char * ggml_op_symbol(enum ggml_op op);
|
||||||
|
|
||||||
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
|
GGML_API const char * ggml_unary_op_name(enum ggml_unary_op op);
|
||||||
|
GGML_API const char * ggml_glu_op_name(enum ggml_glu_op op);
|
||||||
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
|
GGML_API const char * ggml_op_desc(const struct ggml_tensor * t); // unary or op name
|
||||||
|
|
||||||
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
|
GGML_API size_t ggml_element_size(const struct ggml_tensor * tensor);
|
||||||
|
@ -762,6 +773,7 @@ extern "C" {
|
||||||
GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
|
GGML_API void ggml_unravel_index(const struct ggml_tensor * tensor, int64_t i, int64_t * i0, int64_t * i1, int64_t * i2, int64_t * i3);
|
||||||
|
|
||||||
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
|
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
|
||||||
|
GGML_API enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor);
|
||||||
|
|
||||||
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
|
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
|
||||||
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
|
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
|
||||||
|
@ -1090,6 +1102,63 @@ extern "C" {
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_tensor * a);
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
// gated linear unit ops
|
||||||
|
// A: n columns, r rows,
|
||||||
|
// result is n / 2 columns, r rows,
|
||||||
|
// expects gate in second half of row, unless swapped is true
|
||||||
|
GGML_API struct ggml_tensor * ggml_glu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
enum ggml_glu_op op,
|
||||||
|
bool swapped);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_reglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_reglu_swapped(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_geglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_geglu_swapped(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_swiglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_swiglu_swapped(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a);
|
||||||
|
|
||||||
|
// A: n columns, r rows,
|
||||||
|
// B: n columns, r rows,
|
||||||
|
GGML_API struct ggml_tensor * ggml_glu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
enum ggml_glu_op op);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_reglu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_geglu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
|
GGML_API struct ggml_tensor * ggml_swiglu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b);
|
||||||
|
|
||||||
// normalize along rows
|
// normalize along rows
|
||||||
GGML_API struct ggml_tensor * ggml_norm(
|
GGML_API struct ggml_tensor * ggml_norm(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
|
|
@ -1949,6 +1949,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
||||||
{
|
{
|
||||||
ggml_compute_forward_unary(params, tensor);
|
ggml_compute_forward_unary(params, tensor);
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_glu(params, tensor);
|
||||||
|
} break;
|
||||||
case GGML_OP_GET_REL_POS:
|
case GGML_OP_GET_REL_POS:
|
||||||
{
|
{
|
||||||
ggml_compute_forward_get_rel_pos(params, tensor);
|
ggml_compute_forward_get_rel_pos(params, tensor);
|
||||||
|
@ -2159,6 +2163,18 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(node)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
{
|
||||||
|
n_tasks = n_threads;
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_SILU_BACK:
|
case GGML_OP_SILU_BACK:
|
||||||
case GGML_OP_MUL:
|
case GGML_OP_MUL:
|
||||||
case GGML_OP_DIV:
|
case GGML_OP_DIV:
|
||||||
|
|
|
@ -3184,6 +3184,435 @@ void ggml_compute_forward_silu_back(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_reglu
|
||||||
|
|
||||||
|
static void ggml_compute_forward_reglu_f32(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
char * src0_d = (char *) src0->data;
|
||||||
|
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||||
|
const size_t src0_o = src0->nb[1];
|
||||||
|
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||||
|
|
||||||
|
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
float * src0_p = (float *) (src0_d + i1*src0_o);
|
||||||
|
float * src1_p = (float *) (src1_d + i1*src1_o);
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_vec_reglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||||
|
GGML_UNUSED(x);
|
||||||
|
assert(!isnan(x));
|
||||||
|
assert(!isinf(x));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_reglu_f16(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
char * src0_d = (char *) src0->data;
|
||||||
|
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||||
|
const size_t src0_o = src0->nb[1];
|
||||||
|
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||||
|
|
||||||
|
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
||||||
|
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_vec_reglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||||
|
const float v = GGML_FP16_TO_FP32(x);
|
||||||
|
GGML_UNUSED(v);
|
||||||
|
assert(!isnan(v));
|
||||||
|
assert(!isinf(v));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_reglu(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_reglu_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_reglu_f16(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_geglu
|
||||||
|
|
||||||
|
static void ggml_compute_forward_geglu_f32(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
char * src0_d = (char *) src0->data;
|
||||||
|
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||||
|
const size_t src0_o = src0->nb[1];
|
||||||
|
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||||
|
|
||||||
|
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
float * src0_p = (float *) (src0_d + i1*src0_o);
|
||||||
|
float * src1_p = (float *) (src1_d + i1*src1_o);
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_vec_geglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||||
|
GGML_UNUSED(x);
|
||||||
|
assert(!isnan(x));
|
||||||
|
assert(!isinf(x));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_geglu_f16(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
char * src0_d = (char *) src0->data;
|
||||||
|
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||||
|
const size_t src0_o = src0->nb[1];
|
||||||
|
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||||
|
|
||||||
|
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
||||||
|
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_vec_geglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||||
|
const float v = GGML_FP16_TO_FP32(x);
|
||||||
|
GGML_UNUSED(v);
|
||||||
|
assert(!isnan(v));
|
||||||
|
assert(!isinf(v));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_geglu(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_geglu_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_geglu_f16(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_compute_forward_swiglu
|
||||||
|
|
||||||
|
static void ggml_compute_forward_swiglu_f32(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
char * src0_d = (char *) src0->data;
|
||||||
|
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||||
|
const size_t src0_o = src0->nb[1];
|
||||||
|
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||||
|
|
||||||
|
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
float * src0_p = (float *) (src0_d + i1*src0_o);
|
||||||
|
float * src1_p = (float *) (src1_d + i1*src1_o);
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_vec_swiglu_f32(nc, (float *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||||
|
GGML_UNUSED(x);
|
||||||
|
assert(!isnan(x));
|
||||||
|
assert(!isinf(x));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_swiglu_f16(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
char * src0_d = (char *) src0->data;
|
||||||
|
char * src1_d = (char *) (src1 ? src1->data : src0->data);
|
||||||
|
const size_t src0_o = src0->nb[1];
|
||||||
|
const size_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(dst));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int ith = params->ith;
|
||||||
|
const int nth = params->nth;
|
||||||
|
|
||||||
|
const int nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
const int nr = ggml_nrows(src0);
|
||||||
|
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == nr);
|
||||||
|
|
||||||
|
const int32_t swapped = ggml_get_op_params_i32(dst, 1);
|
||||||
|
|
||||||
|
// rows per thread
|
||||||
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
|
||||||
|
// row range for this thread
|
||||||
|
const int ir0 = dr*ith;
|
||||||
|
const int ir1 = MIN(ir0 + dr, nr);
|
||||||
|
|
||||||
|
for (int i1 = ir0; i1 < ir1; i1++) {
|
||||||
|
ggml_fp16_t * src0_p = (ggml_fp16_t *) (src0_d + i1*src0_o);
|
||||||
|
ggml_fp16_t * src1_p = (ggml_fp16_t *) (src1_d + i1*src1_o);
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_vec_swiglu_f16(nc, (ggml_fp16_t *) ((char *) dst->data + i1*(dst->nb[1])), src0_p, src1_p);
|
||||||
|
|
||||||
|
#ifndef NDEBUG
|
||||||
|
for (int k = 0; k < nc; k++) {
|
||||||
|
const ggml_fp16_t x = ((ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])))[k];
|
||||||
|
const float v = GGML_FP16_TO_FP32(x);
|
||||||
|
GGML_UNUSED(v);
|
||||||
|
assert(!isnan(v));
|
||||||
|
assert(!isinf(v));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void ggml_compute_forward_swiglu(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F32:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_swiglu_f32(params, dst);
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_swiglu_f16(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_norm
|
// ggml_compute_forward_norm
|
||||||
|
|
||||||
static void ggml_compute_forward_norm_f32(
|
static void ggml_compute_forward_norm_f32(
|
||||||
|
@ -8052,6 +8481,34 @@ void ggml_compute_forward_unary(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//ggml_compute_forward_glu
|
||||||
|
|
||||||
|
void ggml_compute_forward_glu(
|
||||||
|
const ggml_compute_params * params,
|
||||||
|
ggml_tensor * dst) {
|
||||||
|
|
||||||
|
const ggml_glu_op op = ggml_get_glu_op(dst);
|
||||||
|
|
||||||
|
switch (op) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_reglu(params, dst);
|
||||||
|
} break;
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_geglu(params, dst);
|
||||||
|
} break;
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
{
|
||||||
|
ggml_compute_forward_swiglu(params, dst);
|
||||||
|
} break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_compute_forward_get_rel_pos
|
// ggml_compute_forward_get_rel_pos
|
||||||
|
|
||||||
static void ggml_compute_forward_get_rel_pos_f16(
|
static void ggml_compute_forward_get_rel_pos_f16(
|
||||||
|
|
|
@ -94,6 +94,7 @@ void ggml_compute_forward_ssm_scan(const struct ggml_compute_params * params, st
|
||||||
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_win_part(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_win_unpart(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_unary(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
|
void ggml_compute_forward_glu(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_get_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_add_rel_pos(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
void ggml_compute_forward_rwkv_wkv6(const struct ggml_compute_params * params, struct ggml_tensor * dst);
|
||||||
|
|
|
@ -254,6 +254,30 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g) {
|
||||||
|
int i = 0;
|
||||||
|
#if defined(__AVX512F__) && defined(__AVX512DQ__)
|
||||||
|
for (; i + 15 < n; i += 16) {
|
||||||
|
_mm512_storeu_ps(y + i, _mm512_mul_ps(ggml_v_silu(_mm512_loadu_ps(x + i)), _mm512_loadu_ps(g + i)));
|
||||||
|
}
|
||||||
|
#elif defined(__AVX2__) && defined(__FMA__)
|
||||||
|
for (; i + 7 < n; i += 8) {
|
||||||
|
_mm256_storeu_ps(y + i, _mm256_mul_ps(ggml_v_silu(_mm256_loadu_ps(x + i)), _mm256_loadu_ps(g + i)));
|
||||||
|
}
|
||||||
|
#elif defined(__SSE2__)
|
||||||
|
for (; i + 3 < n; i += 4) {
|
||||||
|
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
|
||||||
|
}
|
||||||
|
#elif defined(__ARM_NEON) && defined(__aarch64__)
|
||||||
|
for (; i + 3 < n; i += 4) {
|
||||||
|
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
for (; i < n; ++i) {
|
||||||
|
y[i] = ggml_silu_f32(x[i]) * g[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
|
ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) {
|
||||||
int i = 0;
|
int i = 0;
|
||||||
ggml_float sum = 0;
|
ggml_float sum = 0;
|
||||||
|
|
|
@ -905,6 +905,60 @@ inline static void ggml_vec_silu_backward_f16(const int n, ggml_fp16_t * dx, con
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline static void ggml_vec_reglu_f32 (const int n, float * y, const float * x, const float * g) {
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
y[i] = (x[i] > 0.f) ? x[i] * g[i] : 0.f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
inline static void ggml_vec_reglu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
float v = GGML_FP16_TO_FP32(x[i]);
|
||||||
|
y[i] = GGML_FP32_TO_FP16((v > 0.f) ? v * GGML_FP16_TO_FP32(g[i]) : 0.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#ifdef GGML_GELU_FP16
|
||||||
|
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
|
||||||
|
uint16_t t;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
if (x[i] <= -10.0f) {
|
||||||
|
y[i] = 0.0f;
|
||||||
|
} else if (x[i] >= 10.0f) {
|
||||||
|
y[i] = x[i] * g[i];
|
||||||
|
} else {
|
||||||
|
ggml_fp16_t fp16 = GGML_FP32_TO_FP16(x[i]);
|
||||||
|
memcpy(&t, &fp16, sizeof(uint16_t));
|
||||||
|
y[i] = GGML_FP16_TO_FP32(ggml_table_gelu_f16[t]) * g[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
inline static void ggml_vec_geglu_f32(const int n, float * y, const float * x, const float * g) {
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
y[i] = ggml_gelu_f32(x[i]) * g[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
|
||||||
|
inline static void ggml_vec_geglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
||||||
|
const uint16_t * i16 = (const uint16_t *) x;
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
float v = GGML_FP16_TO_FP32(g[i]);
|
||||||
|
y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(ggml_table_gelu_f16[i16[i]]) * v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * g);
|
||||||
|
|
||||||
|
inline static void ggml_vec_swiglu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x, const ggml_fp16_t * g) {
|
||||||
|
for (int i = 0; i < n; ++i) {
|
||||||
|
float v = GGML_FP16_TO_FP32(x[i]);
|
||||||
|
float w = GGML_FP16_TO_FP32(g[i]);
|
||||||
|
y[i] = GGML_FP32_TO_FP16((v/(1.0f + expf(-v))) * w);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
inline static void ggml_vec_sum_f32(const int n, float * s, const float * x) {
|
||||||
#ifndef GGML_USE_ACCELERATE
|
#ifndef GGML_USE_ACCELERATE
|
||||||
ggml_float sum = 0.0;
|
ggml_float sum = 0.0;
|
||||||
|
|
|
@ -2303,6 +2303,21 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(dst)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
ggml_cuda_op_reglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
ggml_cuda_op_geglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
ggml_cuda_op_swiglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
ggml_cuda_op_norm(ctx, dst);
|
ggml_cuda_op_norm(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
@ -3096,6 +3111,16 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(op)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
return ggml_is_contiguous_1(op->src[0]);
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
|
|
|
@ -196,6 +196,95 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
ggml_cuda_op_unary<op_log>(ctx, dst);
|
ggml_cuda_op_unary<op_log>(ctx, dst);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* gated ops */
|
||||||
|
|
||||||
|
template <float (*op)(float), typename T>
|
||||||
|
static __global__ void unary_gated_op_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1) {
|
||||||
|
const int64_t i = int64_t(blockDim.x)*blockIdx.x + threadIdx.x;
|
||||||
|
|
||||||
|
if (i >= k) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// perform base op and multiply with gate (either offset in same tensor or a separate one)
|
||||||
|
const int64_t j0 = (i / n) * o0 + (i % n);
|
||||||
|
const int64_t j1 = o0 == o1 ? j0 : (i / n) * o1 + (i % n);
|
||||||
|
|
||||||
|
dst[i] = (T)(op((float)x[j0]) * (float)g[j1]);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <float (*op)(float), typename T>
|
||||||
|
static void unary_gated_cuda(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, cudaStream_t stream) {
|
||||||
|
const int64_t num_blocks = (k + CUDA_GLU_BLOCK_SIZE - 1) / CUDA_GLU_BLOCK_SIZE;
|
||||||
|
unary_gated_op_kernel<op><<<num_blocks, CUDA_GLU_BLOCK_SIZE, 0, stream>>>(x, g, dst, k, n, o0, o1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <float (*op)(float)>
|
||||||
|
void ggml_cuda_op_unary_gated(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
const ggml_tensor * src0 = dst->src[0];
|
||||||
|
const ggml_tensor * src1 = dst->src[1];
|
||||||
|
void * src0_d = src0->data;
|
||||||
|
void * src1_d = src1 ? src1->data : src0->data;
|
||||||
|
const int64_t src0_o = src0->nb[1];
|
||||||
|
const int64_t src1_o = src1 ? src1->nb[1] : src0->nb[1];
|
||||||
|
void * dst_d = dst->data;
|
||||||
|
const int64_t nc = src1 ? src0->ne[0] : src0->ne[0] / 2;
|
||||||
|
cudaStream_t stream = ctx.stream();
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
GGML_ASSERT(src0->nb[0] == ggml_element_size(src0));
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(dst));
|
||||||
|
|
||||||
|
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
|
||||||
|
GGML_ASSERT(src0->type == dst->type);
|
||||||
|
GGML_ASSERT(dst->ne[0] == nc);
|
||||||
|
GGML_ASSERT(ggml_nrows(dst) == ggml_nrows(src0));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src1));
|
||||||
|
GGML_ASSERT(src1->nb[0] == ggml_element_size(src1));
|
||||||
|
GGML_ASSERT(src1->ne[0] == nc);
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32_t swapped = ((const int32_t *) dst->op_params)[1];
|
||||||
|
|
||||||
|
if (src0->type == GGML_TYPE_F16) {
|
||||||
|
half * src0_p = (half *) src0_d;
|
||||||
|
half * src1_p = (half *) src1_d;
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
unary_gated_cuda<op>(src0_p, src1_p, (half *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(half), src1_o / sizeof(half), stream);
|
||||||
|
} else {
|
||||||
|
float * src0_p = (float *) src0_d;
|
||||||
|
float * src1_p = (float *) src1_d;
|
||||||
|
|
||||||
|
if (!src1) {
|
||||||
|
src0_p += swapped ? nc : 0;
|
||||||
|
src1_p += swapped ? 0 : nc;
|
||||||
|
}
|
||||||
|
|
||||||
|
unary_gated_cuda<op>(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), stream);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary_gated<op_relu>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary_gated<op_gelu>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
|
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
|
||||||
|
ggml_cuda_op_unary_gated<op_silu>(ctx, dst);
|
||||||
|
}
|
||||||
|
|
||||||
/* silu_back */
|
/* silu_back */
|
||||||
|
|
||||||
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
static __device__ __forceinline__ float op_silu_back(float grad, float x) {
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
#define CUDA_SQRT_BLOCK_SIZE 256
|
#define CUDA_SQRT_BLOCK_SIZE 256
|
||||||
#define CUDA_SIN_BLOCK_SIZE 256
|
#define CUDA_SIN_BLOCK_SIZE 256
|
||||||
#define CUDA_COS_BLOCK_SIZE 256
|
#define CUDA_COS_BLOCK_SIZE 256
|
||||||
|
#define CUDA_GLU_BLOCK_SIZE 256
|
||||||
|
|
||||||
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
@ -57,3 +58,9 @@ void ggml_cuda_op_sin(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
void ggml_cuda_op_swiglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
|
||||||
|
|
|
@ -422,6 +422,17 @@ typedef struct {
|
||||||
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
|
int32_t KHW; // KH * KW, pre-computed on CPU to save GPU resources
|
||||||
} ggml_metal_kargs_im2col;
|
} ggml_metal_kargs_im2col;
|
||||||
|
|
||||||
|
typedef struct{
|
||||||
|
int32_t ne00;
|
||||||
|
uint64_t nb01;
|
||||||
|
int32_t ne10;
|
||||||
|
uint64_t nb11;
|
||||||
|
int32_t ne0;
|
||||||
|
uint64_t nb1;
|
||||||
|
int32_t i00;
|
||||||
|
int32_t i10;
|
||||||
|
} ggml_metal_kargs_glu;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int64_t ne00;
|
int64_t ne00;
|
||||||
int64_t ne01;
|
int64_t ne01;
|
||||||
|
|
|
@ -526,6 +526,9 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_SIN,
|
GGML_METAL_KERNEL_TYPE_SIN,
|
||||||
GGML_METAL_KERNEL_TYPE_COS,
|
GGML_METAL_KERNEL_TYPE_COS,
|
||||||
GGML_METAL_KERNEL_TYPE_NEG,
|
GGML_METAL_KERNEL_TYPE_NEG,
|
||||||
|
GGML_METAL_KERNEL_TYPE_REGLU,
|
||||||
|
GGML_METAL_KERNEL_TYPE_GEGLU,
|
||||||
|
GGML_METAL_KERNEL_TYPE_SWIGLU,
|
||||||
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
GGML_METAL_KERNEL_TYPE_SUM_ROWS,
|
||||||
GGML_METAL_KERNEL_TYPE_MEAN,
|
GGML_METAL_KERNEL_TYPE_MEAN,
|
||||||
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
|
||||||
|
@ -1502,6 +1505,9 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN, sin, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS, cos, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NEG, neg, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_REGLU, reglu, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GEGLU, geglu, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SWIGLU, swiglu, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MEAN, mean, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX, argmax, true);
|
||||||
|
@ -1680,6 +1686,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(op)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
return ggml_is_contiguous_1(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
case GGML_OP_NONE:
|
case GGML_OP_NONE:
|
||||||
case GGML_OP_RESHAPE:
|
case GGML_OP_RESHAPE:
|
||||||
case GGML_OP_VIEW:
|
case GGML_OP_VIEW:
|
||||||
|
@ -2419,6 +2434,62 @@ static bool ggml_metal_encode_node(
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
{
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(src0));
|
||||||
|
|
||||||
|
if (src1) {
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(src0, src1));
|
||||||
|
}
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
switch (ggml_get_glu_op(node)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_REGLU].pipeline;
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GEGLU].pipeline;
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SWIGLU].pipeline;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
GGML_ABORT("fatal error");
|
||||||
|
}
|
||||||
|
|
||||||
|
const int32_t swp = ((const int32_t *) dst->op_params)[1];
|
||||||
|
|
||||||
|
const int32_t i00 = swp ? ne0 : 0;
|
||||||
|
const int32_t i10 = swp ? 0 : ne0;
|
||||||
|
|
||||||
|
ggml_metal_kargs_glu args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.ne10 =*/ src1 ? ne10 : ne00,
|
||||||
|
/*.nb11 =*/ src1 ? nb11 : nb01,
|
||||||
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.nb1 =*/ nb1,
|
||||||
|
/*.i00 =*/ src1 ? 0 : i00,
|
||||||
|
/*.i10 =*/ src1 ? 0 : i10,
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
|
if (src1) {
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
} else {
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
|
}
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||||
|
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
||||||
|
|
||||||
|
const int64_t nrows = ggml_nrows(src0);
|
||||||
|
|
||||||
|
const int32_t nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00/2);
|
||||||
|
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
|
} break;
|
||||||
case GGML_OP_SQR:
|
case GGML_OP_SQR:
|
||||||
{
|
{
|
||||||
GGML_ASSERT(ggml_is_contiguous(src0));
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
|
@ -1191,6 +1191,70 @@ kernel void kernel_neg(
|
||||||
dst[tpig] = -src0[tpig];
|
dst[tpig] = -src0[tpig];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
kernel void kernel_reglu(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant ggml_metal_kargs_glu & args,
|
||||||
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
|
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||||
|
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||||
|
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
||||||
|
|
||||||
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||||
|
const float x0 = src0_row[i0];
|
||||||
|
const float x1 = src1_row[i0];
|
||||||
|
|
||||||
|
dst_row[i0] = x0*x1*(x0 > 0.0f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_geglu(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant ggml_metal_kargs_glu & args,
|
||||||
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
|
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||||
|
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||||
|
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
||||||
|
|
||||||
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||||
|
const float x0 = src0_row[i0];
|
||||||
|
const float x1 = src1_row[i0];
|
||||||
|
|
||||||
|
const float gelu = 0.5f*x0*(1.0f + precise::tanh(SQRT_2_OVER_PI*x0*(1.0f + GELU_COEF_A*x0*x0)));
|
||||||
|
|
||||||
|
dst_row[i0] = gelu*x1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kernel void kernel_swiglu(
|
||||||
|
device const char * src0,
|
||||||
|
device const char * src1,
|
||||||
|
device char * dst,
|
||||||
|
constant ggml_metal_kargs_glu & args,
|
||||||
|
uint tgpig[[threadgroup_position_in_grid]],
|
||||||
|
uint tpitg[[thread_position_in_threadgroup]],
|
||||||
|
uint ntg[[threads_per_threadgroup]]) {
|
||||||
|
device const float * src0_row = (device const float *) ((device const char *) src0 + tgpig*args.nb01) + args.i00;
|
||||||
|
device const float * src1_row = (device const float *) ((device const char *) src1 + tgpig*args.nb11) + args.i10;
|
||||||
|
device float * dst_row = (device float *) ((device char *) dst + tgpig*args.nb1);
|
||||||
|
|
||||||
|
for (int i0 = tpitg; i0 < args.ne0; i0 += ntg) {
|
||||||
|
const float x0 = src0_row[i0];
|
||||||
|
const float x1 = src1_row[i0];
|
||||||
|
|
||||||
|
const float silu = x0 / (1.0f + exp(-x0));
|
||||||
|
|
||||||
|
dst_row[i0] = silu*x1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
template <bool norm>
|
template <bool norm>
|
||||||
kernel void kernel_sum_rows(
|
kernel void kernel_sum_rows(
|
||||||
constant ggml_metal_kargs_sum_rows & args,
|
constant ggml_metal_kargs_sum_rows & args,
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -3,27 +3,30 @@
|
||||||
|
|
||||||
#include "common.hpp"
|
#include "common.hpp"
|
||||||
#include "ggml.h"
|
#include "ggml.h"
|
||||||
#include <limits.h>
|
#include <limits> // For std::numeric_limits
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
T neg_infinity() {
|
T neg_infinity() {
|
||||||
return -std::numeric_limits<T>::infinity();
|
return -std::numeric_limits<T>::infinity();
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename T>
|
template<typename T_Dst, typename T_Src = T_Dst>
|
||||||
struct typed_data {
|
struct typed_data {
|
||||||
const T * src;
|
const T_Src * src;
|
||||||
T * dst;
|
T_Dst * dst;
|
||||||
};
|
};
|
||||||
|
|
||||||
template<typename T>
|
template<typename T_Dst, typename T_Src = T_Dst>
|
||||||
typed_data<T> cast_data(ggml_tensor * dst) {
|
typed_data<T_Dst, T_Src> cast_data(ggml_tensor * dst) {
|
||||||
return {
|
return {
|
||||||
/* .src = */ static_cast<const T *>(dst->src[0]->data),
|
/* .src = */ static_cast<const T_Src *>(dst->src[0]->data),
|
||||||
/* .dst = */ static_cast<T *>(dst->data)
|
/* .dst = */ static_cast<T_Dst *>(dst->data)
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const float GELU_QUICK_COEF = -1.702f;
|
||||||
|
|
||||||
|
|
||||||
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
void ggml_sycl_sqrt(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
void ggml_sycl_sin(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
@ -73,5 +76,9 @@ void ggml_sycl_sgn(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
void ggml_sycl_abs(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
void ggml_sycl_elu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
#endif // GGML_SYCL_ELEMENTWISE_HPP
|
|
||||||
|
|
||||||
|
void ggml_sycl_geglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
void ggml_sycl_reglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
void ggml_sycl_swiglu(ggml_backend_sycl_context & ctx, ggml_tensor * dst);
|
||||||
|
|
||||||
|
#endif // GGML_SYCL_ELEMENTWISE_HPP
|
||||||
|
|
|
@ -3676,6 +3676,21 @@ static bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct gg
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(dst)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
ggml_sycl_reglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
ggml_sycl_geglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
ggml_sycl_swiglu(ctx, dst);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_NORM:
|
case GGML_OP_NORM:
|
||||||
ggml_sycl_norm(ctx, dst);
|
ggml_sycl_norm(ctx, dst);
|
||||||
break;
|
break;
|
||||||
|
@ -4212,6 +4227,16 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
|
||||||
default:
|
default:
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(op)) {
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
return ggml_is_contiguous_1(op->src[0]);
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
|
|
|
@ -437,6 +437,10 @@ struct vk_device_struct {
|
||||||
vk_pipeline pipeline_tanh[2];
|
vk_pipeline pipeline_tanh[2];
|
||||||
vk_pipeline pipeline_sigmoid[2];
|
vk_pipeline pipeline_sigmoid[2];
|
||||||
|
|
||||||
|
vk_pipeline pipeline_geglu[2];
|
||||||
|
vk_pipeline pipeline_reglu[2];
|
||||||
|
vk_pipeline pipeline_swiglu[2];
|
||||||
|
|
||||||
vk_pipeline pipeline_leaky_relu_f32;
|
vk_pipeline pipeline_leaky_relu_f32;
|
||||||
vk_pipeline pipeline_silu_back_f32;
|
vk_pipeline pipeline_silu_back_f32;
|
||||||
vk_pipeline pipeline_diag_mask_inf_f32;
|
vk_pipeline pipeline_diag_mask_inf_f32;
|
||||||
|
@ -661,6 +665,13 @@ struct vk_op_push_constants {
|
||||||
float param2;
|
float param2;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct vk_op_glu_push_constants {
|
||||||
|
uint32_t N;
|
||||||
|
uint32_t ne00;
|
||||||
|
uint32_t ne20;
|
||||||
|
uint32_t mode; // 0: default, 1: swapped, 2: split
|
||||||
|
};
|
||||||
|
|
||||||
struct vk_op_unary_push_constants {
|
struct vk_op_unary_push_constants {
|
||||||
uint32_t ne;
|
uint32_t ne;
|
||||||
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
|
||||||
|
@ -2757,6 +2768,15 @@ static void ggml_vk_load_shaders(vk_device& device) {
|
||||||
CREATE_UNARY(sigmoid)
|
CREATE_UNARY(sigmoid)
|
||||||
#undef CREATE_UNARY
|
#undef CREATE_UNARY
|
||||||
|
|
||||||
|
#define CREATE_GLU(name) \
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \
|
||||||
|
ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
|
||||||
|
|
||||||
|
CREATE_GLU(geglu)
|
||||||
|
CREATE_GLU(reglu)
|
||||||
|
CREATE_GLU(swiglu)
|
||||||
|
#undef CREATE_GLU
|
||||||
|
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||||
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
|
||||||
|
|
||||||
|
@ -6473,6 +6493,24 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
return nullptr;
|
return nullptr;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
|
||||||
|
(dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
|
||||||
|
(src0->type != dst->type)) {
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
|
|
||||||
|
switch (ggml_get_glu_op(dst)) {
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16];
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16];
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
return nullptr;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
|
||||||
return ctx->device->pipeline_diag_mask_inf_f32;
|
return ctx->device->pipeline_diag_mask_inf_f32;
|
||||||
|
@ -6933,6 +6971,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
case GGML_OP_CONCAT:
|
case GGML_OP_CONCAT:
|
||||||
case GGML_OP_UPSCALE:
|
case GGML_OP_UPSCALE:
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
|
case GGML_OP_GLU:
|
||||||
case GGML_OP_CONV_2D_DW:
|
case GGML_OP_CONV_2D_DW:
|
||||||
{
|
{
|
||||||
uint32_t ne = ggml_nelements(dst);
|
uint32_t ne = ggml_nelements(dst);
|
||||||
|
@ -6973,7 +7012,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (op == GGML_OP_SOFT_MAX) {
|
if (op == GGML_OP_SOFT_MAX || op == GGML_OP_GLU) {
|
||||||
// Empty src1 is possible in soft_max, but the shader needs a buffer
|
// Empty src1 is possible in soft_max, but the shader needs a buffer
|
||||||
vk_subbuffer subbuf_y;
|
vk_subbuffer subbuf_y;
|
||||||
if (use_src1) {
|
if (use_src1) {
|
||||||
|
@ -7566,6 +7605,25 @@ static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, con
|
||||||
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
|
||||||
|
const bool swapped = (bool)dst->op_params[1];
|
||||||
|
const bool split = src1 != nullptr;
|
||||||
|
|
||||||
|
GGML_ASSERT(ggml_is_contiguous(src0));
|
||||||
|
|
||||||
|
if (!split) {
|
||||||
|
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
|
||||||
|
} else {
|
||||||
|
GGML_ASSERT(src0->ne[0] == src1->ne[0]);
|
||||||
|
GGML_ASSERT(src0->ne[0] == dst->ne[0]);
|
||||||
|
GGML_ASSERT(src0->type == src1->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t mode = split ? 2 : (swapped ? 1 : 0);
|
||||||
|
|
||||||
|
ggml_vk_op_f32<vk_op_glu_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, { (uint32_t)ggml_nelements(dst), (uint32_t)src0->ne[0], (uint32_t)dst->ne[0], mode }, dryrun);
|
||||||
|
}
|
||||||
|
|
||||||
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
|
||||||
int32_t * op_params = (int32_t *)dst->op_params;
|
int32_t * op_params = (int32_t *)dst->op_params;
|
||||||
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
|
ggml_vk_op_f32<vk_op_diag_mask_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun);
|
||||||
|
@ -8778,6 +8836,16 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(node)) {
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_REPEAT:
|
case GGML_OP_REPEAT:
|
||||||
case GGML_OP_REPEAT_BACK:
|
case GGML_OP_REPEAT_BACK:
|
||||||
case GGML_OP_GET_ROWS:
|
case GGML_OP_GET_ROWS:
|
||||||
|
@ -8870,6 +8938,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
case GGML_OP_RMS_NORM_BACK:
|
case GGML_OP_RMS_NORM_BACK:
|
||||||
case GGML_OP_L2_NORM:
|
case GGML_OP_L2_NORM:
|
||||||
case GGML_OP_UNARY:
|
case GGML_OP_UNARY:
|
||||||
|
case GGML_OP_GLU:
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
case GGML_OP_SOFT_MAX:
|
case GGML_OP_SOFT_MAX:
|
||||||
case GGML_OP_SOFT_MAX_BACK:
|
case GGML_OP_SOFT_MAX_BACK:
|
||||||
|
@ -9013,6 +9082,17 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(node)) {
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun);
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
|
ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun);
|
||||||
|
|
||||||
|
@ -9138,8 +9218,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
|
||||||
if (!ok) {
|
if (!ok) {
|
||||||
if (node->op == GGML_OP_UNARY) {
|
if (node->op == GGML_OP_UNARY) {
|
||||||
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast<ggml_unary_op>(node->op_params[0])) << ")" << std::endl;
|
||||||
}
|
} else if (node->op == GGML_OP_GLU) {
|
||||||
else {
|
std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast<ggml_glu_op>(node->op_params[0])) << ")" << std::endl;
|
||||||
|
} else {
|
||||||
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9218,6 +9299,17 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_tensor *
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(tensor)) {
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
buf = tensor->buffer;
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
|
@ -10016,6 +10108,19 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
break;
|
break;
|
||||||
|
case GGML_OP_GLU:
|
||||||
|
switch (ggml_get_glu_op(op)) {
|
||||||
|
case GGML_GLU_OP_GEGLU:
|
||||||
|
case GGML_GLU_OP_REGLU:
|
||||||
|
case GGML_GLU_OP_SWIGLU:
|
||||||
|
return ggml_is_contiguous(op->src[0]) &&
|
||||||
|
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
|
||||||
|
(op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
|
||||||
|
(op->src[0]->type == op->type);
|
||||||
|
default:
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
break;
|
||||||
case GGML_OP_MUL_MAT:
|
case GGML_OP_MUL_MAT:
|
||||||
case GGML_OP_MUL_MAT_ID:
|
case GGML_OP_MUL_MAT_ID:
|
||||||
{
|
{
|
||||||
|
@ -10746,6 +10851,12 @@ static void ggml_vk_check_results_0(ggml_tensor * tensor) {
|
||||||
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl;
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
} else if (tensor->op == GGML_OP_GLU) {
|
||||||
|
if (src_clone[1] == nullptr) {
|
||||||
|
tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]);
|
||||||
|
} else {
|
||||||
|
tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]);
|
||||||
|
}
|
||||||
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
|
} else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) {
|
||||||
if (src1 == nullptr) {
|
if (src1 == nullptr) {
|
||||||
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
|
tensor_clone = ggml_dup(ggml_ctx, src_clone[0]);
|
||||||
|
|
|
@ -0,0 +1,13 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "glu_head.comp"
|
||||||
|
|
||||||
|
const float GELU_COEF_A = 0.044715f;
|
||||||
|
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
|
||||||
|
|
||||||
|
float op(float a, float b) {
|
||||||
|
const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a);
|
||||||
|
return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "glu_main.comp"
|
|
@ -0,0 +1,15 @@
|
||||||
|
#extension GL_EXT_shader_16bit_storage : require
|
||||||
|
|
||||||
|
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
|
layout (binding = 1) readonly buffer B {A_TYPE data_b[];};
|
||||||
|
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
|
layout (push_constant) uniform parameter
|
||||||
|
{
|
||||||
|
uint N;
|
||||||
|
uint ne00;
|
||||||
|
uint ne20;
|
||||||
|
uint mode;
|
||||||
|
} p;
|
|
@ -0,0 +1,29 @@
|
||||||
|
void main() {
|
||||||
|
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
|
||||||
|
|
||||||
|
if (i >= p.N) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint row = i / p.ne20;
|
||||||
|
const uint col = i - row * p.ne20;
|
||||||
|
|
||||||
|
if (p.mode == 0) {
|
||||||
|
// Default
|
||||||
|
const uint offset = p.ne00 / 2;
|
||||||
|
const uint idx = row * p.ne00 + col;
|
||||||
|
|
||||||
|
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
|
||||||
|
} else if (p.mode == 1) {
|
||||||
|
// Swapped
|
||||||
|
const uint offset = p.ne00 / 2;
|
||||||
|
const uint idx = row * p.ne00 + col;
|
||||||
|
|
||||||
|
data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
|
||||||
|
} else {
|
||||||
|
// Split
|
||||||
|
const uint idx = row * p.ne00 + col;
|
||||||
|
|
||||||
|
data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,9 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "glu_head.comp"
|
||||||
|
|
||||||
|
float op(float a, float b) {
|
||||||
|
return max(a, 0.0f) * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "glu_main.comp"
|
|
@ -0,0 +1,9 @@
|
||||||
|
#version 450
|
||||||
|
|
||||||
|
#include "glu_head.comp"
|
||||||
|
|
||||||
|
float op(float a, float b) {
|
||||||
|
return a / (1.0f + exp(-a)) * b;
|
||||||
|
}
|
||||||
|
|
||||||
|
#include "glu_main.comp"
|
|
@ -585,6 +585,13 @@ void process_shaders() {
|
||||||
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
|
string_to_spv("geglu_f16", "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("geglu_f32", "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("reglu_f16", "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("reglu_f32", "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
string_to_spv("swiglu_f16", "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
|
||||||
|
string_to_spv("swiglu_f32", "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
|
||||||
|
|
||||||
|
|
138
ggml/src/ggml.c
138
ggml/src/ggml.c
|
@ -982,9 +982,11 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
||||||
"CROSS_ENTROPY_LOSS",
|
"CROSS_ENTROPY_LOSS",
|
||||||
"CROSS_ENTROPY_LOSS_BACK",
|
"CROSS_ENTROPY_LOSS_BACK",
|
||||||
"OPT_STEP_ADAMW",
|
"OPT_STEP_ADAMW",
|
||||||
|
|
||||||
|
"GLU",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"none",
|
"none",
|
||||||
|
@ -1079,9 +1081,11 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||||
"cross_entropy_loss(x,y)",
|
"cross_entropy_loss(x,y)",
|
||||||
"cross_entropy_loss_back(x,y)",
|
"cross_entropy_loss_back(x,y)",
|
||||||
"adamw(x)",
|
"adamw(x)",
|
||||||
|
|
||||||
|
"glu(x)",
|
||||||
};
|
};
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
|
static_assert(GGML_OP_COUNT == 85, "GGML_OP_COUNT != 85");
|
||||||
|
|
||||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
|
|
||||||
|
@ -1107,6 +1111,15 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = {
|
||||||
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
|
static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15");
|
||||||
|
|
||||||
|
|
||||||
|
static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = {
|
||||||
|
"REGLU",
|
||||||
|
"GEGLU",
|
||||||
|
"SWIGLU",
|
||||||
|
};
|
||||||
|
|
||||||
|
static_assert(GGML_GLU_OP_COUNT == 3, "GGML_GLU_OP_COUNT != 3");
|
||||||
|
|
||||||
|
|
||||||
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
|
||||||
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
|
||||||
|
|
||||||
|
@ -1209,11 +1222,19 @@ const char * ggml_unary_op_name(enum ggml_unary_op op) {
|
||||||
return GGML_UNARY_OP_NAME[op];
|
return GGML_UNARY_OP_NAME[op];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const char * ggml_glu_op_name(enum ggml_glu_op op) {
|
||||||
|
return GGML_GLU_OP_NAME[op];
|
||||||
|
}
|
||||||
|
|
||||||
const char * ggml_op_desc(const struct ggml_tensor * t) {
|
const char * ggml_op_desc(const struct ggml_tensor * t) {
|
||||||
if (t->op == GGML_OP_UNARY) {
|
if (t->op == GGML_OP_UNARY) {
|
||||||
enum ggml_unary_op uop = ggml_get_unary_op(t);
|
enum ggml_unary_op uop = ggml_get_unary_op(t);
|
||||||
return ggml_unary_op_name(uop);
|
return ggml_unary_op_name(uop);
|
||||||
}
|
}
|
||||||
|
if (t->op == GGML_OP_GLU) {
|
||||||
|
enum ggml_glu_op gop = ggml_get_glu_op(t);
|
||||||
|
return ggml_glu_op_name(gop);
|
||||||
|
}
|
||||||
return ggml_op_name(t->op);
|
return ggml_op_name(t->op);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1730,6 +1751,11 @@ enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor) {
|
||||||
return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
|
return (enum ggml_unary_op) ggml_get_op_params_i32(tensor, 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
enum ggml_glu_op ggml_get_glu_op(const struct ggml_tensor * tensor) {
|
||||||
|
GGML_ASSERT(tensor->op == GGML_OP_GLU);
|
||||||
|
return (enum ggml_glu_op) ggml_get_op_params_i32(tensor, 0);
|
||||||
|
}
|
||||||
|
|
||||||
const char * ggml_get_name(const struct ggml_tensor * tensor) {
|
const char * ggml_get_name(const struct ggml_tensor * tensor) {
|
||||||
return tensor->name;
|
return tensor->name;
|
||||||
}
|
}
|
||||||
|
@ -2609,6 +2635,114 @@ struct ggml_tensor * ggml_exp_inplace(
|
||||||
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
|
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_EXP);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ggml_glu
|
||||||
|
|
||||||
|
static struct ggml_tensor * ggml_glu_impl(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
enum ggml_glu_op op,
|
||||||
|
bool swapped) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(a));
|
||||||
|
|
||||||
|
if (b) {
|
||||||
|
GGML_ASSERT(ggml_is_contiguous_1(b));
|
||||||
|
GGML_ASSERT(ggml_are_same_shape(a, b));
|
||||||
|
GGML_ASSERT(a->type == b->type);
|
||||||
|
}
|
||||||
|
|
||||||
|
int64_t ne[GGML_MAX_DIMS] = { a->ne[0] / 2 }; for (int i = 1; i < GGML_MAX_DIMS; i++) ne[i] = a->ne[i];
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_impl(ctx, a->type, GGML_MAX_DIMS, b ? a->ne : ne, NULL, 0);
|
||||||
|
|
||||||
|
ggml_set_op_params_i32(result, 0, (int32_t) op);
|
||||||
|
ggml_set_op_params_i32(result, 1, (int32_t) swapped);
|
||||||
|
|
||||||
|
result->op = GGML_OP_GLU;
|
||||||
|
result->src[0] = a;
|
||||||
|
result->src[1] = b;
|
||||||
|
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_glu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
enum ggml_glu_op op,
|
||||||
|
bool swapped) {
|
||||||
|
return ggml_glu_impl(ctx, a, NULL, op, swapped);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_glu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b,
|
||||||
|
enum ggml_glu_op op) {
|
||||||
|
return ggml_glu_impl(ctx, a, b, op, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_reglu
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_reglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_reglu_swapped(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_REGLU, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_reglu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b) {
|
||||||
|
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_REGLU, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_geglu
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_geglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_geglu_swapped(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_GEGLU, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_geglu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b) {
|
||||||
|
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_GEGLU, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ggml_swiglu
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_swiglu(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_swiglu_swapped(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a) {
|
||||||
|
return ggml_glu_impl(ctx, a, NULL, GGML_GLU_OP_SWIGLU, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_tensor * ggml_swiglu_split(
|
||||||
|
struct ggml_context * ctx,
|
||||||
|
struct ggml_tensor * a,
|
||||||
|
struct ggml_tensor * b) {
|
||||||
|
return ggml_glu_impl(ctx, a, b, GGML_GLU_OP_SWIGLU, false);
|
||||||
|
}
|
||||||
|
|
||||||
// ggml_norm
|
// ggml_norm
|
||||||
|
|
||||||
static struct ggml_tensor * ggml_norm_impl(
|
static struct ggml_tensor * ggml_norm_impl(
|
||||||
|
|
|
@ -560,12 +560,20 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||||
|
|
||||||
switch (type_op) {
|
switch (type_op) {
|
||||||
case LLM_FFN_SILU:
|
case LLM_FFN_SILU:
|
||||||
{
|
if (gate && type_gate == LLM_FFN_PAR) {
|
||||||
|
cur = ggml_swiglu_split(ctx0, cur, tmp);
|
||||||
|
cb(cur, "ffn_swiglu", il);
|
||||||
|
type_gate = LLM_FFN_SEQ;
|
||||||
|
} else {
|
||||||
cur = ggml_silu(ctx0, cur);
|
cur = ggml_silu(ctx0, cur);
|
||||||
cb(cur, "ffn_silu", il);
|
cb(cur, "ffn_silu", il);
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_GELU:
|
case LLM_FFN_GELU:
|
||||||
{
|
if (gate && type_gate == LLM_FFN_PAR) {
|
||||||
|
cur = ggml_geglu_split(ctx0, cur, tmp);
|
||||||
|
cb(cur, "ffn_geglu", il);
|
||||||
|
type_gate = LLM_FFN_SEQ;
|
||||||
|
} else {
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
cb(cur, "ffn_gelu", il);
|
cb(cur, "ffn_gelu", il);
|
||||||
if (act_scales != NULL) {
|
if (act_scales != NULL) {
|
||||||
|
@ -574,7 +582,11 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_RELU:
|
case LLM_FFN_RELU:
|
||||||
{
|
if (gate && type_gate == LLM_FFN_PAR) {
|
||||||
|
cur = ggml_reglu_split(ctx0, cur, tmp);
|
||||||
|
cb(cur, "ffn_reglu", il);
|
||||||
|
type_gate = LLM_FFN_SEQ;
|
||||||
|
} else {
|
||||||
cur = ggml_relu(ctx0, cur);
|
cur = ggml_relu(ctx0, cur);
|
||||||
cb(cur, "ffn_relu", il);
|
cb(cur, "ffn_relu", il);
|
||||||
} break;
|
} break;
|
||||||
|
@ -588,32 +600,19 @@ ggml_tensor * llm_graph_context::build_ffn(
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_SWIGLU:
|
case LLM_FFN_SWIGLU:
|
||||||
{
|
{
|
||||||
// Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
|
cur = ggml_swiglu(ctx0, cur);
|
||||||
int64_t split_point = cur->ne[0] / 2;
|
cb(cur, "ffn_swiglu", il);
|
||||||
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
|
||||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
||||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
||||||
|
|
||||||
x0 = ggml_silu(ctx0, x0);
|
|
||||||
cb(cur, "ffn_silu", il);
|
|
||||||
|
|
||||||
cur = ggml_mul(ctx0, x0, x1);
|
|
||||||
cb(cur, "ffn_mul", il);
|
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_GEGLU:
|
case LLM_FFN_GEGLU:
|
||||||
{
|
{
|
||||||
// Split into two equal parts
|
cur = ggml_geglu(ctx0, cur);
|
||||||
int64_t split_point = cur->ne[0] / 2;
|
|
||||||
// TODO: these conts should not be needed, see https://github.com/ggml-org/llama.cpp/pull/14090#discussion_r2137437217
|
|
||||||
ggml_tensor * x0 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], 0));
|
|
||||||
ggml_tensor * x1 = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, split_point, cur->ne[1], cur->nb[1], split_point * ggml_element_size(cur)));
|
|
||||||
|
|
||||||
x0 = ggml_gelu(ctx0, x0);
|
|
||||||
cb(x0, "ffn_gelu", il);
|
|
||||||
|
|
||||||
cur = ggml_mul(ctx0, x0, x1);
|
|
||||||
cb(cur, "ffn_geglu", il);
|
cb(cur, "ffn_geglu", il);
|
||||||
} break;
|
} break;
|
||||||
|
case LLM_FFN_REGLU:
|
||||||
|
{
|
||||||
|
cur = ggml_reglu(ctx0, cur);
|
||||||
|
cb(cur, "ffn_reglu", il);
|
||||||
|
} break;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (gate && type_gate == LLM_FFN_PAR) {
|
if (gate && type_gate == LLM_FFN_PAR) {
|
||||||
|
@ -743,12 +742,18 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
|
|
||||||
switch (type_op) {
|
switch (type_op) {
|
||||||
case LLM_FFN_SILU:
|
case LLM_FFN_SILU:
|
||||||
{
|
if (gate_exps) {
|
||||||
|
cur = ggml_swiglu_split(ctx0, cur, up);
|
||||||
|
cb(cur, "ffn_moe_swiglu", il);
|
||||||
|
} else {
|
||||||
cur = ggml_silu(ctx0, cur);
|
cur = ggml_silu(ctx0, cur);
|
||||||
cb(cur, "ffn_moe_silu", il);
|
cb(cur, "ffn_moe_silu", il);
|
||||||
} break;
|
} break;
|
||||||
case LLM_FFN_GELU:
|
case LLM_FFN_GELU:
|
||||||
{
|
if (gate_exps) {
|
||||||
|
cur = ggml_geglu_split(ctx0, cur, up);
|
||||||
|
cb(cur, "ffn_moe_geglu", il);
|
||||||
|
} else {
|
||||||
cur = ggml_gelu(ctx0, cur);
|
cur = ggml_gelu(ctx0, cur);
|
||||||
cb(cur, "ffn_moe_gelu", il);
|
cb(cur, "ffn_moe_gelu", il);
|
||||||
} break;
|
} break;
|
||||||
|
@ -756,11 +761,6 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
||||||
GGML_ABORT("fatal error");
|
GGML_ABORT("fatal error");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (gate_exps) {
|
|
||||||
cur = ggml_mul(ctx0, cur, up); // [n_ff, n_expert_used, n_tokens]
|
|
||||||
cb(cur, "ffn_moe_gate_par", il);
|
|
||||||
}
|
|
||||||
|
|
||||||
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
experts = build_lora_mm_id(down_exps, cur, selected_experts); // [n_embd, n_expert_used, n_tokens]
|
||||||
cb(experts, "ffn_moe_down", il);
|
cb(experts, "ffn_moe_down", il);
|
||||||
|
|
||||||
|
|
|
@ -38,6 +38,7 @@ enum llm_ffn_op_type {
|
||||||
LLM_FFN_RELU_SQR,
|
LLM_FFN_RELU_SQR,
|
||||||
LLM_FFN_SWIGLU,
|
LLM_FFN_SWIGLU,
|
||||||
LLM_FFN_GEGLU,
|
LLM_FFN_GEGLU,
|
||||||
|
LLM_FFN_REGLU,
|
||||||
};
|
};
|
||||||
|
|
||||||
enum llm_ffn_gate_type {
|
enum llm_ffn_gate_type {
|
||||||
|
|
|
@ -1106,6 +1106,107 @@ struct test_unary : public test_case {
|
||||||
|
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// GGML_OP_GLU
|
||||||
|
struct test_glu : public test_case {
|
||||||
|
const ggml_glu_op op;
|
||||||
|
const ggml_type type;
|
||||||
|
const std::array<int64_t, 4> ne_a;
|
||||||
|
int v; // view (1 : non-contiguous a)
|
||||||
|
bool swapped;
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR4(type, ne_a, v, swapped);
|
||||||
|
}
|
||||||
|
|
||||||
|
test_glu(ggml_glu_op op,
|
||||||
|
ggml_type type = GGML_TYPE_F32,
|
||||||
|
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
|
||||||
|
int v = 0,
|
||||||
|
bool swapped = false)
|
||||||
|
: op(op), type(type), ne_a(ne_a), v(v), swapped(swapped) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * a;
|
||||||
|
if (v & 1) {
|
||||||
|
auto ne = ne_a; ne[0] *= 3;
|
||||||
|
a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
|
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
|
||||||
|
ggml_set_name(a, "view_of_a");
|
||||||
|
} else {
|
||||||
|
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||||
|
ggml_set_name(a, "a");
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * out = ggml_glu(ctx, a, op, swapped);
|
||||||
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
// test extended range of values to check for NaNs in GELU
|
||||||
|
init_tensor_uniform(t, -150.f, 150.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
struct test_glu_split : public test_case {
|
||||||
|
const ggml_glu_op op;
|
||||||
|
const ggml_type type;
|
||||||
|
const std::array<int64_t, 4> ne_a;
|
||||||
|
int v; // view (1 : non-contiguous a)
|
||||||
|
|
||||||
|
std::string vars() override {
|
||||||
|
return VARS_TO_STR3(type, ne_a, v) + ",split";
|
||||||
|
}
|
||||||
|
|
||||||
|
test_glu_split(ggml_glu_op op,
|
||||||
|
ggml_type type = GGML_TYPE_F32,
|
||||||
|
std::array<int64_t, 4> ne_a = {128, 2, 2, 2},
|
||||||
|
int v = 0)
|
||||||
|
: op(op), type(type), ne_a(ne_a), v(v) {}
|
||||||
|
|
||||||
|
ggml_tensor * build_graph(ggml_context * ctx) override {
|
||||||
|
ggml_tensor * a;
|
||||||
|
ggml_tensor * b;
|
||||||
|
if (v & 1) {
|
||||||
|
auto ne = ne_a; ne[0] *= 3;
|
||||||
|
a = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
|
a = ggml_view_4d(ctx, a, ne_a[0], ne_a[1], ne_a[2], ne_a[3], a->nb[1], a->nb[2], a->nb[3], 0);
|
||||||
|
ggml_set_name(a, "view_of_a");
|
||||||
|
|
||||||
|
b = ggml_new_tensor(ctx, type, 4, ne.data());
|
||||||
|
ggml_set_name(b, "b");
|
||||||
|
|
||||||
|
b = ggml_view_4d(ctx, b, ne_a[0], ne_a[1], ne_a[2], ne_a[3], b->nb[1], b->nb[2], b->nb[3], 0);
|
||||||
|
ggml_set_name(a, "view_of_b");
|
||||||
|
} else {
|
||||||
|
a = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||||
|
ggml_set_name(a, "a");
|
||||||
|
|
||||||
|
b = ggml_new_tensor(ctx, type, 4, ne_a.data());
|
||||||
|
ggml_set_name(b, "b");
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_tensor * out = ggml_glu_split(ctx, a, b, op);
|
||||||
|
ggml_set_name(out, "out");
|
||||||
|
|
||||||
|
return out;
|
||||||
|
}
|
||||||
|
|
||||||
|
void initialize_tensors(ggml_context * ctx) override {
|
||||||
|
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
|
||||||
|
// test extended range of values to check for NaNs in GELU
|
||||||
|
init_tensor_uniform(t, -150.f, 150.f);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
// GGML_OP_GET_ROWS
|
// GGML_OP_GET_ROWS
|
||||||
struct test_get_rows : public test_case {
|
struct test_get_rows : public test_case {
|
||||||
const ggml_type type;
|
const ggml_type type;
|
||||||
|
@ -4094,6 +4195,21 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// glu ops
|
||||||
|
for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) {
|
||||||
|
for (int v : {0, 1}) {
|
||||||
|
for (int op = 0; op < GGML_GLU_OP_COUNT; op++) {
|
||||||
|
for (bool swapped : {false, true}) {
|
||||||
|
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v, swapped));
|
||||||
|
test_cases.emplace_back(new test_glu((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v, swapped));
|
||||||
|
}
|
||||||
|
|
||||||
|
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 128, 2, 2, 2 }, v));
|
||||||
|
test_cases.emplace_back(new test_glu_split((ggml_glu_op) op, type, { 5, 7, 11, 13 }, v));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
|
test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
|
||||||
for (ggml_type type : all_types) {
|
for (ggml_type type : all_types) {
|
||||||
for (int b : {1, 7}) {
|
for (int b : {1, 7}) {
|
||||||
|
|
Loading…
Reference in New Issue