fix PR #355 & add unit tests

lidongyang 2022-10-31 20:46:44 +08:00 committed by Jittor
parent 13f9eaafc0
commit b5f03f996b
7 changed files with 821 additions and 131 deletions

python/jittor/einops/_backends.py

@@ -13,7 +13,7 @@ Backends in `einops` are organized to meet the following requirements
import sys
import warnings
- __author__ = 'Alex Rogozhnikov'
+ __author__ = 'Alex Rogozhnikov, RuiYang Liu'
_backends = {}
_debug_importing = False
@@ -107,7 +107,6 @@ class AbstractBackend:
raise NotImplementedError()
def is_float_type(self, x):
# some backends (torch) can't compute average for non-floating types.
# Decided to drop average for all backends if type is not floating
raise NotImplementedError()
@@ -117,6 +116,9 @@ class AbstractBackend:
def __repr__(self):
return "<einops backend for {}>".format(self.framework_name)
+ def einsum(self, pattern, *x):
+ raise NotImplementedError("backend does not support einsum")
class UnknownSize:
""" pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """
@@ -163,10 +165,14 @@ class NumpyBackend(AbstractBackend):
return self.np.tile(x, repeats)
def is_float_type(self, x):
- return x.dtype in ('float16', 'float32', 'float64', 'float128')
+ return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16')
def add_axis(self, x, new_position):
return self.np.expand_dims(x, new_position)
+ def einsum(self, pattern, *x):
+ return self.np.einsum(pattern, *x)
class HashableTuple:
"""Overcomes non-hashability of symbolic elements"""
@@ -192,13 +198,10 @@ class JittorBackend(AbstractBackend):
self.jittor = jittor
def is_appropriate_type(self, tensor):
- return isinstance(tensor, self.jittor.jittor_core.Var)
+ return isinstance(tensor, self.jittor.Var)
def from_numpy(self, x):
variable = self.jittor.array(x)
- if self.is_float_type(variable):
- # attach grad only to floating types
- variable.requires_grad = True
return variable
def to_numpy(self, x):
@@ -208,21 +211,24 @@ class JittorBackend(AbstractBackend):
return self.jittor.arange(start, stop, dtype='int64')
def shape(self, x):
- return HashableTuple(tuple(x.shape))
+ return tuple(x.shape)
def reshape(self, x, shape):
if len(shape) == 0:
return x
return self.jittor.reshape(x, shape)
# def reduce(self, x, operation, axes):
# return getattr(x, operation)(dim=axes)
def reduce(self, x, operation, reduced_axes):
+ if operation == 'prod':
+ # avoid overflow
+ return x.prod(reduced_axes)
for axis in sorted(reduced_axes, reverse=True):
if operation == 'min':
x = x.min(dim=axis)
elif operation == 'max':
x = x.max(dim=axis)
- elif operation in ['sum', 'mean', 'prod']:
+ elif operation in ['sum', 'mean']:
x = getattr(x, operation)(dim=axis)
else:
raise NotImplementedError('Unknown reduction ', operation)
@@ -252,4 +258,7 @@ class JittorBackend(AbstractBackend):
def layers(self):
from jittor.einops.layers import jittor
- return jittor
\ No newline at end of file
+ return jittor
+ def einsum(self, pattern, *x):
+ return self.jittor.linalg.einsum(pattern, *x)
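A quick sketch of how these new per-backend `einsum` hooks are reached (illustrative, not part of the commit; `get_backend` is einops' existing dispatcher):

```python
# Instantiating NumpyBackend directly for clarity; einops normally
# selects the backend via get_backend(tensor).
import numpy as np
from jittor.einops._backends import NumpyBackend

backend = NumpyBackend()
x = np.random.randn(2, 3)
y = np.random.randn(3, 4)
out = backend.einsum('ab,bc->ac', x, y)  # delegates to np.einsum
assert out.shape == (2, 4)
```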

python/jittor/einops/_jittor_specific.py (file deleted)

@@ -1,84 +0,0 @@
"""
Specialization of einops for jittor.
Unfortunately, jittor's jit scripting mechanism isn't strong enough,
and to have scripting supported at least for layers,
a number of changes are required, and this layer helps.
Importantly, whole lib is designed so that you can't use it
"""
from typing import Dict, List
import jittor as jt
from jittor.einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
class JittorJitBackend:
"""
Completely static backend that mimics part of normal backend functionality
but restricted to jittor stuff only
"""
@staticmethod
def reduce(x: jt.jittor_core.Var, operation: str, reduced_axes: List[int]):
if operation == 'min':
return x.min(dims=reduced_axes)
elif operation == 'max':
return x.max(dims=reduced_axes)
elif operation == 'sum':
return x.sum(dims=reduced_axes)
elif operation == 'mean':
return x.mean(dims=reduced_axes)
elif operation == 'prod':
for i in list(sorted(reduced_axes))[::-1]:
x = x.prod(dim=i)
return x
else:
raise NotImplementedError('Unknown reduction ', operation)
@staticmethod
def transpose(x, axes: List[int]):
return x.permute(axes)
@staticmethod
def stack_on_zeroth_dimension(tensors: List[jt.jittor_core.Var]):
return jt.stack(tensors)
@staticmethod
def tile(x, repeats: List[int]):
return x.repeat(repeats)
@staticmethod
def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
repeats = [1] * n_axes
for axis_position, axis_length in pos2len.items():
x = jt.unsqueeze(x, axis_position)
repeats[axis_position] = axis_length
return JittorJitBackend.tile(x, repeats)
@staticmethod
def is_float_type(x):
return x.dtype in ["float16", "float32", "float64"]
@staticmethod
def shape(x):
return x.shape
@staticmethod
def reshape(x, shape: List[int]):
return x.reshape(shape)
# mirrors einops.einops._apply_recipe
def apply_for_scriptable_jittor(recipe: TransformRecipe, tensor: jt.jittor_core.Var, reduction_type: str) -> jt.jittor_core.Var:
backend = JittorJitBackend
init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
_reconstruct_from_shape_uncached(recipe, backend.shape(tensor))
tensor = backend.reshape(tensor, init_shapes)
if len(reduced_axes) > 0:
tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
tensor = backend.transpose(tensor, axes_reordering)
if len(added_axes) > 0:
tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
return backend.reshape(tensor, final_shapes)

python/jittor/einops/einops.py

@@ -1,5 +1,6 @@
import functools
import itertools
+ import string
import typing
from collections import OrderedDict
from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar
@@ -393,7 +394,7 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i
```
Parameters:
- tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
+ tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
list of tensors is also accepted, those should be of the same type and shape
pattern: string, reduction pattern
reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
@@ -418,13 +419,7 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i
- @typing.overload
- def rearrange(tensor: Tensor, pattern: str, **axes_length: int) -> Tensor: ...
- @typing.overload
- def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: int) -> Tensor: ...
- def rearrange(tensor, pattern: str, **axes_lengths):
+ def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor:
"""
einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
@@ -470,7 +465,7 @@ def rearrange(tensor, pattern: str, **axes_lengths):
Find more examples in einops tutorial.
Parameters:
- tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
+ tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
list of tensors is also accepted, those should be of the same type and shape
pattern: string, rearrangement pattern
axes_lengths: any additional specifications for dimensions
@@ -506,8 +501,8 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
(60, 40)
# repeat image 2 times along height and 3 times along width
- >>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape
- (30, 120)
+ >>> repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape
+ (60, 120)
# convert each pixel to a small square 2x2. Upsample image by 2x
>>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
@@ -524,7 +519,7 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
Find more examples in einops tutorial.
Parameters:
- tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
+ tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
list of tensors is also accepted, those should be of the same type and shape
pattern: string, rearrangement pattern
axes_lengths: any additional specifications for dimensions
@@ -536,7 +531,7 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
- def parse_shape(x, pattern: str):
+ def parse_shape(x, pattern: str) -> dict:
"""
Parse a tensor shape to dictionary mapping axes names to their lengths.
@@ -579,11 +574,11 @@ def parse_shape(x, pattern: str):
ellipsis_idx = exp.composition.index(_ellipsis)
composition = (exp.composition[:ellipsis_idx] +
['_'] * (len(shape) - len(exp.composition) + 1) +
- exp.composition[ellipsis_idx+1:])
+ exp.composition[ellipsis_idx + 1:])
else:
composition = exp.composition
result = {}
- for (axis_name, ), axis_length in zip(composition, shape):
+ for (axis_name,), axis_length in zip(composition, shape):
if axis_name != '_':
result[axis_name] = axis_length
return result
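As a reminder of `parse_shape` semantics (unchanged by this commit; assumes `parse_shape` is re-exported from `jittor.einops` as in upstream einops):

```python
# '_' skips an axis; named axes map to their lengths.
import numpy as np
from jittor.einops import parse_shape

assert parse_shape(np.zeros([2, 3, 5, 7]), 'b _ h w') == {'b': 2, 'h': 5, 'w': 7}
```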
@@ -623,3 +618,165 @@ def asnumpy(tensor) -> 'numpy.ndarray':
`numpy.ndarray`, converted to numpy
"""
return get_backend(tensor).to_numpy(tensor)
def _validate_einsum_axis_name(axis_name):
if len(axis_name) == 0:
raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
if len(axis_name) > 1:
raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
axis_name = axis_name[0]
if isinstance(axis_name, AnonymousAxis):
raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
if len(axis_name) == 0:
raise RuntimeError("Encountered empty axis name in einsum.")
if not isinstance(axis_name, str):
raise RuntimeError("Axis name in einsum must be a string.")
@functools.lru_cache(256)
def _compactify_pattern_for_einsum(pattern: str) -> str:
if "->" not in pattern:
# numpy's einsum allows a pattern without '->',
# so guard against users accidentally relying on that here.
raise ValueError("Einsum pattern must contain '->'.")
lefts, right = pattern.split('->')
lefts = lefts.split(',')
lefts = [
ParsedExpression(left, allow_underscore=True, allow_duplicates=True)
for left in lefts
]
right = ParsedExpression(right, allow_underscore=True)
# Start from 'a' and go up to 'Z'
output_axis_names = string.ascii_letters
i = 0
axis_name_mapping = {}
left_patterns = []
for left in lefts:
left_pattern = ""
for raw_axis_name in left.composition:
if raw_axis_name == _ellipsis:
left_pattern += '...'
continue
_validate_einsum_axis_name(raw_axis_name)
axis_name = raw_axis_name[0]
if axis_name not in axis_name_mapping:
if i >= len(output_axis_names):
raise RuntimeError("Too many axes in einsum.")
axis_name_mapping[axis_name] = output_axis_names[i]
i += 1
left_pattern += axis_name_mapping[axis_name]
left_patterns.append(left_pattern)
compact_pattern = ",".join(left_patterns) + "->"
for raw_axis_name in right.composition:
if raw_axis_name == _ellipsis:
compact_pattern += '...'
continue
_validate_einsum_axis_name(raw_axis_name)
axis_name = raw_axis_name[0]
if axis_name not in axis_name_mapping:
raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.")
compact_pattern += axis_name_mapping[axis_name]
return compact_pattern
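To make the compaction concrete, here is roughly what the function above yields (expected values traced by hand from the code, so treat them as a sketch):

```python
# Axis names become single letters in order of first appearance;
# '...' passes through; a repeated name (e.g. a trace) reuses its letter.
assert _compactify_pattern_for_einsum('i i ->') == 'aa->'
assert _compactify_pattern_for_einsum(
    'batch h w, h w channel -> batch channel') == 'abc,bcd->ad'
assert _compactify_pattern_for_einsum(
    'out_dim in_dim, ... in_dim -> ... out_dim') == 'ab,...b->...a'
```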
@typing.overload
def einsum(tensor: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ...
@typing.overload
def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ...
def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor:
"""
einops.einsum calls einsum operations with einops-style named
axes indexing, computing tensor products with an arbitrary
number of tensors. Unlike typical einsum syntax, here you must
pass tensors first, and then the pattern.
Also, note that rearrange operations such as `"(batch chan) out"`,
or singleton axes `()`, are not currently supported.
Examples:
For a given pattern such as:
```python
>>> x, y, z = np.random.randn(3, 20, 20, 20)
>>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k")
```
the following formula is computed:
```tex
output[a, b, k] =
\sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k]
```
where the summation over `c`, `d`, and `g` is performed
because those axis names do not appear on the right-hand side.
Let's see some additional examples:
```python
# Filter a set of images:
>>> batched_images = np.random.randn(128, 16, 16)
>>> filters = np.random.randn(16, 16, 30)
>>> result = einsum(batched_images, filters,
... "batch h w, h w channel -> batch channel")
>>> result.shape
(128, 30)
# Matrix multiplication, with an unknown input shape:
>>> batch_shape = (50, 30)
>>> data = np.random.randn(*batch_shape, 20)
>>> weights = np.random.randn(10, 20)
>>> result = einsum(weights, data,
... "out_dim in_dim, ... in_dim -> ... out_dim")
>>> result.shape
(50, 30, 10)
# Matrix trace on a single tensor:
>>> matrix = np.random.randn(10, 10)
>>> result = einsum(matrix, "i i ->")
>>> result.shape
()
```
Parameters:
tensors: tensors of any supported library (numpy, jittor).
pattern: string, einsum pattern, with commas
separating specifications for each tensor.
Returns:
Tensor of the same type as input, after processing with einsum.
"""
if len(tensors_and_pattern) <= 1:
raise ValueError(
"`einops.einsum` takes at minimum two arguments: the tensors (at least one),"
" followed by the pattern."
)
pattern = tensors_and_pattern[-1]
if not isinstance(pattern, str):
raise ValueError(
"The last argument passed to `einops.einsum` must be a string,"
" representing the einsum pattern."
)
tensors = tensors_and_pattern[:-1]
pattern = _compactify_pattern_for_einsum(pattern)
return get_backend(tensors[0]).einsum(pattern, *tensors)
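A short usage sketch against the jittor backend (assuming `einsum` is re-exported from `jittor.einops` like the other ops; otherwise import it from `jittor.einops.einops`):

```python
import jittor as jt
from jittor.einops import einsum

x = jt.randn(2, 3)
y = jt.randn(3, 4)
# compacted to 'ab,bc->ac', then dispatched to jt.linalg.einsum
out = einsum(x, y, 'i j, j k -> i k')
assert tuple(out.shape) == (2, 4)
```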

python/jittor/einops/layers/_einmix.py

@@ -49,8 +49,8 @@ class _EinmixMixin:
Parameters
:param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
- :param weight_shape: axes of weight. Tensor od this shape is created, stored, and optimized in a layer
- :param bias_shape: axes of bias added to output.
+ :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
+ :param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added.
:param axes_lengths: dimensions of weight tensor
super().__init__()
@@ -58,7 +58,9 @@ class _EinmixMixin:
self.weight_shape = weight_shape
self.bias_shape = bias_shape
self.axes_lengths = axes_lengths
+ self.initialize_einmix(pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths)
+ def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths):
left_pattern, right_pattern = pattern.split('->')
left = ParsedExpression(left_pattern)
right = ParsedExpression(right_pattern)
@@ -84,7 +86,7 @@ class _EinmixMixin:
names += group
composition = ' '.join(names)
pre_reshape_pattern = f'{left_pattern}->{composition}'
- pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names}
+ pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names}
if any(len(group) != 1 for group in right.composition):
names = []
@@ -134,8 +136,7 @@ class _EinmixMixin:
_bias_shape.append(1)
else:
_bias_shape = None
- _bias_input_size = None
weight_bound = (3 / _fan_in) ** 0.5
bias_bound = (1 / _fan_in) ** 0.5
self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound)
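For reference, a minimal EinMix usage sketch with the jittor layer (axis names and sizes here are made up; see also the EinMix lines in the test model below):

```python
import jittor as jt
from jittor.einops.layers.jittor import EinMix

# Named-axes equivalent of a Linear layer mapping 32 -> 64 features, with bias.
mix = EinMix('b i -> b o', weight_shape='i o', bias_shape='o', i=32, o=64)
out = mix(jt.randn(8, 32))
assert tuple(out.shape) == (8, 64)
```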

python/jittor/einops/layers/jittor.py

@@ -6,27 +6,18 @@ import numpy as np
from jittor.einops.layers import RearrangeMixin, ReduceMixin
from jittor.einops.layers._einmix import _EinmixMixin
- from jittor.einops._jittor_specific import apply_for_scriptable_jittor
__author__ = 'Ruiyang Liu'
class Rearrange(RearrangeMixin, jt.nn.Module):
def execute(self, input):
- return apply_for_scriptable_jittor(self._recipe, input, reduction_type='rearrange')
- def _apply_recipe(self, x):
- # overriding parent method to prevent it's scripting
- pass
+ return self._apply_recipe(input)
class Reduce(ReduceMixin, jt.nn.Module):
def execute(self, input):
- return apply_for_scriptable_jittor(self._recipe, input, reduction_type=self.reduction)
- def _apply_recipe(self, x):
- # overriding parent method to prevent it's scripting
- pass
+ return self._apply_recipe(input)
class EinMix(_EinmixMixin, jt.nn.Module):

python/jittor/einops/parsing.py

@@ -26,7 +26,7 @@ class ParsedExpression:
non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
and keeps some information important for downstream
"""
- def __init__(self, expression, *, allow_underscore: bool = False):
+ def __init__(self, expression, *, allow_underscore: bool = False, allow_duplicates: bool = False):
self.has_ellipsis: bool = False
self.has_ellipsis_parenthesized: Optional[bool] = None
self.identifiers: Set[str] = set()
@@ -48,7 +48,7 @@ class ParsedExpression:
def add_axis_name(x):
if x is not None:
if x in self.identifiers:
- if not (allow_underscore and x == "_"):
+ if not (allow_underscore and x == "_") and not allow_duplicates:
raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
if x == _ellipsis:
self.identifiers.add(_ellipsis)
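The new flag exists because einsum patterns may legitimately repeat an axis name on one input, e.g. a trace `'i i ->'`; a quick illustration of the intended behavior:

```python
from jittor.einops.parsing import ParsedExpression

# Accepted with the new flag (used internally by _compactify_pattern_for_einsum):
ParsedExpression('i i', allow_duplicates=True)
# Default behavior is unchanged:
# ParsedExpression('i i')  # raises EinopsError: duplicate dimension "i"
```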

python/jittor/test/test_einops.py (new file)

@@ -0,0 +1,616 @@
# ***************************************************************
# Copyright (c) 2022 Jittor. All Rights Reserved.
# Maintainers: DongYang Li <lidongyang2001@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from collections import namedtuple
import tempfile
import pickle
import itertools
from jittor.einops.einops import (rearrange, reduce, _enumerate_directions, _reductions)
from jittor.einops import EinopsError
import jittor as jt
import numpy
import unittest
# helpers adapted from einops' tests/__init__.py
import os
from jittor.einops import _backends
import warnings
flag_to_bool = {
'': False,
'0': False,
'1': True,
}
def collect_test_backends(symbolic=False, layers=False):
"""
:param symbolic: symbolic or imperative frameworks?
:param layers: layers or operations?
:return: list of backends satisfying set conditions
"""
if not symbolic:
if not layers:
backend_types = [
_backends.NumpyBackend,
_backends.JittorBackend,
]
else:
backend_types = [
_backends.JittorBackend,
]
else:
backend_types = []
result = []
for backend_type in backend_types:
try:
result.append(backend_type())
except ImportError:
# a backend that fails to import fails its specific test function,
# but is skipped in all other test cases
warnings.warn('backend could not be initialized for tests: {}'.format(backend_type))
return result
# tests adapted from einops' test/test_ops.py
imp_op_backends = collect_test_backends(symbolic=False, layers=False)
# tests adapted from einops' test/test_layer.py
class TestSlice(unittest.TestCase):
def test_anonymous_axes(self):
x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
for pattern, axis_dimensions in test_cases_repeat_anonymous:
check_reversion(x, pattern, **axis_dimensions)
def test_repeat_imperatives(self):
x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
for backend in imp_op_backends:
print('Repeat tests for ', backend.framework_name)
for pattern, axis_dimensions in repeat_test_cases:
expected = reduce(x, pattern, reduction='repeat', **axis_dimensions)
converted = backend.from_numpy(x)
repeated = reduce(converted, pattern, reduction='repeat', **axis_dimensions)
result = backend.to_numpy(repeated)
assert numpy.array_equal(result, expected)
def test_repeat_numpy(self):
# check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well
x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
x1 = reduce(x, 'a b c -> copy a b c ', reduction='repeat', copy=1)
assert numpy.array_equal(x[None], x1)
for pattern, axis_dimensions in repeat_test_cases:
check_reversion(x, pattern, **axis_dimensions)
def test_tiling_imperatives(self):
for backend in imp_op_backends:
print('Tiling tests for ', backend.framework_name)
input = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5])
test_cases = [
(1, 1, 1, 1, 1),
(1, 2, 1, 3, 1),
(3, 1, 1, 4, 1),
]
for repeats in test_cases:
expected = numpy.tile(input, repeats)
converted = backend.from_numpy(input)
repeated = backend.tile(converted, repeats)
result = backend.to_numpy(repeated)
assert numpy.array_equal(result, expected)
def test_gradients_imperatives(self):
# lazy - just checking reductions
for reduction in _reductions:
x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype('float32')
results = {}
for backend in imp_op_backends:
y0 = backend.from_numpy(x)
if 'jittor' not in backend.framework_name and not hasattr(y0, 'grad'):
continue
y1 = reduce(y0, 'a b c -> c a', reduction=reduction)
y2 = reduce(y1, 'c a -> a c', reduction=reduction)
y3 = reduce(y2, 'a (c1 c2) -> a', reduction=reduction, c1=2)
y4 = reduce(y3, '... -> ', reduction=reduction)
if 'jittor' in backend.framework_name:
grad = backend.jittor.grad(y4, y0)
else:
y4.backward()
grad = y0.grad
results[backend.framework_name] = backend.to_numpy(grad)
print('comparing gradients for', results.keys())
for name1, grad1 in results.items():
for name2, grad2 in results.items():
assert numpy.allclose(grad1, grad2), [name1, name2, 'provided different gradients']
def test_concatenations_and_stacking(self):
for backend in imp_op_backends:
print('testing shapes for ', backend.framework_name)
for n_arrays in [1, 2, 5]:
shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
for shape in shapes:
if (backend.framework_name == 'jittor')\
and len(shape) == 0:
# jittor stores scalar in 1d array
continue
arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)]
arrays2 = [backend.from_numpy(array) for array in arrays1]
result0 = numpy.asarray(arrays1)
result1 = rearrange(arrays1, '...->...')
result2 = rearrange(arrays2, '...->...')
assert numpy.array_equal(result0, result1)
assert numpy.array_equal(result1, backend.to_numpy(result2))
result1 = rearrange(arrays1, 'b ... -> ... b')
result2 = rearrange(arrays2, 'b ... -> ... b')
assert numpy.array_equal(result1, backend.to_numpy(result2))
def test_enumerating_directions(self):
for backend in imp_op_backends:
print('testing directions for', backend.framework_name)
for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]:
if (backend.framework_name == 'jittor')\
and len(shape) == 0:
# jittor stores scalar in 1d array
continue
x = numpy.arange(numpy.prod(shape)).reshape(shape)
axes1 = _enumerate_directions(x)
axes2 = _enumerate_directions(backend.from_numpy(x))
assert len(axes1) == len(axes2) == len(shape)
for ax1, ax2 in zip(axes1, axes2):
ax2 = backend.to_numpy(ax2)
assert ax1.shape == ax2.shape
assert numpy.allclose(ax1, ax2)
def test_reduction_with_callable_imperatives(self):
x_numpy = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype('float32')
x_numpy /= x_numpy.max()
def logsumexp_jittor(x, tuple_of_axes):
import jittor as jt
return jt.nn.logsumexp(x, tuple_of_axes)
def logsumexp_numpy(x, tuple_of_axes):
# very naive logsumexp to compare to
minused = x.max(tuple_of_axes)
y = x - x.max(tuple_of_axes, keepdims=True)
y = numpy.exp(y)
y = numpy.sum(y, axis=tuple_of_axes)
return numpy.log(y) + minused
from jittor.einops._backends import JittorBackend, NumpyBackend
backend2callback = {
JittorBackend.framework_name: logsumexp_jittor,
NumpyBackend.framework_name: logsumexp_numpy,
}
for backend in imp_op_backends:
if backend.framework_name not in backend2callback:
continue
backend_callback = backend2callback[backend.framework_name]
x_backend = backend.from_numpy(x_numpy)
for pattern1, pattern2 in equivalent_reduction_patterns:
print('Test reduction with callable for ', backend.framework_name, pattern1, pattern2)
output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy)
output_backend = reduce(x_backend, pattern1, reduction=backend_callback)
assert numpy.allclose(
output_numpy,
backend.to_numpy(output_backend),
)
def test_reduction_stress_imperatives(self):
for backend in imp_op_backends:
print('Stress-testing reduction for ', backend.framework_name)
for reduction in _reductions + ('rearrange',):
dtype = 'int64'
coincide = numpy.array_equal
if reduction in ['mean', 'prod']:
dtype = 'float64'
coincide = numpy.allclose
for n_axes in range(11):
shape = numpy.random.randint(2, 4, size=n_axes)
permutation = numpy.random.permutation(n_axes)
skipped = 0 if reduction == 'rearrange' else numpy.random.randint(n_axes + 1)
left = ' '.join('x' + str(i) for i in range(n_axes))
right = ' '.join('x' + str(i) for i in permutation[skipped:])
pattern = left + '->' + right
x = numpy.arange(1, 1 + numpy.prod(shape), dtype=dtype).reshape(shape)
if reduction == 'prod':
x /= x.mean() # to avoid overflows
result1 = reduce(x, pattern, reduction=reduction)
result2 = x.transpose(permutation)
if skipped > 0:
result2 = getattr(result2, reduction)(axis=tuple(range(skipped)))
assert coincide(result1, result2)
check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False)
def test_reduction_imperatives(self):
for backend in imp_op_backends:
print('Reduction tests for ', backend.framework_name)
for reduction in _reductions:
# slight redundancy for simpler order - numpy version is evaluated multiple times
input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape([2, 3, 4, 5, 6])
if reduction in ['mean', 'prod']:
input = input / input.astype('float64').mean()
test_cases = [
['a b c d e -> ', {},
getattr(input, reduction)()],
['a ... -> ', {},
getattr(input, reduction)()],
['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2),
getattr(input, reduction)()],
['a b c d e -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
['a ... c d e -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
['a b c d e ... -> (e c) a', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
['a b c d e -> (e c a)', {},
getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
['(a a2) ... -> (a2 a) ...', dict(a2=1),
input],
]
for pattern, axes_lengths, expected_result in test_cases:
result = reduce(backend.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths)
result = backend.to_numpy(result)
assert numpy.allclose(result, expected_result)
def test_rearrange_permutations_numpy(self):
# tests random permutation of axes against two independent numpy ways
for n_axes in range(1, 10):
input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
permutation = numpy.random.permutation(n_axes)
left_expression = ' '.join('i' + str(axis) for axis in range(n_axes))
right_expression = ' '.join('i' + str(axis) for axis in permutation)
expression = left_expression + ' -> ' + right_expression
result = rearrange(input, expression)
for pick in numpy.random.randint(0, 2, [10, n_axes]):
assert input[tuple(pick)] == result[tuple(pick[permutation])]
for n_axes in range(1, 10):
input = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
permutation = numpy.random.permutation(n_axes)
left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1])
right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1])
expression = left_expression + ' -> ' + right_expression
result = rearrange(input, expression)
assert result.shape == input.shape
expected_result = numpy.zeros_like(input)
for original_axis, result_axis in enumerate(permutation):
expected_result |= ((input >> original_axis) & 1) << result_axis
assert numpy.array_equal(result, expected_result)
def test_rearrange_consistency_numpy(self):
shape = [1, 2, 3, 5, 7, 11]
x = numpy.arange(numpy.prod(shape)).reshape(shape)
for pattern in [
'a b c d e f -> a b c d e f',
'b a c d e f -> a b d e f c',
'a b c d e f -> f e d c b a',
'a b c d e f -> (f e) d (c b a)',
'a b c d e f -> (f e d c b a)',
]:
result = rearrange(x, pattern)
assert len(numpy.setdiff1d(x, result)) == 0
assert result.dtype == x.dtype
result = rearrange(x, 'a b c d e f -> a (b) (c d e) f')
assert numpy.array_equal(x.flatten(), result.flatten())
result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11')
assert numpy.array_equal(x, result)
result1 = rearrange(x, 'a b c d e f -> f e d c b a')
result2 = rearrange(x, 'f e d c b a -> a b c d e f')
assert numpy.array_equal(result1, result2)
result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5)
assert numpy.array_equal(x, result)
sizes = dict(zip('abcdef', shape))
temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes)
result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes)
assert numpy.array_equal(x, result)
x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4])
result = rearrange(x2, 'a b c -> b c a')
assert x2[1, 2, 3] == result[2, 3, 1]
assert x2[0, 1, 2] == result[1, 2, 0]
def test_ellipsis_ops_imperative(self):
""" Checking various patterns against numpy """
x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
for is_symbolic in [True, False]:
for backend in collect_test_backends(symbolic=is_symbolic, layers=False):
for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)):
check_op_against_numpy(backend, x, pattern, axes_lengths={},
reduction='rearrange', is_symbolic=is_symbolic)
for reduction in ['min', 'max', 'sum']:
for pattern in itertools.chain(*equivalent_reduction_patterns):
check_op_against_numpy(backend, x, pattern, axes_lengths={},
reduction=reduction, is_symbolic=is_symbolic)
def test_ellipsis_ops_numpy(self):
x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
for pattern in identity_patterns:
assert numpy.array_equal(x, rearrange(x, pattern)), pattern
for pattern1, pattern2 in equivalent_rearrange_patterns:
assert numpy.array_equal(rearrange(x, pattern1), rearrange(x, pattern2))
for reduction in ['min', 'max', 'sum']:
for pattern1, pattern2 in equivalent_reduction_patterns:
assert numpy.array_equal(reduce(x, pattern1, reduction=reduction),
reduce(x, pattern2, reduction=reduction))
# now just check coincidence with numpy
all_rearrange_patterns = [*identity_patterns]
for pattern_pairs in equivalent_rearrange_patterns:
all_rearrange_patterns.extend(pattern_pairs)
def test_collapsed_ellipsis_errors_out(self):
x = numpy.zeros([1, 1, 1, 1, 1])
rearrange(x, 'a b c d ... -> a b c ... d')
error = 0
try:
rearrange(x, 'a b c d (...) -> a b c ... d')
except Exception as e:
error = 1
assert error == 1
rearrange(x, '... -> (...)')
error = 0
try:
rearrange(x, '(...) -> (...)')
except Exception as e:
error = 1
assert error == 1
def test_rearrange_imperative(self):
for backend in collect_test_backends(symbolic=False, layers=True):
print('Test layer for ', backend.framework_name)
for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
x = numpy.arange(numpy.prod(input_shape), dtype='float32').reshape(input_shape)
result_numpy = rearrange(x, pattern, **axes_lengths)
layer = backend.layers().Rearrange(pattern, **axes_lengths)
for shape in wrong_shapes:
try:
layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
except Exception:
pass
else:
raise AssertionError('Failure expected')
# simple pickling / unpickling
layer2 = pickle.loads(pickle.dumps(layer))
result1 = backend.to_numpy(layer(backend.from_numpy(x)))
result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
assert numpy.allclose(result_numpy, result1)
assert numpy.allclose(result1, result2)
just_sum = backend.layers().Reduce('...->', reduction='sum')
variable = backend.from_numpy(x)
result = just_sum(layer(variable))
if 'jittor' in backend.framework_name:
grad = backend.jittor.grad(result, variable)
else:
result.backward()
grad = variable.grad
assert numpy.allclose(backend.to_numpy(grad), 1)
def test_reduce_imperative(self):
for backend in collect_test_backends(symbolic=False, layers=True):
print('Test layer for ', backend.framework_name)
for reduction in _reductions:
for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape)
x /= x.mean()
result_numpy = reduce(x, pattern, reduction, **axes_lengths)
layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
for shape in wrong_shapes:
try:
layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
except Exception:
pass
else:
raise AssertionError('Failure expected')
# simple pickling / unpickling
layer2 = pickle.loads(pickle.dumps(layer))
result1 = backend.to_numpy(layer(backend.from_numpy(x)))
result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
assert numpy.allclose(result_numpy, result1)
assert numpy.allclose(result1, result2)
just_sum = backend.layers().Reduce('...->', reduction='sum')
variable = backend.from_numpy(x)
result = just_sum(layer(variable))
if 'jittor' in backend.framework_name:
grad = backend.jittor.grad(result, variable)
grad = backend.to_numpy(grad)
else:
result.backward()
grad = backend.to_numpy(variable.grad)
if reduction == 'sum':
assert numpy.allclose(grad, 1)
if reduction == 'mean':
assert numpy.allclose(grad, grad.min())
if reduction in ['max', 'min']:
assert numpy.all(numpy.in1d(grad, [0, 1]))
assert numpy.sum(grad) > 0.5
def test_jittor_layer(self):
has_jittor = any(backend.framework_name == 'jittor' for backend in collect_test_backends(symbolic=False, layers=True))
if has_jittor:
# checked that jittor present
import jittor
rtol = 1e-05
atol = 1e-08
def allclose(input, other): return jittor.all(jittor.abs(input - other) <= atol + rtol * jittor.abs(other))
model1 = create_jittor_model(use_reduce=True)
model2 = create_jittor_model(use_reduce=False)
input = jittor.randn([10, 3, 32, 32])
# random models have different predictions
assert not allclose(model1(input), model2(input))
model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
assert allclose(model1(input), model2(input))
testcase = namedtuple('testcase', ['pattern', 'axes_lengths', 'input_shape', 'wrong_shapes'])
rearrangement_patterns = [
testcase('b c h w -> b (c h w)', dict(c=20), (10, 20, 30, 40),
[(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]]),
testcase('b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1', dict(h2=2, w2=2), (10, 20, 30, 40),
[(), (1, 1, 1, 1), (1, 10, 3), ()]),
testcase('b ... c -> c b ...', dict(b=10), (10, 20, 30),
[(), (10,), (5, 10)]),
]
reduction_patterns = rearrangement_patterns + [
testcase('b c h w -> b ()', dict(b=10), (10, 20, 30, 40),
[(10,), (10, 20, 30)]),
testcase('b c (h1 h2) (w1 w2) -> b c h1 w1', dict(h1=15, h2=2, w2=2), (10, 20, 30, 40),
[(10, 20, 31, 40)]),
testcase('b ... c -> b', dict(b=10), (10, 20, 30, 40),
[(10,), (11, 10)]),
]
equivalent_reduction_patterns = [
('a b c d e -> ', ' ... -> '),
('a b c d e -> (e a)', 'a ... e -> (e a)'),
('a b c d e -> d (a e)', ' a b c d e ... -> d (a e) '),
('a b c d e -> (a b)', ' ... c d e -> (...) '),
]
equivalent_rearrange_patterns = [
('a b c d e -> (a b) c d e', 'a b ... -> (a b) ... '),
('a b c d e -> a b (c d) e', '... c d e -> ... (c d) e'),
('a b c d e -> a b c d e', '... -> ... '),
('a b c d e -> (a b c d e)', '... -> (...)'),
('a b c d e -> b (c d e) a', 'a b ... -> b (...) a'),
('a b c d e -> b (a c d) e', 'a b ... e -> b (a ...) e'),
]
identity_patterns = [
'...->...',
'a b c d e-> a b c d e',
'a b c d e ...-> ... a b c d e',
'a b c d e ...-> a ... b c d e',
'... a b c d e -> ... a b c d e',
'a ... e-> a ... e',
'a ... -> a ... ',
'a ... c d e -> a (...) c d e',
]
test_cases_repeat_anonymous = [
# all assume that input has shape [1, 2, 4, 6]
('a b c d -> c a d b', dict()),
('a b c d -> (c 2 d a b)', dict(a=1, c=4, d=6)),
('1 b c d -> (d copy 1) 3 b c ', dict(copy=3)),
('1 ... -> 3 ... ', dict()),
('() ... d -> 1 (copy1 d copy2) ... ', dict(copy1=2, copy2=3)),
('1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)', dict()),
]
repeat_test_cases = [
# all assume that input has shape [2, 3, 5]
('a b c -> c a b', dict()),
('a b c -> (c copy a b)', dict(copy=2, a=2, b=3, c=5)),
('a b c -> (a copy) b c ', dict(copy=1)),
('a b c -> (c a) (copy1 b copy2)', dict(a=2, copy1=1, copy2=2)),
('a ... -> a ... copy', dict(copy=4)),
('... c -> ... (copy1 c copy2)', dict(copy1=1, copy2=2)),
('... -> ... ', dict()),
(' ... -> copy1 ... copy2 ', dict(copy1=2, copy2=3)),
('a b c -> copy1 a copy2 b c () ', dict(copy1=2, copy2=1)),
]
def check_reversion(x, repeat_pattern, **sizes):
"""Checks repeat pattern by running reduction """
left, right = repeat_pattern.split('->')
reduce_pattern = right + '->' + left
repeated = reduce(x, repeat_pattern, reduction='repeat', **sizes)
reduced_min = reduce(repeated, reduce_pattern, reduction='min', **sizes)
reduced_max = reduce(repeated, reduce_pattern, reduction='max', **sizes)
assert numpy.array_equal(x, reduced_min)
assert numpy.array_equal(x, reduced_max)
def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction='rearrange', is_symbolic=False):
"""
Helper to test result of operation (rearrange or transpose) against numpy
if reduction == 'rearrange', rearrange op is tested, otherwise reduce
"""
if len(numpy_input.shape) == 0:
return
def operation(x):
if reduction == 'rearrange':
return rearrange(x, pattern, **axes_lengths)
else:
return reduce(x, pattern, reduction, **axes_lengths)
numpy_result = operation(numpy_input)
check_equal = numpy.array_equal
p_none_dimension = 0.5
if 'jittor' in backend.framework_name:
check_equal = numpy.allclose
p_none_dimension = 0
if is_symbolic:
symbol_shape = [d if numpy.random.random() >= p_none_dimension else None for d in numpy_input.shape]
symbol = backend.create_symbol(shape=symbol_shape)
result_symbol = operation(symbol)
backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)])
else:
backend_result = operation(backend.from_numpy(numpy_input))
backend_result = backend.to_numpy(backend_result)
assert check_equal(numpy_result, backend_result)
def create_jittor_model(use_reduce=False):
from jittor.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix
return Sequential(
Conv2d(3, 6, kernel_size=(5, 5)),
Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2),
Conv2d(6, 16, kernel_size=(5, 5)),
Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2),
Rearrange('b c h w -> b (c h w)'),
Linear(16 * 5 * 5, 120),
ReLU(),
Linear(120, 84),
ReLU(),
EinMix('b c1 -> (b c2)', weight_shape='c1 c2', bias_shape='c2', c1=84, c2=84),
EinMix('(b c2) -> b c3', weight_shape='c2 c3', bias_shape='c3', c2=84, c3=84),
Linear(84, 10),
)
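An illustrative smoke test of the helper above (batch size assumed):

```python
# 32x32 input: conv(5) -> 28, pool/2 -> 14, conv(5) -> 10, pool/2 -> 5,
# flatten to 16*5*5 = 400, then the Linear/EinMix stack down to 10 logits.
model = create_jittor_model(use_reduce=True)
logits = model(jt.randn(4, 3, 32, 32))
assert tuple(logits.shape) == (4, 10)
```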
if __name__ == '__main__':
unittest.main()