fix simple_presum

This commit is contained in:
Dun Liang 2021-05-07 12:16:46 +08:00
parent 1e82016bfa
commit 04be711ddb
2 changed files with 5 additions and 5 deletions

View File

@ -8,7 +8,8 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.67'
__version__ = '1.2.2.68'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -15,10 +15,9 @@ __inline_static__
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
out[i0*(nl+1)] = 0;
for (int i=0; i<nl; i++)
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*(nl+1)+i];
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*nl+i];
}
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->num);
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]);
'''
return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
cpu_src=src, cuda_src=src)
@ -32,7 +31,7 @@ class OneHotCategorical:
probs = jt.sigmoid(logits)
with jt.no_grad():
self.probs = probs / probs.sum(-1, True)
self.cum_probs = simple_presum(probs)
self.cum_probs = simple_presum(self.probs)
self.cum_probs_l = self.cum_probs[..., :-1]
self.cum_probs_r = self.cum_probs[..., 1:]