ggml : add ggml_scale_bias (#14417)

* ggml : add ggml_scale_bias

* ggml_vec_mad1_f32

* add more simd

* add CUDA

* sycl

* vulkan

* cann (placeholder)

* opencl

* will this fix cpu?

* fix cuda

* suggestions from coderabbit

* fix cann compile error

* vDSP_vsmsa

* rm __ARM_FEATURE_SVE

* use memcpy for op params

* make code looks more consistent

* use scalar for __ARM_FEATURE_SVE

* add x param to ggml_vec_mad1_f32
This commit is contained in:
Xuan-Son Nguyen 2025-07-09 18:16:12 +02:00 committed by GitHub
parent 26a48ad699
commit 98bab638fb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 139 additions and 38 deletions

View File

@ -1297,6 +1297,19 @@ extern "C" {
struct ggml_tensor * a,
float s);
// x = s * a + b
GGML_API struct ggml_tensor * ggml_scale_bias(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b);
GGML_API struct ggml_tensor * ggml_scale_bias_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b);
// b -> view(a,offset,nb1,nb2,3), return modified a
GGML_API struct ggml_tensor * ggml_set(
struct ggml_context * ctx,

View File

@ -2188,7 +2188,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_CLAMP:
@ -2210,6 +2209,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (float*)op->op_params + 1, sizeof(float));
return bias == 0.0f; // TODO: support bias != 0.0f
case GGML_OP_SOFT_MAX:
// TODO: support broadcast
// ref: https://github.com/ggml-org/llama.cpp/pull/14435

View File

@ -4643,9 +4643,11 @@ static void ggml_compute_forward_scale_f32(
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_are_same_shape(src0, dst));
// scale factor
float v;
memcpy(&v, dst->op_params, sizeof(float));
float s; // scale factor
float b; // bias
memcpy(&s, (float *) dst->op_params + 0, sizeof(float));
memcpy(&b, (float *) dst->op_params + 1, sizeof(float));
const int ith = params->ith;
const int nth = params->nth;
@ -4664,12 +4666,22 @@ static void ggml_compute_forward_scale_f32(
const size_t nb1 = dst->nb[1];
for (int i1 = ir0; i1 < ir1; i1++) {
if (dst->data != src0->data) {
// src0 is same shape as dst => same indices
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
if (b == 0.0f) {
for (int i1 = ir0; i1 < ir1; i1++) {
if (dst->data != src0->data) {
// src0 is same shape as dst => same indices
// TODO: add x parameter to ggml_vec_scale_f32 and remove this memcpy
memcpy((char *)dst->data + i1*nb1, (char *)src0->data + i1*nb01, nc * sizeof(float));
}
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), s);
}
} else {
for (int i1 = ir0; i1 < ir1; i1++) {
ggml_vec_mad1_f32(nc,
(float *) ((char *) dst->data + i1*nb1),
(float *) ((char *) src0->data + i1*nb1),
s, b);
}
ggml_vec_scale_f32(nc, (float *) ((char *) dst->data + i1*nb1), v);
}
}

View File

@ -351,6 +351,45 @@ inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int
#endif
}
inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, const float s, const float b) {
#if defined(GGML_USE_ACCELERATE)
vDSP_vsmsa(x, 1, &s, &b, y, 1, n);
#elif defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
// scalar ; TODO: Write SVE code
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
#else
const int np = (n & ~(GGML_F32_STEP - 1));
GGML_F32_VEC vs = GGML_F32_VEC_SET1(s);
GGML_F32_VEC vb = GGML_F32_VEC_SET1(b);
GGML_F32_VEC ay[GGML_F32_ARR];
for (int i = 0; i < np; i += GGML_F32_STEP) {
for (int j = 0; j < GGML_F32_ARR; j++) {
ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR);
ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb);
GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = x[i]*s + b;
}
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
y[i] = x[i]*s + b;
}
#endif
}
//inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { for (int i = 0; i < n; ++i) y[i] *= v; }
inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#if defined(GGML_USE_ACCELERATE)

View File

@ -1,18 +1,18 @@
#include "scale.cuh"
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = scale * x[i];
dst[i] = scale * x[i] + bias;
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
}
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -25,7 +25,9 @@ void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT( dst->type == GGML_TYPE_F32);
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
float bias;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
scale_f32_cuda(src0_d, dst_d, scale, bias, ggml_nelements(src0), stream);
}

View File

@ -2256,7 +2256,9 @@ static bool ggml_metal_encode_node(
GGML_ASSERT(ggml_is_contiguous(src0));
float scale;
memcpy(&scale, dst->op_params, sizeof(scale));
float bias;
memcpy(&scale, ((const int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&bias, ((const int32_t *) dst->op_params) + 1, sizeof(float));
int64_t n = ggml_nelements(dst);
@ -2273,6 +2275,7 @@ static bool ggml_metal_encode_node(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
[encoder setBytes:&bias length:sizeof(bias) atIndex:3];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;

View File

@ -1014,16 +1014,18 @@ kernel void kernel_scale(
device const float * src0,
device float * dst,
constant float & scale,
constant float & bias,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
dst[tpig] = src0[tpig] * scale + bias;
}
kernel void kernel_scale_4(
device const float4 * src0,
device float4 * dst,
constant float & scale,
constant float & bias,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
dst[tpig] = src0[tpig] * scale + bias;
}
kernel void kernel_clamp(

View File

@ -5587,7 +5587,9 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
float scale;
memcpy(&scale, dst->op_params, sizeof(scale));
float bias;
memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(float));
memcpy(&bias, ((int32_t *) dst->op_params) + 1, sizeof(float));
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
@ -5602,6 +5604,7 @@ static void ggml_cl_scale(ggml_backend_t backend, const ggml_tensor * src0, cons
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(float), &scale));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(float), &bias));
int n = ggml_nelements(dst)/4;

View File

@ -8,9 +8,10 @@ kernel void kernel_scale(
ulong offset0,
global float4 * dst,
ulong offsetd,
float scale
float scale,
float bias
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
dst[get_global_id(0)] = src0[get_global_id(0)] * scale;
dst[get_global_id(0)] = src0[get_global_id(0)] * scale + bias;
}

View File

@ -1695,7 +1695,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
static void scale_f32(const float * x, float * dst, const float scale, const int k,
static void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k,
const sycl::nd_item<3> &item_ct1) {
const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
item_ct1.get_local_id(2);
@ -1704,7 +1704,7 @@ static void scale_f32(const float * x, float * dst, const float scale, const int
return;
}
dst[i] = scale * x[i];
dst[i] = scale * x[i] + bias;
}
@ -1842,7 +1842,7 @@ static void ggml_mul_mat_vec_nc_f16_f32_sycl(
static void scale_f32_sycl(const float *x, float *dst, const float scale,
static void scale_f32_sycl(const float *x, float *dst, const float scale, const float bias,
const int k, queue_ptr stream) {
const int num_blocks = (k + SYCL_SCALE_BLOCK_SIZE - 1) / SYCL_SCALE_BLOCK_SIZE;
stream->parallel_for(
@ -1850,7 +1850,7 @@ static void scale_f32_sycl(const float *x, float *dst, const float scale,
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE),
sycl::range<3>(1, 1, SYCL_SCALE_BLOCK_SIZE)),
[=](sycl::nd_item<3> item_ct1) {
scale_f32(x, dst, scale, k, item_ct1);
scale_f32(x, dst, scale, bias, k, item_ct1);
});
}
@ -2319,9 +2319,11 @@ inline void ggml_sycl_op_scale(ggml_backend_sycl_context & ctx, ggml_tensor * ds
float * dst_dd = static_cast<float *>(dst->data);
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
float bias;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&bias, (float *) dst->op_params + 1, sizeof(float));
scale_f32_sycl(src0_dd, dst_dd, scale, ggml_nelements(dst->src[0]), main_stream);
scale_f32_sycl(src0_dd, dst_dd, scale, bias, ggml_nelements(dst->src[0]), main_stream);
/*
DPCT1010:87: SYCL uses exceptions to report errors and does not use the
error codes. The call was replaced with 0. You need to rewrite this code.

View File

@ -7508,7 +7508,7 @@ static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, con
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
op_params[0], 0.0f,
op_params[0], op_params[1],
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
}, dryrun);
}

View File

@ -18,7 +18,7 @@ void main() {
continue;
}
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1));
data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2));
idx += num_threads;
}
}

View File

@ -3069,12 +3069,14 @@ static struct ggml_tensor * ggml_scale_impl(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b,
bool inplace) {
GGML_ASSERT(ggml_is_padded_1d(a));
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
ggml_set_op_params(result, &s, sizeof(s));
float params[2] = { s, b };
ggml_set_op_params(result, &params, sizeof(params));
result->op = GGML_OP_SCALE;
result->src[0] = a;
@ -3086,14 +3088,30 @@ struct ggml_tensor * ggml_scale(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s) {
return ggml_scale_impl(ctx, a, s, false);
return ggml_scale_impl(ctx, a, s, 0.0, false);
}
struct ggml_tensor * ggml_scale_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s) {
return ggml_scale_impl(ctx, a, s, true);
return ggml_scale_impl(ctx, a, s, 0.0, true);
}
struct ggml_tensor * ggml_scale_bias(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b) {
return ggml_scale_impl(ctx, a, s, b, false);
}
struct ggml_tensor * ggml_scale_bias_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
float s,
float b) {
return ggml_scale_impl(ctx, a, s, b, true);
}
// ggml_set
@ -5777,7 +5795,7 @@ static void ggml_compute_backward(
} break;
case GGML_OP_MEAN: {
if (src0_needs_grads) {
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false));
ggml_add1_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], 0.0, false));
}
} break;
case GGML_OP_REPEAT: {
@ -5854,7 +5872,7 @@ static void ggml_compute_backward(
if (src0_needs_grads) {
float s;
memcpy(&s, tensor->op_params, sizeof(float));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false));
ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, 0.0, false));
}
} break;
case GGML_OP_SET: {

View File

@ -2368,22 +2368,24 @@ struct test_scale : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;
float scale;
float bias;
std::string vars() override {
return VARS_TO_STR3(type, ne, scale);
return VARS_TO_STR4(type, ne, scale, bias);
}
test_scale(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {10, 10, 10, 10},
float scale = 2.0f)
: type(type), ne(ne), scale(scale) {}
float scale = 2.0f,
float bias = 0.0f)
: type(type), ne(ne), scale(scale), bias(bias) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(a);
ggml_set_name(a, "a");
ggml_tensor * out = ggml_scale(ctx, a, scale);
ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias);
ggml_set_name(out, "out");
return out;
@ -5044,6 +5046,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_add1());
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f));
test_cases.emplace_back(new test_silu_back());
for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {