mirror of https://github.com/Jittor/Jittor
polish log_softmax precision
commit ba266fa99c
parent 4c5ac0fda9
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.3.4.5'
+__version__ = '1.3.4.6'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
@@ -496,14 +496,12 @@ def softmax(x, dim=None, log=False):
     import jittor.other.code_softmax as code_softmax
     if code_softmax.can_softmax_v1(x, dim):
         return code_softmax.softmax_v1(x, log)
-    if dim is None:
-        x = (x - x.max()).exp()
-        ret = x / x.sum()
-    else:
-        x = (x-x.max(dim, keepdims=True)).exp()
-        ret = x / x.sum(dim, keepdims=True)
-    if log: return ret.log()
-    return ret
+    if dim is None: dim = ()
+    if log:
+        a = x - x.max(dim, keepdims=True)
+        return a - a.exp().sum(dim, keepdims=True).log()
+    x = (x - x.max(dim, keepdims=True)).exp()
+    return x / x.sum(dim, keepdims=True)
 jt.Var.softmax = softmax

 def log_softmax(x, dim=None):
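Note on the hunk above: the old code always computed softmax first and then took its log, so for strongly negative inputs the exponentials underflow to 0 and the final log() returns -inf. The new log branch computes (x - max) - log(sum(exp(x - max))) directly, which stays finite. A minimal NumPy sketch of the difference (illustrative only, not part of the commit):

import numpy as np

x = -np.array([1.0, 2.0, 1e5], dtype=np.float32)   # same input as the new test

# Old route: exponentiate, normalize, then log. exp() underflows to 0 for
# the large negative entry, so the final log() yields -inf there.
p = np.exp(x - x.max())
naive = np.log(p / p.sum())                         # contains -inf

# New route: keep (x - max) and subtract the log of the summed exponentials.
a = x - x.max()
stable = a - np.log(np.exp(a).sum())                # all entries finite

print(naive)
print(stable)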
@@ -39,6 +39,7 @@ def softmax_v1(a, log=False):
 ''', cuda_src=f'''
 __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     typedef cub::BlockReduce<float, {tnum}> BlockReduce;
+    constexpr int need_log = {int(log)};
     __shared__ typename BlockReduce::TempStorage temp_storage;

     int id = blockIdx.x * len;
@@ -62,8 +63,13 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     {for_loop}
     #pragma unroll
     for (int j=0; j<{ILP}; j++) {{
-        v[i][j] = expf(float(v[i][j]) - vmax);
-        v1 += float(v[i][j]);
+        if (need_log) {{
+            v[i][j] = float(v[i][j]) - vmax;
+            v1 += expf(float(v[i][j]));
+        }} else {{
+            v[i][j] = expf(float(v[i][j]) - vmax);
+            v1 += float(v[i][j]);
+        }}
     }}

     tmp = BlockReduce(temp_storage).Sum(v1);
@@ -74,11 +80,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{

     {for_loop}
     #pragma unroll
-    for (int j=0; j<{ILP}; j++)
-        v[i][j] = {
-            "@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
-            else "float(v[i][j])/vsum"
-        };
+    for (int j=0; j<{ILP}; j++) {{
+        if (need_log)
+            v[i][j] = v[i][j] - @expand_op(log,@in0_type,vsum);
+        else
+            v[i][j] = float(v[i][j])/vsum;
+    }}
     {for_loop}
         vload<sizeof(in0_type)*{ILP}>(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
     }}
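Taken together, the three kernel hunks make the CUDA path follow the same log-sum-exp formulation: when need_log is set, each element keeps x - vmax, the block reduction sums expf() of those values, and the final pass subtracts log(vsum) rather than dividing by vsum and taking a log afterwards. A per-row NumPy sketch of that computation (a reading aid, not the kernel itself):

import numpy as np

def row_log_softmax(row):
    v = row - row.max()        # first reduction: subtract the row max
    vsum = np.exp(v).sum()     # second reduction: sum of exp(x - max)
    return v - np.log(vsum)    # final pass: subtract log-sum-exp once

print(row_log_softmax(-np.array([1.0, 2.0, 1e5], dtype=np.float32)))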
@@ -265,6 +265,19 @@ class TestOther(unittest.TestCase):
         z2 = np.arctan2(y.data, x.data)
         np.testing.assert_allclose(z.data, z2)

+    def test_softmax_precision(self):
+        # jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision cpu ok")
+        if not jt.has_cuda: return
+        jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision gpu ok")
+
     def test_code_softmax(self):
         if not jt.has_cuda: return
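The new test only asserts that the outputs are finite. A stricter check one could run locally is to compare against a float64 log-sum-exp reference computed with NumPy (a suggested sketch, not part of the commit):

import numpy as np
import jittor as jt

a = -jt.array([1.0, 2.0, 1e5])
b = a.log_softmax(0).data

x = a.data.astype(np.float64)
ref = (x - x.max()) - np.log(np.exp(x - x.max()).sum())  # stable reference
np.testing.assert_allclose(b, ref, rtol=1e-5)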