From b5f03f996bd7bf5b8b5ca2fe4c1739bdd2d58776 Mon Sep 17 00:00:00 2001 From: lidongyang Date: Mon, 31 Oct 2022 20:46:44 +0800 Subject: [PATCH] fix pr#355&add unittest --- python/jittor/einops/_backends.py | 35 +- python/jittor/einops/_jittor_specific.py | 84 ---- python/jittor/einops/einops.py | 187 ++++++- python/jittor/einops/layers/_einmix.py | 11 +- python/jittor/einops/layers/jittor.py | 15 +- python/jittor/einops/parsing.py | 4 +- python/jittor/test/test_einops.py | 616 +++++++++++++++++++++++ 7 files changed, 821 insertions(+), 131 deletions(-) delete mode 100644 python/jittor/einops/_jittor_specific.py create mode 100644 python/jittor/test/test_einops.py diff --git a/python/jittor/einops/_backends.py b/python/jittor/einops/_backends.py index a04f34be..eef8f21c 100644 --- a/python/jittor/einops/_backends.py +++ b/python/jittor/einops/_backends.py @@ -13,7 +13,7 @@ Backends in `einops` are organized to meet the following requirements import sys import warnings -__author__ = 'Alex Rogozhnikov' +__author__ = 'Alex Rogozhnikov, RuiYang Liu' _backends = {} _debug_importing = False @@ -107,7 +107,6 @@ class AbstractBackend: raise NotImplementedError() def is_float_type(self, x): - # some backends (torch) can't compute average for non-floating types. # Decided to drop average for all backends if type is not floating raise NotImplementedError() @@ -117,6 +116,9 @@ class AbstractBackend: def __repr__(self): return "<einops backend for {}>".format(self.framework_name) + def einsum(self, pattern, *x): + raise NotImplementedError("backend does not support einsum") + class UnknownSize: """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """ @@ -163,10 +165,14 @@ class NumpyBackend(AbstractBackend): return self.np.tile(x, repeats) def is_float_type(self, x): - return x.dtype in ('float16', 'float32', 'float64', 'float128') + return x.dtype in ('float16', 'float32', 'float64', 'float128', 'bfloat16') def add_axis(self, x, new_position): return self.np.expand_dims(x, new_position) + + def einsum(self, pattern, *x): + return self.np.einsum(pattern, *x) + class HashableTuple: """Overcomes non-hashability of symbolic elements""" @@ -192,13 +198,10 @@ class JittorBackend(AbstractBackend): self.jittor = jittor def is_appropriate_type(self, tensor): - return isinstance(tensor, self.jittor.jittor_core.Var) + return isinstance(tensor, self.jittor.Var) def from_numpy(self, x): variable = self.jittor.array(x) - if self.is_float_type(variable): - # attach grad only to floating types - variable.requires_grad = True return variable def to_numpy(self, x): @@ -208,21 +211,24 @@ class JittorBackend(AbstractBackend): return self.jittor.arange(start, stop, dtype='int64') def shape(self, x): - return HashableTuple(tuple(x.shape)) + return tuple(x.shape) def reshape(self, x, shape): + if len(shape) == 0: + return x return self.jittor.reshape(x, shape) - # def reduce(self, x, operation, axes): - # return getattr(x, operation)(dim=axes) - def reduce(self, x, operation, reduced_axes): + + if operation == 'prod': + # avoid overflow + return x.prod(reduced_axes) for axis in sorted(reduced_axes, reverse=True): if operation == 'min': x = x.min(dim=axis) elif operation == 'max': x = x.max(dim=axis) - elif operation in ['sum', 'mean', 'prod']: + elif operation in ['sum', 'mean']: x = getattr(x, operation)(dim=axis) else: raise NotImplementedError('Unknown reduction ', operation) @@ -252,4 +258,7 @@ class JittorBackend(AbstractBackend): def layers(self): from jittor.einops.layers import jittor - return jittor \ No
newline at end of file + return jittor + + def einsum(self, pattern, *x): + return self.jittor.linalg.einsum(pattern, *x) \ No newline at end of file diff --git a/python/jittor/einops/_jittor_specific.py b/python/jittor/einops/_jittor_specific.py deleted file mode 100644 index d21ff76d..00000000 --- a/python/jittor/einops/_jittor_specific.py +++ /dev/null @@ -1,84 +0,0 @@ -""" -Specialization of einops for jittor. - -Unfortunately, jittor's jit scripting mechanism isn't strong enough, -and to have scripting supported at least for layers, -a number of changes is required, and this layer helps. - -Importantly, whole lib is designed so that you can't use it -""" - -from typing import Dict, List - -import jittor as jt -from jittor.einops.einops import TransformRecipe, _reconstruct_from_shape_uncached - - -class JittorJitBackend: - """ - Completely static backend that mimics part of normal backend functionality - but restricted to jittor stuff only - """ - - @staticmethod - def reduce(x: jt.jittor_core.Var, operation: str, reduced_axes: List[int]): - if operation == 'min': - return x.min(dims=reduced_axes) - elif operation == 'max': - return x.max(dims=reduced_axes) - elif operation == 'sum': - return x.sum(dims=reduced_axes) - elif operation == 'mean': - return x.mean(dims=reduced_axes) - elif operation == 'prod': - for i in list(sorted(reduced_axes))[::-1]: - x = x.prod(dim=i) - return x - else: - raise NotImplementedError('Unknown reduction ', operation) - - @staticmethod - def transpose(x, axes: List[int]): - return x.permute(axes) - - @staticmethod - def stack_on_zeroth_dimension(tensors: List[jt.jittor_core.Var]): - return jt.stack(tensors) - - @staticmethod - def tile(x, repeats: List[int]): - return x.repeat(repeats) - - @staticmethod - def add_axes(x, n_axes: int, pos2len: Dict[int, int]): - repeats = [1] * n_axes - for axis_position, axis_length in pos2len.items(): - x = jt.unsqueeze(x, axis_position) - repeats[axis_position] = axis_length - return JittorJitBackend.tile(x, repeats) - - @staticmethod - def is_float_type(x): - return x.dtype in ["float16", "float32", "float64"] - - @staticmethod - def shape(x): - return x.shape - - @staticmethod - def reshape(x, shape: List[int]): - return x.reshape(shape) - - -# mirrors einops.einops._apply_recipe -def apply_for_scriptable_jittor(recipe: TransformRecipe, tensor: jt.jittor_core.Var, reduction_type: str) -> jt.jittor_core.Var: - backend = JittorJitBackend - init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \ - _reconstruct_from_shape_uncached(recipe, backend.shape(tensor)) - tensor = backend.reshape(tensor, init_shapes) - if len(reduced_axes) > 0: - tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes) - tensor = backend.transpose(tensor, axes_reordering) - if len(added_axes) > 0: - tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes) - return backend.reshape(tensor, final_shapes) diff --git a/python/jittor/einops/einops.py b/python/jittor/einops/einops.py index 21fafd77..da931fd3 100644 --- a/python/jittor/einops/einops.py +++ b/python/jittor/einops/einops.py @@ -1,5 +1,6 @@ import functools import itertools +import string import typing from collections import OrderedDict from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar @@ -393,7 +394,7 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i ``` Parameters: - tensor: tensor: tensor of any supported library (e.g. 
numpy.ndarray, jittor.array). + tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var). list of tensors is also accepted, those should be of the same type and shape pattern: string, reduction pattern reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive @@ -418,13 +419,7 @@ def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: i -@typing.overload -def rearrange(tensor: Tensor, pattern: str, **axes_length: int) -> Tensor: ... -@typing.overload -def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: int) -> Tensor: ... - - -def rearrange(tensor, pattern: str, **axes_lengths): +def rearrange(tensor: Union[Tensor, List[Tensor]], pattern: str, **axes_lengths) -> Tensor: """ einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors. This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze, @@ -470,7 +465,7 @@ def rearrange(tensor, pattern: str, **axes_lengths): Find more examples in einops tutorial. Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array). + tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var). list of tensors is also accepted, those should be of the same type and shape pattern: string, rearrangement pattern axes_lengths: any additional specifications for dimensions @@ -506,8 +501,8 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: (60, 40) # repeat image 2 times along height and 3 times along width - >>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape - (30, 120) + >>> repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape + (60, 120) # convert each pixel to a small square 2x2. Upsample image by 2x >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape @@ -524,7 +519,7 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: Find more examples in einops tutorial. Parameters: - tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array). + tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var). list of tensors is also accepted, those should be of the same type and shape pattern: string, rearrangement pattern axes_lengths: any additional specifications for dimensions @@ -536,7 +531,7 @@ def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: return reduce(tensor, pattern, reduction='repeat', **axes_lengths) -def parse_shape(x, pattern: str): +def parse_shape(x, pattern: str) -> dict: """ Parse a tensor shape to dictionary mapping axes names to their lengths.
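[Reviewer note on the `repeat` docstring fix above: the old example `'h w -> h (repeat w)'` only repeated along width, so its advertised shape `(30, 120)` contradicted the neighbouring "2 times along height and 3 times along width" comment; the new pattern repeats both axes. A minimal numpy check of the corrected example, as a sketch rather than part of the patch — the `image` name comes from the docstring and the import path mirrors the test file added below:

import numpy as np
from jittor.einops.einops import repeat

image = np.random.randn(30, 40)
# h2 and w3 are new axes composed with h and w: (2 * 30, 3 * 40) == (60, 120)
assert repeat(image, 'h w -> (h2 h) (w3 w)', h2=2, w3=3).shape == (60, 120)
]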
@@ -579,11 +574,11 @@ def parse_shape(x, pattern: str): ellipsis_idx = exp.composition.index(_ellipsis) composition = (exp.composition[:ellipsis_idx] + ['_'] * (len(shape) - len(exp.composition) + 1) + - exp.composition[ellipsis_idx+1:]) + exp.composition[ellipsis_idx + 1:]) else: composition = exp.composition result = {} - for (axis_name, ), axis_length in zip(composition, shape): + for (axis_name,), axis_length in zip(composition, shape): if axis_name != '_': result[axis_name] = axis_length return result @@ -623,3 +618,165 @@ def asnumpy(tensor) -> 'numpy.ndarray': `numpy.ndarray`, converted to numpy """ return get_backend(tensor).to_numpy(tensor) + +def _validate_einsum_axis_name(axis_name): + if len(axis_name) == 0: + raise NotImplementedError("Singleton () axes are not yet supported in einsum.") + if len(axis_name) > 1: + raise NotImplementedError("Shape rearrangement is not yet supported in einsum.") + + axis_name = axis_name[0] + + if isinstance(axis_name, AnonymousAxis): + raise NotImplementedError("Anonymous axes are not yet supported in einsum.") + if len(axis_name) == 0: + raise RuntimeError("Encountered empty axis name in einsum.") + if not isinstance(axis_name, str): + raise RuntimeError("Axis name in einsum must be a string.") + + +@functools.lru_cache(256) +def _compactify_pattern_for_einsum(pattern: str) -> str: + if "->" not in pattern: + # numpy allows this, so make sure users + # don't accidentally do something like this. + raise ValueError("Einsum pattern must contain '->'.") + lefts, right = pattern.split('->') + lefts = lefts.split(',') + + lefts = [ + ParsedExpression(left, allow_underscore=True, allow_duplicates=True) + for left in lefts + ] + + right = ParsedExpression(right, allow_underscore=True) + + # Start from 'a' and go up to 'Z' + output_axis_names = string.ascii_letters + i = 0 + axis_name_mapping = {} + + left_patterns = [] + for left in lefts: + left_pattern = "" + for raw_axis_name in left.composition: + + if raw_axis_name == _ellipsis: + left_pattern += '...' + continue + + _validate_einsum_axis_name(raw_axis_name) + axis_name = raw_axis_name[0] + if axis_name not in axis_name_mapping: + if i >= len(output_axis_names): + raise RuntimeError("Too many axes in einsum.") + axis_name_mapping[axis_name] = output_axis_names[i] + i += 1 + + left_pattern += axis_name_mapping[axis_name] + left_patterns.append(left_pattern) + + compact_pattern = ",".join(left_patterns) + "->" + + for raw_axis_name in right.composition: + if raw_axis_name == _ellipsis: + compact_pattern += '...' + continue + + _validate_einsum_axis_name(raw_axis_name) + axis_name = raw_axis_name[0] + + if axis_name not in axis_name_mapping: + raise EinopsError(f"Unknown axis {axis_name} on right side of einsum {pattern}.") + + compact_pattern += axis_name_mapping[axis_name] + + return compact_pattern + + +@typing.overload +def einsum(tensor: Tensor, pattern: str) -> Tensor: ... +@typing.overload +def einsum(tensor1: Tensor, tensor2: Tensor, pattern: str) -> Tensor: ... +@typing.overload +def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, pattern: str) -> Tensor: ... +@typing.overload +def einsum(tensor1: Tensor, tensor2: Tensor, tensor3: Tensor, tensor4: Tensor, pattern: str) -> Tensor: ... + + +def einsum(*tensors_and_pattern: Union[Tensor, str]) -> Tensor: + """ + einops.einsum calls einsum operations with einops-style named + axes indexing, computing tensor products with an arbitrary + number of tensors.
Unlike typical einsum syntax, here you must + pass tensors first, and then the pattern. + + Also, note that rearrange operations such as `"(batch chan) out"`, + or singleton axes `()`, are not currently supported. + + Examples: + + For a given pattern such as: + ```python + >>> x, y, z = np.random.randn(3, 20, 20, 20) + >>> output = einsum(x, y, z, "a b c, c b d, a g k -> a b k") + + ``` + the following formula is computed: + ```tex + output[a, b, k] = + \sum_{c, d, g} x[a, b, c] * y[c, b, d] * z[a, g, k] + ``` + where the summation over `c`, `d`, and `g` is performed + because those axes names do not appear on the right-hand side. + + Let's see some additional examples: + ```python + # Filter a set of images: + >>> batched_images = np.random.randn(128, 16, 16) + >>> filters = np.random.randn(16, 16, 30) + >>> result = einsum(batched_images, filters, + ... "batch h w, h w channel -> batch channel") + >>> result.shape + (128, 30) + + # Matrix multiplication, with an unknown input shape: + >>> batch_shape = (50, 30) + >>> data = np.random.randn(*batch_shape, 20) + >>> weights = np.random.randn(10, 20) + >>> result = einsum(weights, data, + ... "out_dim in_dim, ... in_dim -> ... out_dim") + >>> result.shape + (50, 30, 10) + + # Matrix trace on a single tensor: + >>> matrix = np.random.randn(10, 10) + >>> result = einsum(matrix, "i i ->") + >>> result.shape + () + + ``` + + Parameters: + tensors: tensors of any supported library (numpy, jittor). + pattern: string, einsum pattern, with commas + separating specifications for each tensor. + + Returns: + Tensor of the same type as input, after processing with einsum. + + """ + if len(tensors_and_pattern) <= 1: + raise ValueError( + "`einops.einsum` takes at minimum two arguments: the tensors (at least one)," + " followed by the pattern." + ) + pattern = tensors_and_pattern[-1] + if not isinstance(pattern, str): + raise ValueError( + "The last argument passed to `einops.einsum` must be a string," + " representing the einsum pattern." + ) + tensors = tensors_and_pattern[:-1] + pattern = _compactify_pattern_for_einsum(pattern) + return get_backend(tensors[0]).einsum(pattern, *tensors) \ No newline at end of file diff --git a/python/jittor/einops/layers/_einmix.py b/python/jittor/einops/layers/_einmix.py index 29fb0a5a..7f5c5c68 100644 --- a/python/jittor/einops/layers/_einmix.py +++ b/python/jittor/einops/layers/_einmix.py @@ -49,8 +49,8 @@ class _EinmixMixin: Parameters :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output - :param weight_shape: axes of weight. Tensor od this shape is created, stored, and optimized in a layer - :param bias_shape: axes of bias added to output. + :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer + :param bias_shape: axes of bias added to output. Weights of this shape are created and stored. If `None` (the default), no bias is added. 
:param axes_lengths: dimensions of weight tensor """ super().__init__() @@ -58,7 +58,9 @@ class _EinmixMixin: self.weight_shape = weight_shape self.bias_shape = bias_shape self.axes_lengths = axes_lengths + self.initialize_einmix(pattern=pattern, weight_shape=weight_shape, bias_shape=bias_shape, axes_lengths=axes_lengths) + def initialize_einmix(self, pattern, weight_shape, bias_shape, axes_lengths): left_pattern, right_pattern = pattern.split('->') left = ParsedExpression(left_pattern) right = ParsedExpression(right_pattern) @@ -84,7 +86,7 @@ class _EinmixMixin: names += group composition = ' '.join(names) pre_reshape_pattern = f'{left_pattern}->{composition}' - pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names} + pre_reshape_lengths = {name: length for name, length in axes_lengths.items() if name in names} if any(len(group) != 1 for group in right.composition): names = [] @@ -134,8 +136,7 @@ class _EinmixMixin: _bias_shape.append(1) else: _bias_shape = None - _bias_input_size = None - + weight_bound = (3 / _fan_in) ** 0.5 bias_bound = (1 / _fan_in) ** 0.5 self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound) diff --git a/python/jittor/einops/layers/jittor.py b/python/jittor/einops/layers/jittor.py index 2bcb18e1..e2696b87 100644 --- a/python/jittor/einops/layers/jittor.py +++ b/python/jittor/einops/layers/jittor.py @@ -6,27 +6,18 @@ import numpy as np from jittor.einops.layers import RearrangeMixin, ReduceMixin from jittor.einops.layers._einmix import _EinmixMixin -from jittor.einops._jittor_specific import apply_for_scriptable_jittor __author__ = 'Ruiyang Liu' class Rearrange(RearrangeMixin, jt.nn.Module): def execute(self, input): - return apply_for_scriptable_jittor(self._recipe, input, reduction_type='rearrange') - - def _apply_recipe(self, x): - # overriding parent method to prevent it's scripting - pass - + return self._apply_recipe(input) + class Reduce(ReduceMixin, jt.nn.Module): def execute(self, input): - return apply_for_scriptable_jittor(self._recipe, input, reduction_type=self.reduction) - - def _apply_recipe(self, x): - # overriding parent method to prevent it's scripting - pass + return self._apply_recipe(input) class EinMix(_EinmixMixin, jt.nn.Module): diff --git a/python/jittor/einops/parsing.py b/python/jittor/einops/parsing.py index 77513b2f..e298d6b3 100644 --- a/python/jittor/einops/parsing.py +++ b/python/jittor/einops/parsing.py @@ -26,7 +26,7 @@ class ParsedExpression: non-mutable structure that contains information about one side of expression (e.g. 
'b c (h w)') and keeps some information important for downstream """ - def __init__(self, expression, *, allow_underscore: bool = False): + def __init__(self, expression, *, allow_underscore: bool = False, allow_duplicates: bool = False): self.has_ellipsis: bool = False self.has_ellipsis_parenthesized: Optional[bool] = None self.identifiers: Set[str] = set() @@ -48,7 +48,7 @@ class ParsedExpression: def add_axis_name(x): if x is not None: if x in self.identifiers: - if not (allow_underscore and x == "_"): + if not (allow_underscore and x == "_") and not allow_duplicates: raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x)) if x == _ellipsis: self.identifiers.add(_ellipsis) diff --git a/python/jittor/test/test_einops.py b/python/jittor/test/test_einops.py new file mode 100644 index 00000000..31d9cffa --- /dev/null +++ b/python/jittor/test/test_einops.py @@ -0,0 +1,616 @@ +# *************************************************************** +# Copyright (c) 2022 Jittor. All Rights Reserved. +# Maintainers: DongYang Li . +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +from collections import namedtuple +import tempfile +import pickle +import itertools +from jittor.einops.einops import (rearrange, reduce, _enumerate_directions, _reductions) +from jittor.einops import EinopsError +import jittor as jt +import numpy +import unittest + +# tests/__init__.py +import os +from jittor.einops import _backends +import warnings + +flag_to_bool = { + '': False, + '0': False, + '1': True, +} + + +def collect_test_backends(symbolic=False, layers=False): + """ + :param symbolic: symbolic or imperative frameworks? + :param layers: layers or operations? + :return: list of backends satisfying set conditions + """ + if not symbolic: + if not layers: + backend_types = [ + _backends.NumpyBackend, + _backends.JittorBackend, + ] + else: + backend_types = [ + _backends.JittorBackend, + ] + else: + backend_types = [] + result = [] + for backend_type in backend_types: + try: + result.append(backend_type()) + except ImportError: + # problem with backend installation fails a specific test function, + # but will be skipped in all other test cases + warnings.warn('backend could not be initialized for tests: {}'.format(backend_type)) + return result + + +# test/test_ops.py + +imp_op_backends = collect_test_backends(symbolic=False, layers=False) + +# test/test_layer.py + + +class TestSlice(unittest.TestCase): + + def test_anonymous_axes(self): + x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6]) + for pattern, axis_dimensions in test_cases_repeat_anonymous: + check_reversion(x, pattern, **axis_dimensions) + + def test_repeat_imperatives(self): + x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) + for backend in imp_op_backends: + print('Repeat tests for ', backend.framework_name) + + for pattern, axis_dimensions in repeat_test_cases: + expected = reduce(x, pattern, reduction='repeat', **axis_dimensions) + converted = backend.from_numpy(x) + repeated = reduce(converted, pattern, reduction='repeat', **axis_dimensions) + result = backend.to_numpy(repeated) + assert numpy.array_equal(result, expected) + + def test_repeat_numpy(self): + # check repeat vs reduce. 
Repeat works ok if reverse reduction with min and max work well + x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5]) + x1 = reduce(x, 'a b c -> copy a b c ', reduction='repeat', copy=1) + assert numpy.array_equal(x[None], x1) + for pattern, axis_dimensions in repeat_test_cases: + check_reversion(x, pattern, **axis_dimensions) + + def test_tiling_imperatives(self): + for backend in imp_op_backends: + print('Tiling tests for ', backend.framework_name) + input = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5]) + test_cases = [ + (1, 1, 1, 1, 1), + (1, 2, 1, 3, 1), + (3, 1, 1, 4, 1), + ] + for repeats in test_cases: + expected = numpy.tile(input, repeats) + converted = backend.from_numpy(input) + repeated = backend.tile(converted, repeats) + result = backend.to_numpy(repeated) + assert numpy.array_equal(result, expected) + + def test_gradients_imperatives(self): + # lazy - just checking reductions + for reduction in _reductions: + x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype('float32') + results = {} + for backend in imp_op_backends: + y0 = backend.from_numpy(x) + if not 'jittor' in backend.framework_name and not hasattr(y0, 'grad'): + continue + y1 = reduce(y0, 'a b c -> c a', reduction=reduction) + y2 = reduce(y1, 'c a -> a c', reduction=reduction) + y3 = reduce(y2, 'a (c1 c2) -> a', reduction=reduction, c1=2) + y4 = reduce(y3, '... -> ', reduction=reduction) + if 'jittor' in backend.framework_name: + grad = backend.jittor.grad(y4, y0) + else: + y4.backward() + grad = y0.grad + results[backend.framework_name] = backend.to_numpy(grad) + + print('comparing gradients for', results.keys()) + for name1, grad1 in results.items(): + for name2, grad2 in results.items(): + assert numpy.allclose(grad1, grad2), [name1, name2, 'provided different gradients'] + + def test_concatenations_and_stacking(self): + for backend in imp_op_backends: + print('testing shapes for ', backend.framework_name) + for n_arrays in [1, 2, 5]: + shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6] + for shape in shapes: + if (backend.framework_name == 'jittor')\ + and len(shape) == 0: + # jittor stores scalar in 1d array + continue + arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)] + arrays2 = [backend.from_numpy(array) for array in arrays1] + result0 = numpy.asarray(arrays1) + result1 = rearrange(arrays1, '...->...') + result2 = rearrange(arrays2, '...->...') + assert numpy.array_equal(result0, result1) + assert numpy.array_equal(result1, backend.to_numpy(result2)) + + result1 = rearrange(arrays1, 'b ... -> ... b') + result2 = rearrange(arrays2, 'b ... -> ... 
b') + assert numpy.array_equal(result1, backend.to_numpy(result2)) + + def test_enumerating_directions(self): + for backend in imp_op_backends: + print('testing directions for', backend.framework_name) + for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]: + if (backend.framework_name == 'jittor')\ + and len(shape) == 0: + # jittor stores scalar in 1d array + continue + x = numpy.arange(numpy.prod(shape)).reshape(shape) + axes1 = _enumerate_directions(x) + axes2 = _enumerate_directions(backend.from_numpy(x)) + assert len(axes1) == len(axes2) == len(shape) + for ax1, ax2 in zip(axes1, axes2): + ax2 = backend.to_numpy(ax2) + assert ax1.shape == ax2.shape + assert numpy.allclose(ax1, ax2) + + def test_reduction_with_callable_imperatives(self): + x_numpy = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype('float32') + x_numpy /= x_numpy.max() + + def logsumexp_jittor(x, tuple_of_axes): + import jittor as jt + return jt.nn.logsumexp(x, tuple_of_axes) + + def logsumexp_numpy(x, tuple_of_axes): + # very naive logsumexp to compare to + minused = x.max(tuple_of_axes) + y = x - x.max(tuple_of_axes, keepdims=True) + y = numpy.exp(y) + y = numpy.sum(y, axis=tuple_of_axes) + return numpy.log(y) + minused + + from jittor.einops._backends import JittorBackend, NumpyBackend + backend2callback = { + JittorBackend.framework_name: logsumexp_jittor, + NumpyBackend.framework_name: logsumexp_numpy, + } + + for backend in imp_op_backends: + if backend.framework_name not in backend2callback: + continue + + backend_callback = backend2callback[backend.framework_name] + + x_backend = backend.from_numpy(x_numpy) + for pattern1, pattern2 in equivalent_reduction_patterns: + print('Test reduction with callable for ', backend.framework_name, pattern1, pattern2) + output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy) + output_backend = reduce(x_backend, pattern1, reduction=backend_callback) + assert numpy.allclose( + output_numpy, + backend.to_numpy(output_backend), + ) + + def test_reduction_stress_imperatives(self): + for backend in imp_op_backends: + print('Stress-testing reduction for ', backend.framework_name) + for reduction in _reductions + ('rearrange',): + dtype = 'int64' + coincide = numpy.array_equal + if reduction in ['mean', 'prod']: + dtype = 'float64' + coincide = numpy.allclose + for n_axes in range(11): + shape = numpy.random.randint(2, 4, size=n_axes) + permutation = numpy.random.permutation(n_axes) + skipped = 0 if reduction == 'rearrange' else numpy.random.randint(n_axes + 1) + left = ' '.join('x' + str(i) for i in range(n_axes)) + right = ' '.join('x' + str(i) for i in permutation[skipped:]) + pattern = left + '->' + right + x = numpy.arange(1, 1 + numpy.prod(shape), dtype=dtype).reshape(shape) + if reduction == 'prod': + x /= x.mean() # to avoid overflows + result1 = reduce(x, pattern, reduction=reduction) + result2 = x.transpose(permutation) + if skipped > 0: + result2 = getattr(result2, reduction)(axis=tuple(range(skipped))) + assert coincide(result1, result2) + check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False) + + def test_reduction_imperatives(self): + for backend in imp_op_backends: + print('Reduction tests for ', backend.framework_name) + for reduction in _reductions: + # slight redundancy for simpler order - numpy version is evaluated multiple times + input = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape([2, 3, 4, 5, 6]) + if reduction in ['mean', 'prod']: + input = input / input.astype('float64').mean() + 
test_cases = [ + ['a b c d e -> ', {}, + getattr(input, reduction)()], + ['a ... -> ', {}, + getattr(input, reduction)()], + ['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2), + getattr(input, reduction)()], + ['a b c d e -> (e c) a', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])], + ['a ... c d e -> (e c) a', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])], + ['a b c d e ... -> (e c) a', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])], + ['a b c d e -> (e c a)', {}, + getattr(input, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])], + ['(a a2) ... -> (a2 a) ...', dict(a2=1), + input], + ] + for pattern, axes_lengths, expected_result in test_cases: + result = reduce(backend.from_numpy(input.copy()), pattern, reduction=reduction, **axes_lengths) + result = backend.to_numpy(result) + assert numpy.allclose(result, expected_result) + + def test_rearrange_permutations_numpy(self): + # tests random permutation of axes against two independent numpy ways + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)) + right_expression = ' '.join('i' + str(axis) for axis in permutation) + expression = left_expression + ' -> ' + right_expression + result = rearrange(input, expression) + + for pick in numpy.random.randint(0, 2, [10, n_axes]): + assert input[tuple(pick)] == result[tuple(pick[permutation])] + + for n_axes in range(1, 10): + input = numpy.arange(2 ** n_axes).reshape([2] * n_axes) + permutation = numpy.random.permutation(n_axes) + left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1]) + right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1]) + expression = left_expression + ' -> ' + right_expression + result = rearrange(input, expression) + assert result.shape == input.shape + expected_result = numpy.zeros_like(input) + for original_axis, result_axis in enumerate(permutation): + expected_result |= ((input >> original_axis) & 1) << result_axis + + assert numpy.array_equal(result, expected_result) + + def test_rearrange_consistency_numpy(self): + shape = [1, 2, 3, 5, 7, 11] + x = numpy.arange(numpy.prod(shape)).reshape(shape) + for pattern in [ + 'a b c d e f -> a b c d e f', + 'b a c d e f -> a b d e f c', + 'a b c d e f -> f e d c b a', + 'a b c d e f -> (f e) d (c b a)', + 'a b c d e f -> (f e d c b a)', + ]: + result = rearrange(x, pattern) + assert len(numpy.setdiff1d(x, result)) == 0 + assert result.dtype == x.dtype + + result = rearrange(x, 'a b c d e f -> a (b) (c d e) f') + assert numpy.array_equal(x.flatten(), result.flatten()) + + result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11') + assert numpy.array_equal(x, result) + + result1 = rearrange(x, 'a b c d e f -> f e d c b a') + result2 = rearrange(x, 'f e d c b a -> a b c d e f') + assert numpy.array_equal(result1, result2) + + result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5) + assert numpy.array_equal(x, result) + + sizes = dict(zip('abcdef', shape)) + temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes) + result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes) + assert numpy.array_equal(x, result) + + x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4]) + result = rearrange(x2, 'a b c -> b c a') + assert x2[1, 2, 3] == result[2, 3, 1] + 
assert x2[0, 1, 2] == result[1, 2, 0] + + def test_ellipsis_ops_imperative(self): + """ Checking various patterns against numpy """ + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + for is_symbolic in [True, False]: + for backend in collect_test_backends(symbolic=is_symbolic, layers=False): + for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)): + check_op_against_numpy(backend, x, pattern, axes_lengths={}, + reduction='rearrange', is_symbolic=is_symbolic) + + for reduction in ['min', 'max', 'sum']: + for pattern in itertools.chain(*equivalent_reduction_patterns): + check_op_against_numpy(backend, x, pattern, axes_lengths={}, + reduction=reduction, is_symbolic=is_symbolic) + + def test_ellipsis_ops_numpy(self): + x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]) + for pattern in identity_patterns: + assert numpy.array_equal(x, rearrange(x, pattern)), pattern + + for pattern1, pattern2 in equivalent_rearrange_patterns: + assert numpy.array_equal(rearrange(x, pattern1), rearrange(x, pattern2)) + + for reduction in ['min', 'max', 'sum']: + for pattern1, pattern2 in equivalent_reduction_patterns: + assert numpy.array_equal(reduce(x, pattern1, reduction=reduction), + reduce(x, pattern2, reduction=reduction)) + + # now just check coincidence with numpy + all_rearrange_patterns = [*identity_patterns] + for pattern_pairs in equivalent_rearrange_patterns: + all_rearrange_patterns.extend(pattern_pairs) + + def test_collapsed_ellipsis_errors_out(self): + x = numpy.zeros([1, 1, 1, 1, 1]) + rearrange(x, 'a b c d ... -> a b c ... d') + error = 0 + try: + rearrange(x, 'a b c d (...) -> a b c ... d') + except Exception as e: + error = 1 + assert error == 1 + + rearrange(x, '... -> (...)') + error = 0 + try: + rearrange(x, '(...) 
-> (...)') + except Exception as e: + error = 1 + assert error == 1 + + def test_rearrange_imperative(self): + for backend in collect_test_backends(symbolic=False, layers=True): + print('Test layer for ', backend.framework_name) + + for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns: + x = numpy.arange(numpy.prod(input_shape), dtype='float32').reshape(input_shape) + result_numpy = rearrange(x, pattern, **axes_lengths) + layer = backend.layers().Rearrange(pattern, **axes_lengths) + for shape in wrong_shapes: + try: + layer(backend.from_numpy(numpy.zeros(shape, dtype='float32'))) + except: + pass + else: + raise AssertionError('Failure expected') + + # simple pickling / unpickling + layer2 = pickle.loads(pickle.dumps(layer)) + result1 = backend.to_numpy(layer(backend.from_numpy(x))) + result2 = backend.to_numpy(layer2(backend.from_numpy(x))) + assert numpy.allclose(result_numpy, result1) + assert numpy.allclose(result1, result2) + + just_sum = backend.layers().Reduce('...->', reduction='sum') + + + variable = backend.from_numpy(x) + result = just_sum(layer(variable)) + + if 'jittor' in backend.framework_name: + grad = backend.jittor.grad(result, variable) + else: + result.backward() + grad = variable.grad + + assert numpy.allclose(backend.to_numpy(grad), 1) + + def test_reduce_imperative(self): + for backend in collect_test_backends(symbolic=False, layers=True): + print('Test layer for ', backend.framework_name) + for reduction in _reductions: + for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns: + print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes) + x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape) + x /= x.mean() + result_numpy = reduce(x, pattern, reduction, **axes_lengths) + layer = backend.layers().Reduce(pattern, reduction, **axes_lengths) + for shape in wrong_shapes: + try: + layer(backend.from_numpy(numpy.zeros(shape, dtype='float32'))) + except: + pass + else: + raise AssertionError('Failure expected') + + # simple pickling / unpickling + layer2 = pickle.loads(pickle.dumps(layer)) + result1 = backend.to_numpy(layer(backend.from_numpy(x))) + result2 = backend.to_numpy(layer2(backend.from_numpy(x))) + assert numpy.allclose(result_numpy, result1) + assert numpy.allclose(result1, result2) + + just_sum = backend.layers().Reduce('...->', reduction='sum') + + + variable = backend.from_numpy(x) + result = just_sum(layer(variable)) + + if 'jittor' in backend.framework_name: + grad = backend.jittor.grad(result, variable) + grad = backend.to_numpy(grad) + else: + result.backward() + grad = backend.to_numpy(variable.grad) + if reduction == 'sum': + assert numpy.allclose(grad, 1) + if reduction == 'mean': + assert numpy.allclose(grad, grad.min()) + if reduction in ['max', 'min']: + assert numpy.all(numpy.in1d(grad, [0, 1])) + assert numpy.sum(grad) > 0.5 + + def test_jittor_layer(self): + has_jittor = any(backend.framework_name == 'jittor' for backend in collect_test_backends(symbolic=False, layers=True)) + if has_jittor: + # checked that jittor present + import jittor + + rtol = 1e-05 + atol = 1e-08 + def allclose(input, other): return jittor.all(jittor.abs(input-other) <= atol+rtol*jittor.abs(other)) + model1 = create_jittor_model(use_reduce=True) + model2 = create_jittor_model(use_reduce=False) + input = jittor.randn([10, 3, 32, 32]) + # random models have different predictions + assert not allclose(model1(input), model2(input)) + 
model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict()))) + assert allclose(model1(input), model2(input)) + + +testcase = namedtuple('testcase', ['pattern', 'axes_lengths', 'input_shape', 'wrong_shapes']) + +rearrangement_patterns = [ + testcase('b c h w -> b (c h w)', dict(c=20), (10, 20, 30, 40), + [(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]]), + testcase('b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1', dict(h2=2, w2=2), (10, 20, 30, 40), + [(), (1, 1, 1, 1), (1, 10, 3), ()]), + testcase('b ... c -> c b ...', dict(b=10), (10, 20, 30), + [(), (10,), (5, 10)]), +] + +reduction_patterns = rearrangement_patterns + [ + testcase('b c h w -> b ()', dict(b=10), (10, 20, 30, 40), + [(10,), (10, 20, 30)]), + testcase('b c (h1 h2) (w1 w2) -> b c h1 w1', dict(h1=15, h2=2, w2=2), (10, 20, 30, 40), + [(10, 20, 31, 40)]), + testcase('b ... c -> b', dict(b=10), (10, 20, 30, 40), + [(10,), (11, 10)]), +] + +equivalent_reduction_patterns = [ + ('a b c d e -> ', ' ... -> '), + ('a b c d e -> (e a)', 'a ... e -> (e a)'), + ('a b c d e -> d (a e)', ' a b c d e ... -> d (a e) '), + ('a b c d e -> (a b)', ' ... c d e -> (...) '), +] + +equivalent_rearrange_patterns = [ + ('a b c d e -> (a b) c d e', 'a b ... -> (a b) ... '), + ('a b c d e -> a b (c d) e', '... c d e -> ... (c d) e'), + ('a b c d e -> a b c d e', '... -> ... '), + ('a b c d e -> (a b c d e)', '... -> (...)'), + ('a b c d e -> b (c d e) a', 'a b ... -> b (...) a'), + ('a b c d e -> b (a c d) e', 'a b ... e -> b (a ...) e'), +] + +identity_patterns = [ + '...->...', + 'a b c d e-> a b c d e', + 'a b c d e ...-> ... a b c d e', + 'a b c d e ...-> a ... b c d e', + '... a b c d e -> ... a b c d e', + 'a ... e-> a ... e', + 'a ... -> a ... ', + 'a ... c d e -> a (...) c d e', +] + +test_cases_repeat_anonymous = [ + # all assume that input has shape [1, 2, 4, 6] + ('a b c d -> c a d b', dict()), + ('a b c d -> (c 2 d a b)', dict(a=1, c=4, d=6)), + ('1 b c d -> (d copy 1) 3 b c ', dict(copy=3)), + ('1 ... -> 3 ... ', dict()), + ('() ... d -> 1 (copy1 d copy2) ... ', dict(copy1=2, copy2=3)), + ('1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)', dict()), + +] + +repeat_test_cases = [ + # all assume that input has shape [2, 3, 5] + ('a b c -> c a b', dict()), + ('a b c -> (c copy a b)', dict(copy=2, a=2, b=3, c=5)), + ('a b c -> (a copy) b c ', dict(copy=1)), + ('a b c -> (c a) (copy1 b copy2)', dict(a=2, copy1=1, copy2=2)), + ('a ... -> a ... copy', dict(copy=4)), + ('... c -> ... (copy1 c copy2)', dict(copy1=1, copy2=2)), + ('... -> ... ', dict()), + (' ... -> copy1 ... 
copy2 ', dict(copy1=2, copy2=3)), + ('a b c -> copy1 a copy2 b c () ', dict(copy1=2, copy2=1)), +] + + +def check_reversion(x, repeat_pattern, **sizes): + """Checks repeat pattern by running reduction """ + left, right = repeat_pattern.split('->') + reduce_pattern = right + '->' + left + repeated = reduce(x, repeat_pattern, reduction='repeat', **sizes) + reduced_min = reduce(repeated, reduce_pattern, reduction='min', **sizes) + reduced_max = reduce(repeated, reduce_pattern, reduction='max', **sizes) + assert numpy.array_equal(x, reduced_min) + assert numpy.array_equal(x, reduced_max) + + +def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction='rearrange', is_symbolic=False): + """ + Helper to test result of operation (rearrange or reduce) against numpy + if reduction == 'rearrange', rearrange op is tested, otherwise reduce + """ + if len(numpy_input.shape) == 0: + return + + def operation(x): + if reduction == 'rearrange': + return rearrange(x, pattern, **axes_lengths) + else: + return reduce(x, pattern, reduction, **axes_lengths) + + numpy_result = operation(numpy_input) + check_equal = numpy.array_equal + p_none_dimension = 0.5 + if 'jittor' in backend.framework_name: + check_equal = numpy.allclose + p_none_dimension = 0 + if is_symbolic: + symbol_shape = [d if numpy.random.random() >= p_none_dimension else None for d in numpy_input.shape] + symbol = backend.create_symbol(shape=symbol_shape) + result_symbol = operation(symbol) + backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)]) + else: + backend_result = operation(backend.from_numpy(numpy_input)) + backend_result = backend.to_numpy(backend_result) + + assert check_equal(numpy_result, backend_result) + + +def create_jittor_model(use_reduce=False): + from jittor.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU + from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix + return Sequential( + Conv2d(3, 6, kernel_size=(5, 5)), + Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2), + Conv2d(6, 16, kernel_size=(5, 5)), + Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2), + Rearrange('b c h w -> b (c h w)'), + Linear(16 * 5 * 5, 120), + ReLU(), + Linear(120, 84), + ReLU(), + EinMix('b c1 -> (b c2)', weight_shape='c1 c2', bias_shape='c2', c1=84, c2=84), + EinMix('(b c2) -> b c3', weight_shape='c2 c3', bias_shape='c3', c2=84, c3=84), + Linear(84, 10), + ) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file
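[Reviewer note: a quick sanity check for the `einops.einsum` entry point this patch adds. This is a sketch, not part of the test suite; it assumes `jt.linalg.einsum` is available in the installed jittor build, which is what the new `JittorBackend.einsum` method calls:

import numpy as np
import jittor as jt
from jittor.einops.einops import einsum

x = np.random.randn(8, 16)
w = np.random.randn(16, 4)

# numpy path: _compactify_pattern_for_einsum maps the long axis names to
# single letters, so the pattern below becomes 'ab,bc->ac' before dispatch
out_np = einsum(x, w, 'batch in_dim, in_dim out_dim -> batch out_dim')
assert out_np.shape == (8, 4)
assert np.allclose(out_np, x @ w)

# jittor path: the same compacted pattern is routed to jt.linalg.einsum
out_jt = einsum(jt.array(x), jt.array(w), 'batch in_dim, in_dim out_dim -> batch out_dim')
assert np.allclose(out_jt.numpy(), x @ w, atol=1e-5)
]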