add several functions

514flowey 2024-08-20 15:08:19 +08:00
parent c124023085
commit 822955ac00
6 changed files with 856 additions and 17 deletions

View File

@@ -2140,6 +2140,7 @@ from . import sparse
from . import optim
from . import dataset
from . import init
from . import gradfunctional
dtype = NanoString

View File

@@ -1,7 +1,7 @@
from jittor_core import *
from jittor_core.ops import *
from .misc import *
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
from .contrib import concat as concat

View File

@@ -0,0 +1,2 @@
from .functional import jvp, vjp

View File

@@ -0,0 +1,420 @@
# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/autograd/functional.py
import jittor as jt
__all__ = ["vjp", "jvp"]
# Utility functions
def _as_tuple_nocheck(x):
if isinstance(x, tuple):
return x
elif isinstance(x, list):
return tuple(x)
else:
return (x,)
def _as_tuple(inp, arg_name=None, fn_name=None):
# Ensures that inp is a tuple of Tensors
# Returns whether or not the original inp was a tuple and the tupled version of the input
if arg_name is None and fn_name is None:
return _as_tuple_nocheck(inp)
is_inp_tuple = True
if not isinstance(inp, tuple):
inp = (inp,)
is_inp_tuple = False
for i, el in enumerate(inp):
if not isinstance(el, (jt.Var, jt.nn.ComplexNumber)):
if is_inp_tuple:
raise TypeError(
f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
f" value at index {i} has type {type(el)}."
)
else:
raise TypeError(
f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
f" given {arg_name} has type {type(el)}."
)
return is_inp_tuple, inp
def _tuple_postprocess(res, to_unpack):
# Unpacks a potentially nested tuple of Tensors
# to_unpack should be a single boolean or a tuple of two booleans.
# It is used to:
# - invert _as_tuple when res should match the inp given to _as_tuple
# - optionally remove nesting of two tuples created by multiple calls to _as_tuple
if isinstance(to_unpack, tuple):
assert len(to_unpack) == 2
if not to_unpack[1]:
res = tuple(el[0] for el in res)
if not to_unpack[0]:
res = res[0]
else:
if not to_unpack:
res = res[0]
return res
def _grad_preprocess(inputs, create_graph, need_graph):
# Preprocess the inputs to make sure they require gradient
# inputs is a tuple of Tensors to preprocess
# create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
# need_graph specifies if we internally want gradients to flow back to the Tensors in res
# Note that we *always* create a new Tensor object to be able to see the difference between
# inputs given as arguments and the same Tensors automatically captured by the user function.
# Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
res = []
for inp in inputs:
if create_graph and inp.requires_grad:
# Create at least a new Tensor object in a differentiable way
# Use .reshape() to get a shallow copy
res.append(inp.reshape(inp.shape))
else:
if need_graph:
ninp = inp.detach().start_grad()
else:
ninp = inp.detach().stop_grad()
res.append(ninp)
return tuple(res)
def _grad_postprocess(inputs, create_graph):
# Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
# request it.
if isinstance(inputs[0], (jt.Var, jt.nn.ComplexNumber)):
if not create_graph:
return tuple(inp.detach() for inp in inputs)
else:
return inputs
else:
return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)
def _validate_v(v, other, is_other_tuple):
# This assumes that other is the correct shape, and v should match
# Both are assumed to be tuples of Tensors
if len(other) != len(v):
if is_other_tuple:
raise RuntimeError(
f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
)
else:
raise RuntimeError("The given v should contain a single Tensor.")
for idx, (el_v, el_other) in enumerate(zip(v, other)):
if el_v.shape != el_other.shape:
prepend = ""
if is_other_tuple:
prepend = f"Entry {idx} in "
raise RuntimeError(
f"{prepend}v has invalid size: should be {el_other.shape} but got {el_v.shape}."
)
def _check_requires_grad(inputs, input_type, strict):
# Used to make all the necessary checks to raise nice errors in strict mode.
if not strict:
return
if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
raise RuntimeError("Invalid input_type to _check_requires_grad")
for i, inp in enumerate(inputs):
if inp is None:
# This can only be reached for grad_inputs.
raise RuntimeError(
f"The output of the user-provided function is independent of input {i}."
" This is not allowed in strict mode."
)
if not inp.requires_grad:
if input_type == "hessian":
raise RuntimeError(
f"The hessian of the user-provided function with respect to input {i}"
" is independent of the input. This is not allowed in strict mode."
" You should ensure that your function is thrice differentiable and that"
" the hessian depends on the inputs."
)
elif input_type == "jacobian":
raise RuntimeError(
"While computing the hessian, found that the jacobian of the user-provided"
f" function with respect to input {i} is independent of the input. This is not"
" allowed in strict mode. You should ensure that your function is twice"
" differentiable and that the jacobian depends on the inputs (this would be"
" violated by a linear function for example)."
)
elif input_type == "grad_inputs":
raise RuntimeError(
f"The gradient with respect to input {i} is independent of the inputs of the"
" user-provided function. This is not allowed in strict mode."
)
else:
raise RuntimeError(
f"Output {i} of the user-provided function does not require gradients."
" The outputs must be computed in a differentiable manner from the input"
" when running in strict mode."
)
def _autograd_grad(
outputs,
inputs,
grad_outputs=None,
create_graph=True,
):
# Version of grad that accepts `None` in outputs and does not compute gradients for them.
# This has the extra constraint that inputs has to be a tuple
assert isinstance(outputs, tuple)
if grad_outputs is None:
grad_outputs = (None,) * len(outputs)
assert isinstance(grad_outputs, tuple)
assert len(outputs) == len(grad_outputs)
new_outputs = ()
new_grad_outputs = ()
for out, grad_out in zip(outputs, grad_outputs):
if out is not None and out.requires_grad:
new_outputs += (out,)
new_grad_outputs += (grad_out,)
if len(new_outputs) == 0:
# No differentiable output, we don't need to call the autograd engine
return (None,) * len(inputs)
else:
acc_loss = None
for new_output, grad_output in zip(new_outputs, new_grad_outputs):
if isinstance(new_output, jt.nn.ComplexNumber):
if grad_output is not None:
loss = (new_output.value * grad_output.value).sum()
else:
loss = new_output.value.sum()
else:
if grad_output is not None:
new_output = new_output * grad_output
loss = new_output.sum()
if acc_loss is None:
acc_loss = loss
else:
acc_loss += loss
complex_inds = []
var_inputs = []
for idx, inp in enumerate(inputs):
if isinstance(inp, jt.nn.ComplexNumber):
var_inputs.append(inp.value)
complex_inds.append(idx)
else:
var_inputs.append(inp)
grads = jt.grad(acc_loss, var_inputs, retain_graph=create_graph)
for complex_ind in complex_inds:
grads[complex_ind] = jt.nn.ComplexNumber(grads[complex_ind], is_concat_value=True)
return tuple(grads)
def _fill_in_zeros(grads, refs, strict, create_graph, stage):
# Used to detect None in the grads and depending on the flags, either replace them
# with Tensors full of 0s of the appropriate size based on the refs or raise an error.
# strict and create_graph allow us to detect when it is appropriate to raise an error
# stage tells us which backward call this is, so we can give a good error message
if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")
res = ()
for i, grads_i in enumerate(grads):
if grads_i is None:
if strict:
if stage == "back":
raise RuntimeError(
"The output of the user-provided function is independent of "
f"input {i}. This is not allowed in strict mode."
)
elif stage == "back_trick":
raise RuntimeError(
f"The gradient with respect to the input is independent of entry {i}"
" in the grad_outputs when using the double backward trick to compute"
" forward mode gradients. This is not allowed in strict mode."
)
elif stage == "double_back":
raise RuntimeError(
"The jacobian of the user-provided function is independent of "
f"input {i}. This is not allowed in strict mode."
)
else:
raise RuntimeError(
"The hessian of the user-provided function is independent of "
f"entry {i} in the grad_jacobian. This is not allowed in strict "
"mode as it prevents from using the double backward trick to "
"replace forward mode AD."
)
refs_i = refs[i]
if isinstance(refs_i, jt.nn.ComplexNumber):
grads_i = jt.nn.ComplexNumber(jt.zeros_like(refs_i.value), is_concat_value=True)
else:
grads_i = jt.zeros_like(refs_i)
else:
if strict and create_graph and not grads_i.requires_grad:
if "double" not in stage:
raise RuntimeError(
"The jacobian of the user-provided function is independent of "
f"input {i}. This is not allowed in strict mode when create_graph=True."
)
else:
raise RuntimeError(
"The hessian of the user-provided function is independent of "
f"input {i}. This is not allowed in strict mode when create_graph=True."
)
res += (grads_i,)
return res
# Public API
def vjp(func, inputs, v=None, create_graph=False, strict=False):
r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.
Args:
func (function): a Python function that takes Tensor inputs and returns
a tuple of Tensors or a Tensor.
inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
v (tuple of Tensors or Tensor): The vector for which the vector
Jacobian product is computed. Must be the same size as the output
of ``func``. This argument is optional when the output of ``func``
contains a single element and (if it is not provided) will be set
as a Tensor containing a single ``1``.
create_graph (bool, optional): If ``True``, both the output and result
will be computed in a differentiable way. Note that when ``strict``
is ``False``, the result can not require gradients or be
disconnected from the inputs. Defaults to ``False``.
strict (bool, optional): If ``True``, an error will be raised when we
detect that there exists an input such that all the outputs are
independent of it. If ``False``, we return a Tensor of zeros as the
vjp for said inputs, which is the expected mathematical value.
Defaults to ``False``.
Returns:
output (tuple): tuple with:
func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
vjp (tuple of Tensors or Tensor): result of the dot product with
the same shape as the inputs.
"""
with jt.enable_grad():
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
outputs = func(*inputs)
is_outputs_tuple, outputs = _as_tuple(
outputs, "outputs of the user-provided function", "vjp"
)
_check_requires_grad(outputs, "outputs", strict=strict)
if v is not None:
_, v = _as_tuple(v, "v", "vjp")
v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
_validate_v(v, outputs, is_outputs_tuple)
else:
if len(outputs) != 1 or outputs[0].nelement() != 1:
raise RuntimeError(
"The vector v can only be None if the "
"user-provided function returns "
"a single Tensor with a single element."
)
with jt.enable_grad():
grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")
# Cleanup objects and return them to the user
outputs = _grad_postprocess(outputs, create_graph)
vjp = _grad_postprocess(vjp, create_graph)
return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
vjp, is_inputs_tuple
)
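# Shape example for vjp: with f(x) = x.exp().sum(1) and x of shape (4, 4), f(x) has
# shape (4,), so vjp(f, x, v) expects v of shape (4,) and returns the pair
# (f(x), grad), where grad has the same shape as x.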
def jvp(func, inputs, v=None, create_graph=False, strict=False):
r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.
Args:
func (function): a Python function that takes Tensor inputs and returns
a tuple of Tensors or a Tensor.
inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
v (tuple of Tensors or Tensor): The vector for which the Jacobian
vector product is computed. Must be the same size as the input of
``func``. This argument is optional when the input to ``func``
contains a single element and (if it is not provided) will be set
as a Tensor containing a single ``1``.
create_graph (bool, optional): If ``True``, both the output and result
will be computed in a differentiable way. Note that when ``strict``
is ``False``, the result can not require gradients or be
disconnected from the inputs. Defaults to ``False``.
strict (bool, optional): If ``True``, an error will be raised when we
detect that there exists an input such that all the outputs are
independent of it. If ``False``, we return a Tensor of zeros as the
jvp for said inputs, which is the expected mathematical value.
Defaults to ``False``.
Returns:
output (tuple): tuple with:
func_output (tuple of Tensors or Tensor): output of ``func(inputs)``
jvp (tuple of Tensors or Tensor): result of the dot product with
the same shape as the output.
"""
with jt.enable_grad():
is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)
if v is not None:
_, v = _as_tuple(v, "v", "jvp")
v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
_validate_v(v, inputs, is_inputs_tuple)
else:
if len(inputs) != 1 or inputs[0].nelement() != 1:
raise RuntimeError(
"The vector v can only be None if the input to "
"the user-provided function is a single Tensor "
"with a single element."
)
outputs = func(*inputs)
is_outputs_tuple, outputs = _as_tuple(
outputs, "outputs of the user-provided function", "jvp"
)
_check_requires_grad(outputs, "outputs", strict=strict)
# The backward is linear so the value of grad_outputs is not important as
# it won't appear in the double backward graph. We only need to ensure that
# it does not contain inf or nan.
grad_outputs = tuple(
jt.nn.ComplexNumber(jt.zeros_like(out.value), is_concat_value=True) if isinstance(out, jt.nn.ComplexNumber) else jt.zeros_like(out)
for out in outputs
)
grad_inputs = _autograd_grad(outputs, inputs, grad_outputs=grad_outputs, create_graph=True)
_check_requires_grad(grad_inputs, "grad_inputs", strict=strict)
if create_graph:
with jt.enable_grad():
grad_res = _autograd_grad(
grad_inputs, grad_outputs, v, create_graph=create_graph
)
jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
else:
grad_res = _autograd_grad(
grad_inputs, grad_outputs, v, create_graph=create_graph
)
jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
# Cleanup objects and return them to the user
outputs = _grad_postprocess(outputs, create_graph)
jvp = _grad_postprocess(jvp, create_graph)
return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
jvp, is_outputs_tuple
)
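A minimal usage sketch of the new module (illustrative only; exp_reducer and the shapes are invented for the example and mirror the tests added later in this commit):

import jittor as jt
from jittor.gradfunctional import jvp, vjp

def exp_reducer(x):
    return x.exp().sum(1)          # maps shape (4, 4) to shape (4,)

x = jt.randn(4, 4)
# vjp: v must match the output shape (4,); the result has the shape of x
out, vjp_res = vjp(exp_reducer, x, jt.randn(4))
# jvp: v must match the input shape (4, 4); as the comment in jvp notes, the result is
# obtained with the double-backward trick, so the backward pass runs twice
out, jvp_res = jvp(exp_reducer, x, jt.randn(4, 4))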

View File

@@ -3130,6 +3130,10 @@ class ComplexNumber:
assert real.dtype == imag.dtype
self.value = jt.stack([real, imag], dim=-1)
@property
def requires_grad(self):
return self.value.requires_grad
@property
def real(self):
return self.value[..., 0]
@@ -3142,6 +3146,10 @@ class ComplexNumber:
def shape(self):
return self.value.shape[:-1]
@property
def dtype(self):
return "complex64"
def norm(self):
return jt.sqrt(jt.sqr(self.real) + jt.sqr(self.imag))
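A small sketch of how the new properties behave (assumed usage, constructing a ComplexNumber from two real Vars as the constructor above does):

import jittor as jt
from jittor.nn import ComplexNumber

c = ComplexNumber(jt.randn(3, 4), jt.randn(3, 4))  # real and imag stacked on a trailing axis
c.shape          # value.shape[:-1], i.e. the real/imag axis is hidden
c.dtype          # always reported as "complex64"
c.requires_grad  # delegates to the stacked value
c.norm()         # element-wise magnitude, same shape as c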
@@ -3287,6 +3295,129 @@ def view_as_complex(x: jt.Var) -> ComplexNumber:
def view_as_real(x: ComplexNumber) -> jt.Var:
return jt.stack([x.value[...,0],x.value[...,1]],dim=-1)
# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/functional.py#L1258
def tensordot(a, b, dims=2):
r"""Returns a contraction of a and b over multiple dimensions.
:attr:`tensordot` implements a generalized matrix product.
Args:
a (Tensor): Left tensor to contract
b (Tensor): Right tensor to contract
dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
contract or explicit lists of dimensions for :attr:`a` and
:attr:`b` respectively
When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
respectively, :func:`tensordot` computes
.. math::
r_{i_0,...,i_{m-d}, i_d,...,i_n}
= \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
When called with :attr:`dims` of the list form, the given dimensions will be contracted
in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
in these dimensions must match.
"""
if not isinstance(dims, (tuple, list, int)):
raise RuntimeError(
"tensordot expects dims to be int or "
+ "Tuple[List[int], List[int]] or "
+ "List[List[int]] containing two lists, but got "
+ f"dims={dims}"
)
dims_a, dims_b = [], []
if isinstance(dims, (tuple, list)):
dims_a, dims_b = dims
if isinstance(dims, int):
if dims < 0:
raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
if dims > min(len(a.shape), len(b.shape)):
raise RuntimeError(
f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
)
dims_a = list(range(len(a.shape)-dims, len(a.shape)))
dims_b = list(range(dims))
# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/aten/src/ATen/native/Linear.cpp#L769
def __tensordot_native(input1:jt.Var, input2:jt.Var, dims1, dims2):
if not isinstance(dims1, (list, tuple)):
raise RuntimeError("tensordot expects dims1 to be List[Int], but got dims={}".format(dims1))
if not isinstance(dims2, (list, tuple)):
raise RuntimeError("tensordot expects dims2 to be List[Int], but got dims={}".format(dims2))
dims1 = list(dims1)
dims2 = list(dims2)
if len(dims1) != len(dims2):
raise RuntimeError("both dimension lists should have the same length")
if input1.dtype != input2.dtype:
raise RuntimeError("both inputs should have the same dtype")
t1 = input1
t2 = input2
csize = 1
input1_bitmap = np.zeros(len(input1.shape), dtype='bool')
input2_bitmap = np.zeros(len(input2.shape), dtype='bool')
for i in range(len(dims1)):
s1 = input1.shape[dims1[i]]
s2 = input2.shape[dims2[i]]
input1_bitmap[dims1] = True
input2_bitmap[dims2] = True
if s2 == 1: #broadcasted dimensions can be summed right away
t1 = t1.sum(dims1[i], keepdims=True)
elif s1 == 1:
t2 = t2.sum(dims2[i], keepdims=True)
else:
if s1 != s2:
raise RuntimeError("contracted dimensions need to match, but first has size {}, in dim {}, and second has size {}".format(s1, i, s2))
csize *= s1
p1, p2 = [], [] # p1, p2: input permutations
rsizes = []
size1, size2 = 1, 1 # number of non-contracted elements
for i in range(len(input1.shape)):
if not input1_bitmap[i]:
p1.append(i)
size1 *= t1.shape[i]
rsizes.append(t1.shape[i])
p1 += dims1
p2 += dims2
for i in range(len(input2.shape)):
if not input2_bitmap[i]:
p2.append(i)
size2 *= t2.shape[i]
rsizes.append(t2.shape[i])
# permute and reshape for matrix multiplication
t1 = t1.permute(p1).reshape((size1, csize))
t2 = t2.permute(p2).reshape((csize, size2))
# multiply and reshape to target size
return jt.matmul(t1, t2).reshape(rsizes)
return __tensordot_native(a, b, dims_a, dims_b)
# reference: https://github.com/pytorch/pytorch/blob/5ed3b70d09a4ab2a5be4becfda9dd0d3e3227c39/aten/src/ATen/native/LinearAlgebra.cpp#L3375
def kron(a:jt.Var, b:jt.Var):
a_dim, b_dim = len(a.shape), len(b.shape)
max_dim = max(a_dim, b_dim)
pad_a, pad_b = max_dim-a_dim, max_dim-b_dim
a_reshape, b_reshape = [], []
result_reshape = []
for i in range(max_dim):
a_2i_shape = a.shape[i - pad_a] if i >= pad_a else 1
b_2i1_shape = b.shape[i - pad_b] if i >= pad_b else 1
a_reshape.append(a_2i_shape)
a_reshape.append(1)
b_reshape.append(1)
b_reshape.append(b_2i1_shape)
result_reshape.append(a_2i_shape * b_2i1_shape)
a = a.reshape(a_reshape)
b = b.reshape(b_reshape)
return (a * b).reshape(result_reshape)
def one_hot(x: jt.Var, num_classes: int=-1) -> jt.Var:
''' Returns the one_hot encoding of inputs.

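A short sketch of the two new ops through the jt.nn namespace, the same entry points the tests below use; the shapes are arbitrary but chosen so the contractions line up:

import jittor as jt

a = jt.randn(3, 4, 5)
b = jt.randn(4, 5, 6)
c1 = jt.nn.tensordot(a, b, dims=2)                  # contract last 2 dims of a with first 2 of b -> (3, 6)
c2 = jt.nn.tensordot(a, b, dims=([1, 2], [0, 1]))   # the same contraction with explicit dim lists

x = jt.randn(2, 3)
y = jt.randn(4, 5)
z = jt.nn.kron(x, y)                                # Kronecker product -> shape (8, 15)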
View File

@@ -2,6 +2,7 @@ import jittor as jt
from jittor.nn import ComplexNumber
import unittest
import numpy as np
from functools import partial
__skip_torch_test = False
try:
@@ -10,6 +11,15 @@ except:
__skip_torch_test = True
class TestResultAndGrad:
def flatten_list(self, list_like):
results = []
if isinstance(list_like, (list, tuple)):
for x in list_like:
results.extend(self.flatten_list(x))
return results
else:
return [list_like]
def check_results(self, rlist1, rlist2):
assert len(rlist1) == len(rlist2)
for r1, r2 in zip(rlist1, rlist2):
@@ -36,13 +46,21 @@ class TestResultAndGrad:
grads.append(g.detach().cpu().numpy())
return grads
def run_jittor_op(self, op, input_list, weights=None):
def _np_to_jittor(x:np.ndarray):
if x.dtype == np.complex64 or x.dtype == np.complex128:
nx = np.stack([np.real(x), np.imag(x)], axis=-1)
return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
elif x.dtype == np.float32 or x.dtype == np.float64:
return jt.array(x, dtype=jt.float32)
def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs):
def _np_to_jittor(x):
if isinstance(x, np.ndarray):
if x.dtype == np.complex64 or x.dtype == np.complex128:
nx = np.stack([np.real(x), np.imag(x)], axis=-1)
return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
elif x.dtype == np.float32 or x.dtype == np.float64:
return jt.array(x, dtype=jt.float32)
else:
assert False
elif isinstance(x, (list, tuple)):
nx = [_np_to_jittor(vx) for vx in x]
if isinstance(x, tuple):
return tuple(nx)
return nx
else:
assert False
def _jittor_to_np(x):
@@ -51,11 +69,19 @@ class TestResultAndGrad:
elif isinstance(x, ComplexNumber):
return x.real.numpy() + 1j * x.imag.numpy()
assert False
ninput_list = [_np_to_jittor(x) for x in input_list]
output_list = op(*ninput_list)
if key_names is not None:
assert len(ninput_list) == len(key_names)
nkwargs = kwargs.copy()
for k, v in zip(key_names, ninput_list):
nkwargs[k] = v
output_list = op(**nkwargs)
else:
output_list = op(*ninput_list, **kwargs)
if isinstance(output_list, (jt.Var, ComplexNumber)):
output_list = [output_list]
output_list = self.flatten_list(output_list)
losses = []
if weights is None:
weights = []
@@ -73,15 +99,31 @@ class TestResultAndGrad:
output_list = [_jittor_to_np(x) for x in output_list]
return ninput_list, output_list, losses, weights
def run_torch_op(self, op, input_list, weights=None):
def _np_to_torch(x:np.ndarray):
return torch.from_numpy(x).requires_grad_(True)
def run_torch_op(self, op, input_list, weights=None, key_names=None, **kwargs):
def _np_to_torch(x):
if isinstance(x, np.ndarray):
return torch.from_numpy(x).requires_grad_(True)
elif isinstance(x, (list, tuple)):
nx = [_np_to_torch(vx) for vx in x]
if isinstance(x, tuple):
return tuple(nx)
return nx
else:
assert False
def _torch_to_np(x:torch.Tensor) -> np.ndarray:
return x.detach().cpu().numpy()
ninput_list = [_np_to_torch(x) for x in input_list]
output_list = op(*ninput_list)
if key_names is not None:
assert len(ninput_list) == len(key_names)
nkwargs = kwargs.copy()
for k, v in zip(key_names, ninput_list):
nkwargs[k] = v
output_list = op(**nkwargs)
else:
output_list = op(*ninput_list, **kwargs)
if isinstance(output_list, torch.Tensor):
output_list = [output_list]
output_list = self.flatten_list(output_list)
losses = []
if weights is None:
weights = []
@@ -99,10 +141,10 @@ class TestResultAndGrad:
output_list = [_torch_to_np(x) for x in output_list]
return ninput_list, output_list, losses, weights
def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True):
def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs):
weights = None
jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights)
torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights)
jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs)
torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs)
self.check_results(jittor_output, torch_output)
if check_grad:
@@ -195,6 +237,249 @@ class TestComplexLinalg(unittest.TestCase, TestResultAndGrad):
inputs = [m1]
self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)
class TestTensordot(unittest.TestCase, TestResultAndGrad):
def random_complex_matrix(self, shape):
r = np.random.randn(*shape)
i = np.random.randn(*shape)
return r + 1j * i
def random_real_matrix(self, shape):
return np.random.randn(*shape)
def test_complex_tensordot_numberdim(self):
s1 = (3, 4, 5)
s2 = (4, 5, 6)
dims = 2
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
def test_complex_tensordot_tupledim(self):
s1 = (3, 5, 4, 6)
s2 = (6, 4, 5, 3)
dims = ([2, 1, 3], [1, 2, 0])
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
def test_real_tensordot_numberdim(self):
s1 = (3, 4, 5)
s2 = (4, 5, 6)
dims = 2
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
def test_real_tensordot_tupledim(self):
s1 = (3, 5, 4, 6)
s2 = (6, 4, 5, 3)
dims = ([2, 1, 3], [1, 2, 0])
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)
class TestKron(unittest.TestCase, TestResultAndGrad):
def random_complex_matrix(self, shape):
r = np.random.randn(*shape)
i = np.random.randn(*shape)
return r + 1j * i
def random_real_matrix(self, shape):
return np.random.randn(*shape)
def test_complex_firstlarge(self):
s1 = (2, 3, 4)
s2 = (5, 2)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
def test_complex_second_large(self):
s1 = (2, 3)
s2 = (5, 2, 4)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
def test_real_firstlarge(self):
s1 = (2, 3, 4)
s2 = (5, 2)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
def test_real_second_large(self):
s1 = (2, 3)
s2 = (5, 2, 4)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)
@unittest.skipIf(__skip_torch_test, "No Torch found")
class TestGradFunctional(unittest.TestCase, TestResultAndGrad):
def random_complex_matrix(self, shape):
r = np.random.randn(*shape)
i = np.random.randn(*shape)
return r + 1j * i
def random_real_matrix(self, shape):
return np.random.randn(*shape) * 0.0 + 1.0
def test_real_jvp_exp(self):
def exp_reducer(x):
return x.exp().sum(dim=1)
s1 = (5, 6)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s1)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False)
def test_complex_jvp_exp(self):
def exp_reducer(x):
return x.exp().sum(1)
s1 = (5, 6)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s1)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_real_jvp_add(self):
w1, w2 = np.random.rand(), np.random.rand()
def adder(x, y):
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s1)
m3 = self.random_real_matrix(s1)
m4 = self.random_real_matrix(s1)
inputs = [(m1, m2), (m3, m4)]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=adder, create_graph=True),
partial(torch.autograd.functional.jvp, func=adder, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_complex_jvp_add(self):
w1r, w1i = np.random.rand(), np.random.rand()
w2r, w2i = np.random.rand(), np.random.rand()
def adder_pt(x, y):
return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
def adder_jt(x, y):
w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s1)
m3 = self.random_complex_matrix(s1)
m4 = self.random_complex_matrix(s1)
inputs = [(m1, m2), (m3, m4)]
self.check_op_with_torch(
partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True),
partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_real_vjp_exp(self):
def exp_reducer(x):
return x.exp().sum(dim=1)
s1 = (5, 6)
s2 = (5,)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=exp_reducer),
partial(torch.autograd.functional.vjp, func=exp_reducer),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False)
def test_complex_vjp_exp(self):
def exp_reducer(x):
return x.exp().sum(1)
s1 = (5, 6)
s2 = (5,)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s2)
inputs = [m1, m2]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True),
partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_real_vjp_add(self):
w1, w2 = np.random.rand(), np.random.rand()
def adder(x, y):
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_real_matrix(s1)
m2 = self.random_real_matrix(s1)
m3 = self.random_real_matrix(s1)
inputs = [(m1, m2), m3]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=adder, create_graph=True),
partial(torch.autograd.functional.vjp, func=adder, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
def test_complex_vjp_add(self):
w1r, w1i = np.random.rand(), np.random.rand()
w2r, w2i = np.random.rand(), np.random.rand()
def adder_pt(x, y):
return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
def adder_jt(x, y):
w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
return w1 * x + w2 * y
s1 = (5, 6)
m1 = self.random_complex_matrix(s1)
m2 = self.random_complex_matrix(s1)
m3 = self.random_complex_matrix(s1)
inputs = [(m1, m2), (m3)]
self.check_op_with_torch(
partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True),
partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True),
inputs,
jittor_knames = ['inputs', 'v'],
torch_knames = ['inputs', 'v'],
check_grad=False,
)
if __name__ == "__main__":
unittest.main()