mirror of https://github.com/Jittor/Jittor
Support einops for Jittor
This commit is contained in:
parent
8c9bfb639d
commit
13f9eaafc0
@@ -0,0 +1,8 @@
class EinopsError(RuntimeError):
    """ Runtime error thrown by einops """
    pass


__all__ = ['rearrange', 'reduce', 'repeat', 'parse_shape', 'asnumpy', 'EinopsError']

from jittor.einops.einops import rearrange, reduce, repeat, parse_shape, asnumpy
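A quick sanity check of the exported API; a minimal sketch assuming numpy is installed (`rearrange` dispatches to the numpy backend defined in the next hunk):

    import numpy as np
    from jittor.einops import rearrange

    x = np.zeros((2, 3, 4))
    y = rearrange(x, 'b h w -> b (h w)')  # compose the last two axes
    assert y.shape == (2, 12)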
@@ -0,0 +1,255 @@
"""
Backends in `einops` are organized to meet the following requirements:
- backends are not imported unless those are actually needed, because
    - backends may not be installed
    - importing all available backends would lead to a significant memory footprint
    - backends may be present but installed with errors (but never used),
      and importing them may lead to crashes
- backend should be either symbolic or imperative (tensorflow is for both, but that causes problems)
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if backend can't (temporarily) provide symbols for shape dimensions, UnknownSize objects are used
"""

import sys
import warnings

__author__ = 'Alex Rogozhnikov'

_backends = {}
_debug_importing = False


def get_backend(tensor) -> 'AbstractBackend':
    """
    Returns the correct backend for a tensor (e.g. the numpy backend if tensor is a numpy.ndarray).
    If needed, imports the package and creates the backend.
    """
    for framework_name, backend in _backends.items():
        if backend.is_appropriate_type(tensor):
            return backend

    # Find backend subclasses recursively
    backend_subclasses = []
    backends = AbstractBackend.__subclasses__()
    while backends:
        backend = backends.pop()
        backends += backend.__subclasses__()
        backend_subclasses.append(backend)

    for BackendSubclass in backend_subclasses:
        if _debug_importing:
            print('Testing for subclass of ', BackendSubclass)
        if BackendSubclass.framework_name not in _backends:
            # check that module was already imported. Otherwise it can't be imported
            if BackendSubclass.framework_name in sys.modules:
                if _debug_importing:
                    print('Imported backend for ', BackendSubclass.framework_name)
                backend = BackendSubclass()
                _backends[backend.framework_name] = backend
                if backend.is_appropriate_type(tensor):
                    return backend

    raise RuntimeError('Tensor type unknown to einops {}'.format(type(tensor)))

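# Example (a hedged sketch, not part of the module): get_backend dispatches on the
# tensor's type, importing a backend lazily the first time it is needed.
# Assuming numpy is installed:
#
#     import numpy as np
#     backend = get_backend(np.zeros((2, 3)))  # -> NumpyBackend instance
#     assert backend.shape(np.zeros((2, 3))) == (2, 3)
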
class AbstractBackend:
    """ Base backend class, major part of methods are only for debugging purposes. """
    framework_name = None

    def is_appropriate_type(self, tensor):
        """ helper method should recognize tensors it can handle """
        raise NotImplementedError()

    def from_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    def to_numpy(self, x):
        raise NotImplementedError("framework doesn't support imperative execution")

    def create_symbol(self, shape):
        raise NotImplementedError("framework doesn't support symbolic computations")

    def eval_symbol(self, symbol, input_dict):
        raise NotImplementedError("framework doesn't support symbolic computations")

    def arange(self, start, stop):
        # supplementary method used only in testing, so should implement CPU version
        raise NotImplementedError("framework doesn't implement arange")

    def shape(self, x):
        """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)"""
        return x.shape

    def reshape(self, x, shape):
        return x.reshape(shape)

    def transpose(self, x, axes):
        return x.transpose(axes)

    def reduce(self, x, operation, axes):
        return getattr(x, operation)(axis=axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        raise NotImplementedError()

    def add_axis(self, x, new_position):
        raise NotImplementedError()

    def add_axes(self, x, n_axes, pos2len):
        repeats = [1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return self.tile(x, tuple(repeats))

    def tile(self, x, repeats):
        """repeats is a number of copies per axis, one entry for each axis of x"""
        raise NotImplementedError()

    def is_float_type(self, x):
        # some backends (torch) can't compute average for non-floating types.
        # Decided to drop average for all backends if type is not floating
        raise NotImplementedError()

    def layers(self):
        raise NotImplementedError("backend does not provide layers")

    def __repr__(self):
        return "<einops backend for {}>".format(self.framework_name)

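# A minimal sketch of a custom backend (hypothetical, for illustration only):
# get_backend discovers subclasses via AbstractBackend.__subclasses__(), so
# defining a subclass whose framework_name is already in sys.modules is enough.
#
#     class ListBackend(AbstractBackend):
#         framework_name = 'builtins'
#
#         def is_appropriate_type(self, tensor):
#             return isinstance(tensor, list)
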
class UnknownSize:
    """ pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements """

    def __floordiv__(self, other):
        return self

    def __eq__(self, other):
        return True  # we don't know actual size

    def __mul__(self, other):
        return self

    def __rmul__(self, other):
        return self

    def __hash__(self):
        return None.__hash__()


class NumpyBackend(AbstractBackend):
    framework_name = 'numpy'

    def __init__(self):
        import numpy
        self.np = numpy

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.np.ndarray)

    def from_numpy(self, x):
        return x

    def to_numpy(self, x):
        return x

    def arange(self, start, stop):
        return self.np.arange(start, stop)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.np.stack(tensors)

    def tile(self, x, repeats):
        return self.np.tile(x, repeats)

    def is_float_type(self, x):
        return x.dtype in ('float16', 'float32', 'float64', 'float128')

    def add_axis(self, x, new_position):
        return self.np.expand_dims(x, new_position)

class HashableTuple:
    """Overcomes non-hashability of symbolic elements"""

    def __init__(self, elements: tuple):
        self.elements = elements

    def __iter__(self):
        for x in self.elements:
            yield x

    def __len__(self):
        return len(self.elements)

    def __getitem__(self, item):
        return self.elements[item]

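# HashableTuple quacks like a tuple (iteration, len, indexing) while remaining
# hashable through the default object identity; a tiny illustration:
#
#     s = HashableTuple((2, 3, 4))
#     assert len(s) == 3 and s[1] == 3 and list(s) == [2, 3, 4]
#     hash(s)  # ok even if elements themselves are unhashable symbols
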
class JittorBackend(AbstractBackend):
    framework_name = 'jittor'

    def __init__(self):
        import jittor
        self.jittor = jittor

    def is_appropriate_type(self, tensor):
        return isinstance(tensor, self.jittor.jittor_core.Var)

    def from_numpy(self, x):
        variable = self.jittor.array(x)
        if self.is_float_type(variable):
            # attach grad only to floating types
            variable.requires_grad = True
        return variable

    def to_numpy(self, x):
        return x.detach().numpy()

    def arange(self, start, stop):
        return self.jittor.arange(start, stop, dtype='int64')

    def shape(self, x):
        return HashableTuple(tuple(x.shape))

    def reshape(self, x, shape):
        return self.jittor.reshape(x, shape)

    def reduce(self, x, operation, reduced_axes):
        # reduce axes one by one, from last to first,
        # so that the remaining axis indices stay valid
        for axis in sorted(reduced_axes, reverse=True):
            if operation == 'min':
                x = x.min(dim=axis)
            elif operation == 'max':
                x = x.max(dim=axis)
            elif operation in ['sum', 'mean', 'prod']:
                x = getattr(x, operation)(dim=axis)
            else:
                raise NotImplementedError('Unknown reduction ', operation)
        return x

    def transpose(self, x, axes):
        return x.permute(axes)

    def stack_on_zeroth_dimension(self, tensors: list):
        return self.jittor.stack(tensors)

    def add_axes(self, x, n_axes, pos2len):
        # -1 in expand keeps the existing length of that axis
        repeats = [-1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = self.add_axis(x, axis_position)
            repeats[axis_position] = axis_length
        return x.expand(repeats)

    def tile(self, x, repeats):
        return x.repeat(repeats)

    def add_axis(self, x, new_position):
        return self.jittor.unsqueeze(x, new_position)

    def is_float_type(self, x):
        return x.dtype in ["float16", "float32", "float64"]

    def layers(self):
        from jittor.einops.layers import jittor
        return jittor
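With the jittor backend in place, the same operations work directly on jittor variables; a minimal sketch assuming jittor is installed:

    import jittor as jt
    from jittor.einops import rearrange, reduce

    x = jt.zeros((4, 6, 3))
    y = rearrange(x, 'b (h w) c -> b h w c', h=2)   # decompose axis 1
    assert tuple(y.shape) == (4, 2, 3, 3)
    z = reduce(x, 'b t c -> b c', 'max')            # max over axis 1
    assert tuple(z.shape) == (4, 3)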
@@ -0,0 +1,84 @@
"""
Specialization of einops for jittor.

Unfortunately, jittor's jit scripting mechanism isn't strong enough,
and to have scripting supported at least for layers,
a number of changes are required, and this module helps.

Importantly, whole lib is designed so that you can't use it
"""

from typing import Dict, List

import jittor as jt
from jittor.einops.einops import TransformRecipe, _reconstruct_from_shape_uncached


class JittorJitBackend:
    """
    Completely static backend that mimics part of normal backend functionality
    but restricted to jittor stuff only
    """

    @staticmethod
    def reduce(x: jt.jittor_core.Var, operation: str, reduced_axes: List[int]):
        if operation == 'min':
            return x.min(dims=reduced_axes)
        elif operation == 'max':
            return x.max(dims=reduced_axes)
        elif operation == 'sum':
            return x.sum(dims=reduced_axes)
        elif operation == 'mean':
            return x.mean(dims=reduced_axes)
        elif operation == 'prod':
            # prod takes one axis at a time; go from last to first
            # so the remaining axis indices stay valid
            for i in list(sorted(reduced_axes))[::-1]:
                x = x.prod(dim=i)
            return x
        else:
            raise NotImplementedError('Unknown reduction ', operation)

    @staticmethod
    def transpose(x, axes: List[int]):
        return x.permute(axes)

    @staticmethod
    def stack_on_zeroth_dimension(tensors: List[jt.jittor_core.Var]):
        return jt.stack(tensors)

    @staticmethod
    def tile(x, repeats: List[int]):
        return x.repeat(repeats)

    @staticmethod
    def add_axes(x, n_axes: int, pos2len: Dict[int, int]):
        repeats = [1] * n_axes
        for axis_position, axis_length in pos2len.items():
            x = jt.unsqueeze(x, axis_position)
            repeats[axis_position] = axis_length
        return JittorJitBackend.tile(x, repeats)

    @staticmethod
    def is_float_type(x):
        return x.dtype in ["float16", "float32", "float64"]

    @staticmethod
    def shape(x):
        return x.shape

    @staticmethod
    def reshape(x, shape: List[int]):
        return x.reshape(shape)


# mirrors einops.einops._apply_recipe
def apply_for_scriptable_jittor(recipe: TransformRecipe, tensor: jt.jittor_core.Var, reduction_type: str) -> jt.jittor_core.Var:
    backend = JittorJitBackend
    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
        _reconstruct_from_shape_uncached(recipe, backend.shape(tensor))
    tensor = backend.reshape(tensor, init_shapes)
    if len(reduced_axes) > 0:
        tensor = backend.reduce(tensor, operation=reduction_type, reduced_axes=reduced_axes)
    tensor = backend.transpose(tensor, axes_reordering)
    if len(added_axes) > 0:
        tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
    return backend.reshape(tensor, final_shapes)
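The static backend can be exercised directly. A minimal sketch, assuming jittor is installed and that `Var.repeat` tiles per axis as in torch (the module path below is a guess, since the diff does not show file names):

    import jittor as jt
    from jittor.einops._jittor_specific import JittorJitBackend  # hypothetical path

    x = jt.zeros((3,))
    y = JittorJitBackend.add_axes(x, n_axes=2, pos2len={0: 4})  # unsqueeze, then tile
    assert tuple(y.shape) == (4, 3)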
@@ -0,0 +1,625 @@
import functools
import itertools
import typing
from collections import OrderedDict
from typing import Tuple, List, Dict, Union, Callable, Optional, TypeVar

if typing.TYPE_CHECKING:
    import numpy as np

from jittor.einops import EinopsError
from jittor.einops._backends import get_backend
from jittor.einops.parsing import ParsedExpression, _ellipsis, AnonymousAxis

Tensor = TypeVar('Tensor')
ReductionCallable = Callable[[Tensor, List[int]], Tensor]
Reduction = Union[str, ReductionCallable]

_reductions = ('min', 'max', 'sum', 'mean', 'prod')
_ellipsis_not_in_parenthesis: List[int] = [-999]
_unknown_axis_length = -999999


def is_ellipsis_not_in_parenthesis(group: List[int]) -> bool:
    if len(group) != 1:
        return False
    return group[0] == -999


def _product(sequence: List[int]) -> int:
    """ minimalistic product that works both with numbers and symbols. Supports empty lists """
    result = 1
    for element in sequence:
        result *= element
    return result


def _reduce_axes(tensor, reduction_type: Reduction, reduced_axes: List[int], backend):
    reduced_axes = tuple(reduced_axes)
    if callable(reduction_type):
        # custom callable
        return reduction_type(tensor, reduced_axes)
    else:
        # one of built-in operations
        if len(reduced_axes) == 0:
            return tensor
        assert reduction_type in _reductions
        if reduction_type == 'mean':
            if not backend.is_float_type(tensor):
                raise NotImplementedError('reduce_mean is not available for non-floating tensors')
        return backend.reduce(tensor, reduction_type, reduced_axes)

def _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes):
    # 'collapses' neighboring axes if those participate in the result pattern in the same order
    # TODO add support for added_axes
    assert len(axes_reordering) + len(reduced_axes) == len(init_shapes)
    # joining consecutive axes that will be reduced
    # possibly we can skip this if all backends can optimize this (not sure)
    reduced_axes = tuple(sorted(reduced_axes))
    for i in range(len(reduced_axes) - 1)[::-1]:
        if reduced_axes[i] + 1 == reduced_axes[i + 1]:
            removed_axis = reduced_axes[i + 1]
            removed_length = init_shapes[removed_axis]
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            reduced_axes = reduced_axes[:i + 1] + tuple(axis - 1 for axis in reduced_axes[i + 2:])

    # removing axes that are moved together during reshape
    def build_mapping():
        init_to_final = {}
        for axis in range(len(init_shapes)):
            if axis in reduced_axes:
                init_to_final[axis] = None
            else:
                after_reduction = sum(x is not None for x in init_to_final.values())
                init_to_final[axis] = list(axes_reordering).index(after_reduction)
        return init_to_final

    init_axis_to_final_axis = build_mapping()

    for init_axis in range(len(init_shapes) - 1)[::-1]:
        if init_axis_to_final_axis[init_axis] is None:
            continue
        if init_axis_to_final_axis[init_axis + 1] is None:
            continue
        if init_axis_to_final_axis[init_axis] + 1 == init_axis_to_final_axis[init_axis + 1]:
            removed_axis = init_axis + 1
            removed_length = init_shapes[removed_axis]
            removed_axis_after_reduction = sum(x not in reduced_axes for x in range(removed_axis))

            reduced_axes = tuple(axis if axis < removed_axis else axis - 1 for axis in reduced_axes)
            init_shapes = init_shapes[:removed_axis] + init_shapes[removed_axis + 1:]
            init_shapes[removed_axis - 1] *= removed_length
            old_reordering = axes_reordering
            axes_reordering = []
            for axis in old_reordering:
                if axis == removed_axis_after_reduction:
                    pass
                elif axis < removed_axis_after_reduction:
                    axes_reordering.append(axis)
                else:
                    axes_reordering.append(axis - 1)
            init_axis_to_final_axis = build_mapping()

    return init_shapes, reduced_axes, axes_reordering, final_shapes


CookedRecipe = Tuple[List[int], List[int], List[int], Dict[int, int], List[int]]

class TransformRecipe:
    """
    Recipe describes actual computation pathway.
    Recipe can be applied to a tensor or variable.
    """

    # structure is non-mutable. In future, this can be a frozen dataclass (python 3.7+)

    def __init__(self,
                 # list of expressions (or just sizes) for elementary axes as they appear in left expression.
                 # this is what (after computing unknown parts) will be a shape after first transposition.
                 # If ellipsis is present, it forms one dimension here (in the right position).
                 elementary_axes_lengths: List[int],
                 # each dimension in input can help to reconstruct length of one elementary axis
                 # or verify one of dimensions. Each element points to element of elementary_axes_lengths
                 input_composite_axes: List[Tuple[List[int], List[int]]],
                 # indices of axes to be squashed
                 reduced_elementary_axes: List[int],
                 # in which order should axes be reshuffled after reduction
                 axes_permutation: List[int],
                 # at which positions which of elementary axes should appear
                 added_axes: Dict[int, int],
                 # ids of axes as they appear in result, again pointers to elementary_axes_lengths,
                 # only used to infer result dimensions
                 output_composite_axes: List[List[int]],
                 # positions of ellipsis in lhs and rhs of expression
                 ellipsis_position_in_lhs: Optional[int] = None,
                 ):
        self.elementary_axes_lengths: List[int] = elementary_axes_lengths
        self.input_composite_axes: List[Tuple[List[int], List[int]]] = input_composite_axes
        self.output_composite_axes: List[List[int]] = output_composite_axes
        self.axes_permutation: List[int] = axes_permutation
        self.added_axes: Dict[int, int] = added_axes
        # This is redundant information, but more convenient to use
        self.reduced_elementary_axes: List[int] = reduced_elementary_axes
        # setting to a large number to avoid handling Nones in reconstruct_from_shape
        self.ellipsis_position_in_lhs: int = ellipsis_position_in_lhs if ellipsis_position_in_lhs is not None else 10000

def _reconstruct_from_shape_uncached(self: TransformRecipe, shape: List[int]) -> CookedRecipe:
    """
    Reconstruct all actual parameters using shape.
    Shape is a tuple that may contain integers, shape symbols (tf, keras, theano) and UnknownSize (keras, mxnet);
    known axes can be integers or symbols, but not Nones.
    """
    axes_lengths: List[int] = list(self.elementary_axes_lengths)
    if self.ellipsis_position_in_lhs != 10000:
        if len(shape) < len(self.input_composite_axes) - 1:
            raise EinopsError('Expected at least {} dimensions, got {}'.format(
                len(self.input_composite_axes) - 1, len(shape)))
    else:
        if len(shape) != len(self.input_composite_axes):
            raise EinopsError('Expected {} dimensions, got {}'.format(len(self.input_composite_axes), len(shape)))

    ellipsis_shape: List[int] = []
    for input_axis, (known_axes, unknown_axes) in enumerate(self.input_composite_axes):
        before_ellipsis = input_axis
        after_ellipsis = input_axis + len(shape) - len(self.input_composite_axes)
        if input_axis == self.ellipsis_position_in_lhs:
            assert len(known_axes) == 0 and len(unknown_axes) == 1
            unknown_axis, = unknown_axes
            ellipsis_shape = shape[before_ellipsis:after_ellipsis + 1]
            for d in ellipsis_shape:
                if d is None:
                    raise EinopsError("Couldn't infer shape for one or more axes represented by ellipsis")
            total_dim_size: int = _product(ellipsis_shape)
            axes_lengths[unknown_axis] = total_dim_size
        else:
            if input_axis < self.ellipsis_position_in_lhs:
                length = shape[before_ellipsis]
            else:
                length = shape[after_ellipsis]
            known_product = 1
            for axis in known_axes:
                known_product *= axes_lengths[axis]

            if len(unknown_axes) == 0:
                if isinstance(length, int) and isinstance(known_product, int) and length != known_product:
                    raise EinopsError('Shape mismatch, {} != {}'.format(length, known_product))
            # this is enforced when recipe is created
            # elif len(unknown_axes) > 1:
            #     raise EinopsError(
            #         "Lengths of two or more axes in parenthesis not provided (dim={}), can't infer dimensions".
            #         format(known_product)
            #     )
            else:
                if isinstance(length, int) and isinstance(known_product, int) and length % known_product != 0:
                    raise EinopsError("Shape mismatch, can't divide axis of length {} in chunks of {}".format(
                        length, known_product))

                unknown_axis: int = unknown_axes[0]
                inferred_length: int = length // known_product
                axes_lengths[unknown_axis] = inferred_length

    # at this point all axes_lengths are computed (either have values or variables, but not Nones)

    # TODO more readable expression
    init_shapes = axes_lengths[:len(axes_lengths) - len(self.added_axes)]
    final_shapes: List[int] = []
    for output_axis, grouping in enumerate(self.output_composite_axes):
        if is_ellipsis_not_in_parenthesis(grouping):
            final_shapes.extend(ellipsis_shape)
        else:
            lengths = [axes_lengths[elementary_axis] for elementary_axis in grouping]
            final_shapes.append(_product(lengths))
    reduced_axes = self.reduced_elementary_axes
    axes_reordering = self.axes_permutation
    added_axes: Dict[int, int] = {
        pos: axes_lengths[pos_in_elementary] for pos, pos_in_elementary in self.added_axes.items()}
    # if optimize:
    #     assert len(self.added_axes) == 0
    #     return _optimize_transformation(init_shapes, reduced_axes, axes_reordering, final_shapes)
    return init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes


_reconstruct_from_shape = functools.lru_cache(1024)(_reconstruct_from_shape_uncached)


def _apply_recipe(recipe: TransformRecipe, tensor: Tensor, reduction_type: Reduction) -> Tensor:
    # this method works for all backends, but is not compilable with jittor's jit
    # scripting (see apply_for_scriptable_jittor in the previous hunk)
    backend = get_backend(tensor)
    init_shapes, reduced_axes, axes_reordering, added_axes, final_shapes = \
        _reconstruct_from_shape(recipe, backend.shape(tensor))
    tensor = backend.reshape(tensor, init_shapes)
    tensor = _reduce_axes(tensor, reduction_type=reduction_type, reduced_axes=reduced_axes, backend=backend)
    tensor = backend.transpose(tensor, axes_reordering)
    if len(added_axes) > 0:
        tensor = backend.add_axes(tensor, n_axes=len(axes_reordering) + len(added_axes), pos2len=added_axes)
    return backend.reshape(tensor, final_shapes)

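# The pipeline in _apply_recipe is reshape -> reduce -> transpose -> reshape.
# A hedged numpy illustration (not part of the module), unrolled by hand for
# 'b (h w) c -> (b h) w c' with h=2 (no reduction, identity transposition):
#
#     import numpy as np
#     x = np.arange(4 * 6 * 3).reshape(4, 6, 3)
#     manual = x.reshape(4, 2, 3, 3).reshape(8, 3, 3)
#     assert np.array_equal(rearrange(x, 'b (h w) c -> (b h) w c', h=2), manual)
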
@functools.lru_cache(256)
def _prepare_transformation_recipe(pattern: str,
                                   operation: Reduction,
                                   axes_lengths: Tuple[Tuple, ...]) -> TransformRecipe:
    """ Perform initial parsing of pattern and provided supplementary info
    axes_lengths is a tuple of tuples (axis_name, axis_length)
    """
    left, rght = pattern.split('->')
    left = ParsedExpression(left)
    rght = ParsedExpression(rght)

    # checking that axes are in agreement - new axes appear only in repeat, while axes disappear only in reduction
    if not left.has_ellipsis and rght.has_ellipsis:
        raise EinopsError('Ellipsis found in right side, but not left side of a pattern {}'.format(pattern))
    if left.has_ellipsis and left.has_ellipsis_parenthesized:
        raise EinopsError('Ellipsis inside parenthesis on the left side is not allowed: {}'.format(pattern))
    if operation == 'rearrange':
        difference = set.symmetric_difference(left.identifiers, rght.identifiers)
        if left.has_non_unitary_anonymous_axes or rght.has_non_unitary_anonymous_axes:
            raise EinopsError('Non-unitary anonymous axes are not supported in rearrange (exception is length 1)')
        if len(difference) > 0:
            raise EinopsError('Identifiers only on one side of expression (should be on both): {}'.format(difference))
    elif operation == 'repeat':
        difference = set.difference(left.identifiers, rght.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the left side of repeat: {}'.format(difference))
        axes_without_size = set.difference({ax for ax in rght.identifiers if not isinstance(ax, AnonymousAxis)},
                                           {*left.identifiers, *(ax for ax, _ in axes_lengths)})
        if len(axes_without_size) > 0:
            raise EinopsError('Specify sizes for new axes in repeat: {}'.format(axes_without_size))
    elif operation in _reductions or callable(operation):
        difference = set.difference(rght.identifiers, left.identifiers)
        if len(difference) > 0:
            raise EinopsError('Unexpected identifiers on the right side of reduce {}: {}'.format(operation, difference))
    else:
        raise EinopsError('Unknown reduction {}. Expect one of {}.'.format(operation, _reductions))

    # parsing all dimensions to find out lengths
    axis_name2known_length = OrderedDict()
    for composite_axis in left.composition:
        for axis_name in composite_axis:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length

    # axis_ids_after_first_reshape = range(len(axis_name2known_length)) at this point

    repeat_axes_names = []
    for axis_name in rght.identifiers:
        if axis_name not in axis_name2known_length:
            if isinstance(axis_name, AnonymousAxis):
                axis_name2known_length[axis_name] = axis_name.value
            else:
                axis_name2known_length[axis_name] = _unknown_axis_length
            repeat_axes_names.append(axis_name)

    axis_name2position = {name: position for position, name in enumerate(axis_name2known_length)}
    reduced_axes: List[int] = [position for axis, position in axis_name2position.items() if
                               axis not in rght.identifiers]
    reduced_axes: List[int] = list(sorted(reduced_axes))

    for elementary_axis, axis_length in axes_lengths:
        if not ParsedExpression.check_axis_name(elementary_axis):
            raise EinopsError('Invalid name for an axis', elementary_axis)
        if elementary_axis not in axis_name2known_length:
            raise EinopsError('Axis {} is not used in transform'.format(elementary_axis))
        axis_name2known_length[elementary_axis] = axis_length

    input_axes_known_unknown = []
    # some of shapes will be inferred later - all information is prepared for faster inference
    for composite_axis in left.composition:
        known = {axis for axis in composite_axis if axis_name2known_length[axis] != _unknown_axis_length}
        unknown = {axis for axis in composite_axis if axis_name2known_length[axis] == _unknown_axis_length}
        if len(unknown) > 1:
            raise EinopsError('Could not infer sizes for {}'.format(unknown))
        assert len(unknown) + len(known) == len(composite_axis)
        input_axes_known_unknown.append(
            ([axis_name2position[axis] for axis in known],
             [axis_name2position[axis] for axis in unknown])
        )

    axis_position_after_reduction = {}
    for axis_name in itertools.chain(*left.composition):
        if axis_name in rght.identifiers:
            axis_position_after_reduction[axis_name] = len(axis_position_after_reduction)

    result_axes_grouping: List[List[int]] = []
    for composite_axis in rght.composition:
        if composite_axis == _ellipsis:
            result_axes_grouping.append(_ellipsis_not_in_parenthesis)
        else:
            result_axes_grouping.append([axis_name2position[axis] for axis in composite_axis])

    ordered_axis_right = list(itertools.chain(*rght.composition))
    axes_permutation = [
        axis_position_after_reduction[axis] for axis in ordered_axis_right if axis in left.identifiers]
    added_axes = {i: axis_name2position[axis_name] for i, axis_name in enumerate(ordered_axis_right)
                  if axis_name not in left.identifiers}

    ellipsis_left = None if _ellipsis not in left.composition else left.composition.index(_ellipsis)

    return TransformRecipe(
        elementary_axes_lengths=list(axis_name2known_length.values()),
        input_composite_axes=input_axes_known_unknown,
        reduced_elementary_axes=reduced_axes,
        axes_permutation=axes_permutation,
        added_axes=added_axes,
        output_composite_axes=result_axes_grouping,
        ellipsis_position_in_lhs=ellipsis_left,
    )

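# Pattern validation happens above before any tensor work; e.g. an identifier
# present on only one side of a rearrange raises EinopsError (hedged sketch):
#
#     import numpy as np
#     try:
#         rearrange(np.zeros((2, 3)), 'a b -> a c')
#     except EinopsError:
#         pass  # 'b' disappears and 'c' appears: invalid for rearrange
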
def reduce(tensor: Tensor, pattern: str, reduction: Reduction, **axes_lengths: int) -> Tensor:
    """
    einops.reduce provides combination of reordering and reduction using reader-friendly notation.

    Examples for reduce operation:

    ```python
    >>> x = np.random.randn(100, 32, 64)

    # perform max-reduction on the first axis
    >>> y = reduce(x, 't b c -> b c', 'max')

    # same as previous, but with clearer axes meaning
    >>> y = reduce(x, 'time batch channel -> batch channel', 'max')

    >>> x = np.random.randn(10, 20, 30, 40)

    # 2d max-pooling with kernel size = 2 * 2 for image processing
    >>> y1 = reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h2=2, w2=2)

    # if one wants to go back to the original height and width, depth-to-space trick can be applied
    >>> y2 = rearrange(y1, 'b (c h2 w2) h1 w1 -> b c (h1 h2) (w1 w2)', h2=2, w2=2)
    >>> assert parse_shape(x, 'b _ h w') == parse_shape(y2, 'b _ h w')

    # Adaptive 2d max-pooling to 3 * 4 grid
    >>> reduce(x, 'b c (h1 h2) (w1 w2) -> b c h1 w1', 'max', h1=3, w1=4).shape
    (10, 20, 3, 4)

    # Global average pooling
    >>> reduce(x, 'b c h w -> b c', 'mean').shape
    (10, 20)

    # Subtracting mean over batch for each channel
    >>> y = x - reduce(x, 'b c h w -> () c () ()', 'mean')

    # Subtracting per-image mean for each channel
    >>> y = x - reduce(x, 'b c h w -> b c () ()', 'mean')

    ```

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, reduction pattern
        reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
            alternatively, a callable f(tensor, reduced_axes) -> tensor can be provided.
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input
    """
    try:
        hashable_axes_lengths = tuple(sorted(axes_lengths.items()))
        recipe = _prepare_transformation_recipe(pattern, reduction, axes_lengths=hashable_axes_lengths)
        return _apply_recipe(recipe, tensor, reduction_type=reduction)
    except EinopsError as e:
        message = ' Error while processing {}-reduction pattern "{}".'.format(reduction, pattern)
        if not isinstance(tensor, list):
            message += '\n Input tensor shape: {}. '.format(get_backend(tensor).shape(tensor))
        else:
            message += '\n Input is list. '
        message += 'Additional info: {}.'.format(axes_lengths)
        raise EinopsError(message + '\n {}'.format(e))

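# Besides the named reductions, `reduction` may be a callable f(tensor, axes);
# a hedged numpy sketch computing a per-row peak-to-peak range:
#
#     import numpy as np
#     x = np.random.randn(10, 20)
#     ptp = reduce(x, 'b c -> b', lambda t, axes: t.max(axis=axes) - t.min(axis=axes))
#     assert ptp.shape == (10,)
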
@typing.overload
def rearrange(tensor: Tensor, pattern: str, **axes_lengths: int) -> Tensor: ...
@typing.overload
def rearrange(tensor: List[Tensor], pattern: str, **axes_lengths: int) -> Tensor: ...


def rearrange(tensor, pattern: str, **axes_lengths):
    """
    einops.rearrange is a reader-friendly smart element reordering for multidimensional tensors.
    This operation includes functionality of transpose (axes permutation), reshape (view), squeeze, unsqueeze,
    stack, concatenate and other operations.

    Examples for rearrange operation:

    ```python
    # suppose we have a set of 32 images in "h w c" format (height-width-channel)
    >>> images = [np.random.randn(30, 40, 3) for _ in range(32)]

    # stack along first (batch) axis, output is a single array
    >>> rearrange(images, 'b h w c -> b h w c').shape
    (32, 30, 40, 3)

    # concatenate images along height (vertical axis), 960 = 32 * 30
    >>> rearrange(images, 'b h w c -> (b h) w c').shape
    (960, 40, 3)

    # concatenate images along horizontal axis, 1280 = 32 * 40
    >>> rearrange(images, 'b h w c -> h (b w) c').shape
    (30, 1280, 3)

    # reorder axes to "b c h w" format for deep learning
    >>> rearrange(images, 'b h w c -> b c h w').shape
    (32, 3, 30, 40)

    # flatten each image into a vector, 3600 = 30 * 40 * 3
    >>> rearrange(images, 'b h w c -> b (c h w)').shape
    (32, 3600)

    # split each image into 4 smaller ones (top-left, top-right, bottom-left, bottom-right), 128 = 32 * 2 * 2
    >>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
    (128, 15, 20, 3)

    # space-to-depth operation
    >>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
    (32, 15, 20, 12)

    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis).
    Find more examples in the einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, rearrangement pattern
        axes_lengths: any additional specifications for dimensions

    Returns:
        tensor of the same type as input. If possible, a view to the original tensor is returned.

    """
    if isinstance(tensor, list):
        if len(tensor) == 0:
            raise TypeError("Rearrange can't be applied to an empty list")
        tensor = get_backend(tensor[0]).stack_on_zeroth_dimension(tensor)
    return reduce(tensor, pattern, reduction='rearrange', **axes_lengths)


def repeat(tensor: Tensor, pattern: str, **axes_lengths) -> Tensor:
    """
    einops.repeat allows reordering elements and repeating them in arbitrary combinations.
    This operation includes functionality of repeat, tile, and broadcast functions.

    Examples for repeat operation:

    ```python
    # a grayscale image (of shape height x width)
    >>> image = np.random.randn(30, 40)

    # change it to RGB format by repeating in each channel
    >>> repeat(image, 'h w -> h w c', c=3).shape
    (30, 40, 3)

    # repeat image 2 times along height (vertical axis)
    >>> repeat(image, 'h w -> (repeat h) w', repeat=2).shape
    (60, 40)

    # repeat image 3 times along width
    >>> repeat(image, 'h w -> h (repeat w)', repeat=3).shape
    (30, 120)

    # convert each pixel to a small 2x2 square, i.e. upsample the image by 2x
    >>> repeat(image, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (60, 80)

    # pixelate the image: first downsample by 2x, then upsample
    >>> downsampled = reduce(image, '(h h2) (w w2) -> h w', 'mean', h2=2, w2=2)
    >>> repeat(downsampled, 'h w -> (h h2) (w w2)', h2=2, w2=2).shape
    (30, 40)

    ```

    When composing axes, C-order enumeration is used (consecutive elements have different last axis).
    Find more examples in the einops tutorial.

    Parameters:
        tensor: tensor of any supported library (e.g. numpy.ndarray, jittor.Var).
            list of tensors is also accepted, those should be of the same type and shape
        pattern: string, rearrangement pattern
        axes_lengths: any additional specifications for dimensions

    Returns:
        Tensor of the same type as input. If possible, a view to the original tensor is returned.

    """
    return reduce(tensor, pattern, reduction='repeat', **axes_lengths)


def parse_shape(x, pattern: str):
    """
    Parse a tensor shape to dictionary mapping axes names to their lengths.

    ```python
    # Use underscore to skip the dimension in parsing.
    >>> x = np.zeros([2, 3, 5, 7])
    >>> parse_shape(x, 'batch _ h w')
    {'batch': 2, 'h': 5, 'w': 7}

    # `parse_shape` output can be used to specify axes_lengths for other operations:
    >>> y = np.zeros([700])
    >>> rearrange(y, '(b c h w) -> b c h w', **parse_shape(x, 'b _ h w')).shape
    (2, 10, 5, 7)

    ```

    For symbolic frameworks, may return symbols instead of integers.

    Parameters:
        x: tensor of any of supported frameworks
        pattern: str, space-separated names for axes, underscore means skip axis

    Returns:
        dict, maps axes names to their lengths
    """
    exp = ParsedExpression(pattern, allow_underscore=True)
    shape = get_backend(x).shape(x)
    if exp.has_composed_axes():
        raise RuntimeError("Can't parse shape with composite axes: {pattern} {shape}".format(
            pattern=pattern, shape=shape))
    if len(shape) != len(exp.composition):
        if exp.has_ellipsis:
            if len(shape) < len(exp.composition) - 1:
                raise RuntimeError("Can't parse shape with this number of dimensions: {pattern} {shape}".format(
                    pattern=pattern, shape=shape))
        else:
            raise RuntimeError("Can't parse shape with different number of dimensions: {pattern} {shape}".format(
                pattern=pattern, shape=shape))
    if exp.has_ellipsis:
        ellipsis_idx = exp.composition.index(_ellipsis)
        composition = (exp.composition[:ellipsis_idx] +
                       ['_'] * (len(shape) - len(exp.composition) + 1) +
                       exp.composition[ellipsis_idx + 1:])
    else:
        composition = exp.composition
    result = {}
    for (axis_name, ), axis_length in zip(composition, shape):
        if axis_name != '_':
            result[axis_name] = axis_length
    return result


# this one is probably not needed in the public API
def _enumerate_directions(x):
    """
    For an n-dimensional tensor, returns tensors to enumerate each axis.
    ```python
    x = np.zeros([2, 3, 4])  # or any other tensor
    i, j, k = _enumerate_directions(x)
    result = i + 2*j + 3*k
    ```

    `result[i, j, k] = i + 2j + 3k`, and `result` has the same shape as `x`.
    Works very similarly to numpy.ogrid (open indexing grid)
    """
    backend = get_backend(x)
    shape = backend.shape(x)
    result = []
    for axis_id, axis_length in enumerate(shape):
        shape = [1] * len(shape)
        shape[axis_id] = axis_length
        result.append(backend.reshape(backend.arange(0, axis_length), shape))
    return result


def asnumpy(tensor) -> 'numpy.ndarray':
    """
    Convert a tensor of an imperative framework (i.e. numpy/jittor) to `numpy.ndarray`

    Parameters:
        tensor: tensor of any known imperative framework

    Returns:
        `numpy.ndarray`, converted to numpy
    """
    return get_backend(tensor).to_numpy(tensor)
@@ -0,0 +1,393 @@
"""

Indexing one array with the other(s).

Concept for discussion.

Notation targets hard cases, not simple ones, like indexing of a 1d-array with another 1d-array
(notation supports that, but you can't simplify arr[ind], and there is no reason to).

Examples

1. query for every token in sequence a token in the image. Images and sequences are paired
einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, [h_indices_bt, w_indices_bt])

this is equivalent, so you can pass indexers independently or together
einindex('b t c <- b h w c, [h, w] b t', arr_bhwc, np.asarray([h_indices_bt, w_indices_bt]))

after some thinking I decided that having the first axis for the indexing variable is not too restrictive,
but should simplify mapping of such cases.
For this reason the [...] part should always go first in the indexer.

This makes the largest difference with einindex https://github.com/malmaud/einindex,
which has almost identical grammar, but puts the special dimension last, while we put it first.
This trick allows naturally decomposing a multiindex into individual dimensions or vice versa.


2. query for every token in the video the most suitable word in a (matching) sentence
einindex('b t h w <- seq b, [seq] t b h w', arr_tbc, [t_indices_bhw])

note that only one indexer is used, but it still has to be enclosed in a list.
That's the price for being generic. Alternatively, a leading singleton dimension can be added.


3. (not supported now, future planning)
for every timeframe in a video, find the token with the highest norm (across h and w), and compose a new stack of them
indices_2bt = argmax(x_bthwc.norm(dim=-1), 'b t h w -> [h, w] b t')
selected_embeddings_btc = einindex('b t c <- b t h w c, [h, w] b t', x_bthwc, indices_2bt)

while currently the question is 'how do we index',
it is important to pre-align that with the question 'what are natural ways to get indices'.
Most common are min/max. Less common options: topk (works here), random sampling.


Some important properties of this notation:
- support for multiple indexers, including using a single tensor to keep multiple indexers
- 'batch' indexing, when some axes of indexer and array should be matched
- universal (one-indexing-to-rule-them-all)
- extensible for (named) ellipses, including variadic number of indexers
- extensible for einops-style compositions and decompositions
- extensible for outer indexing when indexers are not aligned

The current implementation is based on the python array api and uses loops,
because no appropriate indexing is available in the standard.

"""

from typing import List, Union, TypeVar, Tuple

from jittor.einops import EinopsError

T = TypeVar('T')


class CompositionDecomposition:
    def __init__(
            self,
            decomposed_shape: List[str],
            composed_shape: List[List[str]],
    ):
        flat_shape = []
        for x in composed_shape:
            flat_shape.extend(x)

        self.compose_transposition: Tuple[int] = tuple([decomposed_shape.index(x) for x in flat_shape])
        self.decompose_transposition: Tuple[int] = tuple([flat_shape.index(x) for x in decomposed_shape])
        self.composed_shape = composed_shape
        self.decomposed_shape = decomposed_shape

    def decompose(self, x, known_axes_lengths: dict[str, int]):
        xp = x.__array_namespace__()
        shape = x.shape

        flat_shape = []

        for i, axis_group in enumerate(self.composed_shape):
            unknown_axis_name = None
            known_sizes_prod = 1
            for axis_name in axis_group:
                if axis_name in known_axes_lengths:
                    known_sizes_prod *= known_axes_lengths[axis_name]
                else:
                    if unknown_axis_name is None:
                        unknown_axis_name = axis_name
                    else:
                        raise EinopsError("Can't infer the size")

            if unknown_axis_name is None:
                assert shape[i] == known_sizes_prod
            else:
                known_axes_lengths[unknown_axis_name] = shape[i] // known_sizes_prod

            for axis in axis_group:
                flat_shape.append(known_axes_lengths[axis])

        x = xp.reshape(x, flat_shape)
        return xp.permute_dims(x, self.decompose_transposition)

    def compose(self, x, known_axes_lengths: dict[str, int]):
        xp = x.__array_namespace__()

        for axis_len, axis_name in zip(x.shape, self.decomposed_shape):
            if axis_name in known_axes_lengths:
                assert known_axes_lengths[axis_name] == axis_len
            else:
                known_axes_lengths[axis_name] = axis_len

        x = xp.permute_dims(x, self.compose_transposition)
        new_shape = []
        for axis_group in self.composed_shape:
            composed_axis_size = 1
            for axis_name in axis_group:
                composed_axis_size *= known_axes_lengths[axis_name]
            new_shape.append(composed_axis_size)

        return xp.reshape(x, tuple(new_shape))


def arange_at_position(xp, n_axes, axis, axis_len, device=None):
    x = xp.arange(axis_len, dtype=xp.int64, device=device)
    shape = [1] * n_axes
    shape[axis] = axis_len
    x = xp.reshape(x, shape)
    return x

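# arange_at_position returns a range shaped for broadcasting along one axis;
# a quick sketch (numpy.array_api is what the tests below use):
#
#     import numpy.array_api as np
#     v = arange_at_position(np, n_axes=3, axis=1, axis_len=4)
#     assert v.shape == (1, 4, 1)  # values 0..3 laid out along axis 1
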
class IndexingFormula:

    def __init__(self, pattern: str):
        """
        :param pattern: example 'b t c <- b hsel wsel c, [hsel, wsel] b t'
        """
        self.pattern = pattern
        left, right = pattern.split('<-')
        arg_split = right.index(',')
        arr_pattern, ind_pattern = right[:arg_split], right[arg_split + 1:]
        ind_pattern = ind_pattern.strip()
        assert ind_pattern.startswith('['), 'composition axis should go first in indexer (second argument) [h w] i j k'
        composition_start = ind_pattern.index('[')
        composition_end = ind_pattern.index(']')
        composition = ind_pattern[composition_start + 1: composition_end]
        ind_other_axes = ind_pattern[composition_end + 1:]

        self.result_axes_names = left.split()
        self.array_axes_names = arr_pattern.split()
        self.indexing_axes_names = [x.strip() for x in composition.split(',')]
        self.indexer_other_axes_names = ind_other_axes.split()

        for group_name, group in [
            ('result', self.result_axes_names),
            ('array', self.array_axes_names),
            ('indexer', self.indexing_axes_names + self.indexer_other_axes_names),
        ]:
            if len(set(group)) != len(group):
                # TODO report which axis was duplicated
                raise EinopsError(f'{group_name} pattern ({group}) contains a duplicated axis')

        axis_groups = [
            self.result_axes_names,
            self.array_axes_names,
            self.indexing_axes_names,
            self.indexer_other_axes_names,
        ]

        all_axes = set()
        for group in axis_groups:
            all_axes.update(group)

        self.indexer_axes = []
        self.batch_axes = []
        self.result_and_index_axes = []
        self.result_and_array_axes = []

        for axis in all_axes:
            presence = tuple(axis in g for g in axis_groups)
            # want match-case here. sweet dreams
            if presence == (False, True, True, False):
                self.indexer_axes.append(axis)
            elif presence[2]:
                raise EinopsError(f'Wrong usage of indexer variable {axis}')
            elif presence == (True, True, False, True):
                self.batch_axes.append(axis)
            elif presence == (True, False, False, True):
                self.result_and_index_axes.append(axis)
            elif presence == (True, True, False, False):
                self.result_and_array_axes.append(axis)
            else:
                # TODO better categorization of wrong usage patterns
                raise EinopsError(f'{axis} is used incorrectly in {pattern}')

        assert set(self.indexer_axes) == set(self.indexing_axes_names)
        # order of these variables matters, since we can't lose mapping here
        self.indexer_axes = self.indexing_axes_names

        self.array_composition = CompositionDecomposition(
            decomposed_shape=self.array_axes_names,
            composed_shape=[self.batch_axes + self.indexer_axes, self.result_and_array_axes],
        )

        self.index_composition = CompositionDecomposition(
            decomposed_shape=self.indexer_other_axes_names,
            # single axis after composition
            composed_shape=[self.batch_axes + self.result_and_index_axes],
        )

        self.result_composition = CompositionDecomposition(
            decomposed_shape=self.result_axes_names,
            composed_shape=[self.batch_axes + self.result_and_index_axes, self.result_and_array_axes],
        )

    def apply_to_array_api(self, arr: T, ind: Union[T, List[T]]):
        known_axes_sizes: dict[str, int] = {}
        xp = arr.__array_namespace__()

        if not isinstance(ind, list):
            ind = [ind[i, ...] for i in range(ind.shape[0])]

        for indexer in ind:
            assert len(indexer.shape) == len(self.indexer_other_axes_names)

        # step 1. transpose and reshape arr; learn its dimensions
        arr_2d = self.array_composition.compose(arr, known_axes_sizes)

        # step 2. compute shifts and create an actual indexing array
        shift = 1
        full_index = xp.zeros([1] * len(ind[0].shape), dtype=xp.int64, device=arr.device)

        # original order: [*batch-like axes, *indexing_axes,]
        # now we need to traverse them in the opposite direction

        for axis_name, indexer in list(zip(self.indexing_axes_names, ind))[::-1]:
            full_index = full_index + shift * (indexer % known_axes_sizes[axis_name])
            shift *= known_axes_sizes[axis_name]

        for axis_name in self.batch_axes[::-1]:
            axis_id = self.indexer_other_axes_names.index(axis_name)
            full_index = full_index + arange_at_position(
                xp, len(self.indexer_other_axes_names), axis=axis_id, axis_len=known_axes_sizes[axis_name],
                device=arr.device,
            ) * shift
            shift *= known_axes_sizes[axis_name]

        assert shift == arr_2d.shape[0]

        # step 3. flatten the index
        full_index = self.index_composition.compose(full_index, known_axes_sizes)

        # step 4. indexing
        # python array api lacks any integer indexing, so... I use loops.
        # did you know that there is conceptual programming ... just like art?
        # result_2d = arr_2d[full_index]
        result_2d = xp.stack([arr_2d[full_index[i], :] for i in range(full_index.shape[0])])

        # step 5. decompose the result back into the requested axes
        result = self.result_composition.decompose(result_2d, known_axes_sizes)
        return result


def einindex(pattern: str, arr: T, /, ind: Union[T, List[T]]):
    """
    Demonstrates how einindex should work.
    Supports data-api compliant arrays.
    """
    formula = IndexingFormula(pattern)
    return formula.apply_to_array_api(arr, ind)



def test_composition_and_decomposition():
    import numpy.array_api as np

    x = np.arange(2 * 3 * 5 * 7)
    x = np.reshape(x, (2, 3, 5, 7))
    comp = CompositionDecomposition(
        decomposed_shape=['a', 'b', 'c', 'd'],
        composed_shape=[['a', 'b'], ['c', 'd']],
    )
    assert comp.compose(x, known_axes_lengths={}).shape == (2 * 3, 5 * 7)

    y = CompositionDecomposition(
        decomposed_shape=['a', 'b', 'c', 'd'],
        composed_shape=[['a', 'b'], [], ['c', 'd']],
    ).compose(x, {})
    assert y.shape == (2 * 3, 1, 5 * 7)
    assert np.all(np.reshape(x, (-1,)) == np.reshape(y, (-1,)))

    comp = CompositionDecomposition(
        decomposed_shape=['a', 'b', 'e', 'c', 'd'],
        composed_shape=[['e', 'c'], ['b'], ['a', 'd']],
    )
    x = np.arange(2 * 3 * 5 * 7 * 3)
    x = np.reshape(x, (2, 3, 5, 7, 3))

    axes = {}
    y = comp.compose(x, axes)
    x2 = comp.decompose(y, axes)
    assert np.all(x == x2)
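
# For the last composition above, compose should be equivalent (assuming the implementation is a
# transpose followed by a reshape, as the step-1 comment in apply_to_array_api suggests) to:
#   np.reshape(np.permute_dims(x, (2, 3, 1, 0, 4)), (5 * 7, 3, 2 * 3))
# i.e. axes are first reordered to (e, c, b, a, d) and then each group is flattened.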

def test_simple_indexing():
    import numpy.array_api as np

    # simple 2d test
    arr = np.reshape(np.arange(5 * 7), (5, 7))
    ind = np.arange(7) % 5
    x = einindex('j <- i j, [i] j', arr, [ind])
    for j, i in enumerate(ind):
        assert arr[i, j] == x[j]

    y = einindex('j <- j i, [i] j', np.permute_dims(arr, (1, 0)), [ind])
    for j, i in enumerate(ind):
        assert arr[i, j] == y[j]
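
# In plain numpy (outside the array API restrictions) the first call above is just fancy indexing:
#   x = arr[ind, np.arange(7)]
# which is exactly what einindex spells as 'j <- i j, [i] j'.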

def test_multidimensional_indexing():
    import numpy.array_api as np

    embedding_bhwc = (
        + arange_at_position(np, 4, 0, 2) * 1000
        + arange_at_position(np, 4, 1, 3) * 100
        + arange_at_position(np, 4, 2, 5) * 10
        + arange_at_position(np, 4, 3, 7) * 1
    )

    hindices_bt = np.reshape(np.arange(6), (2, 3)) % 3
    windices_bt = np.reshape(np.arange(6), (2, 3)) % 5

    # imagine that you have pairs of image <> sentence;
    # the goal is to get the most suitable token from the image for every token in the sentence,
    # i.e. for every token in the sentence you pick the best k and v
    result = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, [hindices_bt, windices_bt])

    # example of using a single array for indexing multiple axes
    hw_indices_bt = np.stack([hindices_bt, windices_bt])
    result2 = einindex('c t b <- b h w c, [h, w] b t', embedding_bhwc, hw_indices_bt)
    assert np.all(result == result2)

    # check vs manual element computation
    result_manual = result * 0
    for b in range(2):
        for t in range(3):
            for c in range(7):
                h = hindices_bt[b, t]
                w = windices_bt[b, t]
                result_manual[c, t, b] = embedding_bhwc[b, h, w, c]

    assert np.all(result == result_manual)

def test_reverse_indexing():
    import numpy.array_api as np

    C, T, B = 2, 3, 5
    # G = GPU, a batch-like variable
    G = 4
    H = 7
    W = 9

    arr_gtbc = (
        + arange_at_position(np, 4, 0, G) * 1000
        + arange_at_position(np, 4, 1, T) * 100
        + arange_at_position(np, 4, 2, B) * 10
        + arange_at_position(np, 4, 3, C) * 1
    )

    t_indices_gbhw = np.reshape(np.arange(G * B * H * W), (G, B, H, W)) % T

    result = einindex('g b c h w <- g t b c, [t] g b h w', arr_gtbc, [t_indices_gbhw])

    result_manual = result * 0
    for g in range(G):
        for b in range(B):
            for c in range(C):
                for h in range(H):
                    for w in range(W):
                        t = t_indices_gbhw[g, b, h, w]
                        result_manual[g, b, c, h, w] = arr_gtbc[g, t, b, c]

    assert np.all(result == result_manual)

@@ -0,0 +1,79 @@

__author__ = 'Alex Rogozhnikov'

import functools

from jittor.einops.einops import _apply_recipe
from jittor.einops.einops import TransformRecipe, _prepare_transformation_recipe
from jittor.einops import EinopsError


class RearrangeMixin:
    """
    Rearrange layer behaves identically to einops.rearrange operation.

    :param pattern: str, rearrangement pattern
    :param axes_lengths: any additional specification of dimensions

    See einops.rearrange for examples.
    """

    def __init__(self, pattern, **axes_lengths):
        super().__init__()
        self.pattern = pattern
        self.axes_lengths = axes_lengths
        self._recipe = self.recipe()  # checking parameters

    def __repr__(self):
        params = repr(self.pattern)
        for axis, length in self.axes_lengths.items():
            params += ', {}={}'.format(axis, length)
        return '{}({})'.format(self.__class__.__name__, params)

    @functools.lru_cache(maxsize=1024)
    def recipe(self) -> TransformRecipe:
        try:
            hashable_lengths = tuple(sorted(self.axes_lengths.items()))
            return _prepare_transformation_recipe(self.pattern, operation='rearrange', axes_lengths=hashable_lengths)
        except EinopsError as e:
            raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e))

    def _apply_recipe(self, x):
        return _apply_recipe(self._recipe, x, reduction_type='rearrange')


class ReduceMixin:
    """
    Reduce layer behaves identically to einops.reduce operation.

    :param pattern: str, rearrangement pattern
    :param reduction: one of available reductions ('min', 'max', 'sum', 'mean', 'prod'), case-sensitive
    :param axes_lengths: any additional specification of dimensions

    See einops.reduce for examples.
    """

    def __init__(self, pattern, reduction, **axes_lengths):
        super().__init__()
        self.pattern = pattern
        self.reduction = reduction
        self.axes_lengths = axes_lengths
        self._recipe = self.recipe()  # checking parameters

    def __repr__(self):
        params = '{!r}, {!r}'.format(self.pattern, self.reduction)
        for axis, length in self.axes_lengths.items():
            params += ', {}={}'.format(axis, length)
        return '{}({})'.format(self.__class__.__name__, params)

    @functools.lru_cache(maxsize=1024)
    def recipe(self) -> TransformRecipe:
        try:
            hashable_lengths = tuple(sorted(self.axes_lengths.items()))
            return _prepare_transformation_recipe(
                self.pattern, operation=self.reduction, axes_lengths=hashable_lengths)
        except EinopsError as e:
            raise EinopsError(' Error while preparing {!r}\n {}'.format(self, e))

    def _apply_recipe(self, x):
        return _apply_recipe(self._recipe, x, reduction_type=self.reduction)
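
# A framework binds these mixins to its Module class; a minimal sketch of the pattern
# (the actual jittor binding, which additionally overrides _apply_recipe, appears later in this commit):
#   class Rearrange(RearrangeMixin, jt.nn.Module):
#       def execute(self, x):
#           return self._apply_recipe(x)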

@@ -0,0 +1,175 @@

from typing import Optional, Dict

from jittor.einops import EinopsError
from jittor.einops.parsing import ParsedExpression
import warnings
import string
from jittor.einops.einops import _product


def _report_axes(axes: set, report_message: str):
    if len(axes) > 0:
        raise EinopsError(report_message.format(axes))


class _EinmixMixin:
    def __init__(self, pattern, weight_shape, bias_shape=None, **axes_lengths):
        """
        EinMix - Einstein summation with automated tensor management and axis packing/unpacking.

        EinMix is an advanced tool, helpful tutorial:
        https://github.com/arogozhnikov/einops/blob/master/docs/3-einmix-layer.ipynb

        Imagine taking einsum with two arguments, one of which is the input and the other a tensor of weights
        >>> einsum('time batch channel_in, channel_in channel_out -> time batch channel_out', input, weight)

        This layer manages the weights for you, and the syntax highlights the separate role of the weight matrix
        >>> EinMix('time batch channel_in -> time batch channel_out', weight_shape='channel_in channel_out')
        But otherwise it is the same einsum under the hood.

        Simple linear layer with bias term (you have one like that in your framework)
        >>> EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=10, cout=20)
        There is no restriction to mixing only the last axis. Let's mix along height
        >>> EinMix('h w c -> hout w c', weight_shape='h hout', bias_shape='hout', h=32, hout=32)
        Channel-wise multiplication (like one used in normalizations)
        >>> EinMix('t b c -> t b c', weight_shape='c', c=128)
        Separate dense layer within each head, no connection between different heads
        >>> EinMix('t b (head cin) -> t b (head cout)', weight_shape='head cin cout', ...)

        ... ah yes, you need to specify all dimensions of weight shape/bias shape in parameters.

        Use cases:
        - when channel dimension is not last, use EinMix, not transposition
        - patch/segment embeddings
        - when only within-group connections are needed, to reduce the number of weights and computations
        - perfect as a part of sequential models
        - next-gen MLPs (follow tutorial to learn more)

        Uniform He initialization is applied to the weight tensor and accounts for the number of elements mixed.

        Parameters
        :param pattern: transformation pattern, left side - dimensions of input, right side - dimensions of output
        :param weight_shape: axes of weight. A tensor of this shape is created, stored, and optimized in a layer
        :param bias_shape: axes of bias added to output.
        :param axes_lengths: dimensions of weight tensor
        """
        super().__init__()
        self.pattern = pattern
        self.weight_shape = weight_shape
        self.bias_shape = bias_shape
        self.axes_lengths = axes_lengths

        left_pattern, right_pattern = pattern.split('->')
        left = ParsedExpression(left_pattern)
        right = ParsedExpression(right_pattern)
        weight = ParsedExpression(weight_shape)
        _report_axes(
            set.difference(right.identifiers, {*left.identifiers, *weight.identifiers}),
            'Unrecognized identifiers on the right side of EinMix {}'
        )

        if left.has_ellipsis or right.has_ellipsis or weight.has_ellipsis:
            raise EinopsError('Ellipsis is not supported in EinMix (right now)')
        if any(x.has_non_unitary_anonymous_axes for x in [left, right, weight]):
            raise EinopsError('Anonymous axes (numbers) are not allowed in EinMix')
        if '(' in weight_shape or ')' in weight_shape:
            raise EinopsError(f'Parentheses are not allowed in weight shape: {weight_shape}')

        pre_reshape_pattern = None
        pre_reshape_lengths = None
        post_reshape_pattern = None
        if any(len(group) != 1 for group in left.composition):
            names = []
            for group in left.composition:
                names += group
            composition = ' '.join(names)
            pre_reshape_pattern = f'{left_pattern}->{composition}'
            pre_reshape_lengths = {name: length for name, length in self.axes_lengths.items() if name in names}

        if any(len(group) != 1 for group in right.composition):
            names = []
            for group in right.composition:
                names += group
            composition = ' '.join(names)
            post_reshape_pattern = f'{composition}->{right_pattern}'

        self._create_rearrange_layers(pre_reshape_pattern, pre_reshape_lengths, post_reshape_pattern, {})

        for axis in weight.identifiers:
            if axis not in axes_lengths:
                raise EinopsError('Dimension {} of weight should be specified'.format(axis))
        _report_axes(
            set.difference(set(axes_lengths), {*left.identifiers, *weight.identifiers}),
            'Axes {} are not used in pattern',
        )
        _report_axes(
            set.difference(weight.identifiers, {*left.identifiers, *right.identifiers}),
            'Weight axes {} are redundant'
        )
        if len(weight.identifiers) == 0:
            warnings.warn('EinMix: weight has no dimensions (means multiplication by a number)')

        _weight_shape = [axes_lengths[axis] for axis, in weight.composition]
        # a single output element is a combination of fan_in input elements
        _fan_in = _product([axes_lengths[axis] for axis, in weight.composition if axis not in right.identifiers])
        if bias_shape is not None:
            if not isinstance(bias_shape, str):
                raise EinopsError('bias shape should be a string specifying which axes bias depends on')
            bias = ParsedExpression(bias_shape)
            _report_axes(
                set.difference(bias.identifiers, right.identifiers),
                'Bias axes {} not present in output'
            )
            _report_axes(
                set.difference(bias.identifiers, set(axes_lengths)),
                'Sizes not provided for bias axes {}',
            )

            _bias_shape = []
            for axes in right.composition:
                for axis in axes:
                    if axis in bias.identifiers:
                        _bias_shape.append(axes_lengths[axis])
                    else:
                        _bias_shape.append(1)
        else:
            _bias_shape = None
            _bias_input_size = None

        weight_bound = (3 / _fan_in) ** 0.5
        bias_bound = (1 / _fan_in) ** 0.5
        self._create_parameters(_weight_shape, weight_bound, _bias_shape, bias_bound)

        # rewrite the einsum expression with single-letter latin identifiers so that
        # the expression will be understood by any framework
        mapping2letters = {*left.identifiers, *right.identifiers, *weight.identifiers}
        mapping2letters = {k: letter for letter, k in zip(string.ascii_lowercase, mapping2letters)}

        def write_flat(axes: list):
            return ''.join(mapping2letters[axis] for axis in axes)

        self.einsum_pattern: str = '{},{}->{}'.format(
            write_flat(left.flat_axes_order()),
            write_flat(weight.flat_axes_order()),
            write_flat(right.flat_axes_order()),
        )

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict]):
        raise NotImplementedError('Should be defined in framework implementations')

    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        """ Shape and implementations """
        raise NotImplementedError('Should be defined in framework implementations')

    def __repr__(self):
        params = repr(self.pattern)
        params += f", '{self.weight_shape}'"
        if self.bias_shape is not None:
            params += f", '{self.bias_shape}'"
        for axis, length in self.axes_lengths.items():
            params += ', {}={}'.format(axis, length)
        return '{}({})'.format(self.__class__.__name__, params)
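
# An illustrative note on einsum_pattern (the exact letters depend on set iteration order,
# so this is only an example, not a guaranteed output):
#   EinMix('t b cin -> t b cout', weight_shape='cin cout', cin=10, cout=20)
# could rewrite its pattern as 'abc,cd->abd'.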

@@ -0,0 +1,64 @@

from typing import Optional, Dict

import jittor as jt
from jittor import nn
import numpy as np

from jittor.einops.layers import RearrangeMixin, ReduceMixin
from jittor.einops.layers._einmix import _EinmixMixin
from jittor.einops._jittor_specific import apply_for_scriptable_jittor

__author__ = 'Ruiyang Liu'


class Rearrange(RearrangeMixin, jt.nn.Module):
    def execute(self, input):
        return apply_for_scriptable_jittor(self._recipe, input, reduction_type='rearrange')

    def _apply_recipe(self, x):
        # overriding the parent method to prevent its scripting
        pass


class Reduce(ReduceMixin, jt.nn.Module):
    def execute(self, input):
        return apply_for_scriptable_jittor(self._recipe, input, reduction_type=self.reduction)

    def _apply_recipe(self, x):
        # overriding the parent method to prevent its scripting
        pass


class EinMix(_EinmixMixin, jt.nn.Module):
    def _create_parameters(self, weight_shape, weight_bound, bias_shape, bias_bound):
        self.weight = jt.zeros(weight_shape)
        nn.init.uniform_(self.weight, low=-weight_bound, high=weight_bound)
        if bias_shape is not None:
            self.bias = jt.zeros(bias_shape)
            nn.init.uniform_(self.bias, low=-bias_bound, high=bias_bound)
        else:
            self.bias = None

    def _create_rearrange_layers(self,
                                 pre_reshape_pattern: Optional[str],
                                 pre_reshape_lengths: Optional[Dict],
                                 post_reshape_pattern: Optional[str],
                                 post_reshape_lengths: Optional[Dict],
                                 ):
        self.pre_rearrange = None
        if pre_reshape_pattern is not None:
            self.pre_rearrange = Rearrange(pre_reshape_pattern, **pre_reshape_lengths)

        self.post_rearrange = None
        if post_reshape_pattern is not None:
            self.post_rearrange = Rearrange(post_reshape_pattern, **post_reshape_lengths)

    def execute(self, input):
        if self.pre_rearrange is not None:
            input = self.pre_rearrange(input)
        result = jt.linalg.einsum(self.einsum_pattern, input, self.weight)
        if self.bias is not None:
            result += self.bias
        if self.post_rearrange is not None:
            result = self.post_rearrange(result)
        return result
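
# A minimal usage sketch (shapes are hypothetical; modules are invoked via jittor's standard __call__):
#   x = jt.randn(10, 32, 64)   # (time, batch, channels)
#   layer = EinMix('t b cin -> t b cout', weight_shape='cin cout', bias_shape='cout', cin=64, cout=128)
#   y = layer(x)                                # shape (10, 32, 128)
#   flat = Rearrange('t b c -> b (t c)')(x)     # shape (32, 640)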

@@ -0,0 +1,147 @@

from jittor.einops import EinopsError
import keyword
import warnings
from typing import List, Optional, Set, Tuple

_ellipsis: str = '…'  # NB, this is a single unicode symbol. A string is used as it is not a list, but can be iterated


class AnonymousAxis(object):
    """Important thing: all instances of this class are not equal to each other """

    def __init__(self, value: str):
        self.value = int(value)
        if self.value <= 1:
            if self.value == 1:
                raise EinopsError('No need to create anonymous axis of length 1. Report this as an issue')
            else:
                raise EinopsError('Anonymous axis should have positive length, not {}'.format(self.value))

    def __repr__(self):
        return "{}-axis".format(str(self.value))


class ParsedExpression:
    """
    non-mutable structure that contains information about one side of expression (e.g. 'b c (h w)')
    and keeps some information important for downstream
    """
    def __init__(self, expression, *, allow_underscore: bool = False):
        self.has_ellipsis: bool = False
        self.has_ellipsis_parenthesized: Optional[bool] = None
        self.identifiers: Set[str] = set()
        # that's axes like 2, 3, 4 or 5. Axes with size 1 are exceptional and replaced with empty composition
        self.has_non_unitary_anonymous_axes: bool = False
        # composition keeps the structure of composite axes, see how different corner cases are handled in tests
        self.composition = []
        if '.' in expression:
            if '...' not in expression:
                raise EinopsError('Expression may contain dots only inside ellipsis (...)')
            if str.count(expression, '...') != 1 or str.count(expression, '.') != 3:
                raise EinopsError(
                    'Expression may contain dots only inside ellipsis (...); only one ellipsis per tensor')
            expression = expression.replace('...', _ellipsis)
            self.has_ellipsis = True

        bracket_group = None

        def add_axis_name(x):
            if x is not None:
                if x in self.identifiers:
                    if not (allow_underscore and x == "_"):
                        raise EinopsError('Indexing expression contains duplicate dimension "{}"'.format(x))
                if x == _ellipsis:
                    self.identifiers.add(_ellipsis)
                    if bracket_group is None:
                        self.composition.append(_ellipsis)
                        self.has_ellipsis_parenthesized = False
                    else:
                        bracket_group.append(_ellipsis)
                        self.has_ellipsis_parenthesized = True
                else:
                    is_number = str.isdecimal(x)
                    if is_number and int(x) == 1:
                        # handling the case of anonymous axis of length 1
                        if bracket_group is None:
                            self.composition.append([])
                        else:
                            pass  # no need to think about 1s inside parenthesis
                        return
                    is_axis_name, reason = self.check_axis_name_return_reason(x, allow_underscore=allow_underscore)
                    if not (is_number or is_axis_name):
                        raise EinopsError('Invalid axis identifier: {}\n{}'.format(x, reason))
                    if is_number:
                        x = AnonymousAxis(x)
                    self.identifiers.add(x)
                    if is_number:
                        self.has_non_unitary_anonymous_axes = True
                    if bracket_group is None:
                        self.composition.append([x])
                    else:
                        bracket_group.append(x)

        current_identifier = None
        for char in expression:
            if char in '() ':
                add_axis_name(current_identifier)
                current_identifier = None
                if char == '(':
                    if bracket_group is not None:
                        raise EinopsError("Axis composition is one-level (brackets inside brackets not allowed)")
                    bracket_group = []
                elif char == ')':
                    if bracket_group is None:
                        raise EinopsError('Brackets are not balanced')
                    self.composition.append(bracket_group)
                    bracket_group = None
            elif str.isalnum(char) or char in ['_', _ellipsis]:
                if current_identifier is None:
                    current_identifier = char
                else:
                    current_identifier += char
            else:
                raise EinopsError("Unknown character '{}'".format(char))

        if bracket_group is not None:
            raise EinopsError('Imbalanced parentheses in expression: "{}"'.format(expression))
        add_axis_name(current_identifier)

    def flat_axes_order(self) -> List:
        result = []
        for composed_axis in self.composition:
            assert isinstance(composed_axis, list), 'does not work with ellipsis'
            for axis in composed_axis:
                result.append(axis)
        return result

    def has_composed_axes(self) -> bool:
        # this will ignore 1 inside brackets
        for axes in self.composition:
            if isinstance(axes, list) and len(axes) > 1:
                return True
        return False

    @staticmethod
    def check_axis_name_return_reason(name: str, allow_underscore: bool = False) -> Tuple[bool, str]:
        if not str.isidentifier(name):
            return False, 'not a valid python identifier'
        elif name[0] == '_' or name[-1] == '_':
            if name == '_' and allow_underscore:
                return True, ''
            return False, 'axis name should not start or end with underscore'
        else:
            if keyword.iskeyword(name):
                warnings.warn("It is discouraged to use axes names that are keywords: {}".format(name), RuntimeWarning)
            if name in ['axis']:
                warnings.warn("It is discouraged to use 'axis' as an axis name "
                              "and will raise an error in future", FutureWarning)
            return True, ''

    @staticmethod
    def check_axis_name(name: str) -> bool:
        """
        Valid axes names are python identifiers except keywords,
        and additionally should not start or end with underscore
        """
        is_valid, _reason = ParsedExpression.check_axis_name_return_reason(name)
        return is_valid
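
# Illustrative parse result (follows directly from the construction above):
#   p = ParsedExpression('b c (h w)')
#   p.composition          -> [['b'], ['c'], ['h', 'w']]
#   p.identifiers          -> {'b', 'c', 'h', 'w'}
#   p.has_composed_axes()  -> True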