mirror of https://github.com/Jittor/Jittor
616 lines
27 KiB
Python
616 lines
27 KiB
Python
# ***************************************************************
|
|
# Copyright (c) 2022 Jittor. All Rights Reserved.
|
|
# Maintainers: DongYang Li <lidongyang2001@gmail.com>.
|
|
# This file is subject to the terms and conditions defined in
|
|
# file 'LICENSE.txt', which is part of this source code package.
|
|
# ***************************************************************
|
|
from collections import namedtuple
|
|
import tempfile
|
|
import pickle
|
|
import itertools
|
|
from jittor.einops.einops import (rearrange, reduce, _enumerate_directions, _reductions)
|
|
from jittor.einops import EinopsError
|
|
import jittor as jt
|
|
import numpy
|
|
import unittest
|
|
|
|
# tests/__init__.py
|
|
import os
|
|
from jittor.einops import _backends
|
|
import warnings
|
|
|
|
# Maps the textual value of an on/off flag ('', '0' or '1') to a boolean:
# only '1' means True.
# NOTE(review): not referenced anywhere in this file — presumably carried over
# from the upstream tests/__init__.py; confirm before relying on it.
flag_to_bool = {flag: flag == '1' for flag in ('', '0', '1')}
|
|
|
|
|
|
def collect_test_backends(symbolic=False, layers=False):
    """
    Instantiate the einops backends the tests should run against.

    :param symbolic: symbolic or imperative frameworks?
    :param layers: layers or operations?
    :return: list of backend instances satisfying the conditions; backends
        whose construction raises ImportError are left out (with a warning)
    """
    if symbolic:
        # no symbolic backends are registered for these tests
        candidates = []
    elif layers:
        # layer tests only make sense for jittor
        candidates = [_backends.JittorBackend]
    else:
        candidates = [_backends.NumpyBackend, _backends.JittorBackend]

    backends = []
    for candidate in candidates:
        try:
            instance = candidate()
        except ImportError:
            # problem with backend installation fails a specific test function,
            # but will be skipped in all other test cases
            warnings.warn('backend could not be initialized for tests: {}'.format(candidate))
            continue
        backends.append(instance)
    return backends
|
|
|
|
|
|
# test/test_ops.py

# Imperative (non-symbolic) backends used by the operation-level tests.
imp_op_backends = collect_test_backends(layers=False, symbolic=False)
|
|
|
|
# test/test_layer.py
|
|
|
|
|
|
class TestSlice(unittest.TestCase):
    """Tests of einops ``rearrange`` / ``reduce`` / ``repeat`` and layers.

    NOTE(review): despite its name, this class does not test slicing; it
    gathers the einops operation tests (test_ops.py) and layer tests
    (test_layer.py), run against numpy and jittor backends.
    """

    # ------------------------------------------------------------------
    # private helpers shared by the layer tests
    # ------------------------------------------------------------------

    @staticmethod
    def _assert_rejects_shapes(backend, layer, wrong_shapes):
        """Assert that ``layer`` raises for a zero-filled input of every shape in ``wrong_shapes``."""
        for shape in wrong_shapes:
            try:
                layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
            except Exception:
                # expected: the layer refused the incompatible shape
                continue
            raise AssertionError('Failure expected')

    @staticmethod
    def _assert_pickle_roundtrip(backend, layer, x, expected):
        """Assert ``layer(x)`` equals ``expected`` and survives a pickle round-trip unchanged."""
        # simple pickling / unpickling
        layer2 = pickle.loads(pickle.dumps(layer))
        result1 = backend.to_numpy(layer(backend.from_numpy(x)))
        result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
        assert numpy.allclose(expected, result1)
        assert numpy.allclose(result1, result2)

    @staticmethod
    def _grad_of_sum(backend, layer, x):
        """Return the gradient of ``sum(layer(x))`` w.r.t. ``x`` as a numpy array."""
        just_sum = backend.layers().Reduce('...->', reduction='sum')
        variable = backend.from_numpy(x)
        result = just_sum(layer(variable))
        if 'jittor' in backend.framework_name:
            grad = backend.jittor.grad(result, variable)
        else:
            result.backward()
            grad = variable.grad
        return backend.to_numpy(grad)

    # ------------------------------------------------------------------
    # repeat / tile
    # ------------------------------------------------------------------

    def test_anonymous_axes(self):
        """Repeat patterns with anonymous axes are undone by min/max reductions."""
        x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
        for pattern, axis_dimensions in test_cases_repeat_anonymous:
            check_reversion(x, pattern, **axis_dimensions)

    def test_repeat_imperatives(self):
        """'repeat' on every imperative backend matches the numpy result."""
        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
        for backend in imp_op_backends:
            print('Repeat tests for ', backend.framework_name)

            for pattern, axis_dimensions in repeat_test_cases:
                expected = reduce(x, pattern, reduction='repeat', **axis_dimensions)
                converted = backend.from_numpy(x)
                repeated = reduce(converted, pattern, reduction='repeat', **axis_dimensions)
                result = backend.to_numpy(repeated)
                assert numpy.array_equal(result, expected)

    def test_repeat_numpy(self):
        """'repeat' on numpy is verified by reversing it with reductions."""
        # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well
        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
        x1 = reduce(x, 'a b c -> copy a b c ', reduction='repeat', copy=1)
        assert numpy.array_equal(x[None], x1)
        for pattern, axis_dimensions in repeat_test_cases:
            check_reversion(x, pattern, **axis_dimensions)

    def test_tiling_imperatives(self):
        """``backend.tile`` agrees with ``numpy.tile``."""
        # the input and repeat vectors do not depend on the backend: build them once
        x = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5])
        repeat_vectors = [
            (1, 1, 1, 1, 1),
            (1, 2, 1, 3, 1),
            (3, 1, 1, 4, 1),
        ]
        for backend in imp_op_backends:
            print('Tiling tests for ', backend.framework_name)
            for repeats in repeat_vectors:
                expected = numpy.tile(x, repeats)
                converted = backend.from_numpy(x)
                repeated = backend.tile(converted, repeats)
                result = backend.to_numpy(repeated)
                assert numpy.array_equal(result, expected)

    # ------------------------------------------------------------------
    # reductions
    # ------------------------------------------------------------------

    def test_gradients_imperatives(self):
        """Backends with autograd agree on gradients through chained reductions."""
        # lazy - just checking reductions
        for reduction in _reductions:
            x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype('float32')
            results = {}
            for backend in imp_op_backends:
                y0 = backend.from_numpy(x)
                is_jittor = 'jittor' in backend.framework_name
                # skip backends whose tensors expose no gradient
                if not is_jittor and not hasattr(y0, 'grad'):
                    continue

                y1 = reduce(y0, 'a b c -> c a', reduction=reduction)
                y2 = reduce(y1, 'c a -> a c', reduction=reduction)
                y3 = reduce(y2, 'a (c1 c2) -> a', reduction=reduction, c1=2)
                y4 = reduce(y3, '... -> ', reduction=reduction)
                if is_jittor:
                    grad = backend.jittor.grad(y4, y0)
                else:
                    y4.backward()
                    grad = y0.grad
                results[backend.framework_name] = backend.to_numpy(grad)

            print('comparing gradients for', results.keys())
            for name1, grad1 in results.items():
                for name2, grad2 in results.items():
                    assert numpy.allclose(grad1, grad2), [name1, name2, 'provided different gradients']

    def test_concatenations_and_stacking(self):
        """A list of arrays is stacked along a new leading axis, like ``numpy.asarray``."""
        for backend in imp_op_backends:
            print('testing shapes for ', backend.framework_name)
            for n_arrays in [1, 2, 5]:
                shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
                for shape in shapes:
                    if backend.framework_name == 'jittor' and len(shape) == 0:
                        # jittor stores scalar in 1d array
                        continue
                    arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)]
                    arrays2 = [backend.from_numpy(array) for array in arrays1]
                    result0 = numpy.asarray(arrays1)
                    result1 = rearrange(arrays1, '...->...')
                    result2 = rearrange(arrays2, '...->...')
                    assert numpy.array_equal(result0, result1)
                    assert numpy.array_equal(result1, backend.to_numpy(result2))

                    result1 = rearrange(arrays1, 'b ... -> ... b')
                    result2 = rearrange(arrays2, 'b ... -> ... b')
                    assert numpy.array_equal(result1, backend.to_numpy(result2))

    def test_enumerating_directions(self):
        """``_enumerate_directions`` gives the same per-axis index arrays on every backend."""
        for backend in imp_op_backends:
            print('testing directions for', backend.framework_name)
            for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]:
                if backend.framework_name == 'jittor' and len(shape) == 0:
                    # jittor stores scalar in 1d array
                    continue
                x = numpy.arange(numpy.prod(shape)).reshape(shape)
                axes1 = _enumerate_directions(x)
                axes2 = _enumerate_directions(backend.from_numpy(x))
                assert len(axes1) == len(axes2) == len(shape)
                for ax1, ax2 in zip(axes1, axes2):
                    ax2 = backend.to_numpy(ax2)
                    assert ax1.shape == ax2.shape
                    assert numpy.allclose(ax1, ax2)

    def test_reduction_with_callable_imperatives(self):
        """A callable reduction (logsumexp) matches a naive numpy reference for both pattern spellings."""
        x_numpy = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype('float32')
        x_numpy /= x_numpy.max()

        def logsumexp_jittor(x, tuple_of_axes):
            return jt.nn.logsumexp(x, tuple_of_axes)

        def logsumexp_numpy(x, tuple_of_axes):
            # very naive logsumexp to compare to; subtracting the max keeps exp() from overflowing
            minused = x.max(tuple_of_axes)
            y = x - x.max(tuple_of_axes, keepdims=True)
            y = numpy.exp(y)
            y = numpy.sum(y, axis=tuple_of_axes)
            return numpy.log(y) + minused

        from jittor.einops._backends import JittorBackend, NumpyBackend
        backend2callback = {
            JittorBackend.framework_name: logsumexp_jittor,
            NumpyBackend.framework_name: logsumexp_numpy,
        }

        for backend in imp_op_backends:
            if backend.framework_name not in backend2callback:
                continue

            backend_callback = backend2callback[backend.framework_name]

            x_backend = backend.from_numpy(x_numpy)
            for pattern1, pattern2 in equivalent_reduction_patterns:
                print('Test reduction with callable for ', backend.framework_name, pattern1, pattern2)
                output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy)
                # the two patterns are equivalent, so both must reproduce the numpy reference
                for pattern in (pattern1, pattern2):
                    output_backend = reduce(x_backend, pattern, reduction=backend_callback)
                    assert numpy.allclose(
                        output_numpy,
                        backend.to_numpy(output_backend),
                    )

    def test_reduction_stress_imperatives(self):
        """Random permutations plus reductions of leading axes match numpy transpose + reduce."""
        for backend in imp_op_backends:
            print('Stress-testing reduction for ', backend.framework_name)
            for reduction in _reductions + ('rearrange',):
                dtype = 'int64'
                coincide = numpy.array_equal
                if reduction in ['mean', 'prod']:
                    dtype = 'float64'
                    coincide = numpy.allclose
                for n_axes in range(11):
                    shape = numpy.random.randint(2, 4, size=n_axes)
                    permutation = numpy.random.permutation(n_axes)
                    # 'rearrange' cannot drop axes; other reductions remove a random prefix of the permutation
                    skipped = 0 if reduction == 'rearrange' else numpy.random.randint(n_axes + 1)
                    left = ' '.join('x' + str(i) for i in range(n_axes))
                    right = ' '.join('x' + str(i) for i in permutation[skipped:])
                    pattern = left + '->' + right
                    x = numpy.arange(1, 1 + numpy.prod(shape), dtype=dtype).reshape(shape)
                    if reduction == 'prod':
                        x /= x.mean()  # to avoid overflows
                    result1 = reduce(x, pattern, reduction=reduction)
                    result2 = x.transpose(permutation)
                    if skipped > 0:
                        result2 = getattr(result2, reduction)(axis=tuple(range(skipped)))
                    assert coincide(result1, result2)
                    check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False)

    def test_reduction_imperatives(self):
        """Fixed reduction patterns (with and without ellipsis) match numpy reductions."""
        for backend in imp_op_backends:
            print('Reduction tests for ', backend.framework_name)
            for reduction in _reductions:
                # slight redundancy for simpler order - numpy version is evaluated multiple times
                x = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape([2, 3, 4, 5, 6])
                if reduction in ['mean', 'prod']:
                    x = x / x.astype('float64').mean()
                test_cases = [
                    ['a b c d e -> ', {},
                     getattr(x, reduction)()],
                    ['a ... -> ', {},
                     getattr(x, reduction)()],
                    ['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2),
                     getattr(x, reduction)()],
                    ['a b c d e -> (e c) a', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a ... c d e -> (e c) a', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a b c d e ... -> (e c) a', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a b c d e -> (e c a)', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
                    ['(a a2) ... -> (a2 a) ...', dict(a2=1),
                     x],
                ]
                for pattern, axes_lengths, expected_result in test_cases:
                    result = reduce(backend.from_numpy(x.copy()), pattern, reduction=reduction, **axes_lengths)
                    result = backend.to_numpy(result)
                    assert numpy.allclose(result, expected_result)

    # ------------------------------------------------------------------
    # rearrange on numpy
    # ------------------------------------------------------------------

    def test_rearrange_permutations_numpy(self):
        """Random axis permutations agree with two independent numpy computations."""
        # tests random permutation of axes against two independent numpy ways
        for n_axes in range(1, 10):
            x = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
            permutation = numpy.random.permutation(n_axes)
            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes))
            right_expression = ' '.join('i' + str(axis) for axis in permutation)
            expression = left_expression + ' -> ' + right_expression
            result = rearrange(x, expression)

            for pick in numpy.random.randint(0, 2, [10, n_axes]):
                assert x[tuple(pick)] == result[tuple(pick[permutation])]

        for n_axes in range(1, 10):
            x = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
            permutation = numpy.random.permutation(n_axes)
            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1])
            right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1])
            expression = left_expression + ' -> ' + right_expression
            result = rearrange(x, expression)
            assert result.shape == x.shape
            # each element's value encodes its index in binary, so permuting axes permutes bits
            expected_result = numpy.zeros_like(x)
            for original_axis, result_axis in enumerate(permutation):
                expected_result |= ((x >> original_axis) & 1) << result_axis

            assert numpy.array_equal(result, expected_result)

    def test_rearrange_consistency_numpy(self):
        """Rearrange keeps elements and dtype, and composing with the inverse is identity."""
        shape = [1, 2, 3, 5, 7, 11]
        x = numpy.arange(numpy.prod(shape)).reshape(shape)
        for pattern in [
            'a b c d e f -> a b c d e f',
            'b a c d e f -> a b d e f c',
            'a b c d e f -> f e d c b a',
            'a b c d e f -> (f e) d (c b a)',
            'a b c d e f -> (f e d c b a)',
        ]:
            result = rearrange(x, pattern)
            assert len(numpy.setdiff1d(x, result)) == 0
            assert result.dtype == x.dtype

        result = rearrange(x, 'a b c d e f -> a (b) (c d e) f')
        assert numpy.array_equal(x.flatten(), result.flatten())

        result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11')
        assert numpy.array_equal(x, result)

        result1 = rearrange(x, 'a b c d e f -> f e d c b a')
        result2 = rearrange(x, 'f e d c b a -> a b c d e f')
        assert numpy.array_equal(result1, result2)

        result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5)
        assert numpy.array_equal(x, result)

        sizes = dict(zip('abcdef', shape))
        temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes)
        result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes)
        assert numpy.array_equal(x, result)

        x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4])
        result = rearrange(x2, 'a b c -> b c a')
        assert x2[1, 2, 3] == result[2, 3, 1]
        assert x2[0, 1, 2] == result[1, 2, 0]

    # ------------------------------------------------------------------
    # ellipsis
    # ------------------------------------------------------------------

    def test_ellipsis_ops_imperative(self):
        """ Checking various patterns against numpy """
        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
        for is_symbolic in [True, False]:
            for backend in collect_test_backends(symbolic=is_symbolic, layers=False):
                for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)):
                    check_op_against_numpy(backend, x, pattern, axes_lengths={},
                                           reduction='rearrange', is_symbolic=is_symbolic)

                for reduction in ['min', 'max', 'sum']:
                    for pattern in itertools.chain(*equivalent_reduction_patterns):
                        check_op_against_numpy(backend, x, pattern, axes_lengths={},
                                               reduction=reduction, is_symbolic=is_symbolic)

    def test_ellipsis_ops_numpy(self):
        """Identity patterns are identity; equivalent patterns give equal numpy results."""
        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
        for pattern in identity_patterns:
            assert numpy.array_equal(x, rearrange(x, pattern)), pattern

        for pattern1, pattern2 in equivalent_rearrange_patterns:
            assert numpy.array_equal(rearrange(x, pattern1), rearrange(x, pattern2))

        for reduction in ['min', 'max', 'sum']:
            for pattern1, pattern2 in equivalent_reduction_patterns:
                assert numpy.array_equal(reduce(x, pattern1, reduction=reduction),
                                         reduce(x, pattern2, reduction=reduction))

    def test_collapsed_ellipsis_errors_out(self):
        """A parenthesized ellipsis on the left-hand side is rejected with EinopsError."""
        x = numpy.zeros([1, 1, 1, 1, 1])
        rearrange(x, 'a b c d ... -> a b c ... d')
        with self.assertRaises(EinopsError):
            rearrange(x, 'a b c d (...) -> a b c ... d')

        rearrange(x, '... -> (...)')
        with self.assertRaises(EinopsError):
            rearrange(x, '(...) -> (...)')

    # ------------------------------------------------------------------
    # layers
    # ------------------------------------------------------------------

    def test_rearrange_imperative(self):
        """Rearrange layers match ``rearrange``, reject wrong shapes, pickle, and pass unit gradients."""
        for backend in collect_test_backends(symbolic=False, layers=True):
            print('Test layer for ', backend.framework_name)

            for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
                x = numpy.arange(numpy.prod(input_shape), dtype='float32').reshape(input_shape)
                result_numpy = rearrange(x, pattern, **axes_lengths)
                layer = backend.layers().Rearrange(pattern, **axes_lengths)
                self._assert_rejects_shapes(backend, layer, wrong_shapes)
                self._assert_pickle_roundtrip(backend, layer, x, result_numpy)

                # rearrange only moves elements, so d(sum)/dx is exactly 1 everywhere
                grad = self._grad_of_sum(backend, layer, x)
                assert numpy.allclose(grad, 1)

    def test_reduce_imperative(self):
        """Reduce layers match ``reduce``, reject wrong shapes, pickle, and have sensible gradients."""
        for backend in collect_test_backends(symbolic=False, layers=True):
            print('Test layer for ', backend.framework_name)
            for reduction in _reductions:
                for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
                    print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
                    x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape)
                    x /= x.mean()
                    result_numpy = reduce(x, pattern, reduction, **axes_lengths)
                    layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
                    self._assert_rejects_shapes(backend, layer, wrong_shapes)
                    self._assert_pickle_roundtrip(backend, layer, x, result_numpy)

                    grad = self._grad_of_sum(backend, layer, x)
                    if reduction == 'sum':
                        assert numpy.allclose(grad, 1)
                    if reduction == 'mean':
                        # every element in a group contributes equally
                        assert numpy.allclose(grad, grad.min())
                    if reduction in ['max', 'min']:
                        # gradient is routed only to the selected elements
                        assert numpy.all(numpy.isin(grad, [0, 1]))
                        assert numpy.sum(grad) > 0.5

    def test_jittor_layer(self):
        """Reduce-based and MaxPool-based models differ at init and agree after sharing weights."""
        has_jittor = any(backend.framework_name == 'jittor'
                         for backend in collect_test_backends(symbolic=False, layers=True))
        if not has_jittor:
            return

        # checked that jittor present
        import jittor

        rtol = 1e-05
        atol = 1e-08

        def allclose(a, b):
            return jittor.all(jittor.abs(a - b) <= atol + rtol * jittor.abs(b))

        model1 = create_jittor_model(use_reduce=True)
        model2 = create_jittor_model(use_reduce=False)
        x = jittor.randn([10, 3, 32, 32])
        # random models have different predictions
        assert not allclose(model1(x), model2(x))
        model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
        assert allclose(model1(x), model2(x))
|
|
|
|
|
|
# A layer test case: an einops pattern, the named axis lengths passed with it,
# a valid input shape, and shapes the layer is expected to reject.
testcase = namedtuple('testcase', 'pattern axes_lengths input_shape wrong_shapes')

rearrangement_patterns = [
    testcase(
        pattern='b c h w -> b (c h w)',
        axes_lengths={'c': 20},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]],
    ),
    testcase(
        pattern='b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1',
        axes_lengths={'h2': 2, 'w2': 2},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(), (1, 1, 1, 1), (1, 10, 3), ()],
    ),
    testcase(
        pattern='b ... c -> c b ...',
        axes_lengths={'b': 10},
        input_shape=(10, 20, 30),
        wrong_shapes=[(), (10,), (5, 10)],
    ),
]

# Every rearrangement is also a valid (non-reducing) reduction, plus real reductions.
reduction_patterns = [
    *rearrangement_patterns,
    testcase(
        pattern='b c h w -> b ()',
        axes_lengths={'b': 10},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(10,), (10, 20, 30)],
    ),
    testcase(
        pattern='b c (h1 h2) (w1 w2) -> b c h1 w1',
        axes_lengths={'h1': 15, 'h2': 2, 'w2': 2},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(10, 20, 31, 40)],
    ),
    testcase(
        pattern='b ... c -> b',
        axes_lengths={'b': 10},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(10,), (11, 10)],
    ),
]

# Pairs of reduction patterns that must give equal results on a 5-d input.
equivalent_reduction_patterns = [
    ('a b c d e -> ', ' ... -> '),
    ('a b c d e -> (e a)', 'a ... e -> (e a)'),
    ('a b c d e -> d (a e)', ' a b c d e ... -> d (a e) '),
    ('a b c d e -> (a b)', ' ... c d e -> (...) '),
]

# Pairs of rearrange patterns that must give equal results on a 5-d input.
equivalent_rearrange_patterns = [
    ('a b c d e -> (a b) c d e', 'a b ... -> (a b) ... '),
    ('a b c d e -> a b (c d) e', '... c d e -> ... (c d) e'),
    ('a b c d e -> a b c d e', '... -> ... '),
    ('a b c d e -> (a b c d e)', '... -> (...)'),
    ('a b c d e -> b (c d e) a', 'a b ... -> b (...) a'),
    ('a b c d e -> b (a c d) e', 'a b ... e -> b (a ...) e'),
]

# Patterns that leave a 5-d input unchanged.
identity_patterns = [
    '...->...',
    'a b c d e-> a b c d e',
    'a b c d e ...-> ... a b c d e',
    'a b c d e ...-> a ... b c d e',
    '... a b c d e -> ... a b c d e',
    'a ... e-> a ... e',
    'a ... -> a ... ',
    'a ... c d e -> a (...) c d e',
]

# Repeat patterns with anonymous (numeric) axes; all assume input shape [1, 2, 4, 6].
test_cases_repeat_anonymous = [
    ('a b c d -> c a d b', {}),
    ('a b c d -> (c 2 d a b)', {'a': 1, 'c': 4, 'd': 6}),
    ('1 b c d -> (d copy 1) 3 b c ', {'copy': 3}),
    ('1 ... -> 3 ... ', {}),
    ('() ... d -> 1 (copy1 d copy2) ... ', {'copy1': 2, 'copy2': 3}),
    ('1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)', {}),
]

# Repeat patterns with named axes; all assume input shape [2, 3, 5].
repeat_test_cases = [
    ('a b c -> c a b', {}),
    ('a b c -> (c copy a b)', {'copy': 2, 'a': 2, 'b': 3, 'c': 5}),
    ('a b c -> (a copy) b c ', {'copy': 1}),
    ('a b c -> (c a) (copy1 b copy2)', {'a': 2, 'copy1': 1, 'copy2': 2}),
    ('a ... -> a ... copy', {'copy': 4}),
    ('... c -> ... (copy1 c copy2)', {'copy1': 1, 'copy2': 2}),
    ('... -> ... ', {}),
    (' ... -> copy1 ... copy2 ', {'copy1': 2, 'copy2': 3}),
    ('a b c -> copy1 a copy2 b c () ', {'copy1': 2, 'copy2': 1}),
]
|
|
|
|
|
|
def check_reversion(x, repeat_pattern, **sizes):
    """Check a repeat pattern by undoing it with reductions.

    ``x`` is repeated with ``repeat_pattern``. The mirrored pattern
    (right-hand side -> left-hand side) is then applied as a 'min' and as a
    'max' reduction. Every repeated copy equals its source element, so both
    reductions must give back ``x`` exactly.

    :raises AssertionError: if either reduction does not recover ``x``
    """
    lhs, rhs = repeat_pattern.split('->')
    inverse_pattern = rhs + '->' + lhs
    repeated = reduce(x, repeat_pattern, reduction='repeat', **sizes)
    # compute both reductions before checking either one
    restored = [
        reduce(repeated, inverse_pattern, reduction=kind, **sizes)
        for kind in ('min', 'max')
    ]
    for candidate in restored:
        assert numpy.array_equal(x, candidate)
|
|
|
|
|
|
def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction='rearrange', is_symbolic=False):
    """
    Helper to test result of operation (rearrange or reduce) against numpy.

    The operation is applied to ``numpy_input`` directly and through
    ``backend``, and the two results are asserted to coincide.

    :param backend: einops backend under test
    :param numpy_input: numpy array fed to both computations; 0-d inputs are skipped
    :param pattern: einops pattern
    :param axes_lengths: dict of named axis lengths passed to the operation
    :param reduction: if 'rearrange', the rearrange op is tested, otherwise
        ``reduce`` with this reduction
    :param is_symbolic: evaluate through ``backend.create_symbol`` /
        ``backend.eval_symbol`` instead of eager ``from_numpy``
    :raises AssertionError: if the backend result differs from the numpy result
    """
    if len(numpy_input.shape) == 0:
        return

    def operation(x):
        if reduction == 'rearrange':
            return rearrange(x, pattern, **axes_lengths)
        return reduce(x, pattern, reduction, **axes_lengths)

    numpy_result = operation(numpy_input)

    check_equal = numpy.array_equal
    # probability that a symbolic placeholder gets an unknown (None) dimension
    p_none_dimension = 0.5
    if 'jittor' in backend.framework_name:
        # NOTE(review): approximate comparison, presumably because jittor's
        # numerics can differ slightly from numpy's
        check_equal = numpy.allclose
        p_none_dimension = 0

    if is_symbolic:
        symbol_shape = [d if numpy.random.random() >= p_none_dimension else None for d in numpy_input.shape]
        symbol = backend.create_symbol(shape=symbol_shape)
        result_symbol = operation(symbol)
        backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)])
    else:
        backend_result = operation(backend.from_numpy(numpy_input))
        backend_result = backend.to_numpy(backend_result)

    # previously the comparison result was discarded, so mismatches went unnoticed
    assert check_equal(numpy_result, backend_result), \
        'backend {} disagrees with numpy for pattern {!r} (reduction={!r})'.format(
            backend.framework_name, pattern, reduction)
|
|
|
|
|
|
def create_jittor_model(use_reduce=False):
    """Build a small convolutional jittor classifier for 3x32x32 inputs with 10 outputs.

    :param use_reduce: if True, the first 2x2 max-pooling stage is an einops
        ``Reduce`` layer; otherwise it is ``MaxPool2d(kernel_size=2)``.
        The second pooling stage is always a ``Reduce`` layer.
    :return: a ``jittor.nn.Sequential`` model
    """
    from jittor.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
    from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix

    def max_pool_2x2():
        # einops spelling of a 2x2 max-pool
        return Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2)

    # layers are constructed in forward order: 32 -> conv 28 -> pool 14 -> conv 10 -> pool 5
    layers = [
        Conv2d(3, 6, kernel_size=(5, 5)),
        max_pool_2x2() if use_reduce else MaxPool2d(kernel_size=2),
        Conv2d(6, 16, kernel_size=(5, 5)),
        max_pool_2x2(),
        Rearrange('b c h w -> b (c h w)'),
        Linear(16 * 5 * 5, 120),
        ReLU(),
        Linear(120, 84),
        ReLU(),
        EinMix('b c1 -> (b c2)', weight_shape='c1 c2', bias_shape='c2', c1=84, c2=84),
        EinMix('(b c2) -> b c3', weight_shape='c2 c3', bias_shape='c3', c2=84, c3=84),
        Linear(84, 10),
    ]
    return Sequential(*layers)
|
|
|
|
|
|
if __name__ == '__main__':
    # run every unittest.TestCase defined in this module
    unittest.main()