mirror of https://github.com/Jittor/Jittor

Support einops for Jittor

parent 8c9bfb639d
commit 13f9eaafc0
@@ -0,0 +1,8 @@
class EinopsError(RuntimeError):
    """ Runtime error thrown by einops """
    pass


__all__ = ['rearrange', 'reduce', 'repeat', 'parse_shape', 'asnumpy', 'EinopsError']

from jittor.einops.einops import rearrange, reduce, repeat, parse_shape, asnumpy
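For orientation, a minimal usage sketch of the exported API (assuming jittor is installed; `jt.randn` is Jittor's standard random-tensor constructor):

```python
import jittor as jt
from jittor.einops import rearrange, reduce, repeat, parse_shape, asnumpy

x = jt.randn(10, 20, 30, 40)                                           # b c h w
flat = rearrange(x, 'b c h w -> b (c h w)')                            # (10, 24000)
pooled = reduce(x, 'b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2)  # 2x2 max-pool -> (10, 20, 15, 20)
tiled = repeat(x, 'b c h w -> b c h (w w2)', w2=2)                     # (10, 20, 30, 80)
assert parse_shape(x, 'b c h w') == {'b': 10, 'c': 20, 'h': 30, 'w': 40}
assert asnumpy(flat).shape == (10, 24000)
```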
@@ -0,0 +1,255 @@
"""
Backends in `einops` are organized to meet the following requirements
- backends are not imported unless those are actually needed, because
    - backends may not be installed
    - importing all available backends would lead to a significant memory footprint
    - backends may be present but installed with errors (but never used),
      importing may lead to crashes
- backend should be either symbolic or imperative (tensorflow is for both, but that causes problems)
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if backend can't (temporarily) provide symbols for shape dimensions, UnknownSize objects are used
"""

import sys
import warnings

__author__ = 'Alex Rogozhnikov'

_backends = {}
_debug_importing = False


def get_backend(tensor) -> 'AbstractBackend':
    """
    Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
    If needed, imports package and creates backend
    """
    for framework_name, backend in _backends.items():
        if backend.is_appropriate_type(tensor):
            return backend

    # Find backend subclasses recursively
    backend_subclasses = []
    backends = AbstractBackend.__subclasses__()
    while backends:
        backend = backends.pop()
        backends += backend.__subclasses__()
        backend_subclasses.append(backend)

    for BackendSubclass in backend_subclasses:
        if _debug_importing:
            print('Testing for subclass of ', BackendSubclass)
        if BackendSubclass.framework_name not in _backends:
            # check that module was already imported. Otherwise it can't be imported
            if BackendSubclass.framework_name in sys.modules:
                if _debug_importing:
                    print('Imported backend for ', BackendSubclass.framework_name)
                backend = BackendSubclass()
                _backends[backend.framework_name] = backend
                if backend.is_appropriate_type(tensor):
                    return backend

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))


class AbstractBackend:
    """ Base backend class, major part of methods are only for debugging purposes. """
    framework_name = None

    def is_appropriate_type(self, tensor):
        """ helper method should recognize tensors it can handle """
        raise NotImplementedError()

    def from_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    def to_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    def create_symbol(self, shape):
        raise NotImplementedError("framework doesn't support symbolic computations")

    def eval_symbol(self, symbol, input_dict):
        raise NotImplementedError("framework doesn't support symbolic computations")

    def arange(self, start, stop):
        # supplementary method used only in testing, so should implement CPU version
        raise NotImplementedError("framework doesn't implement arange")

    def shape(self, x):
        """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)"""
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.transpose(axes)

    def reduce(self, x, operation, axes):
        return getattr(x, operation)(axis=axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        raise NotImplementedError()

    def add_axis(self, x, new_position):
        raise NotImplementedError()

    def add_axes(self, x, n_axes, pos2len):
        repeats = [1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return self.tile(x, tuple(repeats))

    def tile(self, x, repeats):
        """repeats is a sequence of repetition counts, one per axis of x"""
        raise NotImplementedError()

    def is_float_type(self, x):
        # some backends (torch) can't compute average for non-floating types.
        # Decided to drop average for all backends if type is not floating
        raise NotImplementedError()

    def layers(self):
        raise NotImplementedError("backend does not provide layers")

    def __repr__(self):
        return "<einops backend for {}>".format(self.framework_name)


class UnknownSize:
    """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """

    def __floordiv__(self, other):
        return self

    def __eq__(self, other):
        return True  # we don't know actual size

    def __mul__(self, other):
        return self

    def __rmul__(self, other):
        return self

    def __hash__(self):
        return None.__hash__()


class NumpyBackend(AbstractBackend):
    framework_name = 'numpy'

    def __init__(self):
        import numpy
        self.np = numpy

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.np.ndarray)

    def from_numpy(self, x):
        return x

    def to_numpy(self, x):
        return x

    def arange(self, start, stop):
        return self.np.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.np.stack(tensors)

    def tile(self, x, repeats):
        return self.np.tile(x, repeats)

    def is_float_type(self, x):
        return x.dtype in ('float16', 'float32', 'float64', 'float128')

    def add_axis(self, x, new_position):
        return self.np.expand_dims(x, new_position)


class HashableTuple:
    """Overcomes non-hashability of symbolic elements"""

    def __init__(self, elements: tuple):
        self.elements = elements

    def __iter__(self):
        for x in self.elements:
            yield x

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, item):
        return self.elements[item]


class JittorBackend(AbstractBackend):
    framework_name = 'jittor'

    def __init__(self):
        import jittor
        self.jittor = jittor

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.jittor.jittor_core.Var)

    def from_numpy(self, x):
        variable = self.jittor.array(x)
        if self.is_float_type(variable):
            # attach grad only to floating types
            variable.requires_grad = True
        return variable

    def to_numpy(self, x):
        return x.detach().numpy()

    def arange(self, start, stop):
        return self.jittor.arange(start, stop, dtype='int64')

    def shape(self, x):
        return HashableTuple(tuple(x.shape))

    def reshape(self, x, shape):
        return self.jittor.reshape(x, shape)

    # def reduce(self, x, operation, axes):
    #     return getattr(x, operation)(dim=axes)

    def reduce(self, x, operation, reduced_axes):
        for axis in sorted(reduced_axes, reverse=True):
            if operation == 'min':
                x = x.min(dim=axis)
            elif operation == 'max':
                x = x.max(dim=axis)
            elif operation in ['sum', 'mean', 'prod']:
                x = getattr(x, operation)(dim=axis)
            else:
                raise NotImplementedError('Unknown reduction ', operation)
        return x

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.jittor.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def add_axis(self, x, new_position):
        return self.jittor.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in ["float16", "float32", "float64"]

    def layers(self):
        from jittor.einops.layers import jittor
        return jittor
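To illustrate the lazy dispatch above, a small sketch (assuming numpy and jittor are importable; backends are instantiated only for frameworks already present in `sys.modules`):

```python
import numpy as np
import jittor as jt
from jittor.einops._backends import get_backend

# numpy arrays dispatch to NumpyBackend, jittor Vars to JittorBackend
print(get_backend(np.zeros([2, 3])))   # <einops backend for numpy>
print(get_backend(jt.zeros((2, 3))))   # <einops backend for jittor>
```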
@@ -0,0 +1,84 @@
"""
Specialization of einops for jittor.

Unfortunately, jittor's jit scripting mechanism isn't strong enough,
and to have scripting supported at least for layers,
a number of changes are required, and this layer helps.

Importantly, whole lib is designed so that you can't use it
"""

from typing import Dict, List

import jittor as jt
from jittor.einops.einops import TransformRecipe, _reconstruct_from_shape_uncached


class JittorJitBackend:
    """
    Completely static backend that mimics part of normal backend functionality
    but restricted to jittor stuff only
    """

    @staticmethod
    def reduce(x: jt.jittor_core.Var, operation: str, reduced_axes: List[int]):
        if operation == 'min':
            return x.min(dims=reduced_axes)
        elif operation == 'max':
            return x.max(dims=reduced_axes)
        elif operation == 'sum':
            return x.sum(dims=reduced_axes)
        elif operation == 'mean':
            return x.mean(dims=reduced_axes)
        elif operation == 'prod':
            for i in list(sorted(reduced_axes))[::-1]:
                x = x.prod(dim=i)
            return x
        else:
            raise NotImplementedError('Unknown reduction ', operation)

    @staticmethod
    def transpose(x, axes: List[int]):
        return x.permute(axes)

    @staticmethod
    def stack_on_zeroth_dimension(tensors: List[jt.jittor_core.Var]):
        return jt.stack(tensors)

    @staticmethod
    def tile(x, repeats: List[int]):
        return x.repeat(repeats)

    @staticmethod
    def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
        repeats = [1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = jt.unsqueeze(x, axis_position)
            repeats[axis_position] = axis_length
        return JittorJitBackend.tile(x, repeats)

    @staticmethod
    def is_float_type(x):
        return x.dtype in ["float16", "float32", "float64"]

    @staticmethod
    def shape(x):
        return x.shape

    @staticmethod
    def reshape(x, shape: List[int]):
        return x.reshape(shape)


# mirrors einops.einops._apply_recipe
def apply_for_scriptable_jittor(recipe: TransformRecipe, tensor: jt.jittor_core.Var,
                                reduction_type: str) -> jt.jittor_core.Var:
    backend = JittorJitBackend
    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
        _reconstruct_from_shape_uncached(recipe, backend.shape(tensor))
    tensor = backend.reshape(tensor, init_shapes)
    if len(reduced_axes) > 0:
        tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
    tensor = backend.transpose(tensor, axes_reordering)
    if len(added_axes) > 0:
        tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
    return backend.reshape(tensor, final_shapes)
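A hedged sketch of how this static path is used (assuming this file lands at `jittor.einops._jittor_specific`; the recipe comes from `_prepare_transformation_recipe`, part of this commit's einops.py):

```python
import jittor as jt
from jittor.einops.einops import _prepare_transformation_recipe
from jittor.einops._jittor_specific import apply_for_scriptable_jittor  # module path assumed

recipe = _prepare_transformation_recipe('b c h w -> b (c h w)', 'rearrange', axes_lengths=())
x = jt.randn(2, 3, 4, 5)
y = apply_for_scriptable_jittor(recipe, x, reduction_type='rearrange')
assert tuple(y.shape) == (2, 60)
```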
@@ -0,0 +1,625 @@
import functools
import itertools
import typing
from collections import OrderedDict
from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar

if typing.TYPE_CHECKING:
    import numpy as np

from jittor.einops import EinopsError
from jittor.einops._backends import get_backend
from jittor.einops.parsing import ParsedExpression, _ellipsis, AnonymousAxis

Tensor = TypeVar('Tensor')
ReductionCallable = Callable[[Tensor, List[int]], Tensor]
Reduction = Union[str, ReductionCallable]

_reductions = ('min', 'max', 'sum', 'mean', 'prod')
_ellipsis_not_in_parenthesis: List[int] = [-999]
_unknown_axis_length = -999999


def is_ellipsis_not_in_parenthesis(group: List[int]) -> bool:
    if len(group) != 1:
        return False
    return group[0] == -999


def _product(sequence: List[int]) -> int:
    """ minimalistic product that works both with numbers and symbols. Supports empty lists """
    result = 1
    for element in sequence:
        result *= element
    return result


def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
    reduced_axes = tuple(reduced_axes)
    if callable(reduction_type):
        # custom callable
        return reduction_type(tensor, reduced_axes)
    else:
        # one of built-in operations
        if len(reduced_axes) == 0:
            return tensor
        assert reduction_type in _reductions
        if reduction_type == 'mean':
            if not backend.is_float_type(tensor):
                raise NotImplementedError('reduce_mean is not available for non-floating tensors')
        return backend.reduce(tensor, reduction_type, reduced_axes)
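The callable branch above is what makes custom reductions possible through the public `reduce`; a sketch (numpy shown for brevity):

```python
import numpy as np
from jittor.einops import reduce

x = np.arange(24, dtype='float32').reshape(2, 3, 4)

# any f(tensor, reduced_axes) -> tensor can serve as the reduction
rms = reduce(x, 'b h w -> b w', lambda t, axes: np.sqrt((t ** 2).mean(axis=axes)))
assert rms.shape == (2, 4)
```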
def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
    # 'collapses' neighboring axes if those participate in the result pattern in the same order
    # TODO add support for added_axes
    assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
    # joining consecutive axes that will be reduced
    # possibly we can skip this if all backends can optimize this (not sure)
    reduced_axes = tuple(sorted(reduced_axes))
    for i in range(len(reduced_axes) - 1)[::-1]:
        if reduced_axes[i] + 1 == reduced_axes[i + 1]:
            removed_axis = reduced_axes[i + 1]
            removed_length = init_shapes[removed_axis]
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:])

    # removing axes that are moved together during reshape
    def build_mapping():
        init_to_final = {}
        for axis in range(len(init_shapes)):
            if axis in reduced_axes:
                init_to_final[axis] = None
            else:
                after_reduction = sum(x is not None for x in init_to_final.values())
                init_to_final[axis] = list(axes_reordering).index(after_reduction)
        return init_to_final

    init_axis_to_final_axis = build_mapping()

    for init_axis in range(len(init_shapes) - 1)[::-1]:
        if init_axis_to_final_axis[init_axis] is None:
            continue
        if init_axis_to_final_axis[init_axis + 1] is None:
            continue
        if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
            removed_axis = init_axis + 1
            removed_length = init_shapes[removed_axis]
            removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))

            reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            old_reordering = axes_reordering
            axes_reordering = []
            for axis in old_reordering:
                if axis == removed_axis_after_reduction:
                    pass
                elif axis < removed_axis_after_reduction:
                    axes_reordering.append(axis)
                else:
                    axes_reordering.append(axis - 1)
            init_axis_to_final_axis = build_mapping()

    return init_shapes, reduced_axes, axes_reordering, final_shapes


CookedRecipe = Tuple[List[int], List[int], List[int], Dict[int, int], List[int]]


class TransformRecipe:
    """
    Recipe describes actual computation pathway.
    Recipe can be applied to a tensor or variable.
    """

    # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+)

    def __init__(self,
                 # list of expressions (or just sizes) for elementary axes as they appear in left expression.
                 # this is what (after computing unknown parts) will be a shape after first transposition.
                 # If ellipsis is present, it forms one dimension here (in the right position).
                 elementary_axes_lengths: List[int],
                 # each dimension in input can help to reconstruct length of one elementary axis
                 # or verify one of dimensions. Each element points to element of elementary_axes_lengths
                 input_composite_axes: List[Tuple[List[int], List[int]]],
                 # indices of axes to be squashed
                 reduced_elementary_axes: List[int],
                 # in which order should axes be reshuffled after reduction
                 axes_permutation: List[int],
                 # at which positions which of elementary axes should appear
                 added_axes: Dict[int, int],
                 # ids of axes as they appear in result, again pointers to elementary_axes_lengths,
                 # only used to infer result dimensions
                 output_composite_axes: List[List[int]],
                 # positions of ellipsis in lhs and rhs of expression
                 ellipsis_position_in_lhs: Optional[int] = None,
                 ):
        self.elementary_axes_lengths: List[int] = elementary_axes_lengths
        self.input_composite_axes: List[Tuple[List[int], List[int]]] = input_composite_axes
        self.output_composite_axes: List[List[int]] = output_composite_axes
        self.axes_permutation: List[int] = axes_permutation
        self.added_axes: Dict[int, int] = added_axes
        # This is redundant information, but more convenient to use
        self.reduced_elementary_axes: List[int] = reduced_elementary_axes
        # setting to a large number to avoid handling Nones in reconstruct_from_shape
        self.ellipsis_position_in_lhs: int = ellipsis_position_in_lhs if ellipsis_position_in_lhs is not None else 10000


def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> CookedRecipe:
    """
    Reconstruct all actual parameters using shape.
    Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet)
    known axes can be integers or symbols, but not Nones.
    """
    axes_lengths: List[int] = list(self.elementary_axes_lengths)
    if self.ellipsis_position_in_lhs != 10000:
        if len(shape) < len(self.input_composite_axes) - 1:
            raise EinopsError('Expected at least {} dimensions, got {}'.format(
                len(self.input_composite_axes) - 1, len(shape)))
    else:
        if len(shape) != len(self.input_composite_axes):
            raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape)))

    ellipsis_shape: List[int] = []
    for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes):
        before_ellipsis = input_axis
        after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes)
        if input_axis == self.ellipsis_position_in_lhs:
            assert len(known_axes) == 0 and len(unknown_axes) == 1
            unknown_axis, = unknown_axes
            ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1]
            for d in ellipsis_shape:
                if d is None:
                    raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis")
            total_dim_size: int = _product(ellipsis_shape)
            axes_lengths[unknown_axis] = total_dim_size
        else:
            if input_axis < self.ellipsis_position_in_lhs:
                length = shape[before_ellipsis]
            else:
                length = shape[after_ellipsis]
            known_product = 1
            for axis in known_axes:
                known_product *= axes_lengths[axis]

            if len(unknown_axes) == 0:
                if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
                    raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product))
            # this is enforced when recipe is created
            # elif len(unknown_axes) > 1:
            #     raise EinopsError(
            #         "Lengths of two or more axes in parenthesis not provided (dim={}), can't infer dimensions".
            #         format(known_product)
            #     )
            else:
                if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
                    raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format(
                        length, known_product))

                unknown_axis: int = unknown_axes[0]
                inferred_length: int = length // known_product
                axes_lengths[unknown_axis] = inferred_length

    # at this point all axes_lengths are computed (either have values or variables, but not Nones)

    # TODO more readable expression
    init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)]
    final_shapes: List[int] = []
    for output_axis, grouping in enumerate(self.output_composite_axes):
        if is_ellipsis_not_in_parenthesis(grouping):
            final_shapes.extend(ellipsis_shape)
        else:
            lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
            final_shapes.append(_product(lengths))
    reduced_axes = self.reduced_elementary_axes
    axes_reordering = self.axes_permutation
    added_axes: Dict[int, int] = {
        pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()}
    # if optimize:
    #     assert len(self.added_axes) == 0
    #     return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes)
    return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes


_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)

def _apply_recipe(recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction) -> Tensor:
    # this method works for all backends but not compilable with
    backend = get_backend(tensor)
    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
        _reconstruct_from_shape(recipe, backend.shape(tensor))
    tensor = backend.reshape(tensor, init_shapes)
    tensor = _reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend)
    tensor = backend.transpose(tensor, axes_reordering)
    if len(added_axes) > 0:
        tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
    return backend.reshape(tensor, final_shapes)
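For intuition, the recipe pipeline above boils down to plain reshape/transpose/reshape; a numpy sketch of the equivalence for one pattern:

```python
import numpy as np
from jittor.einops import rearrange

x = np.arange(60).reshape(2, 6, 5)

# 'b (h h2) w -> b h (w h2)' with h2=2 cooks into:
#   init_shapes=[2, 3, 2, 5], axes_reordering=[0, 1, 3, 2], final_shapes=[2, 3, 10]
manual = x.reshape(2, 3, 2, 5).transpose(0, 1, 3, 2).reshape(2, 3, 10)
assert np.array_equal(rearrange(x, 'b (h h2) w -> b h (w h2)', h2=2), manual)
```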
@functools.lru_cache(256)
def _prepare_transformation_recipe(pattern: str,
                                   operation: Reduction,
                                   axes_lengths: Tuple[Tuple, ...]) -> TransformRecipe:
    """ Perform initial parsing of pattern and provided supplementary info
    axes_lengths is a tuple of tuples (axis_name, axis_length)
    """
    left, rght = pattern.split('->')
    left = ParsedExpression(left)
    rght = ParsedExpression(rght)

    # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction
    if not left.has_ellipsis and rght.has_ellipsis:
        raise EinopsError('Ellipsis found in right side, but not left side of a pattern {}'.format(pattern))
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise EinopsError('Ellipsis inside parenthesis on the left side is not allowed: {}'.format(pattern))
    if operation == 'rearrange':
        difference = set.symmetric_difference(left.identifiers, rght.identifiers)
        if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
            raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)')
        if len(difference) > 0:
            raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference))
    elif operation == 'repeat':
        difference = set.difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference))
        axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
                                           {*left.identifiers, *(ax for ax, _ in axes_lengths)})
        if len(axes_without_size) > 0:
            raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size))
    elif operation in _reductions or callable(operation):
        difference = set.difference(rght.identifiers, left.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference))
    else:
        raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions))

    # parsing all dimensions to find out lengths
    axis_name2known_length = OrderedDict()
    for composite_axis in left.composition:
        for axis_name in composite_axis:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length

    # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point

    repeat_axes_names = []
    for axis_name in rght.identifiers:
        if axis_name not in axis_name2known_length:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length
            repeat_axes_names.append(axis_name)

    axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}
    reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if
                               axis not in rght.identifiers]
    reduced_axes: List[int] = list(sorted(reduced_axes))

    for elementary_axis, axis_length in axes_lengths:
        if not ParsedExpression.check_axis_name(elementary_axis):
            raise EinopsError('Invalid name for an axis', elementary_axis)
        if elementary_axis not in axis_name2known_length:
            raise EinopsError('Axis {} is not used in transform'.format(elementary_axis))
        axis_name2known_length[elementary_axis] = axis_length

    input_axes_known_unknown = []
    # some of shapes will be inferred later - all information is prepared for faster inference
    for composite_axis in left.composition:
        known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
        unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
        if len(unknown) > 1:
            raise EinopsError('Could not infer sizes for {}'.format(unknown))
        assert len(unknown) + len(known) == len(composite_axis)
        input_axes_known_unknown.append(
            ([axis_name2position[axis] for axis in known],
             [axis_name2position[axis] for axis in unknown])
        )

    axis_position_after_reduction = {}
    for axis_name in itertools.chain(*left.composition):
        if axis_name in rght.identifiers:
            axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)

    result_axes_grouping: List[List[int]] = []
    for composite_axis in rght.composition:
        if composite_axis == _ellipsis:
            result_axes_grouping.append(_ellipsis_not_in_parenthesis)
        else:
            result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis])

    ordered_axis_right = list(itertools.chain(*rght.composition))
    axes_permutation = [
        axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers]
    added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right)
                  if axis_name not in left.identifiers}

    ellipsis_left = None if _ellipsis not in left.composition else left.composition.index(_ellipsis)

    return TransformRecipe(
        elementary_axes_lengths=list(axis_name2known_length.values()),
        input_composite_axes=input_axes_known_unknown,
        reduced_elementary_axes=reduced_axes,
        axes_permutation=axes_permutation,
        added_axes=added_axes,
        output_composite_axes=result_axes_grouping,
        ellipsis_position_in_lhs=ellipsis_left,
    )
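The checks above surface as `EinopsError` through the public functions; a tiny sketch:

```python
import numpy as np
from jittor.einops import rearrange, EinopsError

try:
    rearrange(np.zeros((2, 3)), 'a b -> a c')  # 'c' appears only on the right side
except EinopsError as e:
    print(e)   # ... Identifiers only on one side of expression ...
```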
def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
    """
    einops.reduce provides combination of reordering and reduction using reader-friendly notation.

    Examples for reduce operation:

    ```python
    >>> x = np.random.randn(100, 32, 64)

    # perform max-reduction on the first axis
    >>> y = reduce(x, 't b c -> b c', 'max')

    # same as previous, but with clearer axes meaning
    >>> y = reduce(x, 'time batch channel -> batch channel', 'max')

    >>> x = np.random.randn(10, 20, 30, 40)

    # 2d max-pooling with kernel size = 2 * 2 for image processing
    >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)

    # if one wants to go back to the original height and width, depth-to-space trick can be applied
    >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
    >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')

    # Adaptive 2d max-pooling to 3 * 4 grid
    >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
    (10, 20, 3, 4)

    # Global average pooling
    >>> reduce(x, 'b c h w -> b c', 'mean').shape
    (10, 20)

    # Subtracting mean over batch for each channel
    >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')

    # Subtracting per-image mean for each channel
    >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')

    ```

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, reduction pattern
        reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
            alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input
    """
    try:
        hashable_axes_lengths = tuple(sorted(axes_lengths.items()))
        recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths)
        return _apply_recipe(recipe, tensor, reduction_type=reduction)
    except EinopsError as e:
        message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
        if not isinstance(tensor, list):
            message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor))
        else:
            message += '\n Input is list. '
        message += 'Additional info: {}.'.format(axes_lengths)
        raise EinopsError(message + '\n {}'.format(e))


@typing.overload
def rearrange(tensor: Tensor, pattern: str, **axes_lengths: int) -> Tensor: ...
@typing.overload
def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: int) -> Tensor: ...

def rearrange(tensor, pattern: str, **axes_lengths):
    """
    einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
    This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    Examples for rearrange operation:

    ```python
    # suppose we have a set of 32 images in "h w c" format (height-width-channel)
    >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]

    # stack along first (batch) axis, output is a single array
    >>> rearrange(images, 'b h w c -> b h w c').shape
    (32, 30, 40, 3)

    # concatenate images along height (vertical axis), 960 = 32 * 30
    >>> rearrange(images, 'b h w c -> (b h) w c').shape
    (960, 40, 3)

    # concatenate images along horizontal axis, 1280 = 32 * 40
    >>> rearrange(images, 'b h w c -> h (b w) c').shape
    (30, 1280, 3)

    # reorder axes to "b c h w" format for deep learning
    >>> rearrange(images, 'b h w c -> b c h w').shape
    (32, 3, 30, 40)

    # flatten each image into a vector, 3600 = 30 * 40 * 3
    >>> rearrange(images, 'b h w c -> b (c h w)').shape
    (32, 3600)

    # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
    >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
    (128, 15, 20, 3)

    # space-to-depth operation
    >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
    (32, 15, 20, 12)

    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis)
    Find more examples in einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, rearrangement pattern
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input. If possible, a view to the original tensor is returned.

    """
    if isinstance(tensor, list):
        if len(tensor) == 0:
            raise TypeError("Rearrange can't be applied to an empty list")
        tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor)
    return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)

def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
    """
    einops.repeat allows reordering elements and repeating them in arbitrary combinations.
    This operation includes functionality of repeat, tile, broadcast functions.

    Examples for repeat operation:

    ```python
    # a grayscale image (of shape height x width)
    >>> image = np.random.randn(30, 40)

    # change it to RGB format by repeating in each channel
    >>> repeat(image, 'h w -> h w c', c=3).shape
    (30, 40, 3)

    # repeat image 2 times along height (vertical axis)
    >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape
    (60, 40)

    # repeat image 3 times along width
    >>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape
    (30, 120)

    # convert each pixel to a small square 2x2. Upsample image by 2x
    >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (60, 80)

    # pixelate image first by downsampling by 2x, then upsampling
    >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
    >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (30, 40)

    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis)
    Find more examples in einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, rearrangement pattern
        axes_lengths: any additional specifications for dimensions

    Returns:
        Tensor of the same type as input. If possible, a view to the original tensor is returned.

    """
    return reduce(tensor, pattern, reduction='repeat', **axes_lengths)

def parse_shape(x, pattern: str):
    """
    Parse a tensor shape to dictionary mapping axes names to their lengths.

    ```python
    # Use underscore to skip the dimension in parsing.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # `parse_shape` output can be used to specify axes_lengths for other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)

    ```

    For symbolic frameworks, this may return symbols rather than integers.

    Parameters:
        x: tensor of any of supported frameworks
        pattern: str, space separated names for axes, underscore means skip axis

    Returns:
        dict, maps axes names to their lengths
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format(
            pattern=pattern, shape=shape))
    if len(shape) != len(exp.composition):
        if exp.has_ellipsis:
            if len(shape) < len(exp.composition) - 1:
                raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format(
                    pattern=pattern, shape=shape))
        else:
            raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format(
                pattern=pattern, shape=shape))
    if exp.has_ellipsis:
        ellipsis_idx = exp.composition.index(_ellipsis)
        composition = (exp.composition[:ellipsis_idx] +
                       ['_'] * (len(shape) - len(exp.composition) + 1) +
                       exp.composition[ellipsis_idx + 1:])
    else:
        composition = exp.composition
    result = {}
    for (axis_name, ), axis_length in zip(composition, shape):
        if axis_name != '_':
            result[axis_name] = axis_length
    return result


# this one is probably not needed in the public API
def _enumerate_directions(x):
    """
    For an n-dimensional tensor, returns tensors to enumerate each axis.
    ```python
    x = np.zeros([2, 3, 4])  # or any other tensor
    i, j, k = _enumerate_directions(x)
    result = i + 2*j + 3*k
    ```

    `result[i, j, k] = i + 2j + 3k`, and also has the same shape as x
    Works very similarly to numpy.ogrid (open indexing grid)
    """
    backend = get_backend(x)
    shape = backend.shape(x)
    result = []
    for axis_id, axis_length in enumerate(shape):
        shape = [1] * len(shape)
        shape[axis_id] = axis_length
        result.append(backend.reshape(backend.arange(0, axis_length), shape))
    return result

def asnumpy(tensor) -> 'numpy.ndarray':
    """
    Convert a tensor of an imperative framework (i.e. numpy/jittor) to `numpy.ndarray`

    Parameters:
        tensor: tensor of any of known imperative framework

    Returns:
        `numpy.ndarray`, converted to numpy
    """
    return get_backend(tensor).to_numpy(tensor)
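`parse_shape` above also supports ellipsis, which the numpy doctests don't show; a small sketch:

```python
import numpy as np
from jittor.einops import parse_shape

x = np.zeros([2, 3, 5, 7])
# the ellipsis absorbs any number of middle axes
assert parse_shape(x, 'b ... w') == {'b': 2, 'w': 7}
```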
@@ -0,0 +1,393 @@
"""

Indexing one array with the other(s).

Concept for discussion.

Notation targets hard cases, not simple ones, like indexing of 1d-array with another 1d-array
(notation supports that, but you can't simplify arr[ind], and there is no reason to)

Examples

1. query for every token in sequence a token in the image. Images and sequences are paired
   einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, [h_indices_bt, w_indices_bt])

   this is equivalent, so you can pass indexers independently or together
   einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, np.asarray([h_indices_bt, w_indices_bt]))

   after some thinking I decided that having first axis for indexing variable is not too restrictive,
   but should simplify mapping of such cases.
   For this reason [...] part should always go first in indexer.

   This makes the largest difference with einindex https://github.com/malmaud/einindex,
   which has almost identical grammar, but puts special dimension last, while we put it first.
   This trick allows naturally decomposing multiindex into individual dimensions or vice versa.


2. query for every token in the video the most suitable word in a (matching) sentence
   einindex('b t h w <- seq b, [seq] t b h w', arr_tbc, [t_indices_bhw])

   note, that only one indexer is used, but still it has to be enclosed in the list.
   That's a price for being generic. Alternatively leading singleton dimension can be added.


3. (not supported now, future planning)
   for every timeframe in a video, find the token with the highest norm (across h and w), and compose a new stack of them
   indices_2bt = argmax(x_bthwc.norm(dim=-1), 'b t h w -> [h, w] b t')
   selected_embeddings_btc = einindex('b t c <- b t h w c, [h, w] b t', x_bthwc, indices_2bt)

   while currently question is around 'how do we index',
   it is important to pre-align that with a question 'what are natural ways to get indices'.
   Most common are min/max. less common options: topk (works here), random sampling.


Some important properties of this notation:
- support for multiple indexers, including using a single tensor to keep multiple indexers
- 'batch' indexing, when some axes of indexer and array should be matched
- universal (one-indexing-to-rule-them-all)
- extensible for (named) ellipses, including variadic number of indexers
- extensible for einops-style compositions and decompositions
- extensible for outer indexing when indexers are not aligned

Current implementation is based on the python array api and uses loops,
because no appropriate indexing is available in the standard.

"""

from typing import List, Union, TypeVar, Tuple

from jittor.einops import EinopsError

T = TypeVar('T')

class CompositionDecomposition:
    def __init__(
            self,
            decomposed_shape: List[str],
            composed_shape: List[List[str]],
    ):
        flat_shape = []
        for x in composed_shape:
            flat_shape.extend(x)

        self.compose_transposition: Tuple[int] = tuple([decomposed_shape.index(x) for x in flat_shape])
        self.decompose_transposition: Tuple[int] = tuple([flat_shape.index(x) for x in decomposed_shape])
        self.composed_shape = composed_shape
        self.decomposed_shape = decomposed_shape

    def decompose(self, x, known_axes_lengths: dict[str, int]):
        xp = x.__array_namespace__()
        shape = x.shape

        flat_shape = []

        for i, axis_group in enumerate(self.composed_shape):
            unknown_axis_name = None
            known_sizes_prod = 1
            for axis_name in axis_group:
                if axis_name in known_axes_lengths:
                    known_sizes_prod *= known_axes_lengths[axis_name]
                else:
                    if unknown_axis_name is None:
                        unknown_axis_name = axis_name
                    else:
                        raise EinopsError("Can't infer the size")

            if unknown_axis_name is None:
                assert shape[i] == known_sizes_prod
            else:
                known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod

            for axis in axis_group:
                flat_shape.append(known_axes_lengths[axis])

        x = xp.reshape(x, flat_shape)
        return xp.permute_dims(x, self.decompose_transposition)

    def compose(self, x, known_axes_lengths: dict[str, int]):
        xp = x.__array_namespace__()

        for axis_len, axis_name in zip(x.shape, self.decomposed_shape):
            if axis_name in known_axes_lengths:
                assert known_axes_lengths[axis_name] == axis_len
            else:
                known_axes_lengths[axis_name] = axis_len

        x = xp.permute_dims(x, self.compose_transposition)
        new_shape = []
        for axis_group in self.composed_shape:
            composed_axis_size = 1
            for axis_name in axis_group:
                composed_axis_size *= known_axes_lengths[axis_name]
            new_shape.append(composed_axis_size)

        return xp.reshape(x, tuple(new_shape))

def arange_at_position(xp, n_axes, axis, axis_len, device=None):
    x = xp.arange(axis_len, dtype=xp.int64, device=device)
    shape = [1] * n_axes
    shape[axis] = axis_len
    x = xp.reshape(x, shape)
    return x
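A quick sketch of what this helper returns (shown with numpy's array-api namespace, which the tests below also use):

```python
import numpy.array_api as xp

grid = arange_at_position(xp, n_axes=3, axis=1, axis_len=4)
print(grid.shape)   # (1, 4, 1): enumerates axis 1, broadcastable against the rest
```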
class IndexingFormula:

    def __init__(self, pattern: str):
        """
        :param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t'
        """
        self.pattern = pattern
        left, right = pattern.split('<-')
        arg_split = right.index(',')
        arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:]
        ind_pattern = ind_pattern.strip()
        # print(
        #     arr_pattern, '\n',
        #     ind_pattern,
        # )
        assert ind_pattern.startswith('['), 'composition axis should go first in indexer (second argument) [h w] i j k'
        composition_start = ind_pattern.index('[')
        composition_end = ind_pattern.index(']')
        composition = ind_pattern[composition_start + 1: composition_end]
        ind_other_axes = ind_pattern[composition_end + 1:]

        self.result_axes_names = left.split()
        self.array_axes_names = arr_pattern.split()
        self.indexing_axes_names = [x.strip() for x in composition.split(',')]
        self.indexer_other_axes_names = ind_other_axes.split()

        for group_name, group in [
            ('result', self.result_axes_names),
            ('array', self.array_axes_names),
            ('indexer', self.indexing_axes_names + self.indexer_other_axes_names),
        ]:
            if len(set(group)) != len(group):
                # need more verbosity, which axis, raise
                raise EinopsError(f'{group_name} pattern ({group}) contains a duplicated axis')

        axis_groups = [
            self.result_axes_names,
            self.array_axes_names,
            self.indexing_axes_names,
            self.indexer_other_axes_names,
        ]

        all_axes = set()
        for group in axis_groups:
            all_axes.update(group)

        self.indexer_axes = []
        self.batch_axes = []
        self.result_and_index_axes = []
        self.result_and_array_axes = []

        for axis in all_axes:
            presence = tuple(axis in g for g in axis_groups)
            # want match-case here. sweet dreams
            if presence == (False, True, True, False):
                self.indexer_axes.append(axis)
            elif presence[2]:
                raise EinopsError(f'Wrong usage of indexer variable {axis}')
            elif presence == (True, True, False, True):
                self.batch_axes.append(axis)
            elif presence == (True, False, False, True):
                self.result_and_index_axes.append(axis)
            elif presence == (True, True, False, False):
                self.result_and_array_axes.append(axis)
            else:
                # TODO better categorization of wrong usage patterns
                raise EinopsError(f'{axis} is used incorrectly in {pattern}')

        assert set(self.indexer_axes) == set(self.indexing_axes_names)
        # order of these variables matters, since we can't lose mapping here
        self.indexer_axes = self.indexing_axes_names

        self.array_composition = CompositionDecomposition(
            decomposed_shape=self.array_axes_names,
            composed_shape=[self.batch_axes + self.indexer_axes, self.result_and_array_axes],
        )

        self.index_composition = CompositionDecomposition(
            decomposed_shape=self.indexer_other_axes_names,
            # single axis after composition
            composed_shape=[self.batch_axes + self.result_and_index_axes],
        )

        self.result_composition = CompositionDecomposition(
            decomposed_shape=self.result_axes_names,
            composed_shape=[self.batch_axes + self.result_and_index_axes, self.result_and_array_axes],
        )

    def apply_to_array_api(self, arr: T, ind: Union[T, List[T]]):
        known_axes_sizes: dict[str, int] = {}
        xp = arr.__array_namespace__()

        if not isinstance(ind, list):
            ind = [ind[i, ...] for i in range(ind.shape[0])]

        for indexer in ind:
            assert len(indexer.shape) == len(self.indexer_other_axes_names)

        # step 1. transpose, reshapes of arr; learn its dimensions
        arr_2d = self.array_composition.compose(arr, known_axes_sizes)

        # step 2. compute shifts and create an actual indexing array
        shift = 1
        full_index = xp.zeros([1] * len(ind[0].shape), dtype=xp.int64, device=arr.device)

        # original order: [*batch-like axes, *indexing_axes,]
        # now we need to traverse them in the opposite direction

        for axis_name, indexer in list(zip(self.indexing_axes_names, ind))[::-1]:
            full_index = full_index + shift * (indexer % known_axes_sizes[axis_name])
            shift *= known_axes_sizes[axis_name]

        for axis_name in self.batch_axes[::-1]:
            axis_id = self.indexer_other_axes_names.index(axis_name)
            full_index = full_index + arange_at_position(
                xp, len(self.indexer_other_axes_names), axis=axis_id, axis_len=known_axes_sizes[axis_name],
                device=arr.device,
            ) * shift
            shift *= known_axes_sizes[axis_name]

        assert shift == arr_2d.shape[0]

        # step 3. Flatten index
        full_index = self.index_composition.compose(full_index, known_axes_sizes)

        # step 4. indexing
        # python array api lacks any integer indexing, so... I use loops.
        # did you know that there is conceptual programming ... just like art?
        # result_2d = arr_2d[full_index]
        result_2d = xp.stack([arr_2d[full_index[i], :] for i in range(full_index.shape[0])])

        # step 5. decompose the result back to the requested axes
        result = self.result_composition.decompose(result_2d, known_axes_sizes)
        return result

def einindex(pattern: str, arr: T, /, ind: Union[T, List[T]]):
    """
    Demonstrates how einindex should work.
    Supports data-api compliant arrays.
    """
    formula = IndexingFormula(pattern)
    return formula.apply_to_array_api(arr, ind)

def test_composition_and_decomposition():
    import numpy.array_api as np
    x = np.arange(2 * 3 * 5 * 7)
    x = np.reshape(x, (2, 3, 5, 7))
    comp = CompositionDecomposition(
        decomposed_shape=['a', 'b', 'c', 'd'],
        composed_shape=[['a', 'b'], ['c', 'd']],
    )
    assert comp.compose(x, known_axes_lengths={}).shape == (2 * 3, 5 * 7)

    y = CompositionDecomposition(
        decomposed_shape=['a', 'b', 'c', 'd'],
        composed_shape=[['a', 'b'], [], ['c', 'd']],
    ).compose(x, {})
    assert y.shape == (2 * 3, 1, 5 * 7)
    assert np.all(np.reshape(x, (-1,)) == np.reshape(y, (-1,)))

    comp = CompositionDecomposition(
        decomposed_shape=['a', 'b', 'e', 'c', 'd'],
        composed_shape=[['e', 'c'], ['b'], ['a', 'd']],
    )
    x = np.arange(2 * 3 * 5 * 7 * 3)
    x = np.reshape(x, (2, 3, 5, 7, 3))

    axes = {}
    y = comp.compose(x, axes)
    x2 = comp.decompose(y, axes)
    assert np.all(x == x2)


def test_simple_indexing():
    import numpy.array_api as np

    # simple 2d test
    arr = np.reshape(np.arange(5 * 7), (5, 7))
    ind = np.arange(7) % 5
    x = einindex('j <- i j, [i] j', arr, [ind])
    for j, i in enumerate(ind):
        assert arr[i, j] == x[j]

    y = einindex('j <- j i, [i] j', np.permute_dims(arr, (1, 0)), [ind])
    for j, i in enumerate(ind):
        assert arr[i, j] == y[j]


def test_multidimensional_indexing():
    import numpy.array_api as np

    embedding_bhwc = (
        + arange_at_position(np, 4, 0, 2) * 1000
        + arange_at_position(np, 4, 1, 3) * 100
        + arange_at_position(np, 4, 2, 5) * 10
        + arange_at_position(np, 4, 3, 7) * 1
    )

    hindices_bt = np.reshape(np.arange(6), (2, 3)) % 3
    windices_bt = np.reshape(np.arange(6), (2, 3)) % 5

    # imagine that you have pairs of image <> sentence
    # your goal is to get most suitable token from image for every token in sentence
    # thus for every token in sentence you compute best k and v

    result = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, [hindices_bt, windices_bt])
    # example of using a single array for indexing multiple axes
    hw_indices_bt = np.stack([hindices_bt, windices_bt])
    result2 = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, hw_indices_bt)
    assert np.all(result == result2)

    # check vs manual element computation
    result_manual = result * 0
    for b in range(2):
        for t in range(3):
            for c in range(7):
                h = hindices_bt[b, t]
                w = windices_bt[b, t]
                result_manual[c, t, b] = embedding_bhwc[b, h, w, c]

    assert np.all(result == result_manual)


def test_reverse_indexing():
    import numpy.array_api as np

    C, T, B = 2, 3, 5
    # G = GPU, batch-like variable
    G = 4
    H = 7
    W = 9

    arr_gtbc = (
        + arange_at_position(np, 4, 0, G) * 1000
        + arange_at_position(np, 4, 1, T) * 100
        + arange_at_position(np, 4, 2, B) * 10
        + arange_at_position(np, 4, 3, C) * 1
    )

    t_indices_gbhw = np.reshape(np.arange(G * B * H * W), (G, B, H, W)) % T

    result = einindex('g b c h w <- g t b c, [t] g b h w', arr_gtbc, [t_indices_gbhw])

    result_manual = result * 0
    for g in range(G):
        for b in range(B):
            for c in range(C):
                for h in range(H):
                    for w in range(W):
                        t = t_indices_gbhw[g, b, h, w]
                        result_manual[g, b, c, h, w] = arr_gtbc[g, t, b, c]

    assert np.all(result == result_manual)
@ -0,0 +1,79 @@
__author__ = 'Alex Rogozhnikov'

import functools

from jittor.einops.einops import _apply_recipe
from jittor.einops.einops import TransformRecipe, _prepare_transformation_recipe
from jittor.einops import EinopsError


class RearrangeMixin:
    """
    Rearrange layer behaves identically to the einops.rearrange operation.

    :param pattern: str, rearrangement pattern
    :param axes_lengths: any additional specification of dimensions

    See einops.rearrange for examples.
    """

    def __init__(self, pattern, **axes_lengths):
        super().__init__()
        self.pattern = pattern
        self.axes_lengths = axes_lengths
        self._recipe = self.recipe()  # checking parameters

    def __repr__(self):
        params = repr(self.pattern)
        for axis, length in self.axes_lengths.items():
            params += ', {}={}'.format(axis, length)
        return '{}({})'.format(self.__class__.__name__, params)

    @functools.lru_cache(maxsize=1024)
    def recipe(self) -> TransformRecipe:
        try:
            hashable_lengths = tuple(sorted(self.axes_lengths.items()))
            return _prepare_transformation_recipe(self.pattern, operation='rearrange', axes_lengths=hashable_lengths)
        except EinopsError as e:
            raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e))

    def _apply_recipe(self, x):
        return _apply_recipe(self._recipe, x, reduction_type='rearrange')


class ReduceMixin:
    """
    Reduce layer behaves identically to the einops.reduce operation.

    :param pattern: str, rearrangement pattern
    :param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
    :param axes_lengths: any additional specification of dimensions

    See einops.reduce for examples.
    """

    def __init__(self, pattern, reduction, **axes_lengths):
        super().__init__()
        self.pattern = pattern
        self.reduction = reduction
        self.axes_lengths = axes_lengths
        self._recipe = self.recipe()  # checking parameters

    def __repr__(self):
        params = '{!r}, {!r}'.format(self.pattern, self.reduction)
        for axis, length in self.axes_lengths.items():
            params += ', {}={}'.format(axis, length)
        return '{}({})'.format(self.__class__.__name__, params)

    @functools.lru_cache(maxsize=1024)
    def recipe(self) -> TransformRecipe:
        try:
            hashable_lengths = tuple(sorted(self.axes_lengths.items()))
            return _prepare_transformation_recipe(
                self.pattern, operation=self.reduction, axes_lengths=hashable_lengths)
        except EinopsError as e:
            raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e))

    def _apply_recipe(self, x):
        return _apply_recipe(self._recipe, x, reduction_type=self.reduction)
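Since recipe() runs inside __init__, a malformed pattern fails at layer construction rather than on the first forward pass. A small sketch of this fail-fast behaviour, using the concrete Rearrange layer from the Jittor binding defined in a later hunk:

from jittor.einops import EinopsError
from jittor.einops.layers.jittor import Rearrange

try:
    # 'w' appears only on the left side, which rearrange rejects
    Rearrange('b c h w -> b c h')
except EinopsError as error:
    print(error)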
@ -0,0 +1,175 @@
from typing import Optional, Dict

from jittor.einops import EinopsError
from jittor.einops.parsing import ParsedExpression
import warnings
import string
from jittor.einops.einops import _product


def _report_axes(axes: set, report_message: str):
    if len(axes) > 0:
        raise EinopsError(report_message.format(axes))


class _EinmixMixin:
    def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
        """
        EinMix - Einstein summation with automated tensor management and axis packing/unpacking.

        EinMix is an advanced tool, helpful tutorial:
        https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb

        Imagine taking einsum with two arguments: one is the input, the other a tensor of weights
        >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)

        This layer manages the weights for you; the syntax highlights the separate role of the weight matrix
        >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
        But otherwise it is the same einsum under the hood.

        Simple linear layer with bias term (you have one like that in your framework)
        >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
        Mixing is not restricted to the last axis. Let's mix along height
        >>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
        Channel-wise multiplication (like one used in normalizations)
        >>> EinMix('t b c -> t b c', weight_shape='c', c=128)
        Separate dense layer within each head, no connection between different heads
        >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)

        ... ah yes, you need to specify all dimensions of the weight shape/bias shape in parameters.

        Use cases:
        - when the channel dimension is not last, use EinMix, not transposition
        - patch/segment embeddings
        - when only within-group connections are needed, to reduce the number of weights and computations
        - perfect as a part of sequential models
        - next-gen MLPs (follow the tutorial to learn more)

        Uniform He initialization is applied to the weight tensor and accounts for the number of elements mixed.

        Parameters
        :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
        :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
        :param bias_shape: axes of bias added to output.
        :param axes_lengths: dimensions of weight tensor
        """
        super().__init__()
        self.pattern = pattern
        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        self.axes_lengths = axes_lengths

        left_pattern, right_pattern = pattern.split('->')
        left = ParsedExpression(left_pattern)
        right = ParsedExpression(right_pattern)
        weight = ParsedExpression(weight_shape)
        _report_axes(
            set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
            'Unrecognized identifiers on the right side of EinMix {}'
        )

        if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
            raise EinopsError('Ellipsis is not supported in EinMix (right now)')
        if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
            raise EinopsError('Anonymous axes (numbers) are not allowed in EinMix')
        if '(' in weight_shape or ')' in weight_shape:
            raise EinopsError(f'Parenthesis is not allowed in weight shape: {weight_shape}')

        pre_reshape_pattern = None
        pre_reshape_lengths = None
        post_reshape_pattern = None
        if any(len(group) != 1 for group in left.composition):
            names = []
            for group in left.composition:
                names += group
            composition = ' '.join(names)
            pre_reshape_pattern = f'{left_pattern}->{composition}'
            pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names}

        if any(len(group) != 1 for group in right.composition):
            names = []
            for group in right.composition:
                names += group
            composition = ' '.join(names)
            post_reshape_pattern = f'{composition}->{right_pattern}'

        self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})

        for axis in weight.identifiers:
            if axis not in axes_lengths:
                raise EinopsError('Dimension {} of weight should be specified'.format(axis))
        _report_axes(
            set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}),
            'Axes {} are not used in pattern',
        )
        _report_axes(
            set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}),
            'Weight axes {} are redundant'
        )
        if len(weight.identifiers) == 0:
            warnings.warn('EinMix: weight has no dimensions (means multiplication by a number)')

        # each composition group of weight contains exactly one axis (parentheses are disallowed above)
        _weight_shape = [axes_lengths[axis] for axis, in weight.composition]
        # single output element is a combination of fan_in input elements
        _fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers])
        if bias_shape is not None:
            if not isinstance(bias_shape, str):
                raise EinopsError('bias shape should be a string specifying which axes bias depends on')
            bias = ParsedExpression(bias_shape)
            _report_axes(
                set.difference(bias.identifiers, right.identifiers),
                'Bias axes {} not present in output'
            )
            _report_axes(
                set.difference(bias.identifiers, set(axes_lengths)),
                'Sizes not provided for bias axes {}',
            )

            _bias_shape = []
            for axes in right.composition:
                for axis in axes:
                    if axis in bias.identifiers:
                        _bias_shape.append(axes_lengths[axis])
                    else:
                        _bias_shape.append(1)
        else:
            _bias_shape = None
            _bias_input_size = None

        weight_bound = (3 / _fan_in) ** 0.5
        bias_bound = (1 / _fan_in) ** 0.5
        self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound)

        # rewrite the einsum expression with single-letter latin identifiers so that
        # the expression will be understood by any framework
        mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers}
        mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)}

        def write_flat(axes: list):
            return ''.join(mapping2letters[axis] for axis in axes)

        self.einsum_pattern: str = '{},{}->{}'.format(
            write_flat(left.flat_axes_order()),
            write_flat(weight.flat_axes_order()),
            write_flat(right.flat_axes_order()),
        )

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict]):
        raise NotImplementedError('Should be defined in framework implementations')

    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        """ Shape and implementations """
        raise NotImplementedError('Should be defined in framework implementations')

    def __repr__(self):
        params = repr(self.pattern)
        params += f", '{self.weight_shape}'"
        if self.bias_shape is not None:
            params += f", '{self.bias_shape}'"
        for axis, length in self.axes_lengths.items():
            params += ', {}={}'.format(axis, length)
        return '{}({})'.format(self.__class__.__name__, params)
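To make the initialization bounds concrete: for the linear case from the docstring ('t b cin -> t b cout' with weight_shape='cin cout' and cin=10), only cin is absent from the output, so it alone forms the fan-in. A worked sketch of the bounds computed above:

fan_in = 10                          # product of weight axes absent from the output
weight_bound = (3 / fan_in) ** 0.5   # ~0.5477; uniform(-b, b) has variance b**2 / 3 = 1 / fan_in
bias_bound = (1 / fan_in) ** 0.5     # ~0.3162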
@ -0,0 +1,64 @@
from typing import Optional, Dict

import jittor as jt
from jittor import nn
import numpy as np

from jittor.einops.layers import RearrangeMixin, ReduceMixin
from jittor.einops.layers._einmix import _EinmixMixin
from jittor.einops._jittor_specific import apply_for_scriptable_jittor

__author__ = 'Ruiyang Liu'


class Rearrange(RearrangeMixin, jt.nn.Module):
    def execute(self, input):
        return apply_for_scriptable_jittor(self._recipe, input, reduction_type='rearrange')

    def _apply_recipe(self, x):
        # overriding the parent method to prevent its scripting
        pass


class Reduce(ReduceMixin, jt.nn.Module):
    def execute(self, input):
        return apply_for_scriptable_jittor(self._recipe, input, reduction_type=self.reduction)

    def _apply_recipe(self, x):
        # overriding the parent method to prevent its scripting
        pass


class EinMix(_EinmixMixin, jt.nn.Module):
    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        self.weight = jt.zeros(weight_shape)
        nn.init.uniform_(self.weight, low=-weight_bound, high=weight_bound)
        if bias_shape is not None:
            self.bias = jt.zeros(bias_shape)
            nn.init.uniform_(self.bias, low=-bias_bound, high=bias_bound)
        else:
            self.bias = None

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict],
                                 ):
        self.pre_rearrange = None
        if pre_reshape_pattern is not None:
            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)

        self.post_rearrange = None
        if post_reshape_pattern is not None:
            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)

    def execute(self, input):
        if self.pre_rearrange is not None:
            input = self.pre_rearrange(input)
        result = jt.linalg.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
            result += self.bias
        if self.post_rearrange is not None:
            result = self.post_rearrange(result)
        return result
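A brief usage sketch of these layers inside a Jittor model (the patch sizes and shapes here are illustrative, not part of this commit):

import jittor as jt
from jittor.einops.layers.jittor import Rearrange, Reduce

model = jt.nn.Sequential(
    # split each 32x32 image into 4x4 patches and flatten every patch
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=4, p2=4),
    # average-pool over the patch dimension
    Reduce('b n d -> b d', 'mean'),
)
x = jt.random((2, 3, 32, 32))
y = model(x)  # shape (2, 48), since p1 * p2 * c = 4 * 4 * 3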
@ -0,0 +1,147 @@
from jittor.einops import EinopsError
import keyword
import warnings
from typing import List, Optional, Set, Tuple

_ellipsis: str = '…'  # NB, this is a single unicode symbol. A string is used as it is not a list, but can be iterated


class AnonymousAxis(object):
    """Important thing: all instances of this class are not equal to each other """

    def __init__(self, value: str):
        self.value = int(value)
        if self.value <= 1:
            if self.value == 1:
                raise EinopsError('No need to create an anonymous axis of length 1. Report this as an issue')
            else:
                raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))

    def __repr__(self):
        return "{}-axis".format(str(self.value))


class ParsedExpression:
    """
    non-mutable structure that contains information about one side of an expression (e.g. 'b c (h w)')
    and keeps some information important for downstream processing
    """
    def __init__(self, expression, *, allow_underscore: bool = False):
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[str] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with an empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps the structure of composite axes, see how different corner cases are handled in tests
        self.composition = []
        if '.' in expression:
            if '...' not in expression:
                raise EinopsError('Expression may contain dots only inside ellipsis (...)')
            if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
                raise EinopsError(
                    'Expression may contain dots only inside ellipsis (...); only one ellipsis per tensor')
            expression = expression.replace('...', _ellipsis)
            self.has_ellipsis = True

        bracket_group = None

        def add_axis_name(x):
            if x is not None:
                if x in self.identifiers:
                    if not (allow_underscore and x == "_"):
                        raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
                if x == _ellipsis:
                    self.identifiers.add(_ellipsis)
                    if bracket_group is None:
                        self.composition.append(_ellipsis)
                        self.has_ellipsis_parenthesized = False
                    else:
                        bracket_group.append(_ellipsis)
                        self.has_ellipsis_parenthesized = True
                else:
                    is_number = str.isdecimal(x)
                    if is_number and int(x) == 1:
                        # handling the case of an anonymous axis of length 1
                        if bracket_group is None:
                            self.composition.append([])
                        else:
                            pass  # no need to think about 1s inside parenthesis
                        return
                    is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
                    if not (is_number or is_axis_name):
                        raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
                    if is_number:
                        x = AnonymousAxis(x)
                    self.identifiers.add(x)
                    if is_number:
                        self.has_non_unitary_anonymous_axes = True
                    if bracket_group is None:
                        self.composition.append([x])
                    else:
                        bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in '() ':
                add_axis_name(current_identifier)
                current_identifier = None
                if char == '(':
                    if bracket_group is not None:
                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                    bracket_group = []
                elif char == ')':
                    if bracket_group is None:
                        raise EinopsError('Brackets are not balanced')
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ['_', _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
        add_axis_name(current_identifier)

    def flat_axes_order(self) -> List:
        result = []
        for composed_axis in self.composition:
            assert isinstance(composed_axis, list), 'does not work with ellipsis'
            for axis in composed_axis:
                result.append(axis)
        return result

    def has_composed_axes(self) -> bool:
        # this will ignore 1 inside brackets
        for axes in self.composition:
            if isinstance(axes, list) and len(axes) > 1:
                return True
        return False

    @staticmethod
    def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
        if not str.isidentifier(name):
            return False, 'not a valid python identifier'
        elif name[0] == '_' or name[-1] == '_':
            if name == '_' and allow_underscore:
                return True, ''
            return False, 'axis name should not start or end with underscore'
        else:
            if keyword.iskeyword(name):
                warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
            if name in ['axis']:
                warnings.warn("It is discouraged to use 'axis' as an axis name "
                              "and will raise an error in future", FutureWarning)
            return True, ''

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """
        Valid axes names are python identifiers except keywords,
        and additionally should not start or end with an underscore
        """
        is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid
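A short sketch of how the parser decomposes one side of a pattern, based on the class above:

from jittor.einops.parsing import ParsedExpression

expr = ParsedExpression('b c (h w)')
print(expr.composition)           # [['b'], ['c'], ['h', 'w']]
print(expr.has_composed_axes())   # True
print(ParsedExpression.check_axis_name('h'))   # True
print(ParsedExpression.check_axis_name('_h'))  # False: leading underscore is rejected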