# Source path: JittorMirror/python/jittor/test/test_einops.py
# (616 lines, 27 KiB, Python)

# ***************************************************************
# Copyright (c) 2022 Jittor. All Rights Reserved.
# Maintainers: DongYang Li <lidongyang2001@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from collections import namedtuple
import tempfile
import pickle
import itertools
from jittor.einops.einops import (rearrange, reduce, _enumerate_directions, _reductions)
from jittor.einops import EinopsError
import jittor as jt
import numpy
import unittest
# tests/__init__.py
import os
from jittor.einops import _backends
import warnings
# Lookup from flag strings to booleans: '' and '0' are off, '1' is on.
# NOTE(review): presumably intended for environment-variable flags; unused in
# the visible part of this file.
flag_to_bool = {'': False, '0': False, '1': True}
def collect_test_backends(symbolic=False, layers=False):
    """Instantiate the einops backends the tests should run against.

    :param symbolic: request symbolic-framework backends; none are configured
        here, so a symbolic request always yields an empty list.
    :param layers: if True, return only the backend used for layer tests
        (Jittor); otherwise the operation backends (NumPy and Jittor).
    :return: list of backend instances whose constructor succeeded. A backend
        whose constructor raises ImportError is skipped with a warning.
    """
    if symbolic:
        candidates = []
    elif layers:
        candidates = [_backends.JittorBackend]
    else:
        candidates = [_backends.NumpyBackend, _backends.JittorBackend]

    available = []
    for candidate in candidates:
        try:
            instance = candidate()
        except ImportError:
            # A broken installation fails only the tests that need this
            # backend; every other test simply runs without it.
            warnings.warn('backend could not be initialized for tests: {}'.format(candidate))
        else:
            available.append(instance)
    return available
# Operation-level backends (ported from einops' test/test_ops.py), shared by
# the imperative tests in the test class below (ported from test/test_layer.py).
imp_op_backends = collect_test_backends(layers=False, symbolic=False)
class TestSlice(unittest.TestCase):
    """Einops tests: repeat, tile, reduce, rearrange and the layer wrappers.

    Operation tests run against every backend in ``imp_op_backends``; layer
    tests run against ``collect_test_backends(symbolic=False, layers=True)``.
    The class name does not describe its contents. It is left unchanged so
    existing test selectors keep working.
    """

    def test_anonymous_axes(self):
        """Repeat patterns with numeric (anonymous) axes are undone by min/max reduction."""
        x = numpy.arange(1 * 2 * 4 * 6).reshape([1, 2, 4, 6])
        for pattern, axis_dimensions in test_cases_repeat_anonymous:
            check_reversion(x, pattern, **axis_dimensions)

    def test_repeat_imperatives(self):
        """reduction='repeat' on each backend matches the numpy result."""
        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
        for backend in imp_op_backends:
            print('Repeat tests for ', backend.framework_name)
            for pattern, axis_dimensions in repeat_test_cases:
                expected = reduce(x, pattern, reduction='repeat', **axis_dimensions)
                converted = backend.from_numpy(x)
                repeated = reduce(converted, pattern, reduction='repeat', **axis_dimensions)
                result = backend.to_numpy(repeated)
                assert numpy.array_equal(result, expected)

    def test_repeat_numpy(self):
        """Repeat on numpy is consistent with its reversing min/max reductions."""
        # check repeat vs reduce. Repeat works ok if reverse reduction with min and max work well
        x = numpy.arange(2 * 3 * 5).reshape([2, 3, 5])
        x1 = reduce(x, 'a b c -> copy a b c ', reduction='repeat', copy=1)
        assert numpy.array_equal(x[None], x1)
        for pattern, axis_dimensions in repeat_test_cases:
            check_reversion(x, pattern, **axis_dimensions)

    def test_tiling_imperatives(self):
        """backend.tile matches numpy.tile for several repeat vectors."""
        # input and repeat vectors do not depend on the backend: build them once
        x = numpy.arange(2 * 3 * 5, dtype='int64').reshape([2, 1, 3, 1, 5])
        test_cases = [
            (1, 1, 1, 1, 1),
            (1, 2, 1, 3, 1),
            (3, 1, 1, 4, 1),
        ]
        for backend in imp_op_backends:
            print('Tiling tests for ', backend.framework_name)
            for repeats in test_cases:
                expected = numpy.tile(x, repeats)
                converted = backend.from_numpy(x)
                repeated = backend.tile(converted, repeats)
                result = backend.to_numpy(repeated)
                assert numpy.array_equal(result, expected)

    def test_gradients_imperatives(self):
        """Gradients through a chain of reductions agree across autograd-capable backends."""
        # lazy - just checking reductions
        x = numpy.arange(1, 1 + 2 * 3 * 4).reshape([2, 3, 4]).astype('float32')
        for reduction in _reductions:
            results = {}
            for backend in imp_op_backends:
                y0 = backend.from_numpy(x)
                is_jittor = 'jittor' in backend.framework_name
                if not is_jittor and not hasattr(y0, 'grad'):
                    # backend has no autograd (e.g. plain numpy arrays)
                    continue
                y1 = reduce(y0, 'a b c -> c a', reduction=reduction)
                y2 = reduce(y1, 'c a -> a c', reduction=reduction)
                y3 = reduce(y2, 'a (c1 c2) -> a', reduction=reduction, c1=2)
                y4 = reduce(y3, '... -> ', reduction=reduction)
                if is_jittor:
                    grad = backend.jittor.grad(y4, y0)
                else:
                    y4.backward()
                    grad = y0.grad
                results[backend.framework_name] = backend.to_numpy(grad)
            print('comparing gradients for', results.keys())
            for name1, grad1 in results.items():
                for name2, grad2 in results.items():
                    assert numpy.allclose(grad1, grad2), [name1, name2, 'provided different gradients']

    def test_concatenations_and_stacking(self):
        """rearrange stacks a list of arrays the same way on numpy and each backend."""
        shapes = [[], [1], [1, 1], [2, 3, 5, 7], [1] * 6]
        for backend in imp_op_backends:
            print('testing shapes for ', backend.framework_name)
            for n_arrays in [1, 2, 5]:
                for shape in shapes:
                    if backend.framework_name == 'jittor' and len(shape) == 0:
                        # jittor stores scalar in 1d array
                        continue
                    arrays1 = [numpy.arange(i, i + numpy.prod(shape)).reshape(shape) for i in range(n_arrays)]
                    arrays2 = [backend.from_numpy(array) for array in arrays1]
                    result0 = numpy.asarray(arrays1)
                    result1 = rearrange(arrays1, '...->...')
                    result2 = rearrange(arrays2, '...->...')
                    assert numpy.array_equal(result0, result1)
                    assert numpy.array_equal(result1, backend.to_numpy(result2))
                    result1 = rearrange(arrays1, 'b ... -> ... b')
                    result2 = rearrange(arrays2, 'b ... -> ... b')
                    assert numpy.array_equal(result1, backend.to_numpy(result2))

    def test_enumerating_directions(self):
        """_enumerate_directions agrees between numpy and each backend."""
        for backend in imp_op_backends:
            print('testing directions for', backend.framework_name)
            for shape in [[], [1], [1, 1, 1], [2, 3, 5, 7]]:
                if backend.framework_name == 'jittor' and len(shape) == 0:
                    # jittor stores scalar in 1d array
                    continue
                x = numpy.arange(numpy.prod(shape)).reshape(shape)
                axes1 = _enumerate_directions(x)
                axes2 = _enumerate_directions(backend.from_numpy(x))
                assert len(axes1) == len(axes2) == len(shape)
                for ax1, ax2 in zip(axes1, axes2):
                    ax2 = backend.to_numpy(ax2)
                    assert ax1.shape == ax2.shape
                    assert numpy.allclose(ax1, ax2)

    def test_reduction_with_callable_imperatives(self):
        """A user-supplied callable reduction gives the same result on each backend."""
        x_numpy = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6]).astype('float32')
        x_numpy /= x_numpy.max()

        def logsumexp_jittor(x, tuple_of_axes):
            return jt.nn.logsumexp(x, tuple_of_axes)

        def logsumexp_numpy(x, tuple_of_axes):
            # very naive logsumexp to compare to
            minused = x.max(tuple_of_axes)
            y = x - x.max(tuple_of_axes, keepdims=True)
            y = numpy.exp(y)
            y = numpy.sum(y, axis=tuple_of_axes)
            return numpy.log(y) + minused

        backend2callback = {
            _backends.JittorBackend.framework_name: logsumexp_jittor,
            _backends.NumpyBackend.framework_name: logsumexp_numpy,
        }
        for backend in imp_op_backends:
            if backend.framework_name not in backend2callback:
                continue
            backend_callback = backend2callback[backend.framework_name]
            x_backend = backend.from_numpy(x_numpy)
            for pattern1, pattern2 in equivalent_reduction_patterns:
                print('Test reduction with callable for ', backend.framework_name, pattern1, pattern2)
                output_numpy = reduce(x_numpy, pattern1, reduction=logsumexp_numpy)
                output_backend = reduce(x_backend, pattern1, reduction=backend_callback)
                assert numpy.allclose(
                    output_numpy,
                    backend.to_numpy(output_backend),
                )

    def test_reduction_stress_imperatives(self):
        """Random permutation + reduction patterns match numpy transpose/reduce."""
        for backend in imp_op_backends:
            print('Stress-testing reduction for ', backend.framework_name)
            for reduction in _reductions + ('rearrange',):
                dtype = 'int64'
                coincide = numpy.array_equal
                if reduction in ['mean', 'prod']:
                    dtype = 'float64'
                    coincide = numpy.allclose
                for n_axes in range(11):
                    shape = numpy.random.randint(2, 4, size=n_axes)
                    permutation = numpy.random.permutation(n_axes)
                    skipped = 0 if reduction == 'rearrange' else numpy.random.randint(n_axes + 1)
                    left = ' '.join('x' + str(i) for i in range(n_axes))
                    right = ' '.join('x' + str(i) for i in permutation[skipped:])
                    pattern = left + '->' + right
                    x = numpy.arange(1, 1 + numpy.prod(shape), dtype=dtype).reshape(shape)
                    if reduction == 'prod':
                        x /= x.mean()  # to avoid overflows
                    result1 = reduce(x, pattern, reduction=reduction)
                    result2 = x.transpose(permutation)
                    if skipped > 0:
                        result2 = getattr(result2, reduction)(axis=tuple(range(skipped)))
                    assert coincide(result1, result2)
                    check_op_against_numpy(backend, x, pattern, reduction=reduction, axes_lengths={}, is_symbolic=False)

    def test_reduction_imperatives(self):
        """Fixed reduce patterns on each backend match hand-computed numpy results."""
        for backend in imp_op_backends:
            print('Reduction tests for ', backend.framework_name)
            for reduction in _reductions:
                # slight redundancy for simpler order - numpy version is evaluated multiple times
                x = numpy.arange(2 * 3 * 4 * 5 * 6, dtype='int64').reshape([2, 3, 4, 5, 6])
                if reduction in ['mean', 'prod']:
                    x = x / x.astype('float64').mean()
                test_cases = [
                    ['a b c d e -> ', {},
                     getattr(x, reduction)()],
                    ['a ... -> ', {},
                     getattr(x, reduction)()],
                    ['(a1 a2) ... (e1 e2) -> ', dict(a1=1, e2=2),
                     getattr(x, reduction)()],
                    ['a b c d e -> (e c) a', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a ... c d e -> (e c) a', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a b c d e ... -> (e c) a', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1, 2])],
                    ['a b c d e -> (e c a)', {},
                     getattr(x, reduction)(axis=(1, 3)).transpose(2, 1, 0).reshape([-1])],
                    ['(a a2) ... -> (a2 a) ...', dict(a2=1),
                     x],
                ]
                for pattern, axes_lengths, expected_result in test_cases:
                    result = reduce(backend.from_numpy(x.copy()), pattern, reduction=reduction, **axes_lengths)
                    result = backend.to_numpy(result)
                    assert numpy.allclose(result, expected_result)

    def test_rearrange_permutations_numpy(self):
        """Random axis permutations agree with two independent numpy computations."""
        # tests random permutation of axes against two independent numpy ways
        for n_axes in range(1, 10):
            x = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
            permutation = numpy.random.permutation(n_axes)
            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes))
            right_expression = ' '.join('i' + str(axis) for axis in permutation)
            expression = left_expression + ' -> ' + right_expression
            result = rearrange(x, expression)
            for pick in numpy.random.randint(0, 2, [10, n_axes]):
                assert x[tuple(pick)] == result[tuple(pick[permutation])]
        for n_axes in range(1, 10):
            x = numpy.arange(2 ** n_axes).reshape([2] * n_axes)
            permutation = numpy.random.permutation(n_axes)
            left_expression = ' '.join('i' + str(axis) for axis in range(n_axes)[::-1])
            right_expression = ' '.join('i' + str(axis) for axis in permutation[::-1])
            expression = left_expression + ' -> ' + right_expression
            result = rearrange(x, expression)
            assert result.shape == x.shape
            # each element value encodes its own index bits, so moving bits
            # reproduces the expected permuted array
            expected_result = numpy.zeros_like(x)
            for original_axis, result_axis in enumerate(permutation):
                expected_result |= ((x >> original_axis) & 1) << result_axis
            assert numpy.array_equal(result, expected_result)

    def test_rearrange_consistency_numpy(self):
        """rearrange keeps elements and dtype and round-trips through grouped axes."""
        shape = [1, 2, 3, 5, 7, 11]
        x = numpy.arange(numpy.prod(shape)).reshape(shape)
        for pattern in [
            'a b c d e f -> a b c d e f',
            'b a c d e f -> a b d e f c',
            'a b c d e f -> f e d c b a',
            'a b c d e f -> (f e) d (c b a)',
            'a b c d e f -> (f e d c b a)',
        ]:
            result = rearrange(x, pattern)
            assert len(numpy.setdiff1d(x, result)) == 0
            assert result.dtype == x.dtype
        result = rearrange(x, 'a b c d e f -> a (b) (c d e) f')
        assert numpy.array_equal(x.flatten(), result.flatten())
        result = rearrange(x, 'a aa aa1 a1a1 aaaa a11 -> a aa aa1 a1a1 aaaa a11')
        assert numpy.array_equal(x, result)
        result1 = rearrange(x, 'a b c d e f -> f e d c b a')
        result2 = rearrange(x, 'f e d c b a -> a b c d e f')
        assert numpy.array_equal(result1, result2)
        result = rearrange(rearrange(x, 'a b c d e f -> (f d) c (e b) a'), '(f d) c (e b) a -> a b c d e f', b=2, d=5)
        assert numpy.array_equal(x, result)
        sizes = dict(zip('abcdef', shape))
        temp = rearrange(x, 'a b c d e f -> (f d) c (e b) a', **sizes)
        result = rearrange(temp, '(f d) c (e b) a -> a b c d e f', **sizes)
        assert numpy.array_equal(x, result)
        x2 = numpy.arange(2 * 3 * 4).reshape([2, 3, 4])
        result = rearrange(x2, 'a b c -> b c a')
        assert x2[1, 2, 3] == result[2, 3, 1]
        assert x2[0, 1, 2] == result[1, 2, 0]

    def test_ellipsis_ops_imperative(self):
        """ Checking various patterns against numpy """
        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
        for is_symbolic in [True, False]:
            for backend in collect_test_backends(symbolic=is_symbolic, layers=False):
                for pattern in identity_patterns + list(itertools.chain(*equivalent_rearrange_patterns)):
                    check_op_against_numpy(backend, x, pattern, axes_lengths={},
                                           reduction='rearrange', is_symbolic=is_symbolic)
                for reduction in ['min', 'max', 'sum']:
                    for pattern in itertools.chain(*equivalent_reduction_patterns):
                        check_op_against_numpy(backend, x, pattern, axes_lengths={},
                                               reduction=reduction, is_symbolic=is_symbolic)

    def test_ellipsis_ops_numpy(self):
        """Ellipsis patterns are equivalent to their explicit-axis counterparts (numpy only)."""
        x = numpy.arange(2 * 3 * 4 * 5 * 6).reshape([2, 3, 4, 5, 6])
        for pattern in identity_patterns:
            assert numpy.array_equal(x, rearrange(x, pattern)), pattern
        for pattern1, pattern2 in equivalent_rearrange_patterns:
            assert numpy.array_equal(rearrange(x, pattern1), rearrange(x, pattern2))
        for reduction in ['min', 'max', 'sum']:
            for pattern1, pattern2 in equivalent_reduction_patterns:
                assert numpy.array_equal(reduce(x, pattern1, reduction=reduction),
                                         reduce(x, pattern2, reduction=reduction))

    def test_collapsed_ellipsis_errors_out(self):
        """An ellipsis inside parentheses is accepted on the right side but rejected on the left."""
        x = numpy.zeros([1, 1, 1, 1, 1])
        rearrange(x, 'a b c d ... -> a b c ... d')
        with self.assertRaises(EinopsError):
            rearrange(x, 'a b c d (...) -> a b c ... d')
        rearrange(x, '... -> (...)')
        with self.assertRaises(EinopsError):
            rearrange(x, '(...) -> (...)')

    def test_rearrange_imperative(self):
        """Rearrange layer: rejects wrong shapes, matches numpy, pickles, has unit gradient."""
        for backend in collect_test_backends(symbolic=False, layers=True):
            print('Test layer for ', backend.framework_name)
            for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
                x = numpy.arange(numpy.prod(input_shape), dtype='float32').reshape(input_shape)
                result_numpy = rearrange(x, pattern, **axes_lengths)
                layer = backend.layers().Rearrange(pattern, **axes_lengths)
                for shape in wrong_shapes:
                    try:
                        layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
                    except Exception:
                        # any error is acceptable here; a bare except would
                        # also swallow KeyboardInterrupt/SystemExit
                        pass
                    else:
                        raise AssertionError('Failure expected')
                # simple pickling / unpickling
                layer2 = pickle.loads(pickle.dumps(layer))
                result1 = backend.to_numpy(layer(backend.from_numpy(x)))
                result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
                assert numpy.allclose(result_numpy, result1)
                assert numpy.allclose(result1, result2)
                just_sum = backend.layers().Reduce('...->', reduction='sum')
                variable = backend.from_numpy(x)
                result = just_sum(layer(variable))
                if 'jittor' in backend.framework_name:
                    grad = backend.jittor.grad(result, variable)
                else:
                    result.backward()
                    grad = variable.grad
                assert numpy.allclose(backend.to_numpy(grad), 1)

    def test_reduce_imperative(self):
        """Reduce layer: rejects wrong shapes, matches numpy, pickles, has the expected gradients."""
        for backend in collect_test_backends(symbolic=False, layers=True):
            print('Test layer for ', backend.framework_name)
            for reduction in _reductions:
                for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
                    print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
                    x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype='float32').reshape(input_shape)
                    x /= x.mean()
                    result_numpy = reduce(x, pattern, reduction, **axes_lengths)
                    layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
                    for shape in wrong_shapes:
                        try:
                            layer(backend.from_numpy(numpy.zeros(shape, dtype='float32')))
                        except Exception:
                            # any error is acceptable here; a bare except would
                            # also swallow KeyboardInterrupt/SystemExit
                            pass
                        else:
                            raise AssertionError('Failure expected')
                    # simple pickling / unpickling
                    layer2 = pickle.loads(pickle.dumps(layer))
                    result1 = backend.to_numpy(layer(backend.from_numpy(x)))
                    result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
                    assert numpy.allclose(result_numpy, result1)
                    assert numpy.allclose(result1, result2)
                    just_sum = backend.layers().Reduce('...->', reduction='sum')
                    variable = backend.from_numpy(x)
                    result = just_sum(layer(variable))
                    if 'jittor' in backend.framework_name:
                        grad = backend.jittor.grad(result, variable)
                        grad = backend.to_numpy(grad)
                    else:
                        result.backward()
                        grad = backend.to_numpy(variable.grad)
                    if reduction == 'sum':
                        assert numpy.allclose(grad, 1)
                    if reduction == 'mean':
                        assert numpy.allclose(grad, grad.min())
                    if reduction in ['max', 'min']:
                        # numpy.in1d is deprecated (NumPy 2.0); isin is the replacement
                        assert numpy.all(numpy.isin(grad, [0, 1]))
                        assert numpy.sum(grad) > 0.5

    def test_jittor_layer(self):
        """einops layers work inside a jittor model, and a pickled state_dict
        round trip makes two differently initialised models agree."""
        has_jittor = any(backend.framework_name == 'jittor'
                         for backend in collect_test_backends(symbolic=False, layers=True))
        if not has_jittor:
            return
        # checked that jittor present
        import jittor
        rtol = 1e-05
        atol = 1e-08

        def allclose(a, b):
            # same tolerance rule as numpy.allclose, evaluated on jittor Vars
            return jittor.all(jittor.abs(a - b) <= atol + rtol * jittor.abs(b))

        model1 = create_jittor_model(use_reduce=True)
        model2 = create_jittor_model(use_reduce=False)
        x = jittor.randn([10, 3, 32, 32])
        # random models have different predictions
        assert not allclose(model1(x), model2(x))
        model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
        assert allclose(model1(x), model2(x))
# A layer test case: einops pattern, explicit axis lengths, a valid input shape,
# and input shapes the layer is expected to reject.
testcase = namedtuple('testcase', 'pattern axes_lengths input_shape wrong_shapes')

rearrangement_patterns = [
    testcase(
        pattern='b c h w -> b (c h w)',
        axes_lengths={'c': 20},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]],
    ),
    testcase(
        pattern='b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1',
        axes_lengths={'h2': 2, 'w2': 2},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(), (1, 1, 1, 1), (1, 10, 3), ()],
    ),
    testcase(
        pattern='b ... c -> c b ...',
        axes_lengths={'b': 10},
        input_shape=(10, 20, 30),
        wrong_shapes=[(), (10,), (5, 10)],
    ),
]

# reduction_patterns reuses every rearrangement case and adds reduction-only ones.
reduction_patterns = rearrangement_patterns + [
    testcase(
        pattern='b c h w -> b ()',
        axes_lengths={'b': 10},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(10,), (10, 20, 30)],
    ),
    testcase(
        pattern='b c (h1 h2) (w1 w2) -> b c h1 w1',
        axes_lengths={'h1': 15, 'h2': 2, 'w2': 2},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(10, 20, 31, 40)],
    ),
    testcase(
        pattern='b ... c -> b',
        axes_lengths={'b': 10},
        input_shape=(10, 20, 30, 40),
        wrong_shapes=[(10,), (11, 10)],
    ),
]
# Reduction pattern pairs that must give identical results:
# (explicit-axes form, ellipsis form).
equivalent_reduction_patterns = [
    ('a b c d e -> ',
     ' ... -> '),
    ('a b c d e -> (e a)',
     'a ... e -> (e a)'),
    ('a b c d e -> d (a e)',
     ' a b c d e ... -> d (a e) '),
    ('a b c d e -> (a b)',
     ' ... c d e -> (...) '),
]

# Rearrange pattern pairs that must give identical results:
# (explicit-axes form, ellipsis form).
equivalent_rearrange_patterns = [
    ('a b c d e -> (a b) c d e',
     'a b ... -> (a b) ... '),
    ('a b c d e -> a b (c d) e',
     '... c d e -> ... (c d) e'),
    ('a b c d e -> a b c d e',
     '... -> ... '),
    ('a b c d e -> (a b c d e)',
     '... -> (...)'),
    ('a b c d e -> b (c d e) a',
     'a b ... -> b (...) a'),
    ('a b c d e -> b (a c d) e',
     'a b ... e -> b (a ...) e'),
]

# Patterns that must leave the input unchanged.
identity_patterns = [
    '...->...',
    'a b c d e-> a b c d e',
    'a b c d e ...-> ... a b c d e',
    'a b c d e ...-> a ... b c d e',
    '... a b c d e -> ... a b c d e',
    'a ... e-> a ... e',
    'a ... -> a ... ',
    'a ... c d e -> a (...) c d e',
]
# Repeat patterns with anonymous (numeric) axes; every case assumes an input
# of shape [1, 2, 4, 6].
test_cases_repeat_anonymous = [
    ('a b c d -> c a d b', {}),
    ('a b c d -> (c 2 d a b)', {'a': 1, 'c': 4, 'd': 6}),
    ('1 b c d -> (d copy 1) 3 b c ', {'copy': 3}),
    ('1 ... -> 3 ... ', {}),
    ('() ... d -> 1 (copy1 d copy2) ... ', {'copy1': 2, 'copy2': 3}),
    ('1 b c d -> (1 1) (1 b) 2 c 3 d (1 1)', {}),
]

# Repeat patterns with named axes; every case assumes an input of shape [2, 3, 5].
repeat_test_cases = [
    ('a b c -> c a b', {}),
    ('a b c -> (c copy a b)', {'copy': 2, 'a': 2, 'b': 3, 'c': 5}),
    ('a b c -> (a copy) b c ', {'copy': 1}),
    ('a b c -> (c a) (copy1 b copy2)', {'a': 2, 'copy1': 1, 'copy2': 2}),
    ('a ... -> a ... copy', {'copy': 4}),
    ('... c -> ... (copy1 c copy2)', {'copy1': 1, 'copy2': 2}),
    ('... -> ... ', {}),
    (' ... -> copy1 ... copy2 ', {'copy1': 2, 'copy2': 3}),
    ('a b c -> copy1 a copy2 b c () ', {'copy1': 2, 'copy2': 1}),
]
def check_reversion(x, repeat_pattern, **sizes):
    """Verify a repeat pattern by undoing it with reductions.

    ``x`` is repeated with ``repeat_pattern``. The pattern is then reversed (the
    two sides of ``->`` swapped) and applied to the repeated array with both
    'min' and 'max' reductions. Each of them must reproduce ``x`` exactly.
    """
    source_side, target_side = repeat_pattern.split('->')
    inverse_pattern = target_side + '->' + source_side
    repeated = reduce(x, repeat_pattern, reduction='repeat', **sizes)
    # compute both reversals before checking either, as the original did
    restored = [reduce(repeated, inverse_pattern, reduction=kind, **sizes) for kind in ('min', 'max')]
    for candidate in restored:
        assert numpy.array_equal(x, candidate)
def check_op_against_numpy(backend, numpy_input, pattern, axes_lengths, reduction='rearrange', is_symbolic=False):
    """
    Helper to test the result of an einops operation on ``backend`` against numpy.

    If ``reduction == 'rearrange'`` the ``rearrange`` op is tested; otherwise
    ``reduce`` is tested with the given reduction. Zero-dimensional inputs are
    skipped.

    :param backend: einops backend under test (must expose ``framework_name``,
        ``from_numpy``/``to_numpy``, and for symbolic runs ``create_symbol``
        and ``eval_symbol``).
    :param numpy_input: numpy array used as the reference input.
    :param pattern: einops pattern string.
    :param axes_lengths: extra axis sizes passed to the operation.
    :param reduction: 'rearrange' or a reduction name/callable for ``reduce``.
    :param is_symbolic: evaluate via a backend symbol instead of a concrete tensor.
    :raises AssertionError: if the backend result differs from the numpy result.
    """
    if len(numpy_input.shape) == 0:
        return

    def operation(x):
        if reduction == 'rearrange':
            return rearrange(x, pattern, **axes_lengths)
        return reduce(x, pattern, reduction, **axes_lengths)

    numpy_result = operation(numpy_input)
    check_equal = numpy.array_equal
    p_none_dimension = 0.5
    if 'jittor' in backend.framework_name:
        # jittor results are compared with a tolerance, and (with p == 0) jittor
        # symbols always get fully specified shapes
        check_equal = numpy.allclose
        p_none_dimension = 0
    if is_symbolic:
        symbol_shape = [d if numpy.random.random() >= p_none_dimension else None for d in numpy_input.shape]
        symbol = backend.create_symbol(shape=symbol_shape)
        result_symbol = operation(symbol)
        backend_result = backend.eval_symbol(result_symbol, [(symbol, numpy_input)])
    else:
        backend_result = operation(backend.from_numpy(numpy_input))
    backend_result = backend.to_numpy(backend_result)
    # The comparison must be asserted: previously its result was discarded, so
    # this helper could never fail.
    assert check_equal(numpy_result, backend_result), \
        [backend.framework_name, pattern, reduction, 'differs from numpy']
def create_jittor_model(use_reduce=False):
    """Build a small jittor CNN that mixes einops layers with native jittor layers.

    :param use_reduce: if True, the first pooling step is an einops ``Reduce``
        max-pool over 2x2 windows; otherwise it is ``jittor.nn.MaxPool2d(kernel_size=2)``.
    :return: a ``jittor.nn.Sequential`` model. The flattened size ``16 * 5 * 5``
        corresponds to 3x32x32 inputs (two 5x5 convolutions, each followed by 2x pooling).
    """
    from jittor.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
    from jittor.einops.layers.jittor import Rearrange, Reduce, EinMix

    # Layers are constructed in the same order as they appear in the model.
    layers = [Conv2d(3, 6, kernel_size=(5, 5))]
    if use_reduce:
        layers.append(Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2))
    else:
        layers.append(MaxPool2d(kernel_size=2))
    layers += [
        Conv2d(6, 16, kernel_size=(5, 5)),
        Reduce('b c (h h2) (w w2) -> b c h w', 'max', h2=2, w2=2),
        Rearrange('b c h w -> b (c h w)'),
        Linear(16 * 5 * 5, 120),
        ReLU(),
        Linear(120, 84),
        ReLU(),
        EinMix('b c1 -> (b c2)', weight_shape='c1 c2', bias_shape='c2', c1=84, c2=84),
        EinMix('(b c2) -> b c3', weight_shape='c2 c3', bias_shape='c3', c2=84, c3=84),
        Linear(84, 10),
    ]
    return Sequential(*layers)
if __name__ == '__main__':
    # Allow running this module directly, in addition to discovery by a test runner.
    unittest.main()