polish transform

This commit is contained in:
Dun Liang 2021-09-09 11:32:16 +08:00
parent f807a28e6b
commit 42dfaaed2e
3 changed files with 58 additions and 13 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package. # file 'LICENSE.txt', which is part of this source code package.
# *************************************************************** # ***************************************************************
__version__ = '1.2.3.100' __version__ = '1.2.3.101'
from jittor_utils import lock from jittor_utils import lock
with lock.lock_scope(): with lock.lock_scope():
ori_int = int ori_int = int

View File

@ -954,6 +954,22 @@ class Tester(unittest.TestCase):
transform.ToTensor(), transform.ToTensor(),
])(img) ])(img)
def test_not_pil_image(self):
img = jt.random((30,40,3))
result = transform.Compose([
transform.RandomAffine(20),
transform.ToTensor(),
])(img)
img = jt.random((30,40,3))
result = transform.Compose([
transform.ToPILImage(),
transform.Gray(),
transform.Resize(20),
transform.ToTensor(),
])(img)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -152,7 +152,9 @@ class Crop:
self.left = left self.left = left
self.height = height self.height = height
self.width = width self.width = width
def __call__(self, img): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return crop(img, self.top, self.left, self.height, self.width) return crop(img, self.top, self.left, self.height, self.width)
@ -181,6 +183,8 @@ class RandomCropAndResize:
self.interpolation = interpolation self.interpolation = interpolation
def __call__(self, img:Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
width, height = img.size width, height = img.size
scale = self.scale scale = self.scale
ratio = self.ratio ratio = self.ratio
@ -363,6 +367,8 @@ class RandomHorizontalFlip:
self.p = p self.p = p
def __call__(self, img:Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
if random.random() < self.p: if random.random() < self.p:
return img.transpose(Image.FLIP_LEFT_RIGHT) return img.transpose(Image.FLIP_LEFT_RIGHT)
return img return img
@ -384,6 +390,8 @@ class CenterCrop:
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
def __call__(self, img:Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
width, height = img.size width, height = img.size
return crop(img, (height - self.size[0]) / 2, (width - self.size[1]) / 2, self.size[0], self.size[1]) return crop(img, (height - self.size[0]) / 2, (width - self.size[1]) / 2, self.size[0], self.size[1])
@ -682,6 +690,8 @@ class Resize:
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
self.mode = mode self.mode = mode
def __call__(self, img:Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return resize(img, self.size, self.mode) return resize(img, self.size, self.mode)
class Gray: class Gray:
@ -697,6 +707,8 @@ class Gray:
self.num_output_channels = num_output_channels self.num_output_channels = num_output_channels
def __call__(self, img:Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
img = np.float32(img.convert('L')) / np.float32(255.0) img = np.float32(img.convert('L')) / np.float32(255.0)
if self.num_output_channels == 1: if self.num_output_channels == 1:
return img[np.newaxis, :] return img[np.newaxis, :]
@ -720,7 +732,9 @@ class RandomGray:
def __init__(self, p=0.1): def __init__(self, p=0.1):
self.p = p self.p = p
def __call__(self, img: Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
num_output_channels = _get_image_num_channels(img) num_output_channels = _get_image_num_channels(img)
if random.random() < self.p: if random.random() < self.p:
return gray(img, num_output_channels=num_output_channels) return gray(img, num_output_channels=num_output_channels)
@ -742,6 +756,8 @@ class RandomCrop:
def __init__(self, size): def __init__(self, size):
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values") self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
def __call__(self, img:Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
width, height = img.size width, height = img.size
assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop, {(self.size, height, width)}" assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop, {(self.size, height, width)}"
top = np.random.randint(0,height-self.size[0]+1) top = np.random.randint(0,height-self.size[0]+1)
@ -835,7 +851,9 @@ class RandomVerticalFlip:
def __init__(self, p=0.5): def __init__(self, p=0.5):
self.p = p self.p = p
def __call__(self, img: Image.Image): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
if random.random() < self.p: if random.random() < self.p:
return vflip(img) return vflip(img)
return img return img
@ -918,13 +936,15 @@ class ColorJitter:
return transform return transform
def __call__(self, img): def __call__(self, img:Image.Image):
""" """
Args:: Args::
[in] img (PIL Image): Input image. [in] img (PIL Image): Input image.
Returns:: Returns::
[out] PIL Image: Color jittered image. [out] PIL Image: Color jittered image.
""" """
if not isinstance(img, Image.Image):
img = to_pil_image(img)
transform = self._get_transform(self.brightness, self.contrast, self.saturation, self.hue) transform = self._get_transform(self.brightness, self.contrast, self.saturation, self.hue)
return transform(img) return transform(img)
@ -1002,7 +1022,7 @@ class RandomPerspective(object):
self.interpolation = interpolation self.interpolation = interpolation
self.distortion_scale = distortion_scale self.distortion_scale = distortion_scale
def __call__(self, img): def __call__(self, img:Image.Image):
""" """
Args: Args:
img (PIL Image): Image to be Perspectively transformed. img (PIL Image): Image to be Perspectively transformed.
@ -1011,7 +1031,7 @@ class RandomPerspective(object):
PIL Image: Random perspectively transformed image. PIL Image: Random perspectively transformed image.
""" """
if not isinstance(img, Image.Image): if not isinstance(img, Image.Image):
raise TypeError('img should be PIL Image. Got {}'.format(type(img))) img = to_pil_image(img)
if random.random() < self.p: if random.random() < self.p:
width, height = img.size width, height = img.size
@ -1119,7 +1139,7 @@ class RandomResizedCrop(object):
j = (width - w) // 2 j = (width - w) // 2
return i, j, h, w return i, j, h, w
def __call__(self, img): def __call__(self, img:Image.Image):
""" """
Args: Args:
img (PIL Image): Image to be cropped and resized. img (PIL Image): Image to be cropped and resized.
@ -1127,6 +1147,8 @@ class RandomResizedCrop(object):
Returns: Returns:
PIL Image: Randomly cropped and resized image. PIL Image: Randomly cropped and resized image.
""" """
if not isinstance(img, Image.Image):
img = to_pil_image(img)
i, j, h, w = self.get_params(img, self.scale, self.ratio) i, j, h, w = self.get_params(img, self.scale, self.ratio)
return F_pil.resized_crop(img, i, j, h, w, self.size, self.interpolation) return F_pil.resized_crop(img, i, j, h, w, self.size, self.interpolation)
@ -1174,7 +1196,9 @@ class FiveCrop(object):
assert len(size) == 2, "Please provide only two dimensions (h, w) for size." assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
self.size = size self.size = size
def __call__(self, img): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return F_pil.five_crop(img, self.size) return F_pil.five_crop(img, self.size)
def __repr__(self): def __repr__(self):
@ -1217,7 +1241,9 @@ class TenCrop(object):
self.size = size self.size = size
self.vertical_flip = vertical_flip self.vertical_flip = vertical_flip
def __call__(self, img): def __call__(self, img:Image.Image):
if not isinstance(img, Image.Image):
img = to_pil_image(img)
return F_pil.ten_crop(img, self.size, self.vertical_flip) return F_pil.ten_crop(img, self.size, self.vertical_flip)
def __repr__(self): def __repr__(self):
@ -1275,7 +1301,7 @@ class RandomRotation(object):
return angle return angle
def __call__(self, img): def __call__(self, img:Image.Image):
""" """
Args: Args:
img (PIL Image): Image to be rotated. img (PIL Image): Image to be rotated.
@ -1283,7 +1309,8 @@ class RandomRotation(object):
Returns: Returns:
PIL Image: Rotated image. PIL Image: Rotated image.
""" """
if not isinstance(img, Image.Image):
img = to_pil_image(img)
angle = self.get_params(self.degrees) angle = self.get_params(self.degrees)
return F_pil.rotate(img, angle, self.resample, self.expand, self.center, self.fill) return F_pil.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
@ -1405,13 +1432,15 @@ class RandomAffine(object):
return angle, translations, scale, shear return angle, translations, scale, shear
def __call__(self, img): def __call__(self, img:Image.Image):
""" """
img (PIL Image): Image to be transformed. img (PIL Image): Image to be transformed.
Returns: Returns:
PIL Image: Affine transformed image. PIL Image: Affine transformed image.
""" """
if not isinstance(img, Image.Image):
img = to_pil_image(img)
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size) ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
return F_pil.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor) return F_pil.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)