mirror of https://github.com/Jittor/Jittor
better randperm support
This commit is contained in:
parent
6ada098f60
commit
79959b6ab4
|
@ -8,7 +8,7 @@
|
||||||
# This file is subject to the terms and conditions defined in
|
# This file is subject to the terms and conditions defined in
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
__version__ = '1.2.2.29'
|
__version__ = '1.2.2.30'
|
||||||
from . import lock
|
from . import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -483,11 +483,6 @@ def arange(start=0, end=None, step=1,dtype=None):
|
||||||
x= x.cast(dtype)
|
x= x.cast(dtype)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
def randperm(n, dtype="int64"):
|
|
||||||
x = np.arange(n)
|
|
||||||
np.random.shuffle(x)
|
|
||||||
return jt.array(x).cast(dtype)
|
|
||||||
|
|
||||||
def log2(x):
|
def log2(x):
|
||||||
return jt.log(x)/math.log(2.0)
|
return jt.log(x)/math.log(2.0)
|
||||||
|
|
||||||
|
@ -1003,10 +998,10 @@ def linspace(start, end, steps):
|
||||||
res = res*(end-start)/float(steps-1)+start
|
res = res*(end-start)/float(steps-1)+start
|
||||||
return res
|
return res
|
||||||
|
|
||||||
def randperm(n):
|
def randperm(n, dtype="int32"):
|
||||||
# TODO: use jt.random
|
key = jt.random((n,))
|
||||||
idx = np.arange(n)
|
index, _ = jt.argsort(key)
|
||||||
return jt.array(np.random.permutation(idx))
|
return index.cast(dtype)
|
||||||
|
|
||||||
def set_global_seed(seed):
|
def set_global_seed(seed):
|
||||||
import random
|
import random
|
||||||
|
|
Loading…
Reference in New Issue