polish log_softmax precision

This commit is contained in:
Dun Liang 2022-05-18 15:37:09 +08:00
parent 4c5ac0fda9
commit ba266fa99c
4 changed files with 34 additions and 16 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.4.5'
__version__ = '1.3.4.6'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -496,14 +496,12 @@ def softmax(x, dim=None, log=False):
import jittor.other.code_softmax as code_softmax
if code_softmax.can_softmax_v1(x, dim):
return code_softmax.softmax_v1(x, log)
if dim is None:
x = (x - x.max()).exp()
ret = x / x.sum()
else:
if dim is None: dim = ()
if log:
a = x-x.max(dim, keepdims=True)
return a - a.exp().sum(dim, keepdims=True).log()
x = (x-x.max(dim, keepdims=True)).exp()
ret = x / x.sum(dim, keepdims=True)
if log: return ret.log()
return ret
return x / x.sum(dim, keepdims=True)
jt.Var.softmax = softmax
def log_softmax(x,dim=None):

View File

@ -39,6 +39,7 @@ def softmax_v1(a, log=False):
''', cuda_src=f'''
__global__ void kernel(in0_type* x, out0_type* y, int len) {{
typedef cub::BlockReduce<float, {tnum}> BlockReduce;
constexpr int need_log = {int(log)};
__shared__ typename BlockReduce::TempStorage temp_storage;
int id = blockIdx.x * len;
@ -62,9 +63,14 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
{for_loop}
#pragma unroll
for (int j=0; j<{ILP}; j++) {{
if (need_log) {{
v[i][j] = float(v[i][j]) - vmax;
v1 += expf(float(v[i][j]));
}} else {{
v[i][j] = expf(float(v[i][j]) - vmax);
v1 += float(v[i][j]);
}}
}}
tmp = BlockReduce(temp_storage).Sum(v1);
__shared__ float vsum;
@ -74,11 +80,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
{for_loop}
#pragma unroll
for (int j=0; j<{ILP}; j++)
v[i][j] = {
"@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
else "float(v[i][j])/vsum"
};
for (int j=0; j<{ILP}; j++) {{
if (need_log)
v[i][j] = v[i][j] - @expand_op(log,@in0_type,vsum);
else
v[i][j] = float(v[i][j])/vsum;
}}
{for_loop}
vload<sizeof(in0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
}}

View File

@ -265,6 +265,19 @@ class TestOther(unittest.TestCase):
z2 = np.arctan2(y.data, x.data)
np.testing.assert_allclose(z.data, z2)
def test_softmax_precision(self):
# jt.flags.use_cuda = 1
a = -jt.array([1.0,2.0,1e5])
b = a.log_softmax(0)
assert b.isfinite().all().item()
print("test_softmax_precision cpu ok")
if not jt.has_cuda: return
jt.flags.use_cuda = 1
a = -jt.array([1.0,2.0,1e5])
b = a.log_softmax(0)
assert b.isfinite().all().item()
print("test_softmax_precision gpu ok")
def test_code_softmax(self):
if not jt.has_cuda: return