mirror of https://github.com/Jittor/Jittor
fix simple_presum
This commit is contained in:
parent
1e82016bfa
commit
04be711ddb
|
@ -8,7 +8,8 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.67'
|
||||
|
||||
__version__ = '1.2.2.68'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -15,10 +15,9 @@ __inline_static__
|
|||
void kernel(int n0, int i0, in0_type* x, in0_type* out, int nl) {
|
||||
out[i0*(nl+1)] = 0;
|
||||
for (int i=0; i<nl; i++)
|
||||
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*(nl+1)+i];
|
||||
out[i0*(nl+1)+i+1] = out[i0*(nl+1)+i] + x[i0*nl+i];
|
||||
}
|
||||
|
||||
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->num);
|
||||
kernel(in0->num/in0->shape[in0->shape.size()-1], 0, in0_p, out0_p, in0->shape[in0->shape.size()-1]);
|
||||
'''
|
||||
return jt.code(x.shape[:-1]+(x.shape[-1]+1,), x.dtype, [x],
|
||||
cpu_src=src, cuda_src=src)
|
||||
|
@ -32,7 +31,7 @@ class OneHotCategorical:
|
|||
probs = jt.sigmoid(logits)
|
||||
with jt.no_grad():
|
||||
self.probs = probs / probs.sum(-1, True)
|
||||
self.cum_probs = simple_presum(probs)
|
||||
self.cum_probs = simple_presum(self.probs)
|
||||
self.cum_probs_l = self.cum_probs[..., :-1]
|
||||
self.cum_probs_r = self.cum_probs[..., 1:]
|
||||
|
||||
|
|
Loading…
Reference in New Issue