mirror of https://github.com/Jittor/Jittor
fix: randint returns decimals when dtype is float
This commit is contained in:
parent
0278356f31
commit
37e468e0f8
|
@ -522,6 +522,7 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
|
|||
'''
|
||||
if high is None: low, high = 0, low
|
||||
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
|
||||
v = jt.floor(v)
|
||||
return v.astype(dtype)
|
||||
|
||||
def randint_like(x, low, high=None) -> Var:
|
||||
|
|
Loading…
Reference in New Issue