mirror of https://github.com/Jittor/Jittor
fix pr#355&add unittest

This commit is contained in:
parent 13f9eaafc0
commit b5f03f996b
@@ -13,7 +13,7 @@ Backends in `einops` are organized to meet the following requirements
 import sys
 import warnings
 
-__author__ = 'Alex Rogozhnikov'
+__author__ = 'Alex Rogozhnikov, RuiYang Liu'
 
 _backends = {}
 _debug_importing = False
@@ -107,7 +107,6 @@ class AbstractBackend:
         raise NotImplementedError()
 
     def is_float_type(self, x):
-        # some backends (torch) can't compute average for non-floating types.
         # Decided to drop average for all backends if type is not floating
         raise NotImplementedError()
 
@@ -117,6 +116,9 @@ class AbstractBackend:
     def __repr__(self):
         return "<einops backend for {}>".format(self.framework_name)
 
+    def einsum(self, pattern, *x):
+        raise NotImplementedError("backend does not support einsum")
+
 
 class UnknownSize:
     """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """
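The new `einsum` method on `AbstractBackend` raises `NotImplementedError` by default, so einsum support is opt-in per backend. A minimal sketch of how a backend opts in (the subclass below is hypothetical, not part of this commit):

```python
import numpy as np

class MinimalBackend(AbstractBackend):  # hypothetical subclass, for illustration only
    framework_name = 'minimal'

    def einsum(self, pattern, *x):
        # `pattern` arrives already compactified to classic one-letter einsum
        # notation by einops.einsum (see the einops.py hunks further below)
        return np.einsum(pattern, *x)
```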
@@ -163,11 +165,15 @@ class NumpyBackend(AbstractBackend):
         return self.np.tile(x, repeats)
 
     def is_float_type(self, x):
-        return x.dtype in ('float16', 'float32', 'float64', 'float128')
+        return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16')
 
     def add_axis(self, x, new_position):
         return self.np.expand_dims(x, new_position)
 
+    def einsum(self, pattern, *x):
+        return self.np.einsum(pattern, *x)
+
 
 class HashableTuple:
     """Overcomes non-hashability of symbolic elements"""
@@ -192,13 +198,10 @@ class JittorBackend(AbstractBackend):
         self.jittor = jittor
 
     def is_appropriate_type(self, tensor):
-        return isinstance(tensor, self.jittor.jittor_core.Var)
+        return isinstance(tensor, self.jittor.Var)
 
     def from_numpy(self, x):
         variable = self.jittor.array(x)
-        if self.is_float_type(variable):
-            # attach grad only to floating types
-            variable.requires_grad = True
         return variable
 
     def to_numpy(self, x):
@@ -208,21 +211,24 @@ class JittorBackend(AbstractBackend):
         return self.jittor.arange(start, stop, dtype='int64')
 
     def shape(self, x):
-        return HashableTuple(tuple(x.shape))
+        return tuple(x.shape)
 
     def reshape(self, x, shape):
+        if len(shape) == 0:
+            return x
         return self.jittor.reshape(x, shape)
 
-    # def reduce(self, x, operation, axes):
-    #     return getattr(x, operation)(dim=axes)
-
     def reduce(self, x, operation, reduced_axes):
+        if operation == 'prod':
+            # avoid overflow
+            return x.prod(reduced_axes)
         for axis in sorted(reduced_axes, reverse=True):
             if operation == 'min':
                 x = x.min(dim=axis)
             elif operation == 'max':
                 x = x.max(dim=axis)
-            elif operation in ['sum', 'mean', 'prod']:
+            elif operation in ['sum', 'mean']:
                 x = getattr(x, operation)(dim=axis)
             else:
                 raise NotImplementedError('Unknown reduction ', operation)
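`reduce` now special-cases `'prod'` with a single `x.prod(reduced_axes)` call (per the "avoid overflow" comment), while the other reductions still collapse one axis at a time in `sorted(reduced_axes, reverse=True)` order, so lower axis indices stay valid after each collapse. A numpy illustration of why the reverse order matters (not part of the commit):

```python
import numpy as np

x = np.arange(24).reshape(2, 3, 4)
# reducing axes (0, 2): collapsing axis 2 first leaves axis 0 where it was,
# so its stored index is still correct on the next iteration
for axis in sorted((0, 2), reverse=True):
    x = x.max(axis=axis)
print(x.shape)  # (3,)
```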
@@ -253,3 +259,6 @@ class JittorBackend(AbstractBackend):
     def layers(self):
         from jittor.einops.layers import jittor
         return jittor
+
+    def einsum(self, pattern, *x):
+        return self.jittor.linalg.einsum(pattern, *x)
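With this method, `einops.einsum` on Jittor tensors dispatches to `jittor.linalg.einsum`. A quick end-to-end check, assuming `einsum` is re-exported from `jittor.einops` like the other ops (that re-export is not shown in this diff):

```python
import jittor as jt
from jittor.einops import einsum  # assumed re-export

a = jt.randn(3, 4)
b = jt.randn(4, 5)
out = einsum(a, b, 'i j, j k -> i k')  # tensors first, pattern last
assert tuple(out.shape) == (3, 5)
```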
@@ -1,84 +0,0 @@
-"""
-Specialization of einops for jittor.
-
-Unfortunately, jittor's jit scripting mechanism isn't strong enough,
-and to have scripting supported at least for layers,
-a number of changes is required, and this layer helps.
-
-Importantly, whole lib is designed so that you can't use it
-"""
-
-from typing import Dict, List
-
-import jittor as jt
-from jittor.einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
-
-
-class JittorJitBackend:
-    """
-    Completely static backend that mimics part of normal backend functionality
-    but restricted to jittor stuff only
-    """
-
-    @staticmethod
-    def reduce(x: jt.jittor_core.Var, operation: str, reduced_axes: List[int]):
-        if operation == 'min':
-            return x.min(dims=reduced_axes)
-        elif operation == 'max':
-            return x.max(dims=reduced_axes)
-        elif operation == 'sum':
-            return x.sum(dims=reduced_axes)
-        elif operation == 'mean':
-            return x.mean(dims=reduced_axes)
-        elif operation == 'prod':
-            for i in list(sorted(reduced_axes))[::-1]:
-                x = x.prod(dim=i)
-            return x
-        else:
-            raise NotImplementedError('Unknown reduction ', operation)
-
-    @staticmethod
-    def transpose(x, axes: List[int]):
-        return x.permute(axes)
-
-    @staticmethod
-    def stack_on_zeroth_dimension(tensors: List[jt.jittor_core.Var]):
-        return jt.stack(tensors)
-
-    @staticmethod
-    def tile(x, repeats: List[int]):
-        return x.repeat(repeats)
-
-    @staticmethod
-    def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
-        repeats = [1] * n_axes
-        for axis_position, axis_length in pos2len.items():
-            x = jt.unsqueeze(x, axis_position)
-            repeats[axis_position] = axis_length
-        return JittorJitBackend.tile(x, repeats)
-
-    @staticmethod
-    def is_float_type(x):
-        return x.dtype in ["float16", "float32", "float64"]
-
-    @staticmethod
-    def shape(x):
-        return x.shape
-
-    @staticmethod
-    def reshape(x, shape: List[int]):
-        return x.reshape(shape)
-
-
-# mirrors einops.einops._apply_recipe
-def apply_for_scriptable_jittor(recipe: TransformRecipe, tensor: jt.jittor_core.Var, reduction_type: str) -> jt.jittor_core.Var:
-    backend = JittorJitBackend
-    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
-        _reconstruct_from_shape_uncached(recipe, backend.shape(tensor))
-    tensor = backend.reshape(tensor, init_shapes)
-    if len(reduced_axes) > 0:
-        tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
-    tensor = backend.transpose(tensor, axes_reordering)
-    if len(added_axes) > 0:
-        tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
-    return backend.reshape(tensor, final_shapes)
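The deleted `apply_for_scriptable_jittor` duplicated the generic recipe pipeline (reshape, reduce, transpose, add_axes, reshape) in a static form intended for scripting. After this commit, the layers in `jittor.einops.layers.jittor` (see that hunk below) go through the shared `_apply_recipe` path instead, so the duplicate is no longer needed.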
@@ -1,5 +1,6 @@
 import functools
 import itertools
+import string
 import typing
 from collections import OrderedDict
 from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar
@@ -393,7 +394,7 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i
     ```
 
     Parameters:
-        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
+        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
             list of tensors is also accepted, those should be of the same type and shape
         pattern: string, reduction pattern
         reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
@@ -418,13 +419,7 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i
 
 
 
-@typing.overload
-def rearrange(tensor: Tensor, pattern: str, **axes_length: int) -> Tensor: ...
-@typing.overload
-def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: int) -> Tensor: ...
-
-
-def rearrange(tensor, pattern: str, **axes_lengths):
+def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
     """
     einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
     This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
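The two `@typing.overload` stubs collapse into one annotated signature; a list of tensors was already accepted at runtime (it is stacked along a new leading axis), so only the typing changes. For example, with numpy inputs (illustration, not from the commit):

```python
import numpy as np
from jittor.einops import rearrange  # assumed re-export

images = [np.zeros((30, 40)) for _ in range(8)]
batch = rearrange(images, 'b h w -> b h w')  # list input is stacked into axis b
assert batch.shape == (8, 30, 40)
```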
@@ -470,7 +465,7 @@ def rearrange(tensor, pattern: str, **axes_lengths):
     Find more examples in einops tutorial.
 
     Parameters:
-        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
+        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
             list of tensors is also accepted, those should be of the same type and shape
         pattern: string, rearrangement pattern
         axes_lengths: any additional specifications for dimensions
@@ -506,8 +501,8 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
     (60, 40)
 
     # repeat image 2 time along height and 3 times along width
-    >>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape
-    (30, 120)
+    >>> repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape
+    (60, 120)
 
     # convert each pixel to a small square 2x2. Upsample image by 2x
     >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
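The old doctest did not match its comment: `'h w -> h (repeat w)'` repeats only along width, so a `(30, 40)` image gives `(30, 120)` with no doubling of height. The replacement repeats 2x along height and 3x along width, and the expected shape is adjusted to match. Verifying the corrected doctest with numpy (illustration only):

```python
import numpy as np
from jittor.einops import repeat  # assumed re-export

image = np.random.randn(30, 40)
assert repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape == (60, 120)
```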
@@ -524,7 +519,7 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
     Find more examples in einops tutorial.
 
     Parameters:
-        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
+        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
             list of tensors is also accepted, those should be of the same type and shape
         pattern: string, rearrangement pattern
         axes_lengths: any additional specifications for dimensions
@@ -536,7 +531,7 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
     return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
 
 
-def parse_shape(x, pattern: str):
+def parse_shape(x, pattern: str) -> dict:
     """
     Parse a tensor shape to dictionary mapping axes names to their lengths.
 
@@ -579,11 +574,11 @@ def parse_shape(x, pattern: str):
         ellipsis_idx = exp.composition.index(_ellipsis)
         composition = (exp.composition[:ellipsis_idx] +
                        ['_'] * (len(shape) - len(exp.composition) + 1) +
-                       exp.composition[ellipsis_idx+1:])
+                       exp.composition[ellipsis_idx + 1:])
     else:
         composition = exp.composition
     result = {}
-    for (axis_name, ), axis_length in zip(composition, shape):
+    for (axis_name,), axis_length in zip(composition, shape):
         if axis_name != '_':
             result[axis_name] = axis_length
     return result
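Only cosmetics change in `parse_shape` itself (spacing, plus the `-> dict` annotation in the previous hunk); the `'_'` placeholders inserted for an ellipsis are what let named axes around `...` be recovered. Usage, for reference (illustration only):

```python
import numpy as np
from jittor.einops import parse_shape  # assumed re-export

x = np.zeros([2, 3, 5, 7])
assert parse_shape(x, 'b _ h w') == dict(b=2, h=5, w=7)  # '_' axes are skipped
assert parse_shape(x, 'b ... w') == dict(b=2, w=7)       # '...' absorbs the middle axes
```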
@@ -623,3 +618,165 @@ def asnumpy(tensor) -> 'numpy.ndarray':
         `numpy.ndarray`, converted to numpy
     """
     return get_backend(tensor).to_numpy(tensor)
+
+
+def _validate_einsum_axis_name(axis_name):
+    if len(axis_name) == 0:
+        raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
+    if len(axis_name) > 1:
+        raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
+
+    axis_name = axis_name[0]
+
+    if isinstance(axis_name, AnonymousAxis):
+        raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
+    if len(axis_name) == 0:
+        raise RuntimeError("Encountered empty axis name in einsum.")
+    if not isinstance(axis_name, str):
+        raise RuntimeError("Axis name in einsum must be a string.")
+
+
+@functools.lru_cache(256)
+def _compactify_pattern_for_einsum(pattern: str) -> str:
+    if "->" not in pattern:
+        # numpy allows this, so make sure users
+        # don't accidentally do something like this.
+        raise ValueError("Einsum pattern must contain '->'.")
+    lefts, right = pattern.split('->')
+    lefts = lefts.split(',')
+
+    lefts = [
+        ParsedExpression(left, allow_underscore=True, allow_duplicates=True)
+        for left in lefts
+    ]
+
+    right = ParsedExpression(right, allow_underscore=True)
+
+    # Start from 'a' and go up to 'Z'
+    output_axis_names = string.ascii_letters
+    i = 0
+    axis_name_mapping = {}
+
+    left_patterns = []
+    for left in lefts:
+        left_pattern = ""
+        for raw_axis_name in left.composition:
+
+            if raw_axis_name == _ellipsis:
+                left_pattern += '...'
+                continue
+
+            _validate_einsum_axis_name(raw_axis_name)
+            axis_name = raw_axis_name[0]
+            if axis_name not in axis_name_mapping:
+                if i >= len(output_axis_names):
+                    raise RuntimeError("Too many axes in einsum.")
+                axis_name_mapping[axis_name] = output_axis_names[i]
+                i += 1
+
+            left_pattern += axis_name_mapping[axis_name]
+        left_patterns.append(left_pattern)
+
+    compact_pattern = ",".join(left_patterns) + "->"
+
+    for raw_axis_name in right.composition:
+        if raw_axis_name == _ellipsis:
+            compact_pattern += '...'
+            continue
+
+        _validate_einsum_axis_name(raw_axis_name)
+        axis_name = raw_axis_name[0]
+
+        if axis_name not in axis_name_mapping:
+            raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.")
+
+        compact_pattern += axis_name_mapping[axis_name]
+
+    return compact_pattern
+
+
+@typing.overload
+def einsum(tensor: Tensor, pattern: str) -> Tensor: ...
+@typing.overload
+def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ...
+@typing.overload
+def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ...
+@typing.overload
+def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ...
+
+
+def einsum(*tensors_and_pattern: List[Union[Tensor, str]]) -> Tensor:
+    """
+    einops.einsum calls einsum operations with einops-style named
+    axes indexing, computing tensor products with an arbitrary
+    number of tensors. Unlike typical einsum syntax, here you must
+    pass tensors first, and then the pattern.
+
+    Also, note that rearrange operations such as `"(batch chan) out"`,
+    or singleton axes `()`, are not currently supported.
+
+    Examples:
+
+    For a given pattern such as:
+    ```python
+    >>> x, y, z = np.random.randn(3, 20, 20, 20)
+    >>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")
+
+    ```
+    the following formula is computed:
+    ```tex
+    output[a, b, k] =
+        \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k]
+    ```
+    where the summation over `c`, `d`, and `g` is performed
+    because those axes names do not appear on the right-hand side.
+
+    Let's see some additional examples:
+    ```python
+    # Filter a set of images:
+    >>> batched_images = np.random.randn(128, 16, 16)
+    >>> filters = np.random.randn(16, 16, 30)
+    >>> result = einsum(batched_images, filters,
+    ...                 "batch h w, h w channel -> batch channel")
+    >>> result.shape
+    (128, 30)
+
+    # Matrix multiplication, with an unknown input shape:
+    >>> batch_shape = (50, 30)
+    >>> data = np.random.randn(*batch_shape, 20)
+    >>> weights = np.random.randn(10, 20)
+    >>> result = einsum(weights, data,
+    ...                 "out_dim in_dim, ... in_dim -> ... out_dim")
+    >>> result.shape
+    (50, 30, 10)
+
+    # Matrix trace on a single tensor:
+    >>> matrix = np.random.randn(10, 10)
+    >>> result = einsum(matrix, "i i ->")
+    >>> result.shape
+    ()
+
+    ```
+
+    Parameters:
+        tensors: tensors of any supported library (numpy, jittor).
+        pattern: string, einsum pattern, with commas
+            separating specifications for each tensor.
+
+    Returns:
+        Tensor of the same type as input, after processing with einsum.
+
+    """
+    if len(tensors_and_pattern) <= 1:
+        raise ValueError(
+            "`einops.einsum` takes at minimum two arguments: the tensors (at least one),"
+            " followed by the pattern."
+        )
+    pattern = tensors_and_pattern[-1]
+    if not isinstance(pattern, str):
+        raise ValueError(
+            "The last argument passed to `einops.einsum` must be a string,"
+            " representing the einsum pattern."
+        )
+    tensors = tensors_and_pattern[:-1]
+    pattern = _compactify_pattern_for_einsum(pattern)
+    return get_backend(tensors[0]).einsum(pattern, *tensors)
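`_compactify_pattern_for_einsum` is the bridge between einops-style names and the backends' native einsum: each distinct axis name is mapped, in order of first appearance, to one letter from `string.ascii_letters`, `...` passes through unchanged, and a right-hand-side name that never appeared on the left raises `EinopsError`. Roughly (illustration, not a doctest from the commit):

```python
from jittor.einops.einops import _compactify_pattern_for_einsum

# 'batch h w, h w channel -> batch channel' maps batch->a, h->b, w->c, channel->d
assert _compactify_pattern_for_einsum(
    'batch h w, h w channel -> batch channel') == 'abc,bcd->ad'
# ellipsis is preserved verbatim
assert _compactify_pattern_for_einsum(
    'out_dim in_dim, ... in_dim -> ... out_dim') == 'ab,...b->...a'
```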
@@ -49,8 +49,8 @@ class _EinmixMixin:
 
         Parameters
         :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
-        :param weight_shape: axes of weight. Tensor od this shape is created, stored, and optimized in a layer
-        :param bias_shape: axes of bias added to output.
+        :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
+        :param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added.
         :param axes_lengths: dimensions of weight tensor
         """
         super().__init__()
@@ -58,7 +58,9 @@ class _EinmixMixin:
         self.weight_shape = weight_shape
         self.bias_shape = bias_shape
         self.axes_lengths = axes_lengths
+        self.initialize_einmix(pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths)
+
+    def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths):
         left_pattern, right_pattern = pattern.split('->')
         left = ParsedExpression(left_pattern)
         right = ParsedExpression(right_pattern)
@@ -84,7 +86,7 @@ class _EinmixMixin:
             names += group
         composition = ' '.join(names)
         pre_reshape_pattern = f'{left_pattern}->{composition}'
-        pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names}
+        pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names}
 
         if any(len(group) != 1 for group in right.composition):
             names = []
@@ -134,7 +136,6 @@ class _EinmixMixin:
                 _bias_shape.append(1)
         else:
             _bias_shape = None
-            _bias_input_size = None
 
         weight_bound = (3 / _fan_in) ** 0.5
         bias_bound = (1 / _fan_in) ** 0.5
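The heavy setup moves from `__init__` into a named `initialize_einmix` method, which now reads `axes_lengths` from its argument rather than `self.axes_lengths`; the bias docs also make explicit that `bias_shape=None` (the default) means no bias. Typical construction, per the updated docstring (illustration only):

```python
from jittor.einops.layers.jittor import EinMix

# mix only the channel axis of a (batch, time, channel) tensor;
# weight has shape [c_in, c_out]; bias_shape=None would skip the bias
layer = EinMix('b t c_in -> b t c_out', weight_shape='c_in c_out',
               bias_shape='c_out', c_in=32, c_out=64)
```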
@@ -6,27 +6,18 @@ import numpy as np
 
 from jittor.einops.layers import RearrangeMixin, ReduceMixin
 from jittor.einops.layers._einmix import _EinmixMixin
-from jittor.einops._jittor_specific import apply_for_scriptable_jittor
 
 __author__ = 'Ruiyang Liu'
 
 
 class Rearrange(RearrangeMixin, jt.nn.Module):
     def execute(self, input):
-        return apply_for_scriptable_jittor(self._recipe, input, reduction_type='rearrange')
-
-    def _apply_recipe(self, x):
-        # overriding parent method to prevent it's scripting
-        pass
+        return self._apply_recipe(input)
 
 
 class Reduce(ReduceMixin, jt.nn.Module):
     def execute(self, input):
-        return apply_for_scriptable_jittor(self._recipe, input, reduction_type=self.reduction)
-
-    def _apply_recipe(self, x):
-        # overriding parent method to prevent it's scripting
-        pass
+        return self._apply_recipe(input)
 
 
 class EinMix(_EinmixMixin, jt.nn.Module):
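With the scripting shim removed, `execute` simply delegates to the mixins' shared `_apply_recipe`, so the Jittor layers behave like their counterparts for other frameworks. For example (illustration only; the same patterns appear in the new test file's `create_jittor_model`):

```python
import jittor as jt
from jittor.einops.layers.jittor import Rearrange, Reduce

x = jt.randn(10, 3, 32, 32)
pool = Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2)  # 2x2 max-pool
flat = Rearrange('b c h w -> b (c h w)')
y = flat(pool(x))
assert tuple(y.shape) == (10, 3 * 16 * 16)
```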
@@ -26,7 +26,7 @@ class ParsedExpression:
     non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
     and keeps some information important for downstream
     """
-    def __init__(self, expression, *, allow_underscore: bool = False):
+    def __init__(self, expression, *, allow_underscore: bool = False, allow_duplicates: bool = False):
         self.has_ellipsis: bool = False
         self.has_ellipsis_parenthesized: Optional[bool] = None
         self.identifiers: Set[str] = set()
@@ -48,7 +48,7 @@ class ParsedExpression:
         def add_axis_name(x):
             if x is not None:
                 if x in self.identifiers:
-                    if not (allow_underscore and x == "_"):
+                    if not (allow_underscore and x == "_") and not allow_duplicates:
                         raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
                 if x == _ellipsis:
                     self.identifiers.add(_ellipsis)
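`allow_duplicates` exists so einsum's left-hand sides may repeat an axis name, which regular rearrange/reduce patterns still (correctly) reject; repeated indices are what express traces and diagonals. For instance (illustration only):

```python
import numpy as np
from jittor.einops import einsum  # assumed re-export

matrix = np.random.randn(10, 10)
trace = einsum(matrix, 'i i ->')  # 'i i' parses only with allow_duplicates=True
assert np.allclose(trace, np.trace(matrix))
```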
@@ -0,0 +1,616 @@
+# ***************************************************************
+# Copyright (c) 2022 Jittor. All Rights Reserved.
+# Maintainers: DongYang Li <lidongyang2001@gmail.com>.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+from collections import namedtuple
+import tempfile
+import pickle
+import itertools
+from jittor.einops.einops import (rearrange, reduce, _enumerate_directions, _reductions)
+from jittor.einops import EinopsError
+import jittor as jt
+import numpy
+import unittest
+
+# tests/__init__.py
+import os
+from jittor.einops import _backends
+import warnings
+
+flag_to_bool = {
+    '': False,
+    '0': False,
+    '1': True,
+}
+
+
+def collect_test_backends(symbolic=False, layers=False):
+    """
+    :param symbolic: symbolic or imperative frameworks?
+    :param layers: layers or operations?
+    :return: list of backends satisfying set conditions
+    """
+    if not symbolic:
+        if not layers:
+            backend_types = [
+                _backends.NumpyBackend,
+                _backends.JittorBackend,
+            ]
+        else:
+            backend_types = [
+                _backends.JittorBackend,
+            ]
+    else:
+        backend_types = []
+    result = []
+    for backend_type in backend_types:
+        try:
+            result.append(backend_type())
+        except ImportError:
+            # problem with backend installation fails a specific test function,
+            # but will be skipped in all other test cases
+            warnings.warn('backend could not be initialized for tests: {}'.format(backend_type))
+    return result
+
+
+# test/test_ops.py
+
+imp_op_backends = collect_test_backends(symbolic=False, layers=False)
+
+# test/test_layer.py
+
+
+class TestSlice(unittest.TestCase):
+
+    def test_anonymous_axes(self):
+        x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
+        for pattern, axis_dimensions in test_cases_repeat_anonymous:
+            check_reversion(x, pattern, **axis_dimensions)
+
+    def test_repeat_imperatives(self):
+        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
+        for backend in imp_op_backends:
+            print('Repeat tests for ', backend.framework_name)
+
+            for pattern, axis_dimensions in repeat_test_cases:
+                expected = reduce(x, pattern, reduction='repeat', **axis_dimensions)
+                converted = backend.from_numpy(x)
+                repeated = reduce(converted, pattern, reduction='repeat', **axis_dimensions)
+                result = backend.to_numpy(repeated)
+                assert numpy.array_equal(result, expected)
+
+    def test_repeat_numpy(self):
+        # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well
+        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
+        x1 = reduce(x, 'a b c -> copy a b c ', reduction='repeat', copy=1)
+        assert numpy.array_equal(x[None], x1)
+        for pattern, axis_dimensions in repeat_test_cases:
+            check_reversion(x, pattern, **axis_dimensions)
+
+    def test_tiling_imperatives(self):
+        for backend in imp_op_backends:
+            print('Tiling tests for ', backend.framework_name)
+            input = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5])
+            test_cases = [
+                (1, 1, 1, 1, 1),
+                (1, 2, 1, 3, 1),
+                (3, 1, 1, 4, 1),
+            ]
+            for repeats in test_cases:
+                expected = numpy.tile(input, repeats)
+                converted = backend.from_numpy(input)
+                repeated = backend.tile(converted, repeats)
+                result = backend.to_numpy(repeated)
+                assert numpy.array_equal(result, expected)
+
+    def test_gradients_imperatives(self):
+        # lazy - just checking reductions
+        for reduction in _reductions:
+            x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype('float32')
+            results = {}
+            for backend in imp_op_backends:
+                y0 = backend.from_numpy(x)
+                if not 'jittor' in backend.framework_name and not hasattr(y0, 'grad'):
+                    continue
+                y1 = reduce(y0, 'a b c -> c a', reduction=reduction)
+                y2 = reduce(y1, 'c a -> a c', reduction=reduction)
+                y3 = reduce(y2, 'a (c1 c2) -> a', reduction=reduction, c1=2)
+                y4 = reduce(y3, '... -> ', reduction=reduction)
+                if 'jittor' in backend.framework_name:
+                    grad = backend.jittor.grad(y4, y0)
+                else:
+                    y4.backward()
+                    grad = y0.grad
+                results[backend.framework_name] = backend.to_numpy(grad)
+
+            print('comparing gradients for', results.keys())
+            for name1, grad1 in results.items():
+                for name2, grad2 in results.items():
+                    assert numpy.allclose(grad1, grad2), [name1, name2, 'provided different gradients']
+
+    def test_concatenations_and_stacking(self):
+        for backend in imp_op_backends:
+            print('testing shapes for ', backend.framework_name)
+            for n_arrays in [1, 2, 5]:
+                shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
+                for shape in shapes:
+                    if (backend.framework_name == 'jittor')\
+                            and len(shape) == 0:
+                        # jittor stores scalar in 1d array
+                        continue
+                    arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)]
+                    arrays2 = [backend.from_numpy(array) for array in arrays1]
+                    result0 = numpy.asarray(arrays1)
+                    result1 = rearrange(arrays1, '...->...')
+                    result2 = rearrange(arrays2, '...->...')
+                    assert numpy.array_equal(result0, result1)
+                    assert numpy.array_equal(result1, backend.to_numpy(result2))
+
+                    result1 = rearrange(arrays1, 'b ... -> ... b')
+                    result2 = rearrange(arrays2, 'b ... -> ... b')
+                    assert numpy.array_equal(result1, backend.to_numpy(result2))
+
+    def test_enumerating_directions(self):
+        for backend in imp_op_backends:
+            print('testing directions for', backend.framework_name)
+            for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]:
+                if (backend.framework_name == 'jittor')\
+                        and len(shape) == 0:
+                    # jittor stores scalar in 1d array
+                    continue
+                x = numpy.arange(numpy.prod(shape)).reshape(shape)
+                axes1 = _enumerate_directions(x)
+                axes2 = _enumerate_directions(backend.from_numpy(x))
+                assert len(axes1) == len(axes2) == len(shape)
+                for ax1, ax2 in zip(axes1, axes2):
+                    ax2 = backend.to_numpy(ax2)
+                    assert ax1.shape == ax2.shape
+                    assert numpy.allclose(ax1, ax2)
+
+    def test_reduction_with_callable_imperatives(self):
+        x_numpy = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype('float32')
+        x_numpy /= x_numpy.max()
+
+        def logsumexp_jittor(x, tuple_of_axes):
+            import jittor as jt
+            return jt.nn.logsumexp(x, tuple_of_axes)
+
+        def logsumexp_numpy(x, tuple_of_axes):
+            # very naive logsumexp to compare to
+            minused = x.max(tuple_of_axes)
+            y = x - x.max(tuple_of_axes, keepdims=True)
+            y = numpy.exp(y)
+            y = numpy.sum(y, axis=tuple_of_axes)
+            return numpy.log(y) + minused
+
+        from jittor.einops._backends import JittorBackend, NumpyBackend
+        backend2callback = {
+            JittorBackend.framework_name: logsumexp_jittor,
+            NumpyBackend.framework_name: logsumexp_numpy,
+        }
+
+        for backend in imp_op_backends:
+            if backend.framework_name not in backend2callback:
+                continue
+
+            backend_callback = backend2callback[backend.framework_name]
+
+            x_backend = backend.from_numpy(x_numpy)
+            for pattern1, pattern2 in equivalent_reduction_patterns:
+                print('Test reduction with callable for ', backend.framework_name, pattern1, pattern2)
+                output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy)
+                output_backend = reduce(x_backend, pattern1, reduction=backend_callback)
+                assert numpy.allclose(
+                    output_numpy,
+                    backend.to_numpy(output_backend),
+                )
+
+    def test_reduction_stress_imperatives(self):
+        for backend in imp_op_backends:
+            print('Stress-testing reduction for ', backend.framework_name)
+            for reduction in _reductions + ('rearrange',):
+                dtype = 'int64'
+                coincide = numpy.array_equal
+                if reduction in ['mean', 'prod']:
+                    dtype = 'float64'
+                    coincide = numpy.allclose
+                for n_axes in range(11):
+                    shape = numpy.random.randint(2, 4, size=n_axes)
+                    permutation = numpy.random.permutation(n_axes)
+                    skipped = 0 if reduction == 'rearrange' else numpy.random.randint(n_axes + 1)
+                    left = ' '.join('x' + str(i) for i in range(n_axes))
+                    right = ' '.join('x' + str(i) for i in permutation[skipped:])
+                    pattern = left + '->' + right
+                    x = numpy.arange(1, 1 + numpy.prod(shape), dtype=dtype).reshape(shape)
+                    if reduction == 'prod':
+                        x /= x.mean()  # to avoid overflows
+                    result1 = reduce(x, pattern, reduction=reduction)
+                    result2 = x.transpose(permutation)
+                    if skipped > 0:
+                        result2 = getattr(result2, reduction)(axis=tuple(range(skipped)))
+                    assert coincide(result1, result2)
+                    check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False)
+
+    def test_reduction_imperatives(self):
+        for backend in imp_op_backends:
+            print('Reduction tests for ', backend.framework_name)
+            for reduction in _reductions:
+                # slight redundancy for simpler order - numpy version is evaluated multiple times
+                input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape([2, 3, 4, 5, 6])
+                if reduction in ['mean', 'prod']:
+                    input = input / input.astype('float64').mean()
+                test_cases = [
+                    ['a b c d e -> ', {},
+                     getattr(input, reduction)()],
+                    ['a ... -> ', {},
+                     getattr(input, reduction)()],
+                    ['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2),
+                     getattr(input, reduction)()],
+                    ['a b c d e -> (e c) a', {},
+                     getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
+                    ['a ... c d e -> (e c) a', {},
+                     getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
+                    ['a b c d e ... -> (e c) a', {},
+                     getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
+                    ['a b c d e -> (e c a)', {},
+                     getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
+                    ['(a a2) ... -> (a2 a) ...', dict(a2=1),
+                     input],
+                ]
+                for pattern, axes_lengths, expected_result in test_cases:
+                    result = reduce(backend.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths)
+                    result = backend.to_numpy(result)
+                    assert numpy.allclose(result, expected_result)
+
+    def test_rearrange_permutations_numpy(self):
+        # tests random permutation of axes against two independent numpy ways
+        for n_axes in range(1, 10):
+            input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
+            permutation = numpy.random.permutation(n_axes)
+            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes))
+            right_expression = ' '.join('i' + str(axis) for axis in permutation)
+            expression = left_expression + ' -> ' + right_expression
+            result = rearrange(input, expression)
+
+            for pick in numpy.random.randint(0, 2, [10, n_axes]):
+                assert input[tuple(pick)] == result[tuple(pick[permutation])]
+
+        for n_axes in range(1, 10):
+            input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
+            permutation = numpy.random.permutation(n_axes)
+            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1])
+            right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1])
+            expression = left_expression + ' -> ' + right_expression
+            result = rearrange(input, expression)
+            assert result.shape == input.shape
+            expected_result = numpy.zeros_like(input)
+            for original_axis, result_axis in enumerate(permutation):
+                expected_result |= ((input >> original_axis) & 1) << result_axis
+
+            assert numpy.array_equal(result, expected_result)
+
+    def test_rearrange_consistency_numpy(self):
+        shape = [1, 2, 3, 5, 7, 11]
+        x = numpy.arange(numpy.prod(shape)).reshape(shape)
+        for pattern in [
+            'a b c d e f -> a b c d e f',
+            'b a c d e f -> a b d e f c',
+            'a b c d e f -> f e d c b a',
+            'a b c d e f -> (f e) d (c b a)',
+            'a b c d e f -> (f e d c b a)',
+        ]:
+            result = rearrange(x, pattern)
+            assert len(numpy.setdiff1d(x, result)) == 0
+            assert result.dtype == x.dtype
+
+        result = rearrange(x, 'a b c d e f -> a (b) (c d e) f')
+        assert numpy.array_equal(x.flatten(), result.flatten())
+
+        result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11')
+        assert numpy.array_equal(x, result)
+
+        result1 = rearrange(x, 'a b c d e f -> f e d c b a')
+        result2 = rearrange(x, 'f e d c b a -> a b c d e f')
+        assert numpy.array_equal(result1, result2)
+
+        result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5)
+        assert numpy.array_equal(x, result)
+
+        sizes = dict(zip('abcdef', shape))
+        temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes)
+        result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes)
+        assert numpy.array_equal(x, result)
+
+        x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4])
+        result = rearrange(x2, 'a b c -> b c a')
+        assert x2[1, 2, 3] == result[2, 3, 1]
+        assert x2[0, 1, 2] == result[1, 2, 0]
+
+    def test_ellipsis_ops_imperative(self):
+        """ Checking various patterns against numpy """
+        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
+        for is_symbolic in [True, False]:
+            for backend in collect_test_backends(symbolic=is_symbolic, layers=False):
+                for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)):
+                    check_op_against_numpy(backend, x, pattern, axes_lengths={},
+                                           reduction='rearrange', is_symbolic=is_symbolic)
+
+                for reduction in ['min', 'max', 'sum']:
+                    for pattern in itertools.chain(*equivalent_reduction_patterns):
+                        check_op_against_numpy(backend, x, pattern, axes_lengths={},
+                                               reduction=reduction, is_symbolic=is_symbolic)
+
+    def test_ellipsis_ops_numpy(self):
+        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
+        for pattern in identity_patterns:
+            assert numpy.array_equal(x, rearrange(x, pattern)), pattern
+
+        for pattern1, pattern2 in equivalent_rearrange_patterns:
+            assert numpy.array_equal(rearrange(x, pattern1), rearrange(x, pattern2))
+
+        for reduction in ['min', 'max', 'sum']:
+            for pattern1, pattern2 in equivalent_reduction_patterns:
+                assert numpy.array_equal(reduce(x, pattern1, reduction=reduction),
+                                         reduce(x, pattern2, reduction=reduction))
+
+        # now just check coincidence with numpy
+        all_rearrange_patterns = [*identity_patterns]
+        for pattern_pairs in equivalent_rearrange_patterns:
+            all_rearrange_patterns.extend(pattern_pairs)
+
+    def test_collapsed_ellipsis_errors_out(self):
+        x = numpy.zeros([1, 1, 1, 1, 1])
+        rearrange(x, 'a b c d ... -> a b c ... d')
+        error = 0
+        try:
+            rearrange(x, 'a b c d (...) -> a b c ... d')
+        except Exception as e:
+            error = 1
+        assert error == 1
+
+        rearrange(x, '... -> (...)')
+        error = 0
+        try:
+            rearrange(x, '(...) -> (...)')
+        except Exception as e:
+            error = 1
+        assert error == 1
+
+    def test_rearrange_imperative(self):
+        for backend in collect_test_backends(symbolic=False, layers=True):
+            print('Test layer for ', backend.framework_name)
+
+            for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
+                x = numpy.arange(numpy.prod(input_shape), dtype='float32').reshape(input_shape)
+                result_numpy = rearrange(x, pattern, **axes_lengths)
+                layer = backend.layers().Rearrange(pattern, **axes_lengths)
+                for shape in wrong_shapes:
+                    try:
+                        layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
+                    except:
+                        pass
+                    else:
+                        raise AssertionError('Failure expected')
+
+                # simple pickling / unpickling
+                layer2 = pickle.loads(pickle.dumps(layer))
+                result1 = backend.to_numpy(layer(backend.from_numpy(x)))
+                result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
+                assert numpy.allclose(result_numpy, result1)
+                assert numpy.allclose(result1, result2)
+
+                just_sum = backend.layers().Reduce('...->', reduction='sum')
+
+                variable = backend.from_numpy(x)
+                result = just_sum(layer(variable))
+
+                if 'jittor' in backend.framework_name:
+                    grad = backend.jittor.grad(result, variable)
+                else:
+                    result.backward()
+                    grad = variable.grad
+
+                assert numpy.allclose(backend.to_numpy(grad), 1)
+
+    def test_reduce_imperative(self):
+        for backend in collect_test_backends(symbolic=False, layers=True):
+            print('Test layer for ', backend.framework_name)
+            for reduction in _reductions:
+                for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
+                    print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
+                    x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape)
+                    x /= x.mean()
+                    result_numpy = reduce(x, pattern, reduction, **axes_lengths)
+                    layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
+                    for shape in wrong_shapes:
+                        try:
+                            layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
+                        except:
+                            pass
+                        else:
+                            raise AssertionError('Failure expected')
+
+                    # simple pickling / unpickling
+                    layer2 = pickle.loads(pickle.dumps(layer))
+                    result1 = backend.to_numpy(layer(backend.from_numpy(x)))
+                    result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
+                    assert numpy.allclose(result_numpy, result1)
+                    assert numpy.allclose(result1, result2)
+
+                    just_sum = backend.layers().Reduce('...->', reduction='sum')
+
+                    variable = backend.from_numpy(x)
+                    result = just_sum(layer(variable))
+
+                    if 'jittor' in backend.framework_name:
+                        grad = backend.jittor.grad(result, variable)
+                        grad = backend.to_numpy(grad)
+                    else:
+                        result.backward()
+                        grad = backend.to_numpy(variable.grad)
+                    if reduction == 'sum':
+                        assert numpy.allclose(grad, 1)
+                    if reduction == 'mean':
+                        assert numpy.allclose(grad, grad.min())
+                    if reduction in ['max', 'min']:
+                        assert numpy.all(numpy.in1d(grad, [0, 1]))
+                        assert numpy.sum(grad) > 0.5
+
+    def test_jittor_layer(self):
+        has_jittor = any(backend.framework_name == 'jittor' for backend in collect_test_backends(symbolic=False, layers=True))
+        if has_jittor:
+            # checked that jittor present
+            import jittor
+
+            rtol = 1e-05
+            atol = 1e-08
+            def allclose(input, other): return jittor.all(jittor.abs(input-other) <= atol+rtol*jittor.abs(other))
+            model1 = create_jittor_model(use_reduce=True)
+            model2 = create_jittor_model(use_reduce=False)
+            input = jittor.randn([10, 3, 32, 32])
+            # random models have different predictions
+            assert not allclose(model1(input), model2(input))
+            model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
+            assert allclose(model1(input), model2(input))
+
+
+testcase = namedtuple('testcase', ['pattern', 'axes_lengths', 'input_shape', 'wrong_shapes'])
+
+rearrangement_patterns = [
+    testcase('b c h w -> b (c h w)', dict(c=20), (10, 20, 30, 40),
+             [(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]]),
+    testcase('b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1', dict(h2=2, w2=2), (10, 20, 30, 40),
+             [(), (1, 1, 1, 1), (1, 10, 3), ()]),
+    testcase('b ... c -> c b ...', dict(b=10), (10, 20, 30),
+             [(), (10,), (5, 10)]),
+]
+
+reduction_patterns = rearrangement_patterns + [
+    testcase('b c h w -> b ()', dict(b=10), (10, 20, 30, 40),
+             [(10,), (10, 20, 30)]),
+    testcase('b c (h1 h2) (w1 w2) -> b c h1 w1', dict(h1=15, h2=2, w2=2), (10, 20, 30, 40),
+             [(10, 20, 31, 40)]),
+    testcase('b ... c -> b', dict(b=10), (10, 20, 30, 40),
+             [(10,), (11, 10)]),
+]
+
+equivalent_reduction_patterns = [
+    ('a b c d e -> ', ' ... -> '),
+    ('a b c d e -> (e a)', 'a ... e -> (e a)'),
+    ('a b c d e -> d (a e)', ' a b c d e ... -> d (a e) '),
+    ('a b c d e -> (a b)', ' ... c d e -> (...) '),
+]
+
+equivalent_rearrange_patterns = [
+    ('a b c d e -> (a b) c d e', 'a b ... -> (a b) ... '),
+    ('a b c d e -> a b (c d) e', '... c d e -> ... (c d) e'),
+    ('a b c d e -> a b c d e', '... -> ... '),
+    ('a b c d e -> (a b c d e)', '... -> (...)'),
+    ('a b c d e -> b (c d e) a', 'a b ... -> b (...) a'),
+    ('a b c d e -> b (a c d) e', 'a b ... e -> b (a ...) e'),
+]
+
+identity_patterns = [
+    '...->...',
+    'a b c d e-> a b c d e',
+    'a b c d e ...-> ... a b c d e',
+    'a b c d e ...-> a ... b c d e',
+    '... a b c d e -> ... a b c d e',
+    'a ... e-> a ... e',
+    'a ... -> a ... ',
+    'a ... c d e -> a (...) c d e',
+]
+
+test_cases_repeat_anonymous = [
+    # all assume that input has shape [1, 2, 4, 6]
+    ('a b c d -> c a d b', dict()),
+    ('a b c d -> (c 2 d a b)', dict(a=1, c=4, d=6)),
+    ('1 b c d -> (d copy 1) 3 b c ', dict(copy=3)),
+    ('1 ... -> 3 ... ', dict()),
+    ('() ... d -> 1 (copy1 d copy2) ... ', dict(copy1=2, copy2=3)),
+    ('1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)', dict()),
+]
+
+repeat_test_cases = [
+    # all assume that input has shape [2, 3, 5]
+    ('a b c -> c a b', dict()),
+    ('a b c -> (c copy a b)', dict(copy=2, a=2, b=3, c=5)),
+    ('a b c -> (a copy) b c ', dict(copy=1)),
+    ('a b c -> (c a) (copy1 b copy2)', dict(a=2, copy1=1, copy2=2)),
+    ('a ... -> a ... copy', dict(copy=4)),
+    ('... c -> ... (copy1 c copy2)', dict(copy1=1, copy2=2)),
+    ('... -> ... ', dict()),
+    (' ... -> copy1 ... copy2 ', dict(copy1=2, copy2=3)),
+    ('a b c -> copy1 a copy2 b c () ', dict(copy1=2, copy2=1)),
+]
+
+
+def check_reversion(x, repeat_pattern, **sizes):
+    """Checks repeat pattern by running reduction """
+    left, right = repeat_pattern.split('->')
+    reduce_pattern = right + '->' + left
+    repeated = reduce(x, repeat_pattern, reduction='repeat', **sizes)
+    reduced_min = reduce(repeated, reduce_pattern, reduction='min', **sizes)
+    reduced_max = reduce(repeated, reduce_pattern, reduction='max', **sizes)
+    assert numpy.array_equal(x, reduced_min)
+    assert numpy.array_equal(x, reduced_max)
+
+
+def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction='rearrange', is_symbolic=False):
+    """
+    Helper to test result of operation (rearrange or transpose) against numpy
+    if reduction == 'rearrange', rearrange op is tested, otherwise reduce
+    """
+    if len(numpy_input.shape) == 0:
+        return
+
+    def operation(x):
+        if reduction == 'rearrange':
+            return rearrange(x, pattern, **axes_lengths)
+        else:
+            return reduce(x, pattern, reduction, **axes_lengths)
+
+    numpy_result = operation(numpy_input)
+    check_equal = numpy.array_equal
+    p_none_dimension = 0.5
+    if 'jittor' in backend.framework_name:
+        check_equal = numpy.allclose
+        p_none_dimension = 0
+    if is_symbolic:
+        symbol_shape = [d if numpy.random.random() >= p_none_dimension else None for d in numpy_input.shape]
+        symbol = backend.create_symbol(shape=symbol_shape)
+        result_symbol = operation(symbol)
+        backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)])
+    else:
+        backend_result = operation(backend.from_numpy(numpy_input))
+        backend_result = backend.to_numpy(backend_result)
+
+    check_equal(numpy_result, backend_result)
+
+
+def create_jittor_model(use_reduce=False):
+    from jittor.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
+    from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix
+    return Sequential(
+        Conv2d(3, 6, kernel_size=(5, 5)),
+        Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2),
+        Conv2d(6, 16, kernel_size=(5, 5)),
+        Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2),
+        Rearrange('b c h w -> b (c h w)'),
+        Linear(16 * 5 * 5, 120),
+        ReLU(),
+        Linear(120, 84),
+        ReLU(),
+        EinMix('b c1 -> (b c2)', weight_shape='c1 c2', bias_shape='c2', c1=84, c2=84),
+        EinMix('(b c2) -> b c3', weight_shape='c2 c3', bias_shape='c3', c2=84, c3=84),
+        Linear(84, 10),
+    )
+
+
+if __name__ == '__main__':
+    unittest.main()