mirror of https://github.com/Jittor/Jittor
fix cumprod
This commit is contained in:
parent
8a2e7a1881
commit
546860e19e
|
@ -718,7 +718,7 @@ def cumsum(x, dim=None):
|
|||
|
||||
jt.Var.cumsum = cumsum
|
||||
|
||||
def cumprod(x,dim=0):
|
||||
def cumprod(x,dim=None):
|
||||
x = jt.log(x)
|
||||
x = cumsum(x,dim=dim)
|
||||
return jt.exp(x)
|
||||
|
|
Loading…
Reference in New Issue