polish log_softmax precision

Dun Liang 2022-05-18 15:37:09 +08:00
parent 4c5ac0fda9
commit ba266fa99c
4 changed files with 34 additions and 16 deletions


@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.3.4.5'
+__version__ = '1.3.4.6'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int


@@ -496,14 +496,12 @@ def softmax(x, dim=None, log=False):
     import jittor.other.code_softmax as code_softmax
     if code_softmax.can_softmax_v1(x, dim):
         return code_softmax.softmax_v1(x, log)
-    if dim is None:
-        x = (x - x.max()).exp()
-        ret = x / x.sum()
-    else:
-        x = (x-x.max(dim, keepdims=True)).exp()
-        ret = x / x.sum(dim, keepdims=True)
-    if log: return ret.log()
-    return ret
+    if dim is None: dim = ()
+    if log:
+        a = x - x.max(dim, keepdims=True)
+        return a - a.exp().sum(dim, keepdims=True).log()
+    x = (x - x.max(dim, keepdims=True)).exp()
+    return x / x.sum(dim, keepdims=True)
 jt.Var.softmax = softmax
 
 def log_softmax(x,dim=None):
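The rewritten Python path computes log-softmax directly in log space via the log-sum-exp identity instead of taking the log of an already-normalized softmax, so entries whose exponentials underflow to zero no longer become -inf. A minimal NumPy sketch (illustration only, not part of the diff above) of the difference:

    import numpy as np

    x = -np.array([1.0, 2.0, 1e5], dtype=np.float32)

    # Old path: exponentiate, normalize, then log. exp(-99999) underflows to 0,
    # so the final log turns that entry into -inf.
    e = np.exp(x - x.max())
    with np.errstate(divide='ignore'):
        old = np.log(e / e.sum())

    # New path: stay in log space and subtract log-sum-exp once.
    a = x - x.max()
    new = a - np.log(np.exp(a).sum())

    print(old)  # approx [-0.31, -1.31, -inf]
    print(new)  # approx [-0.31, -1.31, -99999.31], all finite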


@@ -39,6 +39,7 @@ def softmax_v1(a, log=False):
 ''', cuda_src=f'''
 __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     typedef cub::BlockReduce<float, {tnum}> BlockReduce;
+    constexpr int need_log = {int(log)};
     __shared__ typename BlockReduce::TempStorage temp_storage;
     int id = blockIdx.x * len;
@@ -62,8 +63,13 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     {for_loop}
         #pragma unroll
         for (int j=0; j<{ILP}; j++) {{
-            v[i][j] = expf(float(v[i][j]) - vmax);
-            v1 += float(v[i][j]);
+            if (need_log) {{
+                v[i][j] = float(v[i][j]) - vmax;
+                v1 += expf(float(v[i][j]));
+            }} else {{
+                v[i][j] = expf(float(v[i][j]) - vmax);
+                v1 += float(v[i][j]);
+            }}
         }}
     tmp = BlockReduce(temp_storage).Sum(v1);
@@ -74,11 +80,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     {for_loop}
         #pragma unroll
-        for (int j=0; j<{ILP}; j++)
-            v[i][j] = {
-                "@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
-                else "float(v[i][j])/vsum"
-            };
+        for (int j=0; j<{ILP}; j++) {{
+            if (need_log)
+                v[i][j] = v[i][j] - @expand_op(log,@in0_type,vsum);
+            else
+                v[i][j] = float(v[i][j])/vsum;
+        }}
     {for_loop}
         vload<sizeof(in0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
 }}
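The kernel change mirrors the Python rewrite: when need_log is set it keeps the shifted values x - vmax, exponentiates them only inside the sum reduction, and then subtracts log(vsum) instead of dividing by vsum and taking the log afterwards. A rough per-row Python sketch of that data flow (an illustration of the arithmetic, not the actual CUDA kernel):

    import math

    def log_softmax_row(row):
        vmax = max(row)                               # block-wide max reduction
        shifted = [v - vmax for v in row]             # values kept in log space
        vsum = sum(math.exp(v) for v in shifted)      # block-wide sum of exp(shifted)
        return [v - math.log(vsum) for v in shifted]  # subtract log(vsum), no division

    print(log_softmax_row([-1.0, -2.0, -1e5]))        # every entry stays finite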


@@ -265,6 +265,19 @@ class TestOther(unittest.TestCase):
         z2 = np.arctan2(y.data, x.data)
         np.testing.assert_allclose(z.data, z2)
 
+    def test_softmax_precision(self):
+        # jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision cpu ok")
+        if not jt.has_cuda: return
+        jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision gpu ok")
+
     def test_code_softmax(self):
         if not jt.has_cuda: return
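For reference, a standalone sketch of what the new test guards against; the softmax(...).log() call is shown only to contrast the old failure mode with the direct log_softmax path added by this commit:

    import jittor as jt

    a = -jt.array([1.0, 2.0, 1e5])
    print(a.softmax(0))        # third entry underflows to 0.0
    print(a.softmax(0).log())  # log(0) -> -inf: the old log_softmax behaviour
    print(a.log_softmax(0))    # finite everywhere with the direct log-space path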