diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index ee2a6d10..99269ce9 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -8,7 +8,7 @@ # This file is subject to the terms and conditions defined in # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.2.29' +__version__ = '1.2.2.30' from . import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 74ceab64..f88f4d6c 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -483,11 +483,6 @@ def arange(start=0, end=None, step=1,dtype=None): x= x.cast(dtype) return x -def randperm(n, dtype="int64"): - x = np.arange(n) - np.random.shuffle(x) - return jt.array(x).cast(dtype) - def log2(x): return jt.log(x)/math.log(2.0) @@ -1003,10 +998,10 @@ def linspace(start, end, steps): res = res*(end-start)/float(steps-1)+start return res -def randperm(n): - # TODO: use jt.random - idx = np.arange(n) - return jt.array(np.random.permutation(idx)) +def randperm(n, dtype="int32"): + key = jt.random((n,)) + index, _ = jt.argsort(key) + return index.cast(dtype) def set_global_seed(seed): import random