From ba266fa99c67e76222d824bfd1446b63eb00ea0d Mon Sep 17 00:00:00 2001
From: Dun Liang
Date: Wed, 18 May 2022 15:37:09 +0800
Subject: [PATCH] polish log_softmax precision

---
 python/jittor/__init__.py           |  2 +-
 python/jittor/nn.py                 | 14 ++++++--------
 python/jittor/other/code_softmax.py | 21 ++++++++++++++-------
 python/jittor/test/test_misc_op.py  | 13 +++++++++++++
 4 files changed, 34 insertions(+), 16 deletions(-)

diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py
index 294b9bb4..433740c7 100644
--- a/python/jittor/__init__.py
+++ b/python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
 
-__version__ = '1.3.4.5'
+__version__ = '1.3.4.6'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int
diff --git a/python/jittor/nn.py b/python/jittor/nn.py
index 35321049..42965538 100644
--- a/python/jittor/nn.py
+++ b/python/jittor/nn.py
@@ -496,14 +496,12 @@ def softmax(x, dim=None, log=False):
     import jittor.other.code_softmax as code_softmax
     if code_softmax.can_softmax_v1(x, dim):
         return code_softmax.softmax_v1(x, log)
-    if dim is None:
-        x = (x - x.max()).exp()
-        ret = x / x.sum()
-    else:
-        x = (x-x.max(dim, keepdims=True)).exp()
-        ret = x / x.sum(dim, keepdims=True)
-    if log: return ret.log()
-    return ret
+    if dim is None: dim = ()
+    if log:
+        a = x-x.max(dim, keepdims=True)
+        return a - a.exp().sum(dim, keepdims=True).log()
+    x = (x-x.max(dim, keepdims=True)).exp()
+    return x / x.sum(dim, keepdims=True)
 jt.Var.softmax = softmax
 
 def log_softmax(x,dim=None):
diff --git a/python/jittor/other/code_softmax.py b/python/jittor/other/code_softmax.py
index 837bd648..0c1b2b56 100644
--- a/python/jittor/other/code_softmax.py
+++ b/python/jittor/other/code_softmax.py
@@ -39,6 +39,7 @@ def softmax_v1(a, log=False):
 ''', cuda_src=f'''
 __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     typedef cub::BlockReduce<float, {tnum}> BlockReduce;
+    constexpr int need_log = {int(log)};
     __shared__ typename BlockReduce::TempStorage temp_storage;
 
     int id = blockIdx.x * len;
@@ -62,8 +63,13 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
     {for_loop}
     #pragma unroll
     for (int j=0; j<{ILP}; j++) {{
-        v[i][j] = expf(float(v[i][j]) - vmax);
-        v1 += float(v[i][j]);
+        if (need_log) {{
+            v[i][j] = float(v[i][j]) - vmax;
+            v1 += expf(float(v[i][j]));
+        }} else {{
+            v[i][j] = expf(float(v[i][j]) - vmax);
+            v1 += float(v[i][j]);
+        }}
     }}
 
     tmp = BlockReduce(temp_storage).Sum(v1);
@@ -74,11 +80,12 @@ __global__ void kernel(in0_type* x, out0_type* y, int len) {{
 
     {for_loop}
     #pragma unroll
-    for (int j=0; j<{ILP}; j++)
-        v[i][j] = {
-            "@expand_op(log,@in0_type,float(v[i][j])/vsum)" if log
-            else "float(v[i][j])/vsum"
-        };
+    for (int j=0; j<{ILP}; j++) {{
+        if (need_log)
+            v[i][j] = v[i][j] - @expand_op(log,@in0_type,vsum);
+        else
+            v[i][j] = float(v[i][j])/vsum;
+    }}
     {for_loop}
         vload(&y[id+(i*{tnum}+threadIdx.x)*{ILP}], v[i]);
 }}
diff --git a/python/jittor/test/test_misc_op.py b/python/jittor/test/test_misc_op.py
index cfe4cde3..772afb83 100644
--- a/python/jittor/test/test_misc_op.py
+++ b/python/jittor/test/test_misc_op.py
@@ -265,6 +265,19 @@ class TestOther(unittest.TestCase):
         z2 = np.arctan2(y.data, x.data)
         np.testing.assert_allclose(z.data, z2)
 
+    def test_softmax_precision(self):
+        # jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision cpu ok")
+        if not jt.has_cuda: return
+        jt.flags.use_cuda = 1
+        a = -jt.array([1.0,2.0,1e5])
+        b = a.log_softmax(0)
+        assert b.isfinite().all().item()
+        print("test_softmax_precision gpu ok")
+
     def test_code_softmax(self):
         if not jt.has_cuda: return
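
Note: the new `if log:` branch in `nn.softmax` and the `need_log` path in the CUDA kernel both rely on the log-sum-exp identity log_softmax(x) = (x - max(x)) - log(sum(exp(x - max(x)))), so `log()` is never applied to a softmax value that has underflowed to zero. Below is a minimal NumPy sketch of why this matters for the test input `-[1.0, 2.0, 1e5]`; it is illustrative only and not part of the patch, and the helper names are made up for the comparison.

    import numpy as np

    def log_softmax_via_softmax(x):
        # Mirrors the replaced code path: compute softmax first, then log.
        # exp(-99999) underflows to 0, so the final log() produces -inf
        # (NumPy may also emit a divide-by-zero warning here).
        e = np.exp(x - x.max())
        return np.log(e / e.sum())

    def log_softmax_stable(x):
        # Mirrors the new code path: shift by the max, then subtract the
        # log of the summed exponentials; nothing underflows into log().
        a = x - x.max()
        return a - np.log(np.exp(a).sum())

    x = -np.array([1.0, 2.0, 1e5])
    print(np.isfinite(log_softmax_via_softmax(x)).all())  # False: -inf appears
    print(np.isfinite(log_softmax_stable(x)).all())       # True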