fix: randint returns decimals when dtype is float

This commit is contained in:
lzhengning 2021-02-09 18:47:47 +08:00
parent 0278356f31
commit 37e468e0f8
1 changed files with 1 additions and 0 deletions

View File

@ -522,6 +522,7 @@ def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
'''
if high is None: low, high = 0, low
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
v = jt.floor(v)
return v.astype(dtype)
def randint_like(x, low, high=None) -> Var: