better randperm support

This commit is contained in:
Dun Liang 2021-02-07 22:37:08 +08:00
parent 6ada098f60
commit 79959b6ab4
2 changed files with 5 additions and 10 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in # This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.2.29' __version__ = '1.2.2.30'
from . import lock from . import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -483,11 +483,6 @@ def arange(start=0, end=None, step=1,dtype=None):
x= x.cast(dtype) x= x.cast(dtype)
return x return x
def randperm(n, dtype="int64"):
x = np.arange(n)
np.random.shuffle(x)
return jt.array(x).cast(dtype)
def log2(x): def log2(x):
return jt.log(x)/math.log(2.0) return jt.log(x)/math.log(2.0)
@ -1003,10 +998,10 @@ def linspace(start, end, steps):
res = res*(end-start)/float(steps-1)+start res = res*(end-start)/float(steps-1)+start
return res return res
def randperm(n): def randperm(n, dtype="int32"):
# TODO: use jt.random key = jt.random((n,))
idx = np.arange(n) index, _ = jt.argsort(key)
return jt.array(np.random.permutation(idx)) return index.cast(dtype)
def set_global_seed(seed): def set_global_seed(seed):
import random import random