mirror of https://github.com/Jittor/Jittor
support cross entropy with large input
This commit is contained in: parent 9b20ee33b2, commit f68620108a
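The change routes cross_entropy_loss through a fused CUDA kernel in jittor.other.code_cross_entropy that is meant to cope with a very large class dimension. Both that kernel and the Python fallback in the diff compute cross entropy in the max-shifted log-sum-exp form. As a reading aid only (not part of the commit), a minimal NumPy sketch of that formula:

import numpy as np

def stable_cross_entropy(x, target):
    # x: (batch, classes) logits; target: (batch,) class indices.
    # Same form as the kernels: loss = log(sum(exp(x - max))) + max - x[target].
    vmax = x.max(axis=1, keepdims=True)
    logsum = np.log(np.exp(x - vmax).sum(axis=1)) + vmax[:, 0]
    return logsum - x[np.arange(len(target)), target]

x = np.random.rand(4, 65000).astype(np.float32)
t = np.random.randint(0, 65000, size=4)
print(stable_cross_entropy(x, t))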
@@ -370,6 +370,8 @@ class PReLU(Module):
        else:
            return jt.maximum(0, x) + self.weight * jt.minimum(0, x)

import jittor.other.code_cross_entropy as code_cross_entropy

#TODO dims is 4 will cause slowly execution
def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction='mean'):
    target_shape = target.shape
@@ -379,7 +381,7 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction=
    output = output.reshape((-1, c_dim))

    target = target.reshape((-1, ))
    target_weight = ((target >= 0) & (target < output.shape[1])).float32()
    target_weight = ((target >= 0) & (target < output.shape[1])).astype(output.dtype)
    if weight is not None:
        target_weight = weight[target]
    if ignore_index is not None:
@@ -389,16 +391,12 @@ def cross_entropy_loss(output, target, weight=None, ignore_index=None,reduction=
            target_weight
        )

    import jittor.other.code_cross_entropy as code_cross_entropy
    if code_cross_entropy.can_cross_entropy(output, -1):
        cross_entropy = code_cross_entropy.cross_entropy(output, target)
    else:
        target = target.broadcast(output, [1])
        target = target.index(1) == target

        output = output - output.max([1], keepdims=True)
        logsum = output.exp().sum(1).log()
        cross_entropy = (logsum - (output*target).sum(1))
    target = target.broadcast(output, [1])
    target = target.index(1) == target

    output = output - output.max([1], keepdims=True)
    logsum = output.exp().sum(1).log()
    cross_entropy = (logsum - (output*target).sum(1))

    loss = cross_entropy * target_weight
    if reduction == 'sum':
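In the hunk above, cross_entropy_loss dispatches to the fused kernel only when code_cross_entropy.can_cross_entropy(output, -1) accepts the input (CUDA enabled, reduction over the last axis); otherwise the broadcast-based computation runs. A hedged usage sketch with a wide class dimension, mirroring the shapes in the test further below (illustration only, not part of the diff; assumes a CUDA build):

import numpy as np
import jittor as jt
from jittor import nn

# Illustration only: a wide class dimension exercises the fused CUDA path.
with jt.flag_scope(use_cuda=1):
    x = jt.rand((200, 65000), dtype="float32")
    t = jt.array(np.random.randint(0, 65000, (200,)))
    loss = nn.cross_entropy_loss(x, t)  # reduction defaults to 'mean'
    print(loss.data)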
@@ -2,37 +2,9 @@ import jittor as jt
from jittor import nn
import numpy as np

def can_cross_entropy(a, dim):
    if not jt.flags.use_cuda:
        return False
    if dim != -1 and dim != len(a.shape)-1:
        return False
    if a.shape[-1] > 10000 and np.prod(a.shape[:-1]) < 64:
        return False
    return True

def cross_entropy(output, target):
    assert can_cross_entropy(output, -1)
    length = output.shape[-1]

    if length < 65536:
        tnum = 250 if length % 250 == 0 else 256
    else:
        tnum = 125 if length % 125 == 0 else 128

    per_thread = (length-1) // tnum + 1
    ILP = 1
    for ilp in [8,4,2]:
        if length % tnum == 0 and per_thread % ilp == 0:
            ILP = ilp
            per_thread //= ILP
            break
    for_loop = f"""
    #pragma unroll
    for (int i=0; i<{per_thread}; i++)
    """
    if length % tnum != 0:
        for_loop += f"if ((i*{tnum}+threadIdx.x)*{ILP} < len)\n"
    tnum = min(512, output.shape[-1])

    class CodeCrossEntropy(jt.Function):
        def execute(self, x, target):
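As the shrinking hunk above suggests, the old per-thread/ILP tuning gives way to a single tnum = min(512, output.shape[-1]); the kernels further below then let each thread stride over the whole row instead of relying on vectorized loads, which removes the old row-length restrictions. Purely illustrative arithmetic comparing the two selections (helper names are assumptions, not part of the diff):

def old_config(length):
    # Former selection: thread count tied to divisibility of the row length.
    tnum = (250 if length % 250 == 0 else 256) if length < 65536 \
        else (125 if length % 125 == 0 else 128)
    per_thread = (length - 1) // tnum + 1
    ILP = 1
    for ilp in [8, 4, 2]:
        if length % tnum == 0 and per_thread % ilp == 0:
            ILP = ilp
            per_thread //= ILP
            break
    return tnum, per_thread, ILP

def new_config(length):
    # New selection: at most 512 threads, each striding over the row.
    return min(512, length)

for length in [2000, 2049, 65000]:
    print(length, old_config(length), new_config(length))
# 2000  -> (250, 1, 8)  vs 512
# 2049  -> (256, 9, 1)  vs 512
# 65000 -> (250, 65, 4) vs 512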
@@ -42,21 +14,15 @@ def cross_entropy(output, target):
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''', cuda_src=f'''
__global__ void kernel(in0_type* x, in1_type* target, out0_type* y, int len) {{
__global__ void kernel(in0_type* x, in1_type* target, out0_type* y, size_t len) {{
    typedef cub::BlockReduce<float, {tnum}> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int id = blockIdx.x * len;
    in0_type v[{per_thread}][{ILP}];
    {for_loop}
        vload<sizeof(in0_type)*{ILP}>(v[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
    size_t id = blockIdx.x * len;

    float v1 = -1e30;
    {for_loop}
        #pragma unroll
        for (int j=0; j<{ILP}; j++) {{
            v1 = max(v1, float(v[i][j]));
        }}
    for (size_t i = threadIdx.x; i < len; i += blockDim.x)
        v1 = ::max(v1, float(x[id + i]));

    __shared__ float vmax;
    auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
@@ -65,146 +31,37 @@ __global__ void kernel(in0_type* x, in1_type* target, out0_type* y, int len) {{
    __syncthreads();

    v1 = 0;
    {for_loop}
        #pragma unroll
        for (int j=0; j<{ILP}; j++) {{
            v1 += expf(float(float(v[i][j]) - vmax));
        }}

    auto vsum = BlockReduce(temp_storage).Sum(v1);
    if (threadIdx.x == 0)
        y[blockIdx.x] = -float(x[id+target[blockIdx.x]]) + vmax + float(@expand_op(log,@in0_type,vsum));
}}
int len = in0->shape[in0->shape.size()-1];
int bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, {tnum}>>>(in0_p, in1_p, out0_p, len);
getLastCudaError("Failed to run CodeCrossEntropy forward");
''')
            return cross_entropy

        def grad(self, grad):
            x, target = self.save_vars
            return jt.code(x.shape, x.dtype, [x, target, grad], cuda_header=f'''
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''', cuda_src=f'''
__global__ void kernel(in0_type* x, in1_type* target, in2_type* grad, out0_type* y, int len) {{
    typedef cub::BlockReduce<float, {tnum}> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int id = blockIdx.x * len;
    in0_type v[{per_thread}][{ILP}];
    {for_loop}
        vload<sizeof(in0_type)*{ILP}>(v[i], &x[id+(i*{tnum}+threadIdx.x)*{ILP}]);
    float v1 = -1e30;
    {for_loop}
        #pragma unroll
        for (int j=0; j<{ILP}; j++) {{
            v1 = max(v1, float(v[i][j]));
        }}
    __shared__ float vmax;
    auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
    if (threadIdx.x == 0)
        vmax = tmp;
    __syncthreads();

    v1 = 0;
    {for_loop}
        #pragma unroll
        for (int j=0; j<{ILP}; j++) {{
            v[i][j] = expf(float(v[i][j]) - vmax);
            v1 += float(v[i][j]);
        }}

    tmp = BlockReduce(temp_storage).Sum(v1);
    __shared__ float vsum;
    if (threadIdx.x == 0)
        vsum = tmp;
    __syncthreads();

    {for_loop}
        #pragma unroll
        for (int j=0; j<{ILP}; j++)
            v[i][j] = float(v[i][j])/vsum * float(grad[blockIdx.x]);

    {for_loop}
        vload<sizeof(out0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
    __syncthreads();

    if (threadIdx.x == blockIdx.x)
        y[id + target[blockIdx.x]] -= grad[blockIdx.x];
}}
int len = in0->shape[in0->shape.size()-1];
int bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, {tnum}>>>(in0_p, in1_p, in2_p, out0_p, len);
getLastCudaError("Failed to run CodeCrossEntropy backward");
''')
    return CodeCrossEntropy()(output, target)


def cross_entropy_v2(output, target):
    class CodeCrossEntropy(jt.Function):
        def execute(self, x, target):
            self.save_vars = [x, target]
            cross_entropy = jt.code(target.shape, x.dtype, [x, target], cuda_header=f'''
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''', cuda_src=f'''
__global__ void kernel(in0_type* x, in1_type* target, out0_type* y, int len) {{
    typedef cub::BlockReduce<float, 1024> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int id = blockIdx.x * len;

    float v1 = -1e30;
    for (int i = threadIdx.x; i < len; i += blockDim.x)
        v1 = max(v1, float(x[id + i]));

    __shared__ float vmax;
    auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
    if (threadIdx.x == 0)
        vmax = tmp;
    __syncthreads();

    v1 = 0;
    for (int i = threadIdx.x; i < len; i += blockDim.x)
    for (size_t i = threadIdx.x; i < len; i += blockDim.x)
        v1 += expf(float(float(x[id + i]) - vmax));

    auto vsum = BlockReduce(temp_storage).Sum(v1);
    if (threadIdx.x == 0)
        y[blockIdx.x] = -float(x[id+target[blockIdx.x]]) + vmax + float(@expand_op(log,@in0_type,vsum));
}}
int len = in0->shape[in0->shape.size()-1];
int bnum = in0->numel() / len;
size_t len = in0->shape[in0->shape.size()-1];
size_t bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, 1024>>>(in0_p, in1_p, out0_p, len);
kernel<<<bnum, {tnum}>>>(in0_p, in1_p, out0_p, len);
getLastCudaError("Failed to run CodeCrossEntropy forward");
''')
            return cross_entropy

        def grad(self, grad):
            x, target = self.save_vars
            # target = target.broadcast(x, [1])
            # target = target.index(1) == target
            # return (jt.nn.softmax(x, dim=1) - target) * grad.broadcast(x, [1])
            return jt.code(x.shape, x.dtype, [x, target, grad], cuda_header=f'''
#include <{jt.compile_extern.cub_home}cub/cub.cuh>
#include <type/fp16_compute.h>
#include <helper_cuda.h>
''', cuda_src=f'''
__global__ void kernel(in0_type* x, in1_type* target, in2_type* grad, out0_type* y, int len) {{
    typedef cub::BlockReduce<float, 1024> BlockReduce;
__global__ void kernel(in0_type* x, in1_type* target, in2_type* grad, out0_type* y, size_t len) {{
    typedef cub::BlockReduce<float, {tnum}> BlockReduce;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    int id = blockIdx.x * len;
    float v1 = -1e30;
    for (int i = threadIdx.x; i < len; i += blockDim.x)
        v1 = max(v1, float(x[id + i]));
    size_t id = blockIdx.x * len;

    float v1 = -1e30;
    for (size_t i = threadIdx.x; i < len; i += blockDim.x)
        v1 = ::max(v1, float(x[id + i]));
    __shared__ float vmax;
    auto tmp = BlockReduce(temp_storage).Reduce(v1, cub::Max());
    if (threadIdx.x == 0)
@@ -212,35 +69,28 @@ __global__ void kernel(in0_type* x, in1_type* target, in2_type* grad, out0_type*
    __syncthreads();

    v1 = 0;
    for (int i = threadIdx.x; i < len; i += blockDim.x) {{
        float _x = expf(float(x[id + i]) - vmax);
        y[id + i] = _x;
        v1 += _x;
    for (size_t i = threadIdx.x; i < len; i += blockDim.x) {{
        y[id + i] = expf(float(x[id + i]) - vmax);
        v1 += float(y[id + i]);
    }}

    tmp = BlockReduce(temp_storage).Sum(v1);
    __shared__ float vsum;
    if (threadIdx.x == 0) {{
        vsum = tmp;
    }}
    __syncthreads();
    if (threadIdx.x == 0)
        if (vsum != vsum)
            printf("found nan! %d\\n", threadIdx.x);

    for (int i = threadIdx.x; i < len; i += blockDim.x) {{
        y[id + i] = float(y[id + i]) * float(grad[blockIdx.x]) / vsum;
    }}
        vsum = tmp;
    __syncthreads();

    if (threadIdx.x == 0) {{
        y[id + target[blockIdx.x]] -= (out0_type) grad[blockIdx.x];
    }}

    for (size_t i = threadIdx.x; i < len; i += blockDim.x)
        y[id + i] = float(y[id + i]) / vsum * float(grad[blockIdx.x]);
    __syncthreads();

    if (threadIdx.x == 0)
        y[id + target[blockIdx.x]] -= grad[blockIdx.x];
}}
int len = in0->shape[in0->shape.size()-1];
int bnum = in0->numel() / len;
size_t len = in0->shape[in0->shape.size()-1];
size_t bnum = in0->numel() / len;
cudaGetLastError();
kernel<<<bnum, 1024>>>(in0_p, in1_p, in2_p, out0_p, len);
kernel<<<bnum, {tnum}>>>(in0_p, in1_p, in2_p, out0_p, len);
getLastCudaError("Failed to run CodeCrossEntropy backward");
''')
    return CodeCrossEntropy()(output, target)
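The backward kernel above follows the standard softmax cross-entropy gradient, which the commented-out line in grad also spells out: d loss / d x = (softmax(x) - onehot(target)) * upstream_grad. A NumPy sketch of that identity (illustration only, not part of the commit):

import numpy as np

def cross_entropy_grad(x, target, grad_out):
    # Same steps as the kernel: softmax via max subtraction, scale by the
    # upstream gradient, then subtract it at the target index.
    e = np.exp(x - x.max(axis=1, keepdims=True))
    y = e / e.sum(axis=1, keepdims=True) * grad_out[:, None]
    y[np.arange(len(target)), target] -= grad_out
    return y

x = np.random.rand(3, 10).astype(np.float32)
t = np.array([1, 4, 7])
g = np.ones(3, dtype=np.float32)
print(cross_entropy_grad(x, t, g).sum(axis=1))  # each row sums to ~0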
@@ -343,13 +343,9 @@ class TestOther(unittest.TestCase):
            jt.array(0).broadcast(target_weight),
            target_weight
        )

    target = target.broadcast(output, [1])
    target = target.index(1) == target

    output = output - output.max([1], keepdims=True)
    logsum = output.exp().sum(1).log()
    cross_entropy = (logsum - (output*target).sum(1))

    import jittor.other.code_cross_entropy as code_cross_entropy
    cross_entropy = code_cross_entropy.cross_entropy(output, target)

    loss = cross_entropy * target_weight
    if reduction == 'sum':
@@ -363,7 +359,7 @@ class TestOther(unittest.TestCase):
        jt.set_global_seed(42)

        with jt.flag_scope(use_cuda = 1):
            for dtype in ["float16", "float32"]:
            for dtype in ["float16", "bfloat16", "float32"]:
                for shape in [(3, 3), (200, 2000), (200, 2049), (16380, 65000)]:
                    print(shape)
                    x = jt.rand(shape, dtype=dtype)
@@ -374,8 +370,8 @@ class TestOther(unittest.TestCase):
                    d2 = jt.grad(bb, x)
                    jt.sync_all(True)

                    np.testing.assert_allclose(bb.data, b.data, rtol=1e-3, atol=1e-5)
                    np.testing.assert_allclose(d1.data, d2.data, rtol=1e-3, atol=1e-3)
                    np.testing.assert_allclose(bb.astype(jt.float32).data, b.astype(jt.float32).data, rtol=1e-3, atol=1e-2)
                    np.testing.assert_allclose(bb.astype(jt.float32).data, b.astype(jt.float32).data, rtol=1e-3, atol=1e-2)


    def test_nan(self):