mirror of https://github.com/Jittor/Jittor
fix zeros_like dtype select
This commit is contained in:
parent
2a06801681
commit
2e74df517b
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.6.2'
|
||||
__version__ = '1.3.6.3'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -453,7 +453,7 @@ def ones(*shape, dtype="float32"):
|
|||
:return: The output Var.
|
||||
:rtype: jittor.Var
|
||||
'''
|
||||
if isinstance(shape, tuple) and isinstance(shape[-1], str):
|
||||
if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
|
||||
dtype = shape[-1]
|
||||
shape = shape[:-1]
|
||||
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
|
||||
|
@ -480,7 +480,7 @@ def zeros(*shape, dtype="float32"):
|
|||
:return: The output Var.
|
||||
:rtype: jittor.Var
|
||||
'''
|
||||
if isinstance(shape, tuple) and isinstance(shape[-1], str):
|
||||
if isinstance(shape, tuple) and isinstance(shape[-1], (str, NanoString)):
|
||||
dtype = shape[-1]
|
||||
shape = shape[:-1]
|
||||
if isinstance(shape, tuple) and isinstance(shape[0], (Sequence, NanoVector)):
|
||||
|
|
|
@ -377,5 +377,11 @@ class TestOther(unittest.TestCase):
|
|||
a = jt.ones(10,10)
|
||||
assert a.shape == (10,10)
|
||||
|
||||
a = jt.ones_like(jt.ones([10], "int16"))
|
||||
assert a.dtype == "int16"
|
||||
|
||||
a = jt.ones_like(jt.ones([10], "bool"))
|
||||
assert a.dtype == "bool"
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue