mirror of https://github.com/Jittor/Jittor
polish to_tensor
This commit is contained in:
parent
80000c6941
commit
a7ced77f69
|
@ -9,7 +9,7 @@
|
||||||
# file 'LICENSE.txt', which is part of this source code package.
|
# file 'LICENSE.txt', which is part of this source code package.
|
||||||
# ***************************************************************
|
# ***************************************************************
|
||||||
|
|
||||||
__version__ = '1.2.3.26'
|
__version__ = '1.2.3.27'
|
||||||
from jittor_utils import lock
|
from jittor_utils import lock
|
||||||
with lock.lock_scope():
|
with lock.lock_scope():
|
||||||
ori_int = int
|
ori_int = int
|
||||||
|
|
|
@ -257,14 +257,14 @@ class Tester(unittest.TestCase):
|
||||||
expect = input_data.transpose(2,0,1)
|
expect = input_data.transpose(2,0,1)
|
||||||
self.assertTrue(np.allclose(expect, output), f"{expect.shape}\n{output.shape}")
|
self.assertTrue(np.allclose(expect, output), f"{expect.shape}\n{output.shape}")
|
||||||
|
|
||||||
ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
|
ndarray = np.random.randint(low=0, high=255, size=(channels, height, width)).astype(np.uint8)
|
||||||
output = trans(ndarray)
|
output = trans(ndarray)
|
||||||
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
|
expected_output = ndarray / 255.0
|
||||||
self.assertTrue(np.allclose(output, expected_output))
|
np.testing.assert_allclose(output, expected_output)
|
||||||
|
|
||||||
ndarray = np.random.rand(height, width, channels).astype(np.float32)
|
ndarray = np.random.rand(channels, height, width).astype(np.float32)
|
||||||
output = trans(ndarray)
|
output = trans(ndarray)
|
||||||
expected_output = ndarray.transpose((2, 0, 1))
|
expected_output = ndarray
|
||||||
self.assertTrue(np.allclose(output, expected_output))
|
self.assertTrue(np.allclose(output, expected_output))
|
||||||
|
|
||||||
# separate test for mode '1' PIL images
|
# separate test for mode '1' PIL images
|
||||||
|
|
|
@ -389,7 +389,7 @@ class CenterCrop:
|
||||||
|
|
||||||
def to_tensor(pic):
|
def to_tensor(pic):
|
||||||
"""
|
"""
|
||||||
Function for turning Image.Image to np.array.
|
Function for turning Image.Image to np.array with CHW format.
|
||||||
|
|
||||||
Args::
|
Args::
|
||||||
|
|
||||||
|
@ -414,14 +414,13 @@ def to_tensor(pic):
|
||||||
if _is_numpy(pic):
|
if _is_numpy(pic):
|
||||||
# handle numpy array
|
# handle numpy array
|
||||||
if pic.ndim == 2:
|
if pic.ndim == 2:
|
||||||
pic = pic[:, :, None]
|
pic = pic[None, :, :]
|
||||||
|
|
||||||
img = pic.transpose((2, 0, 1))
|
|
||||||
# backward compatibility
|
# backward compatibility
|
||||||
if img.dtype == 'uint8':
|
if pic.dtype == 'uint8':
|
||||||
return np.float32(img) * np.float32(1/255.0)
|
return np.float32(pic) * np.float32(1/255.0)
|
||||||
else:
|
else:
|
||||||
return img
|
return pic
|
||||||
|
|
||||||
# handle PIL Image
|
# handle PIL Image
|
||||||
if pic.mode == 'I':
|
if pic.mode == 'I':
|
||||||
|
@ -499,7 +498,7 @@ def _to_jittor_array(pic):
|
||||||
def to_pil_image(pic, mode=None):
|
def to_pil_image(pic, mode=None):
|
||||||
"""Convert a tensor or an ndarray to PIL Image.
|
"""Convert a tensor or an ndarray to PIL Image.
|
||||||
Args:
|
Args:
|
||||||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
|
pic (Tensor or numpy.ndarray): Image(HWC format) to be converted to PIL Image.
|
||||||
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
|
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
|
||||||
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
|
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
|
||||||
Returns:
|
Returns:
|
||||||
|
|
Loading…
Reference in New Issue