fix zeros_like dtype select

This commit is contained in:
Dun Liang 2022-12-09 20:24:04 +08:00
parent 2a06801681
commit 2e74df517b
2 changed files with 9 additions and 3 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.6.2'
__version__ = '1.3.6.3'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -453,7 +453,7 @@ def ones(*shape, dtype="float32"):
:return: The output Var.
:rtype: jittor.Var
'''
if isinstance(shape, tuple) and isinstance(shape[-1], str):
if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
dtype = shape[-1]
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
@ -480,7 +480,7 @@ def zeros(*shape, dtype="float32"):
:return: The output Var.
:rtype: jittor.Var
'''
if isinstance(shape, tuple) and isinstance(shape[-1], str):
if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
dtype = shape[-1]
shape = shape[:-1]
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):

View File

@ -377,5 +377,11 @@ class TestOther(unittest.TestCase):
a = jt.ones(10,10)
assert a.shape == (10,10)
a = jt.ones_like(jt.ones([10], "int16"))
assert a.dtype == "int16"
a = jt.ones_like(jt.ones([10], "bool"))
assert a.dtype == "bool"
if __name__ == "__main__":
unittest.main()