mirror of https://github.com/Jittor/Jittor
polish to_tensor
This commit is contained in:
parent
80000c6941
commit
a7ced77f69
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.26'
|
||||
__version__ = '1.2.3.27'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -257,14 +257,14 @@ class Tester(unittest.TestCase):
|
|||
expect = input_data.transpose(2,0,1)
|
||||
self.assertTrue(np.allclose(expect, output), f"{expect.shape}\n{output.shape}")
|
||||
|
||||
ndarray = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8)
|
||||
ndarray = np.random.randint(low=0, high=255, size=(channels, height, width)).astype(np.uint8)
|
||||
output = trans(ndarray)
|
||||
expected_output = ndarray.transpose((2, 0, 1)) / 255.0
|
||||
self.assertTrue(np.allclose(output, expected_output))
|
||||
expected_output = ndarray / 255.0
|
||||
np.testing.assert_allclose(output, expected_output)
|
||||
|
||||
ndarray = np.random.rand(height, width, channels).astype(np.float32)
|
||||
ndarray = np.random.rand(channels, height, width).astype(np.float32)
|
||||
output = trans(ndarray)
|
||||
expected_output = ndarray.transpose((2, 0, 1))
|
||||
expected_output = ndarray
|
||||
self.assertTrue(np.allclose(output, expected_output))
|
||||
|
||||
# separate test for mode '1' PIL images
|
||||
|
|
|
@ -389,7 +389,7 @@ class CenterCrop:
|
|||
|
||||
def to_tensor(pic):
|
||||
"""
|
||||
Function for turning Image.Image to np.array.
|
||||
Function for turning Image.Image to np.array with CHW format.
|
||||
|
||||
Args::
|
||||
|
||||
|
@ -414,14 +414,13 @@ def to_tensor(pic):
|
|||
if _is_numpy(pic):
|
||||
# handle numpy array
|
||||
if pic.ndim == 2:
|
||||
pic = pic[:, :, None]
|
||||
pic = pic[None, :, :]
|
||||
|
||||
img = pic.transpose((2, 0, 1))
|
||||
# backward compatibility
|
||||
if img.dtype == 'uint8':
|
||||
return np.float32(img) * np.float32(1/255.0)
|
||||
if pic.dtype == 'uint8':
|
||||
return np.float32(pic) * np.float32(1/255.0)
|
||||
else:
|
||||
return img
|
||||
return pic
|
||||
|
||||
# handle PIL Image
|
||||
if pic.mode == 'I':
|
||||
|
@ -499,7 +498,7 @@ def _to_jittor_array(pic):
|
|||
def to_pil_image(pic, mode=None):
|
||||
"""Convert a tensor or an ndarray to PIL Image.
|
||||
Args:
|
||||
pic (Tensor or numpy.ndarray): Image to be converted to PIL Image.
|
||||
pic (Tensor or numpy.ndarray): Image(HWC format) to be converted to PIL Image.
|
||||
mode (`PIL.Image mode`_): color space and pixel depth of input data (optional).
|
||||
.. _PIL.Image mode: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#concept-modes
|
||||
Returns:
|
||||
|
|
Loading…
Reference in New Issue