mirror of https://github.com/Jittor/Jittor

polish log_softmax precision

commit ba266fa99c
parent 4c5ac0fda9
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.4.5'
+__version__ = '1.3.4.6'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -496,14 +496,12 @@ def softmax(x, dim=None, log=False):
     import jittor.other.code_softmax as code_softmax
     if code_softmax.can_softmax_v1(x, dim):
         return code_softmax.softmax_v1(x, log)
-    if dim is None:
-        x = (x - x.max()).exp()
-        ret = x / x.sum()
-    else:
-        x = (x-x.max(dim, keepdims=True)).exp()
-        ret = x / x.sum(dim, keepdims=True)
-    if log: return ret.log()
-    return ret
+    if dim is None: dim = ()
+    if log:
+        a = x - x.max(dim, keepdims=True)
+        return a - a.exp().sum(dim, keepdims=True).log()
+    x = (x-x.max(dim, keepdims=True)).exp()
+    return x / x.sum(dim, keepdims=True)
 jt.Var.softmax = softmax
 
 def log_softmax(x,dim=None):
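This hunk changes the log=True path so it no longer computes softmax(x) first and takes its log afterwards: when a row contains a very negative entry, (x - x.max()).exp() underflows to 0 and the following .log() returns -inf. Working on the shifted values directly, a - a.exp().sum().log(), keeps the result finite. A minimal NumPy sketch of the difference (illustrative only, not Jittor code):

import numpy as np

x = -np.array([1.0, 2.0, 1e5], dtype=np.float32)

# old style: exponentiate, normalize, then take the log
e = np.exp(x - x.max())
naive = np.log(e / e.sum())            # exp(-99999) underflows to 0 -> log(0) = -inf

# new style: subtract the max, then subtract log-sum-exp directly
a = x - x.max()
stable = a - np.log(np.exp(a).sum())   # stays finite, roughly [-0.31, -1.31, -99999.3]

print(naive)    # last element is -inf
print(stable)   # all elements finite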
@@ -39,6 +39,7 @@ def softmax_v1(a, log=False):
 ''', cuda_src=f'''
 __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     typedef cub::BlockReduce<float, {tnum}> BlockReduce;
+    constexpr int need_log = {int(log)};
     __shared__ typename BlockReduce::TempStorage temp_storage;
 
     int id = blockIdx.x * len;
@@ -62,9 +63,14 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     {for_loop}
     #pragma unroll
     for (int j=0; j<{ILP}; j++) {{
+        if (need_log) {{
+            v[i][j] = float(v[i][j]) - vmax;
+            v1 += expf(float(v[i][j]));
+        }} else {{
         v[i][j] = expf(float(v[i][j]) - vmax);
         v1 += float(v[i][j]);
+        }}
     }}
 
     tmp = BlockReduce(temp_storage).Sum(v1);
     __shared__ float vsum;
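Because need_log is generated as {int(log)}, it is a compile-time constant and the branch above is resolved when the kernel is compiled; it only decides what each thread keeps in its registers: the shifted value x - vmax for log-softmax, or exp(x - vmax) for plain softmax, while v1 always accumulates the exponentials. A plain-Python sketch of this first pass (not the CUDA code; the names are illustrative):

import math

def first_pass(row, vmax, need_log):
    kept, v1 = [], 0.0
    for x in row:
        if need_log:
            s = x - vmax            # keep the shifted value for the final step
            v1 += math.exp(s)       # the exp only feeds the normalizer
            kept.append(s)
        else:
            e = math.exp(x - vmax)  # keep exp(x - vmax) directly
            v1 += e
            kept.append(e)
    return kept, v1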
@@ -74,11 +80,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
 
     {for_loop}
     #pragma unroll
-    for (int j=0; j<{ILP}; j++)
-        v[i][j] = {
-            "@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
-            else "float(v[i][j])/vsum"
-        };
+    for (int j=0; j<{ILP}; j++) {{
+        if (need_log)
+            v[i][j] = v[i][j] - @expand_op(log,@in0_type,vsum);
+        else
+            v[i][j] = float(v[i][j])/vsum;
+    }}
     {for_loop}
     vload<sizeof(in0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
 }}
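Previously the log path computed log(exp(x - vmax) / vsum) element-wise, which still hits log(0) once the exponential has underflowed; the rewritten branch subtracts log(vsum) from the already-shifted value instead, so no intermediate ratio can vanish. A plain-Python sketch of this normalization step (not the CUDA code; the names are illustrative):

import math

def normalize(kept, vsum, need_log):
    if need_log:
        # log-softmax: (x - vmax) - log(vsum); no ratio, so no log(0)
        return [s - math.log(vsum) for s in kept]
    # softmax: exp(x - vmax) / vsum
    return [e / vsum for e in kept]

Chaining first_pass and normalize over one row mirrors, element by element, what the two {for_loop} blocks compute per thread block.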
@@ -265,6 +265,19 @@ class TestOther(unittest.TestCase):
         z2 = np.arctan2(y.data, x.data)
         np.testing.assert_allclose(z.data, z2)
 
+    def test_softmax_precision(self):
+        # jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision cpu ok")
+        if not jt.has_cuda: return
+        jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision gpu ok")
+
     def test_code_softmax(self):
         if not jt.has_cuda: return
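The new test only asserts that log_softmax stays finite on both backends for a row containing -1e5. As a hedged follow-up (not part of this commit), the same check could be tightened against a stable NumPy reference; the snippet assumes the imports already used in this test file (jittor as jt, numpy as np) and default float32 Vars:

import numpy as np
import jittor as jt

a = -jt.array([1.0, 2.0, 1e5])
b = a.log_softmax(0).data              # jt.Var -> numpy array

x = -np.array([1.0, 2.0, 1e5], dtype=np.float32)
shifted = x - x.max()
ref = shifted - np.log(np.exp(shifted).sum())  # stable reference

np.testing.assert_allclose(b, ref, rtol=1e-5)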