Support einops for Jittor

liuruiyang98 2022-07-13 21:57:06 +08:00 committed by Jittor
parent 8c9bfb639d
commit 13f9eaafc0
10 changed files with 1830 additions and 0 deletions

View File

@@ -0,0 +1,8 @@
class EinopsError(RuntimeError):
""" Runtime error thrown by einops """
pass
__all__ = ['rearrange', 'reduce', 'repeat', 'parse_shape', 'asnumpy', 'EinopsError']
from jittor.einops.einops import rearrange, reduce, repeat, parse_shape, asnumpy
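# A minimal usage sketch of the exported API (illustrative shapes; assumes
# jittor is installed and provides jt.randn):
#
#     import jittor as jt
#     from jittor.einops import rearrange, reduce, repeat, parse_shape
#
#     x = jt.randn(10, 20, 30, 40)                                      # b c h w
#     y = rearrange(x, 'b c h w -> b h w c')                            # transposition
#     p = reduce(x, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2)  # 2x2 max-pooling
#     r = repeat(x, 'b c h w -> b c (h h2) w', h2=2)                    # upsample height 2x
#     assert parse_shape(y, 'b h w c') == {'b': 10, 'h': 30, 'w': 40, 'c': 20}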

View File

@@ -0,0 +1,255 @@
"""
Backends in `einops` are organized to meet the following requirements:
- backends are not imported unless they are actually needed, because
  - a backend may not be installed
  - importing all available backends would lead to a significant memory footprint
  - a backend may be present but installed with errors (and never used);
    importing it may lead to crashes
- a backend should be either symbolic or imperative (tensorflow covers both, but that causes problems)
  - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if a backend can't (temporarily) provide symbols for shape dimensions, UnknownSize objects are used
"""
import sys
import warnings
__author__ = 'Alex Rogozhnikov'
_backends = {}
_debug_importing = False
def get_backend(tensor) -> 'AbstractBackend':
"""
    Returns the correct backend for a tensor (e.g. the numpy backend if tensor is a numpy.ndarray).
    If needed, imports the package and creates the backend object.
"""
for framework_name, backend in _backends.items():
if backend.is_appropriate_type(tensor):
return backend
# Find backend subclasses recursively
backend_subclasses = []
backends = AbstractBackend.__subclasses__()
while backends:
backend = backends.pop()
backends += backend.__subclasses__()
backend_subclasses.append(backend)
for BackendSubclass in backend_subclasses:
if _debug_importing:
print('Testing for subclass of ', BackendSubclass)
if BackendSubclass.framework_name not in _backends:
            # only consider a framework whose module was already imported; we never import frameworks ourselves
if BackendSubclass.framework_name in sys.modules:
if _debug_importing:
print('Imported backend for ', BackendSubclass.framework_name)
backend = BackendSubclass()
_backends[backend.framework_name] = backend
if backend.is_appropriate_type(tensor):
return backend
raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
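# Discovery sketch: a framework plugs in by subclassing AbstractBackend
# (hypothetical 'mylib' below; the real backends in this file follow the same
# pattern). get_backend instantiates the subclass lazily, and only if user
# code has already imported the module named by framework_name:
#
#     class MyLibBackend(AbstractBackend):
#         framework_name = 'mylib'  # looked up in sys.modules
#         def __init__(self):
#             import mylib
#             self.lib = mylib
#         def is_appropriate_type(self, tensor):
#             return isinstance(tensor, self.lib.Tensor)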
class AbstractBackend:
""" Base backend class, major part of methods are only for debugging purposes. """
framework_name = None
def is_appropriate_type(self, tensor):
""" helper method should recognize tensors it can handle """
raise NotImplementedError()
def from_numpy(self, x):
raise NotImplementedError("framework doesn't support imperative execution")
def to_numpy(self, x):
raise NotImplementedError("framework doesn't support imperative execution")
def create_symbol(self, shape):
raise NotImplementedError("framework doesn't support symbolic computations")
def eval_symbol(self, symbol, input_dict):
raise NotImplementedError("framework doesn't support symbolic computations")
def arange(self, start, stop):
# supplementary method used only in testing, so should implement CPU version
raise NotImplementedError("framework doesn't implement arange")
def shape(self, x):
"""shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)"""
return x.shape
def reshape(self, x, shape):
return x.reshape(shape)
def transpose(self, x, axes):
return x.transpose(axes)
def reduce(self, x, operation, axes):
return getattr(x, operation)(axis=axes)
def stack_on_zeroth_dimension(self, tensors: list):
raise NotImplementedError()
def add_axis(self, x, new_position):
raise NotImplementedError()
def add_axes(self, x, n_axes, pos2len):
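        # insert the new axes one by one, then tile: e.g. x of shape (3,) with
        # n_axes=3, pos2len={0: 2, 2: 5} is unsqueezed to (1, 3, 1) and tiled
        # by (2, 1, 5), giving shape (2, 3, 5)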
repeats = [1] * n_axes
for axis_position, axis_length in pos2len.items():
x = self.add_axis(x, axis_position)
repeats[axis_position] = axis_length
return self.tile(x, tuple(repeats))
def tile(self, x, repeats):
"""repeats is a number of """
raise NotImplementedError()
def is_float_type(self, x):
# some backends (torch) can't compute average for non-floating types.
# Decided to drop average for all backends if type is not floating
raise NotImplementedError()
def layers(self):
raise NotImplementedError("backend does not provide layers")
def __repr__(self):
return "<einops backend for {}>".format(self.framework_name)
class UnknownSize:
""" pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """
def __floordiv__(self, other):
return self
def __eq__(self, other):
return True # we don't know actual size
def __mul__(self, other):
return self
def __rmul__(self, other):
return self
def __hash__(self):
return None.__hash__()
class NumpyBackend(AbstractBackend):
framework_name = 'numpy'
def __init__(self):
import numpy
self.np = numpy
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.np.ndarray)
def from_numpy(self, x):
return x
def to_numpy(self, x):
return x
def arange(self, start, stop):
return self.np.arange(start, stop)
def stack_on_zeroth_dimension(self, tensors: list):
return self.np.stack(tensors)
def tile(self, x, repeats):
return self.np.tile(x, repeats)
def is_float_type(self, x):
return x.dtype in ('float16', 'float32', 'float64', 'float128')
def add_axis(self, x, new_position):
return self.np.expand_dims(x, new_position)
class HashableTuple:
"""Overcomes non-hashability of symbolic elements"""
def __init__(self, elements: tuple):
self.elements = elements
def __iter__(self):
for x in self.elements:
yield x
def __len__(self):
return len(self.elements)
def __getitem__(self, item):
return self.elements[item]
class JittorBackend(AbstractBackend):
framework_name = 'jittor'
def __init__(self):
import jittor
self.jittor = jittor
def is_appropriate_type(self, tensor):
return isinstance(tensor, self.jittor.jittor_core.Var)
def from_numpy(self, x):
variable = self.jittor.array(x)
if self.is_float_type(variable):
# attach grad only to floating types
variable.requires_grad = True
return variable
def to_numpy(self, x):
return x.detach().numpy()
def arange(self, start, stop):
return self.jittor.arange(start, stop, dtype='int64')
def shape(self, x):
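        # jittor shapes are list-like; wrapping makes the result hashable, as
        # required by the functools.lru_cache around _reconstruct_from_shape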
return HashableTuple(tuple(x.shape))
def reshape(self, x, shape):
return self.jittor.reshape(x, shape)
# def reduce(self, x, operation, axes):
# return getattr(x, operation)(dim=axes)
def reduce(self, x, operation, reduced_axes):
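        # reduce axes one at a time, highest index first, so the indices of the
        # remaining axes stay valid after each reduction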
for axis in sorted(reduced_axes, reverse=True):
if operation == 'min':
x = x.min(dim=axis)
elif operation == 'max':
x = x.max(dim=axis)
elif operation in ['sum', 'mean', 'prod']:
x = getattr(x, operation)(dim=axis)
else:
raise NotImplementedError('Unknown reduction ', operation)
return x
def transpose(self, x, axes):
return x.permute(axes)
def stack_on_zeroth_dimension(self, tensors: list):
return self.jittor.stack(tensors)
def add_axes(self, x, n_axes, pos2len):
repeats = [-1] * n_axes
for axis_position, axis_length in pos2len.items():
x = self.add_axis(x, axis_position)
repeats[axis_position] = axis_length
return x.expand(repeats)
def tile(self, x, repeats):
return x.repeat(repeats)
def add_axis(self, x, new_position):
return self.jittor.unsqueeze(x, new_position)
def is_float_type(self, x):
return x.dtype in ["float16", "float32", "float64"]
def layers(self):
from jittor.einops.layers import jittor
return jittor

View File

@@ -0,0 +1,84 @@
"""
Specialization of einops for jittor.
Unfortunately, jittor's JIT scripting mechanism isn't strong enough,
and to have scripting supported at least for layers,
a number of changes are required, and this module helps.
Importantly, whole lib is designed so that you can't use it
"""
from typing import Dict, List
import jittor as jt
from jittor.einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
class JittorJitBackend:
"""
    Completely static backend that mimics part of the normal backend functionality,
    restricted to jittor operations only
"""
@staticmethod
def reduce(x: jt.jittor_core.Var, operation: str, reduced_axes: List[int]):
if operation == 'min':
return x.min(dims=reduced_axes)
elif operation == 'max':
return x.max(dims=reduced_axes)
elif operation == 'sum':
return x.sum(dims=reduced_axes)
elif operation == 'mean':
return x.mean(dims=reduced_axes)
elif operation == 'prod':
for i in list(sorted(reduced_axes))[::-1]:
x = x.prod(dim=i)
return x
else:
raise NotImplementedError('Unknown reduction ', operation)
@staticmethod
def transpose(x, axes: List[int]):
return x.permute(axes)
@staticmethod
def stack_on_zeroth_dimension(tensors: List[jt.jittor_core.Var]):
return jt.stack(tensors)
@staticmethod
def tile(x, repeats: List[int]):
return x.repeat(repeats)
@staticmethod
def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
repeats = [1] * n_axes
for axis_position, axis_length in pos2len.items():
x = jt.unsqueeze(x, axis_position)
repeats[axis_position] = axis_length
return JittorJitBackend.tile(x, repeats)
@staticmethod
def is_float_type(x):
return x.dtype in ["float16", "float32", "float64"]
@staticmethod
def shape(x):
return x.shape
@staticmethod
def reshape(x, shape: List[int]):
return x.reshape(shape)
# mirrors einops.einops._apply_recipe
def apply_for_scriptable_jittor(recipe: TransformRecipe, tensor: jt.jittor_core.Var, reduction_type: str) -> jt.jittor_core.Var:
backend = JittorJitBackend
init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
_reconstruct_from_shape_uncached(recipe, backend.shape(tensor))
tensor = backend.reshape(tensor, init_shapes)
if len(reduced_axes) > 0:
tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
tensor = backend.transpose(tensor, axes_reordering)
if len(added_axes) > 0:
tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
return backend.reshape(tensor, final_shapes)

View File

@@ -0,0 +1,625 @@
import functools
import itertools
import typing
from collections import OrderedDict
from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar
if typing.TYPE_CHECKING:
import numpy as np
from jittor.einops import EinopsError
from jittor.einops._backends import get_backend
from jittor.einops.parsing import ParsedExpression, _ellipsis, AnonymousAxis
Tensor = TypeVar('Tensor')
ReductionCallable = Callable[[Tensor, List[int]], Tensor]
Reduction = Union[str, ReductionCallable]
_reductions = ('min', 'max', 'sum', 'mean', 'prod')
_ellipsis_not_in_parenthesis: List[int] = [-999]
_unknown_axis_length = -999999
def is_ellipsis_not_in_parenthesis(group: List[int]) -> bool:
if len(group) != 1:
return False
return group[0] == -999
def _product(sequence: List[int]) -> int:
""" minimalistic product that works both with numbers and symbols. Supports empty lists """
result = 1
for element in sequence:
result *= element
return result
def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
reduced_axes = tuple(reduced_axes)
if callable(reduction_type):
# custom callable
return reduction_type(tensor, reduced_axes)
else:
# one of built-in operations
if len(reduced_axes) == 0:
return tensor
assert reduction_type in _reductions
if reduction_type == 'mean':
if not backend.is_float_type(tensor):
raise NotImplementedError('reduce_mean is not available for non-floating tensors')
return backend.reduce(tensor, reduction_type, reduced_axes)
def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
# 'collapses' neighboring axes if those participate in the result pattern in the same order
# TODO add support for added_axes
assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
# joining consecutive axes that will be reduced
# possibly we can skip this if all backends can optimize this (not sure)
reduced_axes = tuple(sorted(reduced_axes))
for i in range(len(reduced_axes) - 1)[::-1]:
if reduced_axes[i] + 1 == reduced_axes[i + 1]:
removed_axis = reduced_axes[i + 1]
removed_length = init_shapes[removed_axis]
init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
init_shapes[removed_axis - 1] *= removed_length
reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:])
# removing axes that are moved together during reshape
def build_mapping():
init_to_final = {}
for axis in range(len(init_shapes)):
if axis in reduced_axes:
init_to_final[axis] = None
else:
after_reduction = sum(x is not None for x in init_to_final.values())
init_to_final[axis] = list(axes_reordering).index(after_reduction)
return init_to_final
init_axis_to_final_axis = build_mapping()
for init_axis in range(len(init_shapes) - 1)[::-1]:
if init_axis_to_final_axis[init_axis] is None:
continue
if init_axis_to_final_axis[init_axis + 1] is None:
continue
if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
removed_axis = init_axis + 1
removed_length = init_shapes[removed_axis]
removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))
reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
init_shapes[removed_axis - 1] *= removed_length
old_reordering = axes_reordering
axes_reordering = []
for axis in old_reordering:
if axis == removed_axis_after_reduction:
pass
elif axis < removed_axis_after_reduction:
axes_reordering.append(axis)
else:
axes_reordering.append(axis - 1)
init_axis_to_final_axis = build_mapping()
return init_shapes, reduced_axes, axes_reordering, final_shapes
CookedRecipe = Tuple[List[int], List[int], List[int], Dict[int, int], List[int]]
class TransformRecipe:
"""
Recipe describes actual computation pathway.
Recipe can be applied to a tensor or variable.
"""
# structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+)
def __init__(self,
# list of expressions (or just sizes) for elementary axes as they appear in left expression.
# this is what (after computing unknown parts) will be a shape after first transposition.
# If ellipsis is present, it forms one dimension here (in the right position).
elementary_axes_lengths: List[int],
# each dimension in input can help to reconstruct length of one elementary axis
# or verify one of dimensions. Each element points to element of elementary_axes_lengths
input_composite_axes: List[Tuple[List[int], List[int]]],
# indices of axes to be squashed
reduced_elementary_axes: List[int],
# in which order should axes be reshuffled after reduction
axes_permutation: List[int],
# at which positions which of elementary axes should appear
added_axes: Dict[int, int],
# ids of axes as they appear in result, again pointers to elementary_axes_lengths,
# only used to infer result dimensions
output_composite_axes: List[List[int]],
# positions of ellipsis in lhs and rhs of expression
ellipsis_position_in_lhs: Optional[int] = None,
):
self.elementary_axes_lengths: List[int] = elementary_axes_lengths
self.input_composite_axes: List[Tuple[List[int], List[int]]] = input_composite_axes
self.output_composite_axes: List[List[int]] = output_composite_axes
self.axes_permutation: List[int] = axes_permutation
self.added_axes: Dict[int, int] = added_axes
# This is redundant information, but more convenient to use
self.reduced_elementary_axes: List[int] = reduced_elementary_axes
# setting to a large number to avoid handling Nones in reconstruct_from_shape
self.ellipsis_position_in_lhs: int = ellipsis_position_in_lhs if ellipsis_position_in_lhs is not None else 10000
def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> CookedRecipe:
"""
Reconstruct all actual parameters using shape.
Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet)
known axes can be integers or symbols, but not Nones.
"""
axes_lengths: List[int] = list(self.elementary_axes_lengths)
if self.ellipsis_position_in_lhs != 10000:
if len(shape) < len(self.input_composite_axes) - 1:
raise EinopsError('Expected at least {} dimensions, got {}'.format(
len(self.input_composite_axes) - 1, len(shape)))
else:
if len(shape) != len(self.input_composite_axes):
raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape)))
ellipsis_shape: List[int] = []
for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes):
before_ellipsis = input_axis
after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes)
if input_axis == self.ellipsis_position_in_lhs:
assert len(known_axes) == 0 and len(unknown_axes) == 1
unknown_axis, = unknown_axes
ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1]
for d in ellipsis_shape:
if d is None:
raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis")
total_dim_size: int = _product(ellipsis_shape)
axes_lengths[unknown_axis] = total_dim_size
else:
if input_axis < self.ellipsis_position_in_lhs:
length = shape[before_ellipsis]
else:
length = shape[after_ellipsis]
known_product = 1
for axis in known_axes:
known_product *= axes_lengths[axis]
if len(unknown_axes) == 0:
if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product))
# this is enforced when recipe is created
# elif len(unknown_axes) > 1:
# raise EinopsError(
# "Lengths of two or more axes in parenthesis not provided (dim={}), can't infer dimensions".
# format(known_product)
# )
else:
if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format(
length, known_product))
unknown_axis: int = unknown_axes[0]
inferred_length: int = length // known_product
axes_lengths[unknown_axis] = inferred_length
# at this point all axes_lengths are computed (either have values or variables, but not Nones)
# TODO more readable expression
init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)]
final_shapes: List[int] = []
for output_axis, grouping in enumerate(self.output_composite_axes):
if is_ellipsis_not_in_parenthesis(grouping):
final_shapes.extend(ellipsis_shape)
else:
lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
final_shapes.append(_product(lengths))
reduced_axes = self.reduced_elementary_axes
axes_reordering = self.axes_permutation
added_axes: Dict[int, int] = {
pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()}
# if optimize:
# assert len(self.added_axes) == 0
# return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes)
return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes
_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)
def _apply_recipe(recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction) -> Tensor:
    # this method works for all backends, but is not scriptable (a scriptable jittor variant lives in _jittor_specific)
backend = get_backend(tensor)
init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
_reconstruct_from_shape(recipe, backend.shape(tensor))
tensor = backend.reshape(tensor, init_shapes)
tensor = _reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend)
tensor = backend.transpose(tensor, axes_reordering)
if len(added_axes) > 0:
tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
return backend.reshape(tensor, final_shapes)
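# hand-traced example of the pipeline above (not executed here): for
# reduce(x, 'b (h h2) (w w2) -> b h w', 'mean', h2=2, w2=2) with x.shape == (2, 6, 8):
#   init_shapes     = [2, 3, 2, 4, 2]  # elementary axes b, h, h2, w, w2
#   reduced_axes    = [2, 4]           # positions of h2 and w2
#   axes_reordering = [0, 1, 2]        # b, h, w are already in output order
#   added_axes      = {}
#   final_shapes    = [2, 3, 4]
# i.e. reshape, mean over axes (2, 4), identity transpose, final reshape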
@functools.lru_cache(256)
def _prepare_transformation_recipe(pattern: str,
operation: Reduction,
axes_lengths: Tuple[Tuple, ...]) -> TransformRecipe:
""" Perform initial parsing of pattern and provided supplementary info
axes_lengths is a tuple of tuples (axis_name, axis_length)
"""
left, rght = pattern.split('->')
left = ParsedExpression(left)
rght = ParsedExpression(rght)
    # checking that axes are in agreement: new axes may appear only in repeat, axes may disappear only in reduction
if not left.has_ellipsis and rght.has_ellipsis:
raise EinopsError('Ellipsis found in right side, but not left side of a pattern {}'.format(pattern))
if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise EinopsError('Ellipsis inside parenthesis on the left side is not allowed: {}'.format(pattern))
if operation == 'rearrange':
difference = set.symmetric_difference(left.identifiers, rght.identifiers)
if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)')
if len(difference) > 0:
raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference))
elif operation == 'repeat':
difference = set.difference(left.identifiers, rght.identifiers)
if len(difference) > 0:
raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference))
axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
{*left.identifiers, *(ax for ax, _ in axes_lengths)})
if len(axes_without_size) > 0:
raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size))
elif operation in _reductions or callable(operation):
difference = set.difference(rght.identifiers, left.identifiers)
if len(difference) > 0:
raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference))
else:
raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions))
# parsing all dimensions to find out lengths
axis_name2known_length = OrderedDict()
for composite_axis in left.composition:
for axis_name in composite_axis:
if isinstance(axis_name, AnonymousAxis):
axis_name2known_length[axis_name] = axis_name.value
else:
axis_name2known_length[axis_name] = _unknown_axis_length
# axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point
repeat_axes_names = []
for axis_name in rght.identifiers:
if axis_name not in axis_name2known_length:
if isinstance(axis_name, AnonymousAxis):
axis_name2known_length[axis_name] = axis_name.value
else:
axis_name2known_length[axis_name] = _unknown_axis_length
repeat_axes_names.append(axis_name)
axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}
reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if
axis not in rght.identifiers]
reduced_axes: List[int] = list(sorted(reduced_axes))
for elementary_axis, axis_length in axes_lengths:
if not ParsedExpression.check_axis_name(elementary_axis):
raise EinopsError('Invalid name for an axis', elementary_axis)
if elementary_axis not in axis_name2known_length:
raise EinopsError('Axis {} is not used in transform'.format(elementary_axis))
axis_name2known_length[elementary_axis] = axis_length
input_axes_known_unknown = []
# some of shapes will be inferred later - all information is prepared for faster inference
for composite_axis in left.composition:
known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
if len(unknown) > 1:
raise EinopsError('Could not infer sizes for {}'.format(unknown))
assert len(unknown) + len(known) == len(composite_axis)
input_axes_known_unknown.append(
([axis_name2position[axis] for axis in known],
[axis_name2position[axis] for axis in unknown])
)
axis_position_after_reduction = {}
for axis_name in itertools.chain(*left.composition):
if axis_name in rght.identifiers:
axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)
result_axes_grouping: List[List[int]] = []
for composite_axis in rght.composition:
if composite_axis == _ellipsis:
result_axes_grouping.append(_ellipsis_not_in_parenthesis)
else:
result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis])
ordered_axis_right = list(itertools.chain(*rght.composition))
axes_permutation = [
axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers]
added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right)
if axis_name not in left.identifiers}
ellipsis_left = None if _ellipsis not in left.composition else left.composition.index(_ellipsis)
return TransformRecipe(
elementary_axes_lengths=list(axis_name2known_length.values()),
input_composite_axes=input_axes_known_unknown,
reduced_elementary_axes=reduced_axes,
axes_permutation=axes_permutation,
added_axes=added_axes,
output_composite_axes=result_axes_grouping,
ellipsis_position_in_lhs=ellipsis_left,
)
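# hand-traced example: _prepare_transformation_recipe('b (h h2) -> b h', 'mean',
# axes_lengths=(('h2', 2),)) produces a recipe with
#   elementary_axes_lengths = [-999999, -999999, 2]    # b, h unknown; h2 = 2
#   input_composite_axes    = [([], [0]), ([2], [1])]  # h is inferred from h2
#   reduced_elementary_axes = [2]                      # h2 is reduced away
#   axes_permutation        = [0, 1]
#   added_axes              = {}
#   output_composite_axes   = [[0], [1]]
# the unknown lengths (_unknown_axis_length) are filled in from the actual
# tensor shape by _reconstruct_from_shape_uncached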
def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
"""
    einops.reduce provides a combination of reordering and reduction using reader-friendly notation.
Examples for reduce operation:
```python
>>> x = np.random.randn(100, 32, 64)
# perform max-reduction on the first axis
>>> y = reduce(x, 't b c -> b c', 'max')
# same as previous, but with clearer axes meaning
>>> y = reduce(x, 'time batch channel -> batch channel', 'max')
>>> x = np.random.randn(10, 20, 30, 40)
# 2d max-pooling with kernel size = 2 * 2 for image processing
>>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
# if one wants to go back to the original height and width, depth-to-space trick can be applied
>>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
>>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')
# Adaptive 2d max-pooling to 3 * 4 grid
>>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
(10, 20, 3, 4)
# Global average pooling
>>> reduce(x, 'b c h w -> b c', 'mean').shape
(10, 20)
# Subtracting mean over batch for each channel
>>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')
# Subtracting per-image mean for each channel
>>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
```
Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
            list of tensors is also accepted, those should be of the same type and shape
pattern: string, reduction pattern
reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
axes_lengths: any additional specifications for dimensions
Returns:
tensor of the same type as input
"""
try:
hashable_axes_lengths = tuple(sorted(axes_lengths.items()))
recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths)
return _apply_recipe(recipe, tensor, reduction_type=reduction)
except EinopsError as e:
message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
if not isinstance(tensor, list):
message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor))
else:
message += '\n Input is list. '
message += 'Additional info: {}.'.format(axes_lengths)
raise EinopsError(message + '\n {}'.format(e))
@typing.overload
def rearrange(tensor: Tensor, pattern: str, **axes_length: int) -> Tensor: ...
@typing.overload
def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: int) -> Tensor: ...
def rearrange(tensor, pattern: str, **axes_lengths):
"""
einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
stack, concatenate and other operations.
Examples for rearrange operation:
```python
# suppose we have a set of 32 images in "h w c" format (height-width-channel)
>>> images = [np.random.randn(30, 40, 3) for _ in range(32)]
# stack along first (batch) axis, output is a single array
>>> rearrange(images, 'b h w c -> b h w c').shape
(32, 30, 40, 3)
# concatenate images along height (vertical axis), 960 = 32 * 30
>>> rearrange(images, 'b h w c -> (b h) w c').shape
(960, 40, 3)
# concatenated images along horizontal axis, 1280 = 32 * 40
>>> rearrange(images, 'b h w c -> h (b w) c').shape
(30, 1280, 3)
# reordered axes to "b c h w" format for deep learning
>>> rearrange(images, 'b h w c -> b c h w').shape
(32, 3, 30, 40)
# flattened each image into a vector, 3600 = 30 * 40 * 3
>>> rearrange(images, 'b h w c -> b (c h w)').shape
(32, 3600)
# split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
(128, 15, 20, 3)
# space-to-depth operation
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
(32, 15, 20, 12)
```
    When composing axes, C-order enumeration is used (consecutive elements have different last axis)
Find more examples in einops tutorial.
Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
            list of tensors is also accepted, those should be of the same type and shape
pattern: string, rearrangement pattern
axes_lengths: any additional specifications for dimensions
Returns:
tensor of the same type as input. If possible, a view to the original tensor is returned.
"""
if isinstance(tensor, list):
if len(tensor) == 0:
raise TypeError("Rearrange can't be applied to an empty list")
tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor)
return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)
def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
"""
einops.repeat allows reordering elements and repeating them in arbitrary combinations.
This operation includes functionality of repeat, tile, broadcast functions.
Examples for repeat operation:
```python
# a grayscale image (of shape height x width)
>>> image = np.random.randn(30, 40)
# change it to RGB format by repeating in each channel
>>> repeat(image, 'h w -> h w c', c=3).shape
(30, 40, 3)
# repeat image 2 times along height (vertical axis)
>>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape
(60, 40)
    # repeat image 3 times along width
>>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape
(30, 120)
# convert each pixel to a small square 2x2. Upsample image by 2x
>>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
(60, 80)
# pixelate image first by downsampling by 2x, then upsampling
>>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
>>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
(30, 40)
```
    When composing axes, C-order enumeration is used (consecutive elements have different last axis)
Find more examples in einops tutorial.
Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
            list of tensors is also accepted, those should be of the same type and shape
pattern: string, rearrangement pattern
axes_lengths: any additional specifications for dimensions
Returns:
Tensor of the same type as input. If possible, a view to the original tensor is returned.
"""
return reduce(tensor, pattern, reduction='repeat', **axes_lengths)
def parse_shape(x, pattern: str):
"""
Parse a tensor shape to dictionary mapping axes names to their lengths.
```python
# Use underscore to skip the dimension in parsing.
>>> x = np.zeros([2, 3, 5, 7])
>>> parse_shape(x, 'batch _ h w')
{'batch': 2, 'h': 5, 'w': 7}
# `parse_shape` output can be used to specify axes_lengths for other operations:
>>> y = np.zeros([700])
>>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
(2, 10, 5, 7)
```
    For symbolic frameworks, may return symbols rather than integers.
Parameters:
x: tensor of any of supported frameworks
pattern: str, space separated names for axes, underscore means skip axis
Returns:
dict, maps axes names to their lengths
"""
exp = ParsedExpression(pattern, allow_underscore=True)
shape = get_backend(x).shape(x)
if exp.has_composed_axes():
raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format(
pattern=pattern, shape=shape))
if len(shape) != len(exp.composition):
if exp.has_ellipsis:
if len(shape) < len(exp.composition) - 1:
raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format(
pattern=pattern, shape=shape))
else:
raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format(
pattern=pattern, shape=shape))
if exp.has_ellipsis:
ellipsis_idx = exp.composition.index(_ellipsis)
composition = (exp.composition[:ellipsis_idx] +
['_'] * (len(shape) - len(exp.composition) + 1) +
exp.composition[ellipsis_idx+1:])
else:
composition = exp.composition
result = {}
for (axis_name, ), axis_length in zip(composition, shape):
if axis_name != '_':
result[axis_name] = axis_length
return result
# this one is probably not needed in the public API
def _enumerate_directions(x):
"""
For an n-dimensional tensor, returns tensors to enumerate each axis.
```python
x = np.zeros([2, 3, 4]) # or any other tensor
i, j, k = _enumerate_directions(x)
result = i + 2*j + 3*k
```
    `result[i, j, k] = i + 2j + 3k`, and also has the same shape as x
Works very similarly to numpy.ogrid (open indexing grid)
"""
backend = get_backend(x)
shape = backend.shape(x)
result = []
for axis_id, axis_length in enumerate(shape):
shape = [1] * len(shape)
shape[axis_id] = axis_length
result.append(backend.reshape(backend.arange(0, axis_length), shape))
return result
def asnumpy(tensor) -> 'numpy.ndarray':
"""
    Convert a tensor of an imperative framework (i.e. numpy/jittor) to `numpy.ndarray`
Parameters:
tensor: tensor of any of known imperative framework
Returns:
`numpy.ndarray`, converted to numpy
"""
return get_backend(tensor).to_numpy(tensor)

View File

@@ -0,0 +1,393 @@
"""
Indexing one array with the other(s).
Concept for discussion.
Notation targets hard cases, not simple ones, like indexing of 1d-array with another 1d-array
(notation supports that, but you can't simplify arr[ind], and there is no reason to)
Examples
1. query for every token in sequence a token in the image. Images and sequences are paired
einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, [h_indices_bt, w_indices_bt])
this is equivalent, so you can pass indexers independently or together
einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, np.asarray([h_indices_bt, w_indices_bt]))
after some thinking I decided that having first axis for indexing variable is not too restrictive,
but should simplify mapping of such cases.
For this reason [...] part should always go first in indexer.
This makes the largest difference with einindex https://github.com/malmaud/einindex,
which has almost identical grammar, but puts special dimension last, while we put it first.
This trick allows naturally decomposing a multiindex into individual dimensions or vice versa.
2. query for every token in the video the most suitable word in a (matching) sentence
einindex('b t h w <- seq b, [seq] t b h w', arr_tbc, [t_indices_bhw])
note that only one indexer is used, but it still has to be enclosed in a list.
That's the price of being generic. Alternatively, a leading singleton dimension can be added.
3. (not supported now, future planning)
for every timeframe in a video, find the token with the highest norm (across h and w), and compose a new stack of them
indices_2bt = argmax(x_bthwc.norm(dim=-1), 'b t h w -> [h, w] b t')
selected_embeddings_btc = einindex('b t c <- b t h w c, [h, w] b t', x_bthwc, indices_2bt)
while the current question is 'how do we index',
it is important to pre-align that with the question 'what are natural ways to get indices'.
Most common are min/max; less common options: topk (works here), random sampling.
Some important properties of this notation:
- support for multiple indexers, including using a single tensor to keep multiple indexers
- 'batch' indexing, when some axes of indexer and array should be matched
- universal (one-indexing-to-rule-them-all)
- extensible for (named) ellipses, including variadic number of indexers
- extensible for einops-style compositions and decompositions
- extensible for outer indexing when indexers are not aligned
The current implementation is based on the python array api and uses loops,
because no appropriate indexing is available in the standard.
"""
from typing import Dict, List, Union, TypeVar, Tuple
from jittor.einops import EinopsError
T = TypeVar('T')
class CompositionDecomposition:
def __init__(
self,
decomposed_shape: List[str],
composed_shape: List[List[str]],
):
flat_shape = []
for x in composed_shape:
flat_shape.extend(x)
self.compose_transposition: Tuple[int] = tuple([decomposed_shape.index(x) for x in flat_shape])
self.decompose_transposition: Tuple[int] = tuple([flat_shape.index(x) for x in decomposed_shape])
self.composed_shape = composed_shape
self.decomposed_shape = decomposed_shape
    def decompose(self, x, known_axes_lengths: Dict[str, int]):
xp = x.__array_namespace__()
shape = x.shape
flat_shape = []
for i, axis_group in enumerate(self.composed_shape):
unknown_axis_name = None
known_sizes_prod = 1
for axis_name in axis_group:
if axis_name in known_axes_lengths:
known_sizes_prod *= known_axes_lengths[axis_name]
else:
if unknown_axis_name is None:
unknown_axis_name = axis_name
else:
raise EinopsError("Can't infer the size")
if unknown_axis_name is None:
assert shape[i] == known_sizes_prod
else:
known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod
for axis in axis_group:
flat_shape.append(known_axes_lengths[axis])
x = xp.reshape(x, flat_shape)
return xp.permute_dims(x, self.decompose_transposition)
    def compose(self, x, known_axes_lengths: Dict[str, int]):
xp = x.__array_namespace__()
for axis_len, axis_name in zip(x.shape, self.decomposed_shape):
if axis_name in known_axes_lengths:
assert known_axes_lengths[axis_name] == axis_len
else:
known_axes_lengths[axis_name] = axis_len
x = xp.permute_dims(x, self.compose_transposition)
new_shape = []
for axis_group in self.composed_shape:
composed_axis_size = 1
for axis_name in axis_group:
composed_axis_size *= known_axes_lengths[axis_name]
new_shape.append(composed_axis_size)
return xp.reshape(x, tuple(new_shape))
def arange_at_position(xp, n_axes, axis, axis_len, device=None):
x = xp.arange(axis_len, dtype=xp.int64, device=device)
shape = [1] * n_axes
shape[axis] = axis_len
x = xp.reshape(x, shape)
return x
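# e.g. arange_at_position(xp, n_axes=3, axis=1, axis_len=4) has shape (1, 4, 1)
# and counts 0..3 along axis 1, ready to broadcast against the other axes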
class IndexingFormula:
def __init__(self, pattern: str):
"""
:param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t'
"""
self.pattern = pattern
left, right = pattern.split('<-')
arg_split = right.index(',')
arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:]
ind_pattern = ind_pattern.strip()
# print(
# arr_pattern, '\n',
# ind_pattern,
# )
assert ind_pattern.startswith('['), 'composition axis should go first in indexer (second argument) [h w] i j k'
composition_start = ind_pattern.index('[')
composition_end = ind_pattern.index(']')
composition = ind_pattern[composition_start + 1: composition_end]
ind_other_axes = ind_pattern[composition_end + 1:]
self.result_axes_names = left.split()
self.array_axes_names = arr_pattern.split()
self.indexing_axes_names = [x.strip() for x in composition.split(',')]
self.indexer_other_axes_names = ind_other_axes.split()
for group_name, group in [
('result', self.result_axes_names),
('array', self.array_axes_names),
('indexer', self.indexing_axes_names + self.indexer_other_axes_names),
]:
if len(set(group)) != len(group):
# need more verbosity, which axis, raise
raise EinopsError(f'{group_name} pattern ({group}) contains a duplicated axis')
axis_groups = [
self.result_axes_names,
self.array_axes_names,
self.indexing_axes_names,
self.indexer_other_axes_names,
]
all_axes = set()
for group in axis_groups:
all_axes.update(group)
self.indexer_axes = []
self.batch_axes = []
self.result_and_index_axes = []
self.result_and_array_axes = []
for axis in all_axes:
presence = tuple(axis in g for g in axis_groups)
# want match-case here. sweet dreams
if presence == (False, True, True, False):
self.indexer_axes.append(axis)
elif presence[2]:
raise EinopsError(f'Wrong usage of indexer variable {axis}')
elif presence == (True, True, False, True):
self.batch_axes.append(axis)
elif presence == (True, False, False, True):
self.result_and_index_axes.append(axis)
elif presence == (True, True, False, False):
self.result_and_array_axes.append(axis)
else:
# TODO better categorization of wrong usage patterns
raise EinopsError(f'{axis} is used incorrectly in {pattern}')
assert set(self.indexer_axes) == set(self.indexing_axes_names)
# order of these variables matters, since we can't lose mapping here
self.indexer_axes = self.indexing_axes_names
self.array_composition = CompositionDecomposition(
decomposed_shape=self.array_axes_names,
composed_shape=[self.batch_axes + self.indexer_axes, self.result_and_array_axes],
)
self.index_composition = CompositionDecomposition(
decomposed_shape=self.indexer_other_axes_names,
# single axis after composition
composed_shape=[self.batch_axes + self.result_and_index_axes],
)
self.result_composition = CompositionDecomposition(
decomposed_shape=self.result_axes_names,
composed_shape=[self.batch_axes + self.result_and_index_axes, self.result_and_array_axes],
)
def apply_to_array_api(self, arr: T, ind: Union[T, List[T]]):
        known_axes_sizes: Dict[str, int] = {}
xp = arr.__array_namespace__()
if not isinstance(ind, list):
ind = [ind[i, ...] for i in range(ind.shape[0])]
for indexer in ind:
assert len(indexer.shape) == len(self.indexer_other_axes_names)
# step 1. transpose, reshapes of arr; learn its dimensions
arr_2d = self.array_composition.compose(arr, known_axes_sizes)
# step 2. compute shifts and create an actual indexing array
shift = 1
full_index = xp.zeros([1] * len(ind[0].shape), dtype=xp.int64, device=arr.device)
# original order: [*batch-like axes, *indexing_axes,]
# now we need to traverse them in the opposite direction
for axis_name, indexer in list(zip(self.indexing_axes_names, ind))[::-1]:
full_index = full_index + shift * (indexer % known_axes_sizes[axis_name])
shift *= known_axes_sizes[axis_name]
for axis_name in self.batch_axes[::-1]:
axis_id = self.indexer_other_axes_names.index(axis_name)
full_index = full_index + arange_at_position(
xp, len(self.indexer_other_axes_names), axis=axis_id, axis_len=known_axes_sizes[axis_name],
device=arr.device,
) * shift
shift *= known_axes_sizes[axis_name]
assert shift == arr_2d.shape[0]
# step 3. Flatten index
full_index = self.index_composition.compose(full_index, known_axes_sizes)
# step 4. indexing
# python array api lacks any integer indexing, so... I use loops.
# did you know that there is conceptual programming ... just like art?
# result_2d = arr_2d[full_index]
result_2d = xp.stack([arr_2d[full_index[i], :] for i in range(full_index.shape[0])])
        # step 5. reshape the flat result back into the requested output axes
result = self.result_composition.decompose(result_2d, known_axes_sizes)
return result
def einindex(pattern: str, arr: T, /, ind: Union[T, List[T]]):
"""
Demonstrates how einindex should work.
Supports data-api compliant arrays.
"""
formula = IndexingFormula(pattern)
return formula.apply_to_array_api(arr, ind)
def test_composition_and_decomposition():
import numpy.array_api as np
x = np.arange(2 * 3 * 5 * 7)
x = np.reshape(x, (2, 3, 5, 7))
comp = CompositionDecomposition(
decomposed_shape=['a', 'b', 'c', 'd'],
composed_shape=[['a', 'b'], ['c', 'd']],
)
assert comp.compose(x, known_axes_lengths={}).shape == (2 * 3, 5 * 7)
y = CompositionDecomposition(
decomposed_shape=['a', 'b', 'c', 'd'],
composed_shape=[['a', 'b'], [], ['c', 'd']],
).compose(x, {})
assert y.shape == (2 * 3, 1, 5 * 7)
assert np.all(np.reshape(x, (-1,)) == np.reshape(y, (-1,)))
comp = CompositionDecomposition(
decomposed_shape=['a', 'b', 'e', 'c', 'd'],
composed_shape=[['e', 'c'], ['b'], ['a', 'd']],
)
x = np.arange(2 * 3 * 5 * 7 * 3)
x = np.reshape(x, (2, 3, 5, 7, 3))
axes = {}
y = comp.compose(x, axes)
x2 = comp.decompose(y, axes)
assert np.all(x == x2)
def test_simple_indexing():
import numpy.array_api as np
# simple 2d test
arr = np.reshape(np.arange(5 * 7), (5, 7))
ind = np.arange(7) % 5
x = einindex('j <- i j, [i] j', arr, [ind])
for j, i in enumerate(ind):
assert arr[i, j] == x[j]
y = einindex('j <- j i, [i] j', np.permute_dims(arr, (1, 0)), [ind])
for j, i in enumerate(ind):
assert arr[i, j] == y[j]
def test_multidimensional_indexing():
import numpy.array_api as np
embedding_bhwc = (
+ arange_at_position(np, 4, 0, 2) * 1000
+ arange_at_position(np, 4, 1, 3) * 100
+ arange_at_position(np, 4, 2, 5) * 10
+ arange_at_position(np, 4, 3, 7) * 1
)
hindices_bt = np.reshape(np.arange(6), (2, 3)) % 3
windices_bt = np.reshape(np.arange(6), (2, 3)) % 5
# imagine that you have pairs of image <> sentence
# your goal is to get most suitable token from image for every token in sentence
# thus for every token in sentence you compute best k and v
result = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, [hindices_bt, windices_bt])
# example of using a single array for indexing multiple axes
hw_indices_bt = np.stack([hindices_bt, windices_bt])
result2 = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, hw_indices_bt)
assert np.all(result == result2)
# check vs manual element computation
result_manual = result * 0
for b in range(2):
for t in range(3):
for c in range(7):
h = hindices_bt[b, t]
w = windices_bt[b, t]
result_manual[c, t, b] = embedding_bhwc[b, h, w, c]
assert np.all(result == result_manual)
def test_reverse_indexing():
import numpy.array_api as np
C, T, B = 2, 3, 5
    # G = GPU, batch-like variable
G = 4
H = 7
W = 9
arr_gtbc = (
+ arange_at_position(np, 4, 0, G) * 1000
+ arange_at_position(np, 4, 1, T) * 100
+ arange_at_position(np, 4, 2, B) * 10
+ arange_at_position(np, 4, 3, C) * 1
)
t_indices_gbhw = np.reshape(np.arange(G * B * H * W), (G, B, H, W)) % T
result = einindex('g b c h w <- g t b c, [t] g b h w', arr_gtbc, [t_indices_gbhw])
result_manual = result * 0
for g in range(G):
for b in range(B):
for c in range(C):
for h in range(H):
for w in range(W):
t = t_indices_gbhw[g, b, h, w]
result_manual[g, b, c, h, w] = arr_gtbc[g, t, b, c]
assert np.all(result == result_manual)

View File

@@ -0,0 +1,79 @@
__author__ = 'Alex Rogozhnikov'
import functools
from jittor.einops.einops import _apply_recipe
from jittor.einops.einops import TransformRecipe, _prepare_transformation_recipe
from jittor.einops import EinopsError
class RearrangeMixin:
"""
Rearrange layer behaves identically to einops.rearrange operation.
:param pattern: str, rearrangement pattern
:param axes_lengths: any additional specification of dimensions
    See einops.rearrange for examples.
"""
def __init__(self, pattern, **axes_lengths):
super().__init__()
self.pattern = pattern
self.axes_lengths = axes_lengths
self._recipe = self.recipe() # checking parameters
def __repr__(self):
params = repr(self.pattern)
for axis, length in self.axes_lengths.items():
params += ', {}={}'.format(axis, length)
return '{}({})'.format(self.__class__.__name__, params)
@functools.lru_cache(maxsize=1024)
def recipe(self) -> TransformRecipe:
try:
hashable_lengths = tuple(sorted(self.axes_lengths.items()))
return _prepare_transformation_recipe(self.pattern, operation='rearrange', axes_lengths=hashable_lengths)
except EinopsError as e:
raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e))
def _apply_recipe(self, x):
return _apply_recipe(self._recipe, x, reduction_type='rearrange')
class ReduceMixin:
"""
Reduce layer behaves identically to einops.reduce operation.
:param pattern: str, rearrangement pattern
:param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
:param axes_lengths: any additional specification of dimensions
    See einops.reduce for examples.
"""
def __init__(self, pattern, reduction, **axes_lengths):
super().__init__()
self.pattern = pattern
self.reduction = reduction
self.axes_lengths = axes_lengths
self._recipe = self.recipe() # checking parameters
def __repr__(self):
params = '{!r}, {!r}'.format(self.pattern, self.reduction)
for axis, length in self.axes_lengths.items():
params += ', {}={}'.format(axis, length)
return '{}({})'.format(self.__class__.__name__, params)
@functools.lru_cache(maxsize=1024)
def recipe(self) -> TransformRecipe:
try:
hashable_lengths = tuple(sorted(self.axes_lengths.items()))
return _prepare_transformation_recipe(
self.pattern, operation=self.reduction, axes_lengths=hashable_lengths)
except EinopsError as e:
raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e))
def _apply_recipe(self, x):
return _apply_recipe(self._recipe, x, reduction_type=self.reduction)

View File

@@ -0,0 +1,175 @@
from typing import Optional, Dict
from jittor.einops import EinopsError
from jittor.einops.parsing import ParsedExpression
import warnings
import string
from jittor.einops.einops import _product
def _report_axes(axes: set, report_message: str):
if len(axes) > 0:
raise EinopsError(report_message.format(axes))
class _EinmixMixin:
def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
"""
EinMix - Einstein summation with automated tensor management and axis packing/unpacking.
EinMix is an advanced tool, helpful tutorial:
https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb
        Imagine taking einsum with two arguments: one is the input, the other is a tensor with weights
        >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)
        This layer manages the weights for you, and the syntax highlights the separate role of the weight matrix
        >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
        But otherwise it is the same einsum under the hood.
        Simple linear layer with bias term (you have one like that in your framework)
        >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
        There is no restriction to mixing only the last axis. Let's mix along height
        >>> EinMix('h w c -> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
Channel-wise multiplication (like one used in normalizations)
>>> EinMix('t b c -> t b c', weight_shape='c', c=128)
Separate dense layer within each head, no connection between different heads
>>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)
... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters.
Use cases:
- when channel dimension is not last, use EinMix, not transposition
- patch/segment embeddings
- when need only within-group connections to reduce number of weights and computations
- perfect as a part of sequential models
- next-gen MLPs (follow tutorial to learn more)
        Uniform He initialization is applied to the weight tensor and accounts for the number of elements mixed.
Parameters
:param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
        :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
:param bias_shape: axes of bias added to output.
:param axes_lengths: dimensions of weight tensor
"""
super().__init__()
self.pattern = pattern
self.weight_shape = weight_shape
self.bias_shape = bias_shape
self.axes_lengths = axes_lengths
left_pattern, right_pattern = pattern.split('->')
left = ParsedExpression(left_pattern)
right = ParsedExpression(right_pattern)
weight = ParsedExpression(weight_shape)
_report_axes(
set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
'Unrecognized identifiers on the right side of EinMix {}'
)
if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
raise EinopsError('Ellipsis is not supported in EinMix (right now)')
if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
raise EinopsError('Anonymous axes (numbers) are not allowed in EinMix')
if '(' in weight_shape or ')' in weight_shape:
raise EinopsError(f'Parenthesis is not allowed in weight shape: {weight_shape}')
pre_reshape_pattern = None
pre_reshape_lengths = None
post_reshape_pattern = None
if any(len(group) != 1 for group in left.composition):
names = []
for group in left.composition:
names += group
composition = ' '.join(names)
pre_reshape_pattern = f'{left_pattern}->{composition}'
pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names}
if any(len(group) != 1 for group in right.composition):
names = []
for group in right.composition:
names += group
composition = ' '.join(names)
post_reshape_pattern = f'{composition}->{right_pattern}'
self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})
for axis in weight.identifiers:
if axis not in axes_lengths:
raise EinopsError('Dimension {} of weight should be specified'.format(axis))
_report_axes(
set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}),
'Axes {} are not used in pattern',
)
_report_axes(
set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}),
'Weight axes {} are redundant'
)
if len(weight.identifiers) == 0:
warnings.warn('EinMix: weight has no dimensions (means multiplication by a number)')
_weight_shape = [axes_lengths[axis] for axis, in weight.composition]
# single output element is a combination of fan_in input elements
_fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers])
if bias_shape is not None:
if not isinstance(bias_shape, str):
raise EinopsError('bias shape should be string specifying which axes bias depends on')
bias = ParsedExpression(bias_shape)
_report_axes(
set.difference(bias.identifiers, right.identifiers),
'Bias axes {} not present in output'
)
_report_axes(
set.difference(bias.identifiers, set(axes_lengths)),
'Sizes not provided for bias axes {}',
)
_bias_shape = []
for axes in right.composition:
for axis in axes:
if axis in bias.identifiers:
_bias_shape.append(axes_lengths[axis])
else:
_bias_shape.append(1)
else:
_bias_shape = None
_bias_input_size = None
weight_bound = (3 / _fan_in) ** 0.5
bias_bound = (1 / _fan_in) ** 0.5
self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound)
# rewrite einsum expression with single-letter latin identifiers so that
# expression will be understood by any framework
mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers}
mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)}
def write_flat(axes: list):
return ''.join(mapping2letters[axis] for axis in axes)
self.einsum_pattern: str = '{},{}->{}'.format(
write_flat(left.flat_axes_order()),
write_flat(weight.flat_axes_order()),
write_flat(right.flat_axes_order()),
)
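        # e.g. EinMix('t b cin -> t b cout', weight_shape='cin cout') yields an
        # einsum pattern equivalent (up to letter choice) to 'abc,cd->abd'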
def _create_rearrange_layers(self,
pre_reshape_pattern: Optional[str],
pre_reshape_lengths: Optional[Dict],
post_reshape_pattern: Optional[str],
post_reshape_lengths: Optional[Dict]):
raise NotImplementedError('Should be defined in framework implementations')
def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
""" Shape and implementations """
raise NotImplementedError('Should be defined in framework implementations')
def __repr__(self):
params = repr(self.pattern)
params += f", '{self.weight_shape}'"
if self.bias_shape is not None:
params += f", '{self.bias_shape}'"
for axis, length in self.axes_lengths.items():
params += ', {}={}'.format(axis, length)
return '{}({})'.format(self.__class__.__name__, params)

View File

@@ -0,0 +1,64 @@
from typing import Optional, Dict
import jittor as jt
from jittor import nn
import numpy as np
from jittor.einops.layers import RearrangeMixin, ReduceMixin
from jittor.einops.layers._einmix import _EinmixMixin
from jittor.einops._jittor_specific import apply_for_scriptable_jittor
__author__ = 'Ruiyang Liu'
class Rearrange(RearrangeMixin, jt.nn.Module):
def execute(self, input):
return apply_for_scriptable_jittor(self._recipe, input, reduction_type='rearrange')
def _apply_recipe(self, x):
        # overriding parent method to prevent its scripting
pass
class Reduce(ReduceMixin, jt.nn.Module):
def execute(self, input):
return apply_for_scriptable_jittor(self._recipe, input, reduction_type=self.reduction)
def _apply_recipe(self, x):
        # overriding parent method to prevent its scripting
pass
class EinMix(_EinmixMixin, jt.nn.Module):
def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
self.weight = jt.zeros(weight_shape)
nn.init.uniform_(self.weight, low = -weight_bound, high = weight_bound)
if bias_shape is not None:
self.bias = jt.zeros(bias_shape)
nn.init.uniform_(self.bias, low = -bias_bound, high = bias_bound)
else:
self.bias = None
def _create_rearrange_layers(self,
pre_reshape_pattern: Optional[str],
pre_reshape_lengths: Optional[Dict],
post_reshape_pattern: Optional[str],
post_reshape_lengths: Optional[Dict],
):
self.pre_rearrange = None
if pre_reshape_pattern is not None:
self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
self.post_rearrange = None
if post_reshape_pattern is not None:
self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)
def execute(self, input):
if self.pre_rearrange is not None:
input = self.pre_rearrange(input)
result = jt.linalg.einsum(self.einsum_pattern, input, self.weight)
if self.bias is not None:
result += self.bias
if self.post_rearrange is not None:
result = self.post_rearrange(result)
return result
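# usage sketch: these layers drop into jittor models like any other Module
# (illustrative shapes; patch size and channel counts are made up):
#
#     from jittor import nn
#     from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix
#
#     model = nn.Sequential(
#         Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=16, p2=16),
#         EinMix('b t cin -> b t cout', weight_shape='cin cout', bias_shape='cout',
#                cin=16 * 16 * 3, cout=256),
#     )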

View File

@@ -0,0 +1,147 @@
from jittor.einops import EinopsError
import keyword
import warnings
from typing import List, Optional, Set, Tuple
_ellipsis: str = '…'  # NB, this is a single unicode symbol. String is used as it is not a list, but can be iterated
class AnonymousAxis(object):
"""Important thing: all instances of this class are not equal to each other """
def __init__(self, value: str):
self.value = int(value)
if self.value <= 1:
if self.value == 1:
raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
else:
raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))
def __repr__(self):
return "{}-axis".format(str(self.value))
class ParsedExpression:
"""
non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
and keeps some information important for downstream
"""
def __init__(self, expression, *, allow_underscore: bool = False):
self.has_ellipsis: bool = False
self.has_ellipsis_parenthesized: Optional[bool] = None
self.identifiers: Set[str] = set()
# that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
self.has_non_unitary_anonymous_axes: bool = False
# composition keeps structure of composite axes, see how different corner cases are handled in tests
self.composition = []
if '.' in expression:
if '...' not in expression:
raise EinopsError('Expression may contain dots only inside ellipsis (...)')
if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
raise EinopsError(
                'Expression may contain dots only inside ellipsis (...); only one ellipsis per tensor is allowed')
expression = expression.replace('...', _ellipsis)
self.has_ellipsis = True
bracket_group = None
def add_axis_name(x):
if x is not None:
if x in self.identifiers:
if not (allow_underscore and x == "_"):
raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
if x == _ellipsis:
self.identifiers.add(_ellipsis)
if bracket_group is None:
self.composition.append(_ellipsis)
self.has_ellipsis_parenthesized = False
else:
bracket_group.append(_ellipsis)
self.has_ellipsis_parenthesized = True
else:
is_number = str.isdecimal(x)
if is_number and int(x) == 1:
# handling the case of anonymous axis of length 1
if bracket_group is None:
self.composition.append([])
else:
pass # no need to think about 1s inside parenthesis
return
is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
if not (is_number or is_axis_name):
raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
if is_number:
x = AnonymousAxis(x)
self.identifiers.add(x)
if is_number:
self.has_non_unitary_anonymous_axes = True
if bracket_group is None:
self.composition.append([x])
else:
bracket_group.append(x)
current_identifier = None
for char in expression:
if char in '() ':
add_axis_name(current_identifier)
current_identifier = None
if char == '(':
if bracket_group is not None:
raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
bracket_group = []
elif char == ')':
if bracket_group is None:
raise EinopsError('Brackets are not balanced')
self.composition.append(bracket_group)
bracket_group = None
elif str.isalnum(char) or char in ['_', _ellipsis]:
if current_identifier is None:
current_identifier = char
else:
current_identifier += char
else:
raise EinopsError("Unknown character '{}'".format(char))
if bracket_group is not None:
raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
add_axis_name(current_identifier)
def flat_axes_order(self) -> List:
result = []
for composed_axis in self.composition:
assert isinstance(composed_axis, list), 'does not work with ellipsis'
for axis in composed_axis:
result.append(axis)
return result
def has_composed_axes(self) -> bool:
# this will ignore 1 inside brackets
for axes in self.composition:
if isinstance(axes, list) and len(axes) > 1:
return True
return False
@staticmethod
def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
if not str.isidentifier(name):
return False, 'not a valid python identifier'
elif name[0] == '_' or name[-1] == '_':
if name == '_' and allow_underscore:
return True, ''
            return False, 'axis name should not start or end with underscore'
else:
if keyword.iskeyword(name):
warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
if name in ['axis']:
warnings.warn("It is discouraged to use 'axis' as an axis name "
"and will raise an error in future", FutureWarning)
return True, ''
@staticmethod
def check_axis_name(name: str) -> bool:
"""
Valid axes names are python identifiers except keywords,
and additionally should not start or end with underscore
"""
is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
return is_valid
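# e.g. ParsedExpression('b c (h w)') yields composition [['b'], ['c'], ['h', 'w']]
# and identifiers {'b', 'c', 'h', 'w'}; ParsedExpression('b 1 c') yields
# [['b'], [], ['c']], since a unit axis becomes an empty group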