diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index 6ce53563..97c6922f 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -9,7 +9,7 @@ # file 'LICENSE.txt', which is part of this source code package. # *************************************************************** -__version__ = '1.2.3.100' +__version__ = '1.2.3.101' from jittor_utils import lock with lock.lock_scope(): ori_int = int diff --git a/python/jittor/test/test_transform.py b/python/jittor/test/test_transform.py index b59ea4b4..5498367c 100644 --- a/python/jittor/test/test_transform.py +++ b/python/jittor/test/test_transform.py @@ -954,6 +954,22 @@ class Tester(unittest.TestCase): transform.ToTensor(), ])(img) + def test_not_pil_image(self): + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.RandomAffine(20), + transform.ToTensor(), + ])(img) + + img = jt.random((30,40,3)) + result = transform.Compose([ + transform.ToPILImage(), + transform.Gray(), + transform.Resize(20), + transform.ToTensor(), + ])(img) + + if __name__ == '__main__': diff --git a/python/jittor/transform/__init__.py b/python/jittor/transform/__init__.py index ee817bb8..d4b7009f 100644 --- a/python/jittor/transform/__init__.py +++ b/python/jittor/transform/__init__.py @@ -152,7 +152,9 @@ class Crop: self.left = left self.height = height self.width = width - def __call__(self, img): + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) return crop(img, self.top, self.left, self.height, self.width) @@ -181,6 +183,8 @@ class RandomCropAndResize: self.interpolation = interpolation def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) width, height = img.size scale = self.scale ratio = self.ratio @@ -363,6 +367,8 @@ class RandomHorizontalFlip: self.p = p def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) if random.random() < self.p: return img.transpose(Image.FLIP_LEFT_RIGHT) return img @@ -384,6 +390,8 @@ class CenterCrop: self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) width, height = img.size return crop(img, (height - self.size[0]) / 2, (width - self.size[1]) / 2, self.size[0], self.size[1]) @@ -682,6 +690,8 @@ class Resize: self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") self.mode = mode def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) return resize(img, self.size, self.mode) class Gray: @@ -697,6 +707,8 @@ class Gray: self.num_output_channels = num_output_channels def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) img = np.float32(img.convert('L')) / np.float32(255.0) if self.num_output_channels == 1: return img[np.newaxis, :] @@ -720,7 +732,9 @@ class RandomGray: def __init__(self, p=0.1): self.p = p - def __call__(self, img: Image.Image): + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) num_output_channels = _get_image_num_channels(img) if random.random() < self.p: return gray(img, num_output_channels=num_output_channels) @@ -742,6 +756,8 @@ class RandomCrop: def __init__(self, size): self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) width, height = img.size assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop, {(self.size, height, width)}" top = np.random.randint(0,height-self.size[0]+1) @@ -835,7 +851,9 @@ class RandomVerticalFlip: def __init__(self, p=0.5): self.p = p - def __call__(self, img: Image.Image): + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) if random.random() < self.p: return vflip(img) return img @@ -918,13 +936,15 @@ class ColorJitter: return transform - def __call__(self, img): + def __call__(self, img:Image.Image): """ Args:: [in] img (PIL Image): Input image. Returns:: [out] PIL Image: Color jittered image. """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) transform = self._get_transform(self.brightness, self.contrast, self.saturation, self.hue) return transform(img) @@ -1002,7 +1022,7 @@ class RandomPerspective(object): self.interpolation = interpolation self.distortion_scale = distortion_scale - def __call__(self, img): + def __call__(self, img:Image.Image): """ Args: img (PIL Image): Image to be Perspectively transformed. @@ -1011,7 +1031,7 @@ class RandomPerspective(object): PIL Image: Random perspectivley transformed image. """ if not isinstance(img, Image.Image): - raise TypeError('img should be PIL Image. Got {}'.format(type(img))) + img = to_pil_image(img) if random.random() < self.p: width, height = img.size @@ -1119,7 +1139,7 @@ class RandomResizedCrop(object): j = (width - w) // 2 return i, j, h, w - def __call__(self, img): + def __call__(self, img:Image.Image): """ Args: img (PIL Image): Image to be cropped and resized. @@ -1127,6 +1147,8 @@ class RandomResizedCrop(object): Returns: PIL Image: Randomly cropped and resized image. """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) i, j, h, w = self.get_params(img, self.scale, self.ratio) return F_pil.resized_crop(img, i, j, h, w, self.size, self.interpolation) @@ -1174,7 +1196,9 @@ class FiveCrop(object): assert len(size) == 2, "Please provide only two dimensions (h, w) for size." self.size = size - def __call__(self, img): + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) return F_pil.five_crop(img, self.size) def __repr__(self): @@ -1217,7 +1241,9 @@ class TenCrop(object): self.size = size self.vertical_flip = vertical_flip - def __call__(self, img): + def __call__(self, img:Image.Image): + if not isinstance(img, Image.Image): + img = to_pil_image(img) return F_pil.ten_crop(img, self.size, self.vertical_flip) def __repr__(self): @@ -1275,7 +1301,7 @@ class RandomRotation(object): return angle - def __call__(self, img): + def __call__(self, img:Image.Image): """ Args: img (PIL Image): Image to be rotated. @@ -1283,7 +1309,8 @@ class RandomRotation(object): Returns: PIL Image: Rotated image. """ - + if not isinstance(img, Image.Image): + img = to_pil_image(img) angle = self.get_params(self.degrees) return F_pil.rotate(img, angle, self.resample, self.expand, self.center, self.fill) @@ -1405,13 +1432,15 @@ class RandomAffine(object): return angle, translations, scale, shear - def __call__(self, img): + def __call__(self, img:Image.Image): """ img (PIL Image): Image to be transformed. Returns: PIL Image: Affine transformed image. """ + if not isinstance(img, Image.Image): + img = to_pil_image(img) ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) return F_pil.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)