mirror of https://github.com/Jittor/Jittor
add bernoulli
This commit is contained in:
parent
33f85afe6c
commit
36a3c24a46
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.49'
|
||||
__version__ = '1.2.2.50'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -64,6 +64,8 @@ def any(x,dim):
|
|||
return x.any_(dim).bool()
|
||||
jt.Var.any = any
|
||||
|
||||
def bernoulli(input):
|
||||
return (input>jt.rand_like(input)).cast(input.dtype)
|
||||
|
||||
def repeat(x, *shape):
|
||||
r'''
|
||||
|
|
Loading…
Reference in New Issue