mirror of https://github.com/Jittor/Jittor
fix array dtype
This commit is contained in:
parent
fb90824178
commit
f3ec2e956c
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.1.7.6'
|
||||
__version__ = '1.1.7.7'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
from . import compiler
|
||||
|
@ -204,7 +204,11 @@ def array(data, dtype=None):
|
|||
if dtype is None:
|
||||
return data.clone()
|
||||
return cast(data, dtype)
|
||||
if dtype != None:
|
||||
if dtype is not None:
|
||||
if isinstance(dtype, NanoString):
|
||||
dtype = str(dtype)
|
||||
elif callable(dtype):
|
||||
dtype = dtype.__name__
|
||||
return ops.array(np.array(data, dtype))
|
||||
return ops.array(data)
|
||||
|
||||
|
|
|
@ -127,6 +127,10 @@ class TestArray(unittest.TestCase):
|
|||
assert jt.array(np.int32(1)).data == 1
|
||||
assert jt.array(np.int64(1)).data == 1
|
||||
|
||||
def test_array_dtype(self):
|
||||
a = jt.array([1,2,3], dtype=jt.NanoString("float32"))
|
||||
a = jt.array([1,2,3], dtype=jt.float32)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue