diff --git a/python/jittor/einops/__init__.py b/python/jittor/einops/__init__.py
new file mode 100644
index 00000000..503dff42
--- /dev/null
+++ b/python/jittor/einops/__init__.py
@@ -0,0 +1,8 @@
+class EinopsError(RuntimeError):
+    """ Runtime error thrown by einops """
+    pass
+
+
+__all__ = ['rearrange', 'reduce', 'repeat', 'parse_shape', 'asnumpy', 'EinopsError']
+
+from jittor.einops.einops import rearrange, reduce, repeat, parse_shape, asnumpy
diff --git a/python/jittor/einops/_backends.py b/python/jittor/einops/_backends.py
new file mode 100644
index 00000000..a04f34be
--- /dev/null
+++ b/python/jittor/einops/_backends.py
@@ -0,0 +1,255 @@
+"""
+Backends in `einops` are organized to meet the following requirements
+- backends are not imported unless they are actually needed, because
+  - backends may not be installed
+  - importing all available backends would lead to a significant memory footprint
+  - backends may be present but installed with errors (and never used),
+    and importing them may lead to crashes
+- a backend should be either symbolic or imperative (tensorflow is both, but that causes problems)
+  - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
+- if a backend can't (temporarily) provide symbols for shape dimensions, UnknownSize objects are used
+"""
+
+import sys
+import warnings
+
+__author__ = 'Alex Rogozhnikov'
+
+_backends = {}
+_debug_importing = False
+
+
+def get_backend(tensor) -> 'AbstractBackend':
+    """
+    Returns the correct backend for a tensor (e.g. the numpy backend if tensor is a numpy.ndarray).
+    If needed, imports the package and creates the backend.
+    """
+    for framework_name, backend in _backends.items():
+        if backend.is_appropriate_type(tensor):
+            return backend
+
+    # Find backend subclasses recursively
+    backend_subclasses = []
+    backends = AbstractBackend.__subclasses__()
+    while backends:
+        backend = backends.pop()
+        backends += backend.__subclasses__()
+        backend_subclasses.append(backend)
+
+    for BackendSubclass in backend_subclasses:
+        if _debug_importing:
+            print('Testing for subclass of ', BackendSubclass)
+        if BackendSubclass.framework_name not in _backends:
+            # check that the module was already imported. Otherwise it can't be imported
+            if BackendSubclass.framework_name in sys.modules:
+                if _debug_importing:
+                    print('Imported backend for ', BackendSubclass.framework_name)
+                backend = BackendSubclass()
+                _backends[backend.framework_name] = backend
+                if backend.is_appropriate_type(tensor):
+                    return backend
+
+    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))
+
+
+class AbstractBackend:
+    """ Base backend class; most methods exist only for debugging purposes. """
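+    # Dispatch note: get_backend() above walks AbstractBackend.__subclasses__() and
+    # instantiates a backend lazily, only once its framework module is already in
+    # sys.modules; e.g. get_backend(numpy.zeros([2, 3])) yields a cached NumpyBackend.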
""" + framework_name = None + + def is_appropriate_type(self, tensor): + """ helper method should recognize tensors it can handle """ + raise NotImplementedError() + + def from_numpy(self, x): + raise NotImplementedError("framework doesn't support imperative execution") + + def to_numpy(self, x): + raise NotImplementedError("framework doesn't support imperative execution") + + def create_symbol(self, shape): + raise NotImplementedError("framework doesn't support symbolic computations") + + def eval_symbol(self, symbol, input_dict): + raise NotImplementedError("framework doesn't support symbolic computations") + + def arange(self, start, stop): + # supplementary method used only in testing, so should implement CPU version + raise NotImplementedError("framework doesn't implement arange") + + def shape(self, x): + """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)""" + return x.shape + + def reshape(self, x, shape): + return x.reshape(shape) + + def transpose(self, x, axes): + return x.transpose(axes) + + def reduce(self, x, operation, axes): + return getattr(x, operation)(axis=axes) + + def stack_on_zeroth_dimension(self, tensors: list): + raise NotImplementedError() + + def add_axis(self, x, new_position): + raise NotImplementedError() + + def add_axes(self, x, n_axes, pos2len): + repeats = [1] * n_axes + for axis_position, axis_length in pos2len.items(): + x = self.add_axis(x, axis_position) + repeats[axis_position] = axis_length + return self.tile(x, tuple(repeats)) + + def tile(self, x, repeats): + """repeats is a number of """ + raise NotImplementedError() + + def is_float_type(self, x): + # some backends (torch) can't compute average for non-floating types. + # Decided to drop average for all backends if type is not floating + raise NotImplementedError() + + def layers(self): + raise NotImplementedError("backend does not provide layers") + + def __repr__(self): + return "".format(self.framework_name) + + +class UnknownSize: + """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """ + + def __floordiv__(self, other): + return self + + def __eq__(self, other): + return True # we don't know actual size + + def __mul__(self, other): + return self + + def __rmul__(self, other): + return self + + def __hash__(self): + return None.__hash__() + + +class NumpyBackend(AbstractBackend): + framework_name = 'numpy' + + def __init__(self): + import numpy + self.np = numpy + + def is_appropriate_type(self, tensor): + return isinstance(tensor, self.np.ndarray) + + def from_numpy(self, x): + return x + + def to_numpy(self, x): + return x + + def arange(self, start, stop): + return self.np.arange(start, stop) + + def stack_on_zeroth_dimension(self, tensors: list): + return self.np.stack(tensors) + + def tile(self, x, repeats): + return self.np.tile(x, repeats) + + def is_float_type(self, x): + return x.dtype in ('float16', 'float32', 'float64', 'float128') + + def add_axis(self, x, new_position): + return self.np.expand_dims(x, new_position) + +class HashableTuple: + """Overcomes non-hashability of symbolic elements""" + + def __init__(self, elements: tuple): + self.elements = elements + + def __iter__(self): + for x in self.elements: + yield x + + def __len__(self): + return len(self.elements) + + def __getitem__(self, item): + return self.elements[item] + +class JittorBackend(AbstractBackend): + framework_name = 'jittor' + + def __init__(self): + import jittor + self.jittor = jittor + + def is_appropriate_type(self, 
tensor): + return isinstance(tensor, self.jittor.jittor_core.Var) + + def from_numpy(self, x): + variable = self.jittor.array(x) + if self.is_float_type(variable): + # attach grad only to floating types + variable.requires_grad = True + return variable + + def to_numpy(self, x): + return x.detach().numpy() + + def arange(self, start, stop): + return self.jittor.arange(start, stop, dtype='int64') + + def shape(self, x): + return HashableTuple(tuple(x.shape)) + + def reshape(self, x, shape): + return self.jittor.reshape(x, shape) + + # def reduce(self, x, operation, axes): + # return getattr(x, operation)(dim=axes) + + def reduce(self, x, operation, reduced_axes): + for axis in sorted(reduced_axes, reverse=True): + if operation == 'min': + x = x.min(dim=axis) + elif operation == 'max': + x = x.max(dim=axis) + elif operation in ['sum', 'mean', 'prod']: + x = getattr(x, operation)(dim=axis) + else: + raise NotImplementedError('Unknown reduction ', operation) + return x + + def transpose(self, x, axes): + return x.permute(axes) + + def stack_on_zeroth_dimension(self, tensors: list): + return self.jittor.stack(tensors) + + def add_axes(self, x, n_axes, pos2len): + repeats = [-1] * n_axes + for axis_position, axis_length in pos2len.items(): + x = self.add_axis(x, axis_position) + repeats[axis_position] = axis_length + return x.expand(repeats) + + def tile(self, x, repeats): + return x.repeat(repeats) + + def add_axis(self, x, new_position): + return self.jittor.unsqueeze(x, new_position) + + def is_float_type(self, x): + return x.dtype in ["float16", "float32", "float64"] + + def layers(self): + from jittor.einops.layers import jittor + return jittor \ No newline at end of file diff --git a/python/jittor/einops/_jittor_specific.py b/python/jittor/einops/_jittor_specific.py new file mode 100644 index 00000000..d21ff76d --- /dev/null +++ b/python/jittor/einops/_jittor_specific.py @@ -0,0 +1,84 @@ +""" +Specialization of einops for jittor. + +Unfortunately, jittor's jit scripting mechanism isn't strong enough, +and to have scripting supported at least for layers, +a number of changes is required, and this layer helps. 
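+
+A minimal usage sketch (hypothetical shapes; the recipe helper lives in
+jittor.einops.einops and is imported here only for illustration):
+
+    from jittor.einops.einops import _prepare_transformation_recipe
+    recipe = _prepare_transformation_recipe('b c h w -> b (c h w)', 'rearrange', axes_lengths=())
+    flat = apply_for_scriptable_jittor(recipe, jt.randn(2, 3, 4, 4), reduction_type='rearrange')
+    # flat is a jittor Var with shape (2, 48)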
+
+Importantly, the rest of the library is designed so that it does not depend on
+this module; only the jittor layers use it.
+"""
+
+from typing import Dict, List
+
+import jittor as jt
+from jittor.einops.einops import TransformRecipe, _reconstruct_from_shape_uncached
+
+
+class JittorJitBackend:
+    """
+    Completely static backend that mimics part of the normal backend functionality
+    but is restricted to jittor only
+    """
+
+    @staticmethod
+    def reduce(x: jt.jittor_core.Var, operation: str, reduced_axes: List[int]):
+        if operation == 'min':
+            return x.min(dims=reduced_axes)
+        elif operation == 'max':
+            return x.max(dims=reduced_axes)
+        elif operation == 'sum':
+            return x.sum(dims=reduced_axes)
+        elif operation == 'mean':
+            return x.mean(dims=reduced_axes)
+        elif operation == 'prod':
+            for i in list(sorted(reduced_axes))[::-1]:
+                x = x.prod(dim=i)
+            return x
+        else:
+            raise NotImplementedError('Unknown reduction ', operation)
+
+    @staticmethod
+    def transpose(x, axes: List[int]):
+        return x.permute(axes)
+
+    @staticmethod
+    def stack_on_zeroth_dimension(tensors: List[jt.jittor_core.Var]):
+        return jt.stack(tensors)
+
+    @staticmethod
+    def tile(x, repeats: List[int]):
+        return x.repeat(repeats)
+
+    @staticmethod
+    def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
+        repeats = [1] * n_axes
+        for axis_position, axis_length in pos2len.items():
+            x = jt.unsqueeze(x, axis_position)
+            repeats[axis_position] = axis_length
+        return JittorJitBackend.tile(x, repeats)
+
+    @staticmethod
+    def is_float_type(x):
+        return x.dtype in ["float16", "float32", "float64"]
+
+    @staticmethod
+    def shape(x):
+        return x.shape
+
+    @staticmethod
+    def reshape(x, shape: List[int]):
+        return x.reshape(shape)
+
+
+# mirrors einops.einops._apply_recipe
+def apply_for_scriptable_jittor(recipe: TransformRecipe, tensor: jt.jittor_core.Var, reduction_type: str) -> jt.jittor_core.Var:
+    backend = JittorJitBackend
+    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
+        _reconstruct_from_shape_uncached(recipe, backend.shape(tensor))
+    tensor = backend.reshape(tensor, init_shapes)
+    if len(reduced_axes) > 0:
+        tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
+    tensor = backend.transpose(tensor, axes_reordering)
+    if len(added_axes) > 0:
+        tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
+    return backend.reshape(tensor, final_shapes)
diff --git a/python/jittor/einops/einops.py b/python/jittor/einops/einops.py
new file mode 100644
index 00000000..21fafd77
--- /dev/null
+++ b/python/jittor/einops/einops.py
@@ -0,0 +1,625 @@
+import functools
+import itertools
+import typing
+from collections import OrderedDict
+from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar
+
+if typing.TYPE_CHECKING:
+    import numpy as np
+
+from jittor.einops import EinopsError
+from jittor.einops._backends import get_backend
+from jittor.einops.parsing import ParsedExpression, _ellipsis, AnonymousAxis
+
+Tensor = TypeVar('Tensor')
+ReductionCallable = Callable[[Tensor, List[int]], Tensor]
+Reduction = Union[str, ReductionCallable]
+
+_reductions = ('min', 'max', 'sum', 'mean', 'prod')
+_ellipsis_not_in_parenthesis: List[int] = [-999]
+_unknown_axis_length = -999999
+
+
+def is_ellipsis_not_in_parenthesis(group: List[int]) -> bool:
+    if len(group) != 1:
+        return False
+    return group[0] == -999
+
+
+def _product(sequence: List[int]) -> int:
+    """ minimalistic product that works both with numbers and symbols.
Supports empty lists """ + result = 1 + for element in sequence: + result *= element + return result + + +def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend): + reduced_axes = tuple(reduced_axes) + if callable(reduction_type): + # custom callable + return reduction_type(tensor, reduced_axes) + else: + # one of built-in operations + if len(reduced_axes) == 0: + return tensor + assert reduction_type in _reductions + if reduction_type == 'mean': + if not backend.is_float_type(tensor): + raise NotImplementedError('reduce_mean is not available for non-floating tensors') + return backend.reduce(tensor, reduction_type, reduced_axes) + + +def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes): + # 'collapses' neighboring axes if those participate in the result pattern in the same order + # TODO add support for added_axes + assert len(axes_reordering) + len(reduced_axes) == len(init_shapes) + # joining consecutive axes that will be reduced + # possibly we can skip this if all backends can optimize this (not sure) + reduced_axes = tuple(sorted(reduced_axes)) + for i in range(len(reduced_axes) - 1)[::-1]: + if reduced_axes[i] + 1 == reduced_axes[i + 1]: + removed_axis = reduced_axes[i + 1] + removed_length = init_shapes[removed_axis] + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:]) + + # removing axes that are moved together during reshape + def build_mapping(): + init_to_final = {} + for axis in range(len(init_shapes)): + if axis in reduced_axes: + init_to_final[axis] = None + else: + after_reduction = sum(x is not None for x in init_to_final.values()) + init_to_final[axis] = list(axes_reordering).index(after_reduction) + return init_to_final + + init_axis_to_final_axis = build_mapping() + + for init_axis in range(len(init_shapes) - 1)[::-1]: + if init_axis_to_final_axis[init_axis] is None: + continue + if init_axis_to_final_axis[init_axis + 1] is None: + continue + if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]: + removed_axis = init_axis + 1 + removed_length = init_shapes[removed_axis] + removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis)) + + reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes) + init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:] + init_shapes[removed_axis - 1] *= removed_length + old_reordering = axes_reordering + axes_reordering = [] + for axis in old_reordering: + if axis == removed_axis_after_reduction: + pass + elif axis < removed_axis_after_reduction: + axes_reordering.append(axis) + else: + axes_reordering.append(axis - 1) + init_axis_to_final_axis = build_mapping() + + return init_shapes, reduced_axes, axes_reordering, final_shapes + + +CookedRecipe = Tuple[List[int], List[int], List[int], Dict[int, int], List[int]] + + +class TransformRecipe: + """ + Recipe describes actual computation pathway. + Recipe can be applied to a tensor or variable. + """ + + # structure is non-mutable. In future, this can be non-mutable dataclass (python 3.7+) + + def __init__(self, + # list of expressions (or just sizes) for elementary axes as they appear in left expression. + # this is what (after computing unknown parts) will be a shape after first transposition. 
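+                 # (e.g. for 'b (h w) c -> b h w c' with h=2 known, this holds lengths of b, h, w, c)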
+ # If ellipsis is present, it forms one dimension here (in the right position). + elementary_axes_lengths: List[int], + # each dimension in input can help to reconstruct length of one elementary axis + # or verify one of dimensions. Each element points to element of elementary_axes_lengths + input_composite_axes: List[Tuple[List[int], List[int]]], + # indices of axes to be squashed + reduced_elementary_axes: List[int], + # in which order should axes be reshuffled after reduction + axes_permutation: List[int], + # at which positions which of elementary axes should appear + added_axes: Dict[int, int], + # ids of axes as they appear in result, again pointers to elementary_axes_lengths, + # only used to infer result dimensions + output_composite_axes: List[List[int]], + # positions of ellipsis in lhs and rhs of expression + ellipsis_position_in_lhs: Optional[int] = None, + ): + self.elementary_axes_lengths: List[int] = elementary_axes_lengths + self.input_composite_axes: List[Tuple[List[int], List[int]]] = input_composite_axes + self.output_composite_axes: List[List[int]] = output_composite_axes + self.axes_permutation: List[int] = axes_permutation + self.added_axes: Dict[int, int] = added_axes + # This is redundant information, but more convenient to use + self.reduced_elementary_axes: List[int] = reduced_elementary_axes + # setting to a large number to avoid handling Nones in reconstruct_from_shape + self.ellipsis_position_in_lhs: int = ellipsis_position_in_lhs if ellipsis_position_in_lhs is not None else 10000 + + +def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> CookedRecipe: + """ + Reconstruct all actual parameters using shape. + Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet) + known axes can be integers or symbols, but not Nones. 
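+
+    For example (a sketch): for pattern '(h w) c -> h w c' with h=2 known and an input
+    of shape (6, 5), the composite axis has length 6, so the unknown w is inferred as 6 // 2 = 3.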
+ """ + axes_lengths: List[int] = list(self.elementary_axes_lengths) + if self.ellipsis_position_in_lhs != 10000: + if len(shape) < len(self.input_composite_axes) - 1: + raise EinopsError('Expected at least {} dimensions, got {}'.format( + len(self.input_composite_axes) - 1, len(shape))) + else: + if len(shape) != len(self.input_composite_axes): + raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape))) + + ellipsis_shape: List[int] = [] + for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes): + before_ellipsis = input_axis + after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes) + if input_axis == self.ellipsis_position_in_lhs: + assert len(known_axes) == 0 and len(unknown_axes) == 1 + unknown_axis, = unknown_axes + ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1] + for d in ellipsis_shape: + if d is None: + raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis") + total_dim_size: int = _product(ellipsis_shape) + axes_lengths[unknown_axis] = total_dim_size + else: + if input_axis < self.ellipsis_position_in_lhs: + length = shape[before_ellipsis] + else: + length = shape[after_ellipsis] + known_product = 1 + for axis in known_axes: + known_product *= axes_lengths[axis] + + if len(unknown_axes) == 0: + if isinstance(length, int) and isinstance(known_product, int) and length != known_product: + raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product)) + # this is enforced when recipe is created + # elif len(unknown_axes) > 1: + # raise EinopsError( + # "Lengths of two or more axes in parenthesis not provided (dim={}), can't infer dimensions". + # format(known_product) + # ) + else: + if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0: + raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format( + length, known_product)) + + unknown_axis: int = unknown_axes[0] + inferred_length: int = length // known_product + axes_lengths[unknown_axis] = inferred_length + + # at this point all axes_lengths are computed (either have values or variables, but not Nones) + + # TODO more readable expression + init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)] + final_shapes: List[int] = [] + for output_axis, grouping in enumerate(self.output_composite_axes): + if is_ellipsis_not_in_parenthesis(grouping): + final_shapes.extend(ellipsis_shape) + else: + lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping] + final_shapes.append(_product(lengths)) + reduced_axes = self.reduced_elementary_axes + axes_reordering = self.axes_permutation + added_axes: Dict[int, int] = { + pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()} + # if optimize: + # assert len(self.added_axes) == 0 + # return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes) + return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes + + +_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached) + + +def _apply_recipe(recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction) -> Tensor: + # this method works for all backends but not compilable with + backend = get_backend(tensor) + init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \ + _reconstruct_from_shape(recipe, backend.shape(tensor)) + tensor = backend.reshape(tensor, init_shapes) + tensor = 
_reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend) + tensor = backend.transpose(tensor, axes_reordering) + if len(added_axes) > 0: + tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes) + return backend.reshape(tensor, final_shapes) + + +@functools.lru_cache(256) +def _prepare_transformation_recipe(pattern: str, + operation: Reduction, + axes_lengths: Tuple[Tuple, ...]) -> TransformRecipe: + """ Perform initial parsing of pattern and provided supplementary info + axes_lengths is a tuple of tuples (axis_name, axis_length) + """ + left, rght = pattern.split('->') + left = ParsedExpression(left) + rght = ParsedExpression(rght) + + # checking that axes are in agreement - new axes appear only in repeat, while disappear only in reduction + if not left.has_ellipsis and rght.has_ellipsis: + raise EinopsError('Ellipsis found in right side, but not left side of a pattern {}'.format(pattern)) + if left.has_ellipsis and left.has_ellipsis_parenthesized: + raise EinopsError('Ellipsis is parenthesis in the left side is not allowed: {}'.format(pattern)) + if operation == 'rearrange': + difference = set.symmetric_difference(left.identifiers, rght.identifiers) + if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes: + raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)') + if len(difference) > 0: + raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference)) + elif operation == 'repeat': + difference = set.difference(left.identifiers, rght.identifiers) + if len(difference) > 0: + raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference)) + axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)}, + {*left.identifiers, *(ax for ax, _ in axes_lengths)}) + if len(axes_without_size) > 0: + raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size)) + elif operation in _reductions or callable(operation): + difference = set.difference(rght.identifiers, left.identifiers) + if len(difference) > 0: + raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference)) + else: + raise EinopsError('Unknown reduction {}. 
Expect one of {}.'.format(operation, _reductions)) + + # parsing all dimensions to find out lengths + axis_name2known_length = OrderedDict() + for composite_axis in left.composition: + for axis_name in composite_axis: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + + # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point + + repeat_axes_names = [] + for axis_name in rght.identifiers: + if axis_name not in axis_name2known_length: + if isinstance(axis_name, AnonymousAxis): + axis_name2known_length[axis_name] = axis_name.value + else: + axis_name2known_length[axis_name] = _unknown_axis_length + repeat_axes_names.append(axis_name) + + axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)} + reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if + axis not in rght.identifiers] + reduced_axes: List[int] = list(sorted(reduced_axes)) + + for elementary_axis, axis_length in axes_lengths: + if not ParsedExpression.check_axis_name(elementary_axis): + raise EinopsError('Invalid name for an axis', elementary_axis) + if elementary_axis not in axis_name2known_length: + raise EinopsError('Axis {} is not used in transform'.format(elementary_axis)) + axis_name2known_length[elementary_axis] = axis_length + + input_axes_known_unknown = [] + # some of shapes will be inferred later - all information is prepared for faster inference + for composite_axis in left.composition: + known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length} + unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length} + if len(unknown) > 1: + raise EinopsError('Could not infer sizes for {}'.format(unknown)) + assert len(unknown) + len(known) == len(composite_axis) + input_axes_known_unknown.append( + ([axis_name2position[axis] for axis in known], + [axis_name2position[axis] for axis in unknown]) + ) + + axis_position_after_reduction = {} + for axis_name in itertools.chain(*left.composition): + if axis_name in rght.identifiers: + axis_position_after_reduction[axis_name] = len(axis_position_after_reduction) + + result_axes_grouping: List[List[int]] = [] + for composite_axis in rght.composition: + if composite_axis == _ellipsis: + result_axes_grouping.append(_ellipsis_not_in_parenthesis) + else: + result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis]) + + ordered_axis_right = list(itertools.chain(*rght.composition)) + axes_permutation = [ + axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers] + added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right) + if axis_name not in left.identifiers} + + ellipsis_left = None if _ellipsis not in left.composition else left.composition.index(_ellipsis) + + return TransformRecipe( + elementary_axes_lengths=list(axis_name2known_length.values()), + input_composite_axes=input_axes_known_unknown, + reduced_elementary_axes=reduced_axes, + axes_permutation=axes_permutation, + added_axes=added_axes, + output_composite_axes=result_axes_grouping, + ellipsis_position_in_lhs=ellipsis_left, + ) + + +def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor: + """ + einops.reduce provides combination of reordering and reduction using reader-friendly notation. 
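+
+    A quick jittor sketch (assumes jittor is installed; shapes are hypothetical):
+
+    ```python
+    >>> import jittor as jt
+    >>> x = jt.randn(16, 64, 32, 32)
+    >>> y = reduce(x, 'b c h w -> b c', 'mean')  # global average pooling, shape (16, 64)
+    ```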
+
+    Examples for reduce operation:
+
+    ```python
+    >>> x = np.random.randn(100, 32, 64)
+
+    # perform max-reduction on the first axis
+    >>> y = reduce(x, 't b c -> b c', 'max')
+
+    # same as previous, but with clearer axes meaning
+    >>> y = reduce(x, 'time batch channel -> batch channel', 'max')
+
+    >>> x = np.random.randn(10, 20, 30, 40)
+
+    # 2d max-pooling with kernel size = 2 * 2 for image processing
+    >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)
+
+    # if one wants to go back to the original height and width, depth-to-space trick can be applied
+    >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
+    >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')
+
+    # Adaptive 2d max-pooling to 3 * 4 grid
+    >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
+    (10, 20, 3, 4)
+
+    # Global average pooling
+    >>> reduce(x, 'b c h w -> b c', 'mean').shape
+    (10, 20)
+
+    # Subtracting mean over batch for each channel
+    >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')
+
+    # Subtracting per-image mean for each channel
+    >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')
+
+    ```
+
+    Parameters:
+        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array).
+            list of tensors is also accepted, those should be of the same type and shape
+        pattern: string, reduction pattern
+        reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
+            alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
+        axes_lengths: any additional specifications for dimensions
+
+    Returns:
+        tensor of the same type as input
+    """
+    try:
+        hashable_axes_lengths = tuple(sorted(axes_lengths.items()))
+        recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths)
+        return _apply_recipe(recipe, tensor, reduction_type=reduction)
+    except EinopsError as e:
+        message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
+        if not isinstance(tensor, list):
+            message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor))
+        else:
+            message += '\n Input is list. '
+        message += 'Additional info: {}.'.format(axes_lengths)
+        raise EinopsError(message + '\n {}'.format(e))
+
+
+@typing.overload
+def rearrange(tensor: Tensor, pattern: str, **axes_lengths: int) -> Tensor: ...
+@typing.overload
+def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: int) -> Tensor: ...
+
+
+def rearrange(tensor, pattern: str, **axes_lengths):
+    """
+    einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
+    This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
+    stack, concatenate and other operations.
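+
+    A quick jittor sketch (hypothetical shapes; jittor Vars are handled like numpy arrays here):
+
+    ```python
+    >>> import jittor as jt
+    >>> y = rearrange(jt.randn(32, 30, 40, 3), 'b h w c -> b c h w')  # shape (32, 3, 30, 40)
+    ```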
+ + Examples for rearrange operation: + + ```python + # suppose we have a set of 32 images in "h w c" format (height-width-channel) + >>> images = [np.random.randn(30, 40, 3) for _ in range(32)] + + # stack along first (batch) axis, output is a single array + >>> rearrange(images, 'b h w c -> b h w c').shape + (32, 30, 40, 3) + + # concatenate images along height (vertical axis), 960 = 32 * 30 + >>> rearrange(images, 'b h w c -> (b h) w c').shape + (960, 40, 3) + + # concatenated images along horizontal axis, 1280 = 32 * 40 + >>> rearrange(images, 'b h w c -> h (b w) c').shape + (30, 1280, 3) + + # reordered axes to "b c h w" format for deep learning + >>> rearrange(images, 'b h w c -> b c h w').shape + (32, 3, 30, 40) + + # flattened each image into a vector, 3600 = 30 * 40 * 3 + >>> rearrange(images, 'b h w c -> b (c h w)').shape + (32, 3600) + + # split each image into 4 smaller (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2 + >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape + (128, 15, 20, 3) + + # space-to-depth operation + >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape + (32, 15, 20, 12) + + ``` + + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. + + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array). + list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions + + Returns: + tensor of the same type as input. If possible, a view to the original tensor is returned. + + """ + if isinstance(tensor, list): + if len(tensor) == 0: + raise TypeError("Rearrange can't be applied to an empty list") + tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor) + return reduce(tensor, pattern, reduction='rearrange', **axes_lengths) + + +def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor: + """ + einops.repeat allows reordering elements and repeating them in arbitrary combinations. + This operation includes functionality of repeat, tile, broadcast functions. + + Examples for repeat operation: + + ```python + # a grayscale image (of shape height x width) + >>> image = np.random.randn(30, 40) + + # change it to RGB format by repeating in each channel + >>> repeat(image, 'h w -> h w c', c=3).shape + (30, 40, 3) + + # repeat image 2 times along height (vertical axis) + >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape + (60, 40) + + # repeat image 2 time along height and 3 times along width + >>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape + (30, 120) + + # convert each pixel to a small square 2x2. Upsample image by 2x + >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (60, 80) + + # pixelate image first by downsampling by 2x, then upsampling + >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2) + >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape + (30, 40) + + ``` + + When composing axes, C-order enumeration used (consecutive elements have different last axis) + Find more examples in einops tutorial. + + Parameters: + tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.array). 
+ list of tensors is also accepted, those should be of the same type and shape + pattern: string, rearrangement pattern + axes_lengths: any additional specifications for dimensions + + Returns: + Tensor of the same type as input. If possible, a view to the original tensor is returned. + + """ + return reduce(tensor, pattern, reduction='repeat', **axes_lengths) + + +def parse_shape(x, pattern: str): + """ + Parse a tensor shape to dictionary mapping axes names to their lengths. + + ```python + # Use underscore to skip the dimension in parsing. + >>> x = np.zeros([2, 3, 5, 7]) + >>> parse_shape(x, 'batch _ h w') + {'batch': 2, 'h': 5, 'w': 7} + + # `parse_shape` output can be used to specify axes_lengths for other operations: + >>> y = np.zeros([700]) + >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape + (2, 10, 5, 7) + + ``` + + For symbolic frameworks may return symbols, not integers. + + Parameters: + x: tensor of any of supported frameworks + pattern: str, space separated names for axes, underscore means skip axis + + Returns: + dict, maps axes names to their lengths + """ + exp = ParsedExpression(pattern, allow_underscore=True) + shape = get_backend(x).shape(x) + if exp.has_composed_axes(): + raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format( + pattern=pattern, shape=shape)) + if len(shape) != len(exp.composition): + if exp.has_ellipsis: + if len(shape) < len(exp.composition) - 1: + raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format( + pattern=pattern, shape=shape)) + else: + raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format( + pattern=pattern, shape=shape)) + if exp.has_ellipsis: + ellipsis_idx = exp.composition.index(_ellipsis) + composition = (exp.composition[:ellipsis_idx] + + ['_'] * (len(shape) - len(exp.composition) + 1) + + exp.composition[ellipsis_idx+1:]) + else: + composition = exp.composition + result = {} + for (axis_name, ), axis_length in zip(composition, shape): + if axis_name != '_': + result[axis_name] = axis_length + return result + + +# this one is probably not needed in the public API +def _enumerate_directions(x): + """ + For an n-dimensional tensor, returns tensors to enumerate each axis. + ```python + x = np.zeros([2, 3, 4]) # or any other tensor + i, j, k = _enumerate_directions(x) + result = i + 2*j + 3*k + ``` + + `result[i, j, k] = i + 2j + 3k`, and also has the same shape as result + Works very similarly to numpy.ogrid (open indexing grid) + """ + backend = get_backend(x) + shape = backend.shape(x) + result = [] + for axis_id, axis_length in enumerate(shape): + shape = [1] * len(shape) + shape[axis_id] = axis_length + result.append(backend.reshape(backend.arange(0, axis_length), shape)) + return result + + +def asnumpy(tensor) -> 'numpy.ndarray': + """ + Convert a tensor of an imperative framework (i.e. numpy/jittor.) 
to `numpy.ndarray`
+
+    Parameters:
+        tensor: tensor of any known imperative framework
+
+    Returns:
+        `numpy.ndarray`, converted to numpy
+    """
+    return get_backend(tensor).to_numpy(tensor)
diff --git a/python/jittor/einops/experimental/__init__.py b/python/jittor/einops/experimental/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/python/jittor/einops/experimental/indexing.py b/python/jittor/einops/experimental/indexing.py
new file mode 100644
index 00000000..4ba9e9de
--- /dev/null
+++ b/python/jittor/einops/experimental/indexing.py
@@ -0,0 +1,393 @@
+"""
+
+Indexing one array with the other(s).
+
+Concept for discussion.
+
+Notation targets hard cases, not simple ones, like indexing of a 1d-array with another 1d-array
+(notation supports that, but you can't simplify arr[ind], and there is no reason to)
+
+Examples
+
+1. for every token in a sequence, query a token in the image; images and sequences are paired
+   einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, [h_indices_bt, w_indices_bt])
+
+   this is equivalent, so you can pass indexers independently or together
+   einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, np.asarray([h_indices_bt, w_indices_bt]))
+
+   after some thinking I decided that having the first axis for the indexing variable is not too restrictive,
+   but should simplify mapping of such cases.
+   For this reason the [...] part should always go first in the indexer.
+
+   This makes the largest difference with einindex https://github.com/malmaud/einindex,
+   which has almost identical grammar, but puts the special dimension last, while we put it first.
+   This trick allows naturally decomposing a multiindex into individual dimensions or vice versa.
+
+
+2. for every token in the video, query the most suitable word in a (matching) sentence
+   einindex('b t h w <- seq b, [seq] t b h w', arr_tbc, [t_indices_bhw])
+
+   note that only one indexer is used, but it still has to be enclosed in the list.
+   That's a price for being generic. Alternatively, a leading singleton dimension can be added.
+
+
+3. (not supported now, future planning)
+   for every timeframe in a video, find the token with the highest norm (across h and w), and compose a new stack of them
+   indices_2bt = argmax(x_bthwc.norm(dim=-1), 'b t h w -> [h, w] b t')
+   selected_embeddings_btc = einindex('b t c <- b t h w c, [h, w] b t', x_bthwc, indices_2bt)
+
+   while the current question is 'how do we index',
+   it is important to pre-align that with the question 'what are natural ways to get indices'.
+   Most common are min/max. Less common options: topk (works here), random sampling.
+
+
+Some important properties of this notation:
+- support for multiple indexers, including using a single tensor to keep multiple indexers
+- 'batch' indexing, when some axes of indexer and array should be matched
+- universal (one-indexing-to-rule-them-all)
+- extensible for (named) ellipses, including variadic number of indexers
+- extensible for einops-style compositions and decompositions
+- extensible for outer indexing when indexers are not aligned
+
+The current implementation is based on the python array api and uses loops,
+because no appropriate indexing is available in the standard.
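+
+A minimal runnable sketch (mirrors test_simple_indexing below; numpy.array_api is
+assumed to be importable, as in the tests at the end of this file):
+
+    import numpy.array_api as np
+    arr = np.reshape(np.arange(5 * 7), (5, 7))
+    ind = np.arange(7) % 5
+    x = einindex('j <- i j, [i] j', arr, [ind])  # x[j] == arr[ind[j], j]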
+ +""" + +from typing import List, Union, TypeVar, Tuple + +from jittor.einops import EinopsError + +T = TypeVar('T') + + +class CompositionDecomposition: + def __init__( + self, + decomposed_shape: List[str], + composed_shape: List[List[str]], + ): + flat_shape = [] + for x in composed_shape: + flat_shape.extend(x) + + self.compose_transposition: Tuple[int] = tuple([decomposed_shape.index(x) for x in flat_shape]) + self.decompose_transposition: Tuple[int] = tuple([flat_shape.index(x) for x in decomposed_shape]) + self.composed_shape = composed_shape + self.decomposed_shape = decomposed_shape + + def decompose(self, x, known_axes_lengths: dict[str, int]): + xp = x.__array_namespace__() + shape = x.shape + + flat_shape = [] + + for i, axis_group in enumerate(self.composed_shape): + unknown_axis_name = None + known_sizes_prod = 1 + for axis_name in axis_group: + if axis_name in known_axes_lengths: + known_sizes_prod *= known_axes_lengths[axis_name] + else: + if unknown_axis_name is None: + unknown_axis_name = axis_name + else: + raise EinopsError("Can't infer the size") + + if unknown_axis_name is None: + assert shape[i] == known_sizes_prod + else: + known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod + + for axis in axis_group: + flat_shape.append(known_axes_lengths[axis]) + + x = xp.reshape(x, flat_shape) + return xp.permute_dims(x, self.decompose_transposition) + + def compose(self, x, known_axes_lengths: dict[str, int]): + xp = x.__array_namespace__() + + for axis_len, axis_name in zip(x.shape, self.decomposed_shape): + if axis_name in known_axes_lengths: + assert known_axes_lengths[axis_name] == axis_len + else: + known_axes_lengths[axis_name] = axis_len + + x = xp.permute_dims(x, self.compose_transposition) + new_shape = [] + for axis_group in self.composed_shape: + composed_axis_size = 1 + for axis_name in axis_group: + composed_axis_size *= known_axes_lengths[axis_name] + new_shape.append(composed_axis_size) + + return xp.reshape(x, tuple(new_shape)) + + +def arange_at_position(xp, n_axes, axis, axis_len, device=None): + x = xp.arange(axis_len, dtype=xp.int64, device=device) + shape = [1] * n_axes + shape[axis] = axis_len + x = xp.reshape(x, shape) + return x + + +class IndexingFormula: + + def __init__(self, pattern: str): + """ + :param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t' + """ + self.pattern = pattern + left, right = pattern.split('<-') + arg_split = right.index(',') + arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:] + ind_pattern = ind_pattern.strip() + # print( + # arr_pattern, '\n', + # ind_pattern, + # ) + assert ind_pattern.startswith('['), 'composition axis should go first in indexer (second argument) [h w] i j k' + composition_start = ind_pattern.index('[') + composition_end = ind_pattern.index(']') + composition = ind_pattern[composition_start + 1: composition_end] + ind_other_axes = ind_pattern[composition_end + 1:] + + self.result_axes_names = left.split() + self.array_axes_names = arr_pattern.split() + self.indexing_axes_names = [x.strip() for x in composition.split(',')] + self.indexer_other_axes_names = ind_other_axes.split() + + for group_name, group in [ + ('result', self.result_axes_names), + ('array', self.array_axes_names), + ('indexer', self.indexing_axes_names + self.indexer_other_axes_names), + ]: + if len(set(group)) != len(group): + # need more verbosity, which axis, raise + raise EinopsError(f'{group_name} pattern ({group}) contains a duplicated axis') + + axis_groups = [ + self.result_axes_names, 
+ self.array_axes_names, + self.indexing_axes_names, + self.indexer_other_axes_names, + ] + + all_axes = set() + for group in axis_groups: + all_axes.update(group) + + self.indexer_axes = [] + self.batch_axes = [] + self.result_and_index_axes = [] + self.result_and_array_axes = [] + + for axis in all_axes: + presence = tuple(axis in g for g in axis_groups) + # want match-case here. sweet dreams + if presence == (False, True, True, False): + self.indexer_axes.append(axis) + elif presence[2]: + raise EinopsError(f'Wrong usage of indexer variable {axis}') + elif presence == (True, True, False, True): + self.batch_axes.append(axis) + elif presence == (True, False, False, True): + self.result_and_index_axes.append(axis) + elif presence == (True, True, False, False): + self.result_and_array_axes.append(axis) + else: + # TODO better categorization of wrong usage patterns + raise EinopsError(f'{axis} is used incorrectly in {pattern}') + + assert set(self.indexer_axes) == set(self.indexing_axes_names) + # order of these variables matters, since we can't lose mapping here + self.indexer_axes = self.indexing_axes_names + + self.array_composition = CompositionDecomposition( + decomposed_shape=self.array_axes_names, + composed_shape=[self.batch_axes + self.indexer_axes, self.result_and_array_axes], + ) + + self.index_composition = CompositionDecomposition( + decomposed_shape=self.indexer_other_axes_names, + # single axis after composition + composed_shape=[self.batch_axes + self.result_and_index_axes], + ) + + self.result_composition = CompositionDecomposition( + decomposed_shape=self.result_axes_names, + composed_shape=[self.batch_axes + self.result_and_index_axes, self.result_and_array_axes], + ) + + def apply_to_array_api(self, arr: T, ind: Union[T, List[T]]): + known_axes_sizes: dict[str, int] = {} + xp = arr.__array_namespace__() + + if not isinstance(ind, list): + ind = [ind[i, ...] for i in range(ind.shape[0])] + + for indexer in ind: + assert len(indexer.shape) == len(self.indexer_other_axes_names) + + # step 1. transpose, reshapes of arr; learn its dimensions + arr_2d = self.array_composition.compose(arr, known_axes_sizes) + + # step 2. compute shifts and create an actual indexing array + shift = 1 + full_index = xp.zeros([1] * len(ind[0].shape), dtype=xp.int64, device=arr.device) + + # original order: [*batch-like axes, *indexing_axes,] + # now we need to traverse them in the opposite direction + + for axis_name, indexer in list(zip(self.indexing_axes_names, ind))[::-1]: + full_index = full_index + shift * (indexer % known_axes_sizes[axis_name]) + shift *= known_axes_sizes[axis_name] + + for axis_name in self.batch_axes[::-1]: + axis_id = self.indexer_other_axes_names.index(axis_name) + full_index = full_index + arange_at_position( + xp, len(self.indexer_other_axes_names), axis=axis_id, axis_len=known_axes_sizes[axis_name], + device=arr.device, + ) * shift + shift *= known_axes_sizes[axis_name] + + assert shift == arr_2d.shape[0] + + # step 3. Flatten index + full_index = self.index_composition.compose(full_index, known_axes_sizes) + + # step 4. indexing + # python array api lacks any integer indexing, so... I use loops. + # did you know that there is conceptual programming ... just like art? + # result_2d = arr_2d[full_index] + result_2d = xp.stack([arr_2d[full_index[i], :] for i in range(full_index.shape[0])]) + + # step 5. 
decompose the result back into the requested axes
+        result = self.result_composition.decompose(result_2d, known_axes_sizes)
+        return result
+
+
+def einindex(pattern: str, arr: T, /, ind: Union[T, List[T]]):
+    """
+    Demonstrates how einindex should work.
+    Supports data-api compliant arrays.
+    """
+    formula = IndexingFormula(pattern)
+    return formula.apply_to_array_api(arr, ind)
+
+
+def test_composition_and_decomposition():
+    import numpy.array_api as np
+    x = np.arange(2 * 3 * 5 * 7)
+    x = np.reshape(x, (2, 3, 5, 7))
+    comp = CompositionDecomposition(
+        decomposed_shape=['a', 'b', 'c', 'd'],
+        composed_shape=[['a', 'b'], ['c', 'd']],
+    )
+    assert comp.compose(x, known_axes_lengths={}).shape == (2 * 3, 5 * 7)
+
+    y = CompositionDecomposition(
+        decomposed_shape=['a', 'b', 'c', 'd'],
+        composed_shape=[['a', 'b'], [], ['c', 'd']],
+    ).compose(x, {})
+    assert y.shape == (2 * 3, 1, 5 * 7)
+    assert np.all(np.reshape(x, (-1,)) == np.reshape(y, (-1,)))
+
+    comp = CompositionDecomposition(
+        decomposed_shape=['a', 'b', 'e', 'c', 'd'],
+        composed_shape=[['e', 'c'], ['b'], ['a', 'd']],
+    )
+    x = np.arange(2 * 3 * 5 * 7 * 3)
+    x = np.reshape(x, (2, 3, 5, 7, 3))
+
+    axes = {}
+    y = comp.compose(x, axes)
+    x2 = comp.decompose(y, axes)
+    assert np.all(x == x2)
+
+
+def test_simple_indexing():
+    import numpy.array_api as np
+
+    # simple 2d test
+    arr = np.reshape(np.arange(5 * 7), (5, 7))
+    ind = np.arange(7) % 5
+    x = einindex('j <- i j, [i] j', arr, [ind])
+    for j, i in enumerate(ind):
+        assert arr[i, j] == x[j]
+
+    y = einindex('j <- j i, [i] j', np.permute_dims(arr, (1, 0)), [ind])
+    for j, i in enumerate(ind):
+        assert arr[i, j] == y[j]
+
+
+def test_multidimensional_indexing():
+    import numpy.array_api as np
+
+    embedding_bhwc = (
+        + arange_at_position(np, 4, 0, 2) * 1000
+        + arange_at_position(np, 4, 1, 3) * 100
+        + arange_at_position(np, 4, 2, 5) * 10
+        + arange_at_position(np, 4, 3, 7) * 1
+    )
+
+    hindices_bt = np.reshape(np.arange(6), (2, 3)) % 3
+    windices_bt = np.reshape(np.arange(6), (2, 3)) % 5
+
+    # imagine that you have pairs of image <> sentence
+    # your goal is to get the most suitable token from the image for every token in the sentence
+    # thus for every token in the sentence you compute the best k and v
+
+    result = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, [hindices_bt, windices_bt])
+    # example of using a single array for indexing multiple axes
+    hw_indices_bt = np.stack([hindices_bt, windices_bt])
+    result2 = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, hw_indices_bt)
+    assert np.all(result == result2)
+
+    # check vs manual element computation
+    result_manual = result * 0
+    for b in range(2):
+        for t in range(3):
+            for c in range(7):
+                h = hindices_bt[b, t]
+                w = windices_bt[b, t]
+                result_manual[c, t, b] = embedding_bhwc[b, h, w, c]
+
+    assert np.all(result == result_manual)
+
+
+def test_reverse_indexing():
+    import numpy.array_api as np
+
+    C, T, B = 2, 3, 5
+    # G = GPU, batch-like variable
+    G = 4
+    H = 7
+    W = 9
+
+    arr_gtbc = (
+        + arange_at_position(np, 4, 0, G) * 1000
+        + arange_at_position(np, 4, 1, T) * 100
+        + arange_at_position(np, 4, 2, B) * 10
+        + arange_at_position(np, 4, 3, C) * 1
+    )
+
+    t_indices_gbhw = np.reshape(np.arange(G * B * H * W), (G, B, H, W)) % T
+
+    result = einindex('g b c h w <- g t b c, [t] g b h w', arr_gtbc, [t_indices_gbhw])
+
+    result_manual = result * 0
+    for g in range(G):
+        for b in range(B):
+            for c in range(C):
+                for h in range(H):
+                    for w in range(W):
+                        t = t_indices_gbhw[g, b, h, w]
+                        result_manual[g, b, c, h, w] = arr_gtbc[g, t, b, c]
+
+    assert np.all(result 
== result_manual)
+
+
diff --git a/python/jittor/einops/layers/__init__.py b/python/jittor/einops/layers/__init__.py
new file mode 100644
index 00000000..7b7f43d9
--- /dev/null
+++ b/python/jittor/einops/layers/__init__.py
@@ -0,0 +1,79 @@
+__author__ = 'Alex Rogozhnikov'
+
+import functools
+
+from jittor.einops.einops import _apply_recipe
+
+from jittor.einops.einops import TransformRecipe, _prepare_transformation_recipe
+from jittor.einops import EinopsError
+
+
+class RearrangeMixin:
+    """
+    Rearrange layer behaves identically to einops.rearrange operation.
+
+    :param pattern: str, rearrangement pattern
+    :param axes_lengths: any additional specification of dimensions
+
+    See einops.rearrange for examples.
+    """
+
+    def __init__(self, pattern, **axes_lengths):
+        super().__init__()
+        self.pattern = pattern
+        self.axes_lengths = axes_lengths
+        self._recipe = self.recipe()  # checking parameters
+
+    def __repr__(self):
+        params = repr(self.pattern)
+        for axis, length in self.axes_lengths.items():
+            params += ', {}={}'.format(axis, length)
+        return '{}({})'.format(self.__class__.__name__, params)
+
+    @functools.lru_cache(maxsize=1024)
+    def recipe(self) -> TransformRecipe:
+        try:
+            hashable_lengths = tuple(sorted(self.axes_lengths.items()))
+            return _prepare_transformation_recipe(self.pattern, operation='rearrange', axes_lengths=hashable_lengths)
+        except EinopsError as e:
+            raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e))
+
+    def _apply_recipe(self, x):
+        return _apply_recipe(self._recipe, x, reduction_type='rearrange')
+
+
+class ReduceMixin:
+    """
+    Reduce layer behaves identically to einops.reduce operation.
+
+    :param pattern: str, rearrangement pattern
+    :param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
+    :param axes_lengths: any additional specification of dimensions
+
+    See einops.reduce for examples.
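+
+    A usage sketch with the jittor layer (hypothetical shapes):
+
+    >>> import jittor as jt
+    >>> from jittor.einops.layers.jittor import Reduce
+    >>> pool = Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2)
+    >>> pool(jt.randn(10, 20, 30, 40))  # 2x2 max-pooling, shape (10, 20, 15, 20)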
+ """ + + def __init__(self, pattern, reduction, **axes_lengths): + super().__init__() + self.pattern = pattern + self.reduction = reduction + self.axes_lengths = axes_lengths + self._recipe = self.recipe() # checking parameters + + def __repr__(self): + params = '{!r}, {!r}'.format(self.pattern, self.reduction) + for axis, length in self.axes_lengths.items(): + params += ', {}={}'.format(axis, length) + return '{}({})'.format(self.__class__.__name__, params) + + @functools.lru_cache(maxsize=1024) + def recipe(self) -> TransformRecipe: + try: + hashable_lengths = tuple(sorted(self.axes_lengths.items())) + return _prepare_transformation_recipe( + self.pattern, operation=self.reduction, axes_lengths=hashable_lengths) + except EinopsError as e: + raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e)) + + def _apply_recipe(self, x): + return _apply_recipe(self._recipe, x, reduction_type=self.reduction) diff --git a/python/jittor/einops/layers/_einmix.py b/python/jittor/einops/layers/_einmix.py new file mode 100644 index 00000000..29fb0a5a --- /dev/null +++ b/python/jittor/einops/layers/_einmix.py @@ -0,0 +1,175 @@ +from typing import Optional, Dict + +from jittor.einops import EinopsError +from jittor.einops.parsing import ParsedExpression +import warnings +import string +from jittor.einops.einops import _product + + +def _report_axes(axes: set, report_message: str): + if len(axes) > 0: + raise EinopsError(report_message.format(axes)) + + +class _EinmixMixin: + def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths): + """ + EinMix - Einstein summation with automated tensor management and axis packing/unpacking. + + EinMix is an advanced tool, helpful tutorial: + https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb + + Imagine taking einsum with two arguments, one of each input, and one - tensor with weights + >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight) + + This layer manages weights for you, syntax highlights separate role of weight matrix + >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out') + But otherwise it is the same einsum under the hood. + + Simple linear layer with bias term (you have one like that in your framework) + >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20) + There is restriction to mix the last axis. Let's mix along height + >>> EinMix('h w c-> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32) + Channel-wise multiplication (like one used in normalizations) + >>> EinMix('t b c -> t b c', weight_shape='c', c=128) + Separate dense layer within each head, no connection between different heads + >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...) + + ... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters. + + Use cases: + - when channel dimension is not last, use EinMix, not transposition + - patch/segment embeddings + - when need only within-group connections to reduce number of weights and computations + - perfect as a part of sequential models + - next-gen MLPs (follow tutorial to learn more) + + Uniform He initialization is applied to weight tensor and encounters for number of elements mixed. + + Parameters + :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output + :param weight_shape: axes of weight. 
Tensor od this shape is created, stored, and optimized in a layer + :param bias_shape: axes of bias added to output. + :param axes_lengths: dimensions of weight tensor + """ + super().__init__() + self.pattern = pattern + self.weight_shape = weight_shape + self.bias_shape = bias_shape + self.axes_lengths = axes_lengths + + left_pattern, right_pattern = pattern.split('->') + left = ParsedExpression(left_pattern) + right = ParsedExpression(right_pattern) + weight = ParsedExpression(weight_shape) + _report_axes( + set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}), + 'Unrecognized identifiers on the right side of EinMix {}' + ) + + if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis: + raise EinopsError('Ellipsis is not supported in EinMix (right now)') + if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]): + raise EinopsError('Anonymous axes (numbers) are not allowed in EinMix') + if '(' in weight_shape or ')' in weight_shape: + raise EinopsError(f'Parenthesis is not allowed in weight shape: {weight_shape}') + + pre_reshape_pattern = None + pre_reshape_lengths = None + post_reshape_pattern = None + if any(len(group) != 1 for group in left.composition): + names = [] + for group in left.composition: + names += group + composition = ' '.join(names) + pre_reshape_pattern = f'{left_pattern}->{composition}' + pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names} + + if any(len(group) != 1 for group in right.composition): + names = [] + for group in right.composition: + names += group + composition = ' '.join(names) + post_reshape_pattern = f'{composition}->{right_pattern}' + + self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {}) + + for axis in weight.identifiers: + if axis not in axes_lengths: + raise EinopsError('Dimension {} of weight should be specified'.format(axis)) + _report_axes( + set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}), + 'Axes {} are not used in pattern', + ) + _report_axes( + set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}), + 'Weight axes {} are redundant' + ) + if len(weight.identifiers) == 0: + warnings.warn('EinMix: weight has no dimensions (means multiplication by a number)') + + _weight_shape = [axes_lengths[axis] for axis, in weight.composition] + # single output element is a combination of fan_in input elements + _fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers]) + if bias_shape is not None: + if not isinstance(bias_shape, str): + raise EinopsError('bias shape should be string specifying which axes bias depends on') + bias = ParsedExpression(bias_shape) + _report_axes( + set.difference(bias.identifiers, right.identifiers), + 'Bias axes {} not present in output' + ) + _report_axes( + set.difference(bias.identifiers, set(axes_lengths)), + 'Sizes not provided for bias axes {}', + ) + + _bias_shape = [] + for axes in right.composition: + for axis in axes: + if axis in bias.identifiers: + _bias_shape.append(axes_lengths[axis]) + else: + _bias_shape.append(1) + else: + _bias_shape = None + _bias_input_size = None + + weight_bound = (3 / _fan_in) ** 0.5 + bias_bound = (1 / _fan_in) ** 0.5 + self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound) + + # rewrite einsum expression with single-letter latin identifiers so that + # expression will be understood by any framework + mapping2letters = 
+
+    def _create_rearrange_layers(self,
+                                 pre_reshape_pattern: Optional[str],
+                                 pre_reshape_lengths: Optional[Dict],
+                                 post_reshape_pattern: Optional[str],
+                                 post_reshape_lengths: Optional[Dict]):
+        raise NotImplementedError('Should be defined in framework implementations')
+
+    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
+        """ Create (and initialize) weight and bias parameters of the given shapes """
+        raise NotImplementedError('Should be defined in framework implementations')
+
+    def __repr__(self):
+        params = repr(self.pattern)
+        params += f", '{self.weight_shape}'"
+        if self.bias_shape is not None:
+            params += f", '{self.bias_shape}'"
+        for axis, length in self.axes_lengths.items():
+            params += ', {}={}'.format(axis, length)
+        return '{}({})'.format(self.__class__.__name__, params)
diff --git a/python/jittor/einops/layers/jittor.py b/python/jittor/einops/layers/jittor.py
new file mode 100644
index 00000000..2bcb18e1
--- /dev/null
+++ b/python/jittor/einops/layers/jittor.py
@@ -0,0 +1,64 @@
+from typing import Optional, Dict
+
+import jittor as jt
+from jittor import nn
+import numpy as np
+
+from jittor.einops.layers import RearrangeMixin, ReduceMixin
+from jittor.einops.layers._einmix import _EinmixMixin
+from jittor.einops._jittor_specific import apply_for_scriptable_jittor
+
+__author__ = 'Ruiyang Liu'
+
+
+class Rearrange(RearrangeMixin, jt.nn.Module):
+    def execute(self, input):
+        return apply_for_scriptable_jittor(self._recipe, input, reduction_type='rearrange')
+
+    def _apply_recipe(self, x):
+        # overriding the parent method to prevent its scripting
+        pass
+
+
+class Reduce(ReduceMixin, jt.nn.Module):
+    def execute(self, input):
+        return apply_for_scriptable_jittor(self._recipe, input, reduction_type=self.reduction)
+
+    def _apply_recipe(self, x):
+        # overriding the parent method to prevent its scripting
+        pass
+
+
+class EinMix(_EinmixMixin, jt.nn.Module):
+    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
+        self.weight = jt.zeros(weight_shape)
+        nn.init.uniform_(self.weight, low=-weight_bound, high=weight_bound)
+        if bias_shape is not None:
+            self.bias = jt.zeros(bias_shape)
+            nn.init.uniform_(self.bias, low=-bias_bound, high=bias_bound)
+        else:
+            self.bias = None
+
+    def _create_rearrange_layers(self,
+                                 pre_reshape_pattern: Optional[str],
+                                 pre_reshape_lengths: Optional[Dict],
+                                 post_reshape_pattern: Optional[str],
+                                 post_reshape_lengths: Optional[Dict],
+                                 ):
+        self.pre_rearrange = None
+        if pre_reshape_pattern is not None:
+            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)
+
+        self.post_rearrange = None
+        if post_reshape_pattern is not None:
+            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)
+
+    def execute(self, input):
+        if self.pre_rearrange is not None:
+            input = self.pre_rearrange(input)
+        result = jt.linalg.einsum(self.einsum_pattern, input, self.weight)
+        if self.bias is not None:
+            result += self.bias
+        if self.post_rearrange is not None:
+            result = self.post_rearrange(result)
+        return result
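+
+# Usage sketch (dimensions are illustrative): a per-head dense layer with no
+# connections between heads, mirroring the docstring example in _einmix.py:
+# >>> mix = EinMix('t b (head cin) -> t b (head cout)',
+# ...              weight_shape='head cin cout', head=8, cin=64, cout=32)
+# >>> y = mix(x)  # x: (t, b, 8*64) -> y: (t, b, 8*32)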
diff --git a/python/jittor/einops/parsing.py b/python/jittor/einops/parsing.py
new file mode 100644
index 00000000..77513b2f
--- /dev/null
+++ b/python/jittor/einops/parsing.py
@@ -0,0 +1,147 @@
+from jittor.einops import EinopsError
+import keyword
+import warnings
+from typing import List, Optional, Set, Tuple
+
+_ellipsis: str = '…'  # NB, this is a single unicode symbol. A string is used as it is not a list, but can be iterated
+
+
+class AnonymousAxis(object):
+    """ Important: instances of this class are never equal to each other """
+
+    def __init__(self, value: str):
+        self.value = int(value)
+        if self.value <= 1:
+            if self.value == 1:
+                raise EinopsError('No need to create an anonymous axis of length 1. Report this as an issue')
+            else:
+                raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))
+
+    def __repr__(self):
+        return "{}-axis".format(str(self.value))
+
+
+class ParsedExpression:
+    """
+    Immutable structure that contains information about one side of an expression (e.g. 'b c (h w)')
+    and keeps some information important for downstream processing
+    """
+    def __init__(self, expression, *, allow_underscore: bool = False):
+        self.has_ellipsis: bool = False
+        self.has_ellipsis_parenthesized: Optional[bool] = None
+        self.identifiers: Set[str] = set()
+        # anonymous axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with an empty composition
+        self.has_non_unitary_anonymous_axes: bool = False
+        # composition keeps the structure of composite axes; see how different corner cases are handled in tests
+        self.composition = []
+        if '.' in expression:
+            if '...' not in expression:
+                raise EinopsError('Expression may contain dots only inside ellipsis (...)')
+            if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
+                raise EinopsError(
+                    'Expression may contain dots only inside ellipsis (...); only one ellipsis per tensor is allowed')
+            expression = expression.replace('...', _ellipsis)
+            self.has_ellipsis = True
+
+        bracket_group = None
+
+        def add_axis_name(x):
+            if x is not None:
+                if x in self.identifiers:
+                    if not (allow_underscore and x == "_"):
+                        raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
+                if x == _ellipsis:
+                    self.identifiers.add(_ellipsis)
+                    if bracket_group is None:
+                        self.composition.append(_ellipsis)
+                        self.has_ellipsis_parenthesized = False
+                    else:
+                        bracket_group.append(_ellipsis)
+                        self.has_ellipsis_parenthesized = True
+                else:
+                    is_number = str.isdecimal(x)
+                    if is_number and int(x) == 1:
+                        # handling the case of an anonymous axis of length 1
+                        if bracket_group is None:
+                            self.composition.append([])
+                        else:
+                            pass  # no need to think about 1s inside parenthesis
+                        return
+                    is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
+                    if not (is_number or is_axis_name):
+                        raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
+                    if is_number:
+                        x = AnonymousAxis(x)
+                    self.identifiers.add(x)
+                    if is_number:
+                        self.has_non_unitary_anonymous_axes = True
+                    if bracket_group is None:
+                        self.composition.append([x])
+                    else:
+                        bracket_group.append(x)
+
+        current_identifier = None
+        for char in expression:
+            if char in '() ':
+                add_axis_name(current_identifier)
+                current_identifier = None
+                if char == '(':
+                    if bracket_group is not None:
+                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
+                    bracket_group = []
+                elif char == ')':
+                    if bracket_group is None:
+                        raise EinopsError('Brackets are not balanced')
+                    self.composition.append(bracket_group)
+                    bracket_group = None
+            elif str.isalnum(char) or char in ['_', _ellipsis]:
+                if current_identifier is None:
+                    current_identifier = char
+                else:
+                    current_identifier += char
+            else:
+                raise EinopsError("Unknown character '{}'".format(char))
+
+        if bracket_group is not None:
+            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
+        add_axis_name(current_identifier)
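+
+    # For example (illustrative):
+    # >>> p = ParsedExpression('b c (h w)')
+    # >>> p.composition
+    # [['b'], ['c'], ['h', 'w']]
+    # >>> sorted(p.identifiers)
+    # ['b', 'c', 'h', 'w']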
+
+    def flat_axes_order(self) -> List:
+        result = []
+        for composed_axis in self.composition:
+            assert isinstance(composed_axis, list), 'does not work with ellipsis'
+            for axis in composed_axis:
+                result.append(axis)
+        return result
+
+    def has_composed_axes(self) -> bool:
+        # this will ignore 1 inside brackets
+        for axes in self.composition:
+            if isinstance(axes, list) and len(axes) > 1:
+                return True
+        return False
+
+    @staticmethod
+    def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
+        if not str.isidentifier(name):
+            return False, 'not a valid python identifier'
+        elif name[0] == '_' or name[-1] == '_':
+            if name == '_' and allow_underscore:
+                return True, ''
+            return False, 'axis name should not start or end with underscore'
+        else:
+            if keyword.iskeyword(name):
+                warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
+            if name in ['axis']:
+                warnings.warn("It is discouraged to use 'axis' as an axis name "
+                              "and will raise an error in future", FutureWarning)
+            return True, ''
+
+    @staticmethod
+    def check_axis_name(name: str) -> bool:
+        """
+        Valid axes names are python identifiers except keywords,
+        and additionally should not start or end with an underscore
+        """
+        is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
+        return is_valid
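+
+# For example (illustrative):
+# >>> ParsedExpression.check_axis_name('height')
+# True
+# >>> ParsedExpression.check_axis_name('_h')  # leading underscore is rejected
+# False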