mirror of https://github.com/Jittor/Jittor
polish transform
This commit is contained in:
parent
f807a28e6b
commit
42dfaaed2e
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.100'
|
||||
__version__ = '1.2.3.101'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -954,6 +954,22 @@ class Tester(unittest.TestCase):
|
|||
transform.ToTensor(),
|
||||
])(img)
|
||||
|
||||
def test_not_pil_image(self):
|
||||
img = jt.random((30,40,3))
|
||||
result = transform.Compose([
|
||||
transform.RandomAffine(20),
|
||||
transform.ToTensor(),
|
||||
])(img)
|
||||
|
||||
img = jt.random((30,40,3))
|
||||
result = transform.Compose([
|
||||
transform.ToPILImage(),
|
||||
transform.Gray(),
|
||||
transform.Resize(20),
|
||||
transform.ToTensor(),
|
||||
])(img)
|
||||
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -152,7 +152,9 @@ class Crop:
|
|||
self.left = left
|
||||
self.height = height
|
||||
self.width = width
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
return crop(img, self.top, self.left, self.height, self.width)
|
||||
|
||||
|
||||
|
@ -181,6 +183,8 @@ class RandomCropAndResize:
|
|||
self.interpolation = interpolation
|
||||
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
width, height = img.size
|
||||
scale = self.scale
|
||||
ratio = self.ratio
|
||||
|
@ -363,6 +367,8 @@ class RandomHorizontalFlip:
|
|||
self.p = p
|
||||
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
if random.random() < self.p:
|
||||
return img.transpose(Image.FLIP_LEFT_RIGHT)
|
||||
return img
|
||||
|
@ -384,6 +390,8 @@ class CenterCrop:
|
|||
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
|
||||
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
width, height = img.size
|
||||
return crop(img, (height - self.size[0]) / 2, (width - self.size[1]) / 2, self.size[0], self.size[1])
|
||||
|
||||
|
@ -682,6 +690,8 @@ class Resize:
|
|||
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
|
||||
self.mode = mode
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
return resize(img, self.size, self.mode)
|
||||
|
||||
class Gray:
|
||||
|
@ -697,6 +707,8 @@ class Gray:
|
|||
self.num_output_channels = num_output_channels
|
||||
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
img = np.float32(img.convert('L')) / np.float32(255.0)
|
||||
if self.num_output_channels == 1:
|
||||
return img[np.newaxis, :]
|
||||
|
@ -721,6 +733,8 @@ class RandomGray:
|
|||
self.p = p
|
||||
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
num_output_channels = _get_image_num_channels(img)
|
||||
if random.random() < self.p:
|
||||
return gray(img, num_output_channels=num_output_channels)
|
||||
|
@ -742,6 +756,8 @@ class RandomCrop:
|
|||
def __init__(self, size):
|
||||
self.size = _setup_size(size, error_msg="If size is a sequence, it should have 2 values")
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
width, height = img.size
|
||||
assert self.size[0] <= height and self.size[1] <= width, f"crop size exceeds the input image in RandomCrop, {(self.size, height, width)}"
|
||||
top = np.random.randint(0,height-self.size[0]+1)
|
||||
|
@ -836,6 +852,8 @@ class RandomVerticalFlip:
|
|||
self.p = p
|
||||
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
if random.random() < self.p:
|
||||
return vflip(img)
|
||||
return img
|
||||
|
@ -918,13 +936,15 @@ class ColorJitter:
|
|||
|
||||
return transform
|
||||
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
"""
|
||||
Args::
|
||||
[in] img (PIL Image): Input image.
|
||||
Returns::
|
||||
[out] PIL Image: Color jittered image.
|
||||
"""
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
transform = self._get_transform(self.brightness, self.contrast, self.saturation, self.hue)
|
||||
|
||||
return transform(img)
|
||||
|
@ -1002,7 +1022,7 @@ class RandomPerspective(object):
|
|||
self.interpolation = interpolation
|
||||
self.distortion_scale = distortion_scale
|
||||
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image): Image to be Perspectively transformed.
|
||||
|
@ -1011,7 +1031,7 @@ class RandomPerspective(object):
|
|||
PIL Image: Random perspectivley transformed image.
|
||||
"""
|
||||
if not isinstance(img, Image.Image):
|
||||
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
|
||||
img = to_pil_image(img)
|
||||
|
||||
if random.random() < self.p:
|
||||
width, height = img.size
|
||||
|
@ -1119,7 +1139,7 @@ class RandomResizedCrop(object):
|
|||
j = (width - w) // 2
|
||||
return i, j, h, w
|
||||
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image): Image to be cropped and resized.
|
||||
|
@ -1127,6 +1147,8 @@ class RandomResizedCrop(object):
|
|||
Returns:
|
||||
PIL Image: Randomly cropped and resized image.
|
||||
"""
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
i, j, h, w = self.get_params(img, self.scale, self.ratio)
|
||||
return F_pil.resized_crop(img, i, j, h, w, self.size, self.interpolation)
|
||||
|
||||
|
@ -1174,7 +1196,9 @@ class FiveCrop(object):
|
|||
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
|
||||
self.size = size
|
||||
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
return F_pil.five_crop(img, self.size)
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -1217,7 +1241,9 @@ class TenCrop(object):
|
|||
self.size = size
|
||||
self.vertical_flip = vertical_flip
|
||||
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
return F_pil.ten_crop(img, self.size, self.vertical_flip)
|
||||
|
||||
def __repr__(self):
|
||||
|
@ -1275,7 +1301,7 @@ class RandomRotation(object):
|
|||
|
||||
return angle
|
||||
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
"""
|
||||
Args:
|
||||
img (PIL Image): Image to be rotated.
|
||||
|
@ -1283,7 +1309,8 @@ class RandomRotation(object):
|
|||
Returns:
|
||||
PIL Image: Rotated image.
|
||||
"""
|
||||
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
angle = self.get_params(self.degrees)
|
||||
|
||||
return F_pil.rotate(img, angle, self.resample, self.expand, self.center, self.fill)
|
||||
|
@ -1405,13 +1432,15 @@ class RandomAffine(object):
|
|||
|
||||
return angle, translations, scale, shear
|
||||
|
||||
def __call__(self, img):
|
||||
def __call__(self, img:Image.Image):
|
||||
"""
|
||||
img (PIL Image): Image to be transformed.
|
||||
|
||||
Returns:
|
||||
PIL Image: Affine transformed image.
|
||||
"""
|
||||
if not isinstance(img, Image.Image):
|
||||
img = to_pil_image(img)
|
||||
ret = self.get_params(self.degrees, self.translate, self.scale, self.shear, img.size)
|
||||
return F_pil.affine(img, *ret, resample=self.resample, fillcolor=self.fillcolor)
|
||||
|
||||
|
|
Loading…
Reference in New Issue