mirror of https://github.com/Jittor/Jittor
add several functions
This commit is contained in:
parent c124023085
commit 822955ac00
@@ -2140,6 +2140,7 @@ from . import sparse
from . import optim
from . import dataset
from . import init
from . import gradfunctional

dtype = NanoString

@@ -1,7 +1,7 @@
from jittor_core import *
from jittor_core.ops import *
from .misc import *
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse
from . import attention as attention, contrib as contrib, dataset as dataset, init as init, linalg as linalg, lr_scheduler as lr_scheduler, numpy2cupy as numpy2cupy, optim as optim, sparse as sparse, gradfunctional as gradfunctional
from .compile_extern import cublas as cublas, cudnn as cudnn, cufft as cufft, curand as curand, mkl_ops as mkl_ops, mpi_ops as mpi_ops, world_size as world_size
from .compiler import compile_custom_op as compile_custom_op, compile_custom_ops as compile_custom_ops
from .contrib import concat as concat
@@ -0,0 +1,2 @@
from .functional import jvp, vjp
@@ -0,0 +1,420 @@
# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/autograd/functional.py
import jittor as jt

__all__ = ["vjp", "jvp", "jacobian", "hessian", "hvp", "vhp"]

# Utility functions
def _as_tuple_nocheck(x):
    if isinstance(x, tuple):
        return x
    elif isinstance(x, list):
        return tuple(x)
    else:
        return (x,)

def _as_tuple(inp, arg_name=None, fn_name=None):
    # Ensures that inp is a tuple of Tensors
    # Returns whether or not the original inp was a tuple and the tupled version of the input
    if arg_name is None and fn_name is None:
        return _as_tuple_nocheck(inp)

    is_inp_tuple = True
    if not isinstance(inp, tuple):
        inp = (inp,)
        is_inp_tuple = False

    for i, el in enumerate(inp):
        if not isinstance(el, (jt.Var, jt.nn.ComplexNumber)):
            if is_inp_tuple:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" value at index {i} has type {type(el)}."
                )
            else:
                raise TypeError(
                    f"The {arg_name} given to {fn_name} must be either a Tensor or a tuple of Tensors but the"
                    f" given {arg_name} has type {type(el)}."
                )

    return is_inp_tuple, inp


def _tuple_postprocess(res, to_unpack):
    # Unpacks a potentially nested tuple of Tensors
    # to_unpack should be a single boolean or a tuple of two booleans.
    # It is used to:
    # - invert _as_tuple when res should match the inp given to _as_tuple
    # - optionally remove nesting of two tuples created by multiple calls to _as_tuple
    if isinstance(to_unpack, tuple):
        assert len(to_unpack) == 2
        if not to_unpack[1]:
            res = tuple(el[0] for el in res)
        if not to_unpack[0]:
            res = res[0]
    else:
        if not to_unpack:
            res = res[0]
    return res


def _grad_preprocess(inputs, create_graph, need_graph):
    # Preprocess the inputs to make sure they require gradient
    # inputs is a tuple of Tensors to preprocess
    # create_graph specifies if the user wants gradients to flow back to the Tensors in inputs
    # need_graph specifies if we internally want gradients to flow back to the Tensors in res
    # Note that we *always* create a new Tensor object to be able to see the difference between
    # inputs given as arguments and the same Tensors automatically captured by the user function.
    # Check this issue for more details on how that can happen: https://github.com/pytorch/pytorch/issues/32576
    res = []
    for inp in inputs:
        if create_graph and inp.requires_grad:
            # Create at least a new Tensor object in a differentiable way
            # Use .reshape() to get a shallow copy
            res.append(inp.reshape(inp.shape))
        else:
            if need_graph:
                ninp = inp.detach().start_grad()
            else:
                ninp = inp.detach().stop_grad()
            res.append(ninp)
    return tuple(res)


def _grad_postprocess(inputs, create_graph):
    # Postprocess the generated Tensors to avoid returning Tensors with history when the user did not
    # request it.
    if isinstance(inputs[0], (jt.Var, jt.nn.ComplexNumber)):
        if not create_graph:
            return tuple(inp.detach() for inp in inputs)
        else:
            return inputs
    else:
        return tuple(_grad_postprocess(inp, create_graph) for inp in inputs)


def _validate_v(v, other, is_other_tuple):
    # This assumes that other is the correct shape, and v should match
    # Both are assumed to be tuples of Tensors
    if len(other) != len(v):
        if is_other_tuple:
            raise RuntimeError(
                f"v is a tuple of invalid length: should be {len(other)} but got {len(v)}."
            )
        else:
            raise RuntimeError("The given v should contain a single Tensor.")

    for idx, (el_v, el_other) in enumerate(zip(v, other)):
        if el_v.shape != el_other.shape:
            prepend = ""
            if is_other_tuple:
                prepend = f"Entry {idx} in "
            raise RuntimeError(
                f"{prepend}v has invalid size: should be {el_other.shape} but got {el_v.shape}."
            )


def _check_requires_grad(inputs, input_type, strict):
    # Used to make all the necessary checks to raise nice errors in strict mode.
    if not strict:
        return

    if input_type not in ["outputs", "grad_inputs", "jacobian", "hessian"]:
        raise RuntimeError("Invalid input_type to _check_requires_grad")
    for i, inp in enumerate(inputs):
        if inp is None:
            # This can only be reached for grad_inputs.
            raise RuntimeError(
                f"The output of the user-provided function is independent of input {i}."
                " This is not allowed in strict mode."
            )
        if not inp.requires_grad:
            if input_type == "hessian":
                raise RuntimeError(
                    f"The hessian of the user-provided function with respect to input {i}"
                    " is independent of the input. This is not allowed in strict mode."
                    " You should ensure that your function is thrice differentiable and that"
                    " the hessian depends on the inputs."
                )
            elif input_type == "jacobian":
                raise RuntimeError(
                    "While computing the hessian, found that the jacobian of the user-provided"
                    f" function with respect to input {i} is independent of the input. This is not"
                    " allowed in strict mode. You should ensure that your function is twice"
                    " differentiable and that the jacobian depends on the inputs (this would be"
                    " violated by a linear function for example)."
                )
            elif input_type == "grad_inputs":
                raise RuntimeError(
                    f"The gradient with respect to input {i} is independent of the inputs of the"
                    " user-provided function. This is not allowed in strict mode."
                )
            else:
                raise RuntimeError(
                    f"Output {i} of the user-provided function does not require gradients."
                    " The outputs must be computed in a differentiable manner from the input"
                    " when running in strict mode."
                )


def _autograd_grad(
    outputs,
    inputs,
    grad_outputs=None,
    create_graph=True,
):
    # Version of grad that accepts `None` in outputs and does not compute gradients for them.
    # This has the extra constraint that inputs has to be a tuple
    assert isinstance(outputs, tuple)
    if grad_outputs is None:
        grad_outputs = (None,) * len(outputs)
    assert isinstance(grad_outputs, tuple)
    assert len(outputs) == len(grad_outputs)

    new_outputs = ()
    new_grad_outputs = ()
    for out, grad_out in zip(outputs, grad_outputs):
        if out is not None and out.requires_grad:
            new_outputs += (out,)
            new_grad_outputs += (grad_out,)

    if len(new_outputs) == 0:
        # No differentiable output, we don't need to call the autograd engine
        return (None,) * len(inputs)
    else:
        acc_loss = None
        for new_output, grad_output in zip(new_outputs, new_grad_outputs):
            if isinstance(new_output, jt.nn.ComplexNumber):
                if grad_output is not None:
                    loss = (new_output.value * grad_output.value).sum()
                else:
                    loss = new_output.value.sum()
            else:
                if grad_output is not None:
                    new_output = new_output * grad_output
                loss = new_output.sum()
            if acc_loss is None:
                acc_loss = loss
            else:
                acc_loss += loss

        complex_inds = []
        var_inputs = []
        for idx, inp in enumerate(inputs):
            if isinstance(inp, jt.nn.ComplexNumber):
                var_inputs.append(inp.value)
                complex_inds.append(idx)
            else:
                var_inputs.append(inp)

        grads = jt.grad(acc_loss, var_inputs, retain_graph=create_graph)
        for complex_ind in complex_inds:
            grads[complex_ind] = jt.nn.ComplexNumber(grads[complex_ind], is_concat_value=True)
        return tuple(grads)

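The `_autograd_grad` helper above emulates a grad call with `grad_outputs` by folding everything into one scalar: it accumulates `(output * grad_output).sum()` over the differentiable outputs and hands that scalar to `jt.grad`, which is exactly a vector-Jacobian product. An illustrative aside (not part of the commit), assuming a working Jittor install:

import jittor as jt

x = jt.array([[1.0, 2.0], [3.0, 4.0]])
y = x * x                      # element-wise square; Jacobian is diag(2*x)
v = jt.ones_like(y)            # the "vector" contracted with the Jacobian

loss = (y * v).sum()           # the same scalar _autograd_grad accumulates
grads = jt.grad(loss, [x])     # vector-Jacobian product; here simply 2 * x
print(grads[0].numpy())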

def _fill_in_zeros(grads, refs, strict, create_graph, stage):
    # Used to detect None in the grads and depending on the flags, either replace them
    # with Tensors full of 0s of the appropriate size based on the refs or raise an error.
    # strict and create_graph allow us to detect when it is appropriate to raise an error
    # stage gives us information of which backward call we consider to give good error message
    if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
        raise RuntimeError(f"Invalid stage argument '{stage}' to _fill_in_zeros")

    res = ()
    for i, grads_i in enumerate(grads):
        if grads_i is None:
            if strict:
                if stage == "back":
                    raise RuntimeError(
                        "The output of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                elif stage == "back_trick":
                    raise RuntimeError(
                        f"The gradient with respect to the input is independent of entry {i}"
                        " in the grad_outputs when using the double backward trick to compute"
                        " forward mode gradients. This is not allowed in strict mode."
                    )
                elif stage == "double_back":
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"entry {i} in the grad_jacobian. This is not allowed in strict "
                        "mode as it prevents from using the double backward trick to "
                        "replace forward mode AD."
                    )

            refs_i = refs[i]
            if isinstance(refs_i, jt.nn.ComplexNumber):
                grads_i = jt.nn.ComplexNumber(jt.zeros_like(refs_i.value), is_concat_value=True)
            else:
                grads_i = jt.zeros_like(refs_i)
        else:
            if strict and create_graph and not grads_i.requires_grad:
                if "double" not in stage:
                    raise RuntimeError(
                        "The jacobian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )
                else:
                    raise RuntimeError(
                        "The hessian of the user-provided function is independent of "
                        f"input {i}. This is not allowed in strict mode when create_graph=True."
                    )

        res += (grads_i,)

    return res


# Public API

def vjp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between a vector ``v`` and the Jacobian of the given function at the point given by the inputs.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the vector
            Jacobian product is computed. Must be the same size as the output
            of ``func``. This argument is optional when the output of ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            vjp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            vjp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the inputs.
    """
    with jt.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "vjp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "vjp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)

        if v is not None:
            _, v = _as_tuple(v, "v", "vjp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, outputs, is_outputs_tuple)
        else:
            if len(outputs) != 1 or outputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the "
                    "user-provided function returns "
                    "a single Tensor with a single element."
                )

    with jt.enable_grad():
        grad_res = _autograd_grad(outputs, inputs, v, create_graph=create_graph)
        vjp = _fill_in_zeros(grad_res, inputs, strict, create_graph, "back")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    vjp = _grad_postprocess(vjp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        vjp, is_inputs_tuple
    )

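A short usage sketch for `vjp` (an illustrative aside, not part of the commit; shapes are arbitrary and a working Jittor build is assumed):

import jittor as jt
from jittor.gradfunctional import vjp

def exp_reducer(x):
    # maps an (N, M) input to an (N,) output
    return x.exp().sum(1)

x = jt.rand(4, 3)
v = jt.ones((4,))                    # must match the output shape of exp_reducer
out, x_vjp = vjp(exp_reducer, x, v)  # out == exp_reducer(x); x_vjp has x's shape
print(out.shape, x_vjp.shape)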

def jvp(func, inputs, v=None, create_graph=False, strict=False):
    r"""Compute the dot product between the Jacobian of the given function at the point given by the inputs and a vector ``v``.

    Args:
        func (function): a Python function that takes Tensor inputs and returns
            a tuple of Tensors or a Tensor.
        inputs (tuple of Tensors or Tensor): inputs to the function ``func``.
        v (tuple of Tensors or Tensor): The vector for which the Jacobian
            vector product is computed. Must be the same size as the input of
            ``func``. This argument is optional when the input to ``func``
            contains a single element and (if it is not provided) will be set
            as a Tensor containing a single ``1``.
        create_graph (bool, optional): If ``True``, both the output and result
            will be computed in a differentiable way. Note that when ``strict``
            is ``False``, the result can not require gradients or be
            disconnected from the inputs. Defaults to ``False``.
        strict (bool, optional): If ``True``, an error will be raised when we
            detect that there exists an input such that all the outputs are
            independent of it. If ``False``, we return a Tensor of zeros as the
            jvp for said inputs, which is the expected mathematical value.
            Defaults to ``False``.

    Returns:
        output (tuple): tuple with:
            func_output (tuple of Tensors or Tensor): output of ``func(inputs)``

            jvp (tuple of Tensors or Tensor): result of the dot product with
            the same shape as the output.

    """
    with jt.enable_grad():
        is_inputs_tuple, inputs = _as_tuple(inputs, "inputs", "jvp")
        inputs = _grad_preprocess(inputs, create_graph=create_graph, need_graph=True)

        if v is not None:
            _, v = _as_tuple(v, "v", "jvp")
            v = _grad_preprocess(v, create_graph=create_graph, need_graph=False)
            _validate_v(v, inputs, is_inputs_tuple)
        else:
            if len(inputs) != 1 or inputs[0].nelement() != 1:
                raise RuntimeError(
                    "The vector v can only be None if the input to "
                    "the user-provided function is a single Tensor "
                    "with a single element."
                )

        outputs = func(*inputs)
        is_outputs_tuple, outputs = _as_tuple(
            outputs, "outputs of the user-provided function", "jvp"
        )
        _check_requires_grad(outputs, "outputs", strict=strict)
        # The backward is linear so the value of grad_outputs is not important as
        # it won't appear in the double backward graph. We only need to ensure that
        # it does not contain inf or nan.
        grad_outputs = tuple(
            jt.nn.ComplexNumber(jt.zeros_like(out.value), is_concat_value=True) if isinstance(out, jt.nn.ComplexNumber) else jt.zeros_like(out)
            for out in outputs
        )

        grad_inputs = _autograd_grad(outputs, inputs, grad_outputs=grad_outputs, create_graph=True)
        _check_requires_grad(grad_inputs, "grad_inputs", strict=strict)

    if create_graph:
        with jt.enable_grad():
            grad_res = _autograd_grad(
                grad_inputs, grad_outputs, v, create_graph=create_graph
            )
            jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")
    else:
        grad_res = _autograd_grad(
            grad_inputs, grad_outputs, v, create_graph=create_graph
        )
        jvp = _fill_in_zeros(grad_res, outputs, strict, create_graph, "back_trick")

    # Cleanup objects and return them to the user
    outputs = _grad_postprocess(outputs, create_graph)
    jvp = _grad_postprocess(jvp, create_graph)

    return _tuple_postprocess(outputs, is_outputs_tuple), _tuple_postprocess(
        jvp, is_outputs_tuple
    )
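And the forward-mode counterpart, `jvp`, where `v` lives in the input space (again an illustrative aside, not part of the commit):

import jittor as jt
from jittor.gradfunctional import jvp

def exp_reducer(x):
    return x.exp().sum(1)

x = jt.rand(4, 3)
v = jt.ones((4, 3))                   # must match the *input* shape for jvp
out, out_jvp = jvp(exp_reducer, x, v)
print(out.shape, out_jvp.shape)       # both have the output shape (4,)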
@@ -3130,6 +3130,10 @@ class ComplexNumber:
            assert real.dtype == imag.dtype
            self.value = jt.stack([real, imag], dim=-1)

    @property
    def requires_grad(self):
        return self.value.requires_grad

    @property
    def real(self):
        return self.value[..., 0]
@@ -3142,6 +3146,10 @@ class ComplexNumber:
    def shape(self):
        return self.value.shape[:-1]

    @property
    def dtype(self):
        return "complex64"

    def norm(self):
        return jt.sqrt(jt.sqr(self.real) + jt.sqr(self.imag))
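For orientation, a tiny sketch of how the new `requires_grad` and `dtype` properties surface on `ComplexNumber` (an illustrative aside, not part of the commit; values are arbitrary):

import jittor as jt
from jittor.nn import ComplexNumber

z = ComplexNumber(jt.rand(2, 3), jt.rand(2, 3))  # real part, imaginary part
print(z.shape)          # [2,3]: the trailing real/imag stack dim is hidden
print(z.dtype)          # "complex64"
print(z.requires_grad)  # mirrors the underlying stacked jt.Var
print(z.norm().shape)   # element-wise magnitude, also [2,3]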
@@ -3287,6 +3295,129 @@ def view_as_complex(x: jt.Var) -> ComplexNumber:
def view_as_real(x: ComplexNumber) -> jt.Var:
    return jt.stack([x.value[...,0],x.value[...,1]],dim=-1)

# reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/torch/functional.py#L1258
def tensordot(a, b, dims=2):
    r"""Returns a contraction of a and b over multiple dimensions.

    :attr:`tensordot` implements a generalized matrix product.

    Args:
        a (Tensor): Left tensor to contract
        b (Tensor): Right tensor to contract
        dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists): number of dimensions to
            contract or explicit lists of dimensions for :attr:`a` and
            :attr:`b` respectively

    When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
    the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
    respectively, :func:`tensordot` computes

    .. math::
        r_{i_0,...,i_{m-d}, i_d,...,i_n}
            = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.

    When called with :attr:`dims` of the list form, the given dimensions will be contracted
    in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :attr:`b`. The sizes
    in these dimensions must match.

    """
    if not isinstance(dims, (tuple, list, int)):
        raise RuntimeError(
            "tensordot expects dims to be int or "
            + "Tuple[List[int], List[int]] or "
            + "List[List[int]] containing two lists, but got "
            + f"dims={dims}"
        )

    dims_a, dims_b = [], []

    if isinstance(dims, (tuple, list)):
        dims_a, dims_b = dims

    if isinstance(dims, int):
        if dims < 0:
            raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
        if dims > min(len(a.shape), len(b.shape)):
            raise RuntimeError(
                f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
            )
        dims_a = list(range(len(a.shape)-dims, len(a.shape)))
        dims_b = list(range(dims))

    # reference: https://github.com/pytorch/pytorch/blob/8ea5b572a63b1acc538a9fc8d3862c73739116e8/aten/src/ATen/native/Linear.cpp#L769
    def __tensordot_native(input1:jt.Var, input2:jt.Var, dims1, dims2):
        if not isinstance(dims1, (list, tuple)):
            raise RuntimeError("tensordot expects dims1 to be List[Int], but got dims={}".format(dims1))
        if not isinstance(dims2, (list, tuple)):
            raise RuntimeError("tensordot expects dims2 to be List[Int], but got dims={}".format(dims2))
        dims1 = list(dims1)
        dims2 = list(dims2)
        if len(dims1) != len(dims2):
            raise RuntimeError("both dimension lists should have the same length")
        if input1.dtype != input2.dtype:
            raise RuntimeError("both inputs should have the same dtype")
        t1 = input1
        t2 = input2
        csize = 1
        input1_bitmap = np.zeros(len(input1.shape), dtype='bool')
        input2_bitmap = np.zeros(len(input2.shape), dtype='bool')
        for i in range(len(dims1)):
            s1 = input1.shape[dims1[i]]
            s2 = input2.shape[dims2[i]]
            input1_bitmap[dims1] = True
            input2_bitmap[dims2] = True
            if s2 == 1: # broadcasted dimensions can be summed right away
                t1 = t1.sum(dims1[i], keepdims=True)
            elif s1 == 1:
                t2 = t2.sum(dims2[i], keepdims=True)
            else:
                if s1 != s2:
                    raise RuntimeError("contracted dimensions need to match, but first has size {}, in dim {}, and second has size {}".format(s1, i, s2))
                csize *= s1

        p1, p2 = [], []  # p1, p2: input permutations
        rsizes = []
        size1, size2 = 1, 1  # number of non-contracted elements
        for i in range(len(input1.shape)):
            if not input1_bitmap[i]:
                p1.append(i)
                size1 *= t1.shape[i]
                rsizes.append(t1.shape[i])
        p1 += dims1
        p2 += dims2
        for i in range(len(input2.shape)):
            if not input2_bitmap[i]:
                p2.append(i)
                size2 *= t2.shape[i]
                rsizes.append(t2.shape[i])

        # permute and reshape for matrix multiplication
        t1 = t1.permute(p1).reshape((size1, csize))
        t2 = t2.permute(p2).reshape((csize, size2))
        # multiply and reshape to target size
        return jt.matmul(t1, t2).reshape(rsizes)

    return __tensordot_native(a, b, dims_a, dims_b)

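A brief usage sketch of the two `dims` forms accepted above (an illustrative aside, not part of the commit; `tensordot` is reached as `jt.nn.tensordot`, as the tests below do):

import jittor as jt

a = jt.rand(3, 4, 5)
b = jt.rand(4, 5, 6)

# integer dims: contract the last 2 dims of a with the first 2 dims of b
c = jt.nn.tensordot(a, b, dims=2)
print(c.shape)  # [3,6]

# explicit dims: pair a's dims (1, 2) with b's dims (0, 1)
d = jt.nn.tensordot(a, b, dims=([1, 2], [0, 1]))
print(d.shape)  # [3,6]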

# reference: https://github.com/pytorch/pytorch/blob/5ed3b70d09a4ab2a5be4becfda9dd0d3e3227c39/aten/src/ATen/native/LinearAlgebra.cpp#L3375
def kron(a:jt.Var, b:jt.Var):
    a_dim, b_dim = len(a.shape), len(b.shape)
    max_dim = max(a_dim, b_dim)
    pad_a, pad_b = max_dim-a_dim, max_dim-b_dim
    a_reshape, b_reshape = [], []
    result_reshape = []
    for i in range(max_dim):
        a_2i_shape = a.shape[i - pad_a] if i >= pad_a else 1
        b_2i1_shape = b.shape[i - pad_b] if i >= pad_b else 1
        a_reshape.append(a_2i_shape)
        a_reshape.append(1)
        b_reshape.append(1)
        b_reshape.append(b_2i1_shape)
        result_reshape.append(a_2i_shape * b_2i1_shape)
    a = a.reshape(a_reshape)
    b = b.reshape(b_reshape)
    return (a * b).reshape(result_reshape)
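A small worked example of the interleaved-reshape trick used by `kron` (an illustrative aside, not part of the commit):

import jittor as jt

a = jt.array([[1.0, 2.0],
              [3.0, 4.0]])   # shape (2, 2)
b = jt.array([[0.0, 1.0]])   # shape (1, 2)

k = jt.nn.kron(a, b)         # Kronecker product, shape (2*1, 2*2) = (2, 4)
print(k.numpy())
# [[0. 1. 0. 2.]
#  [0. 3. 0. 4.]]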

def one_hot(x: jt.Var, num_classes: int=-1) -> jt.Var:
    ''' Returns the one_hot encoding of inputs.

@@ -2,6 +2,7 @@ import jittor as jt
from jittor.nn import ComplexNumber
import unittest
import numpy as np
from functools import partial

__skip_torch_test = False
try:
@@ -10,6 +11,15 @@ except:
    __skip_torch_test = True

class TestResultAndGrad:
    def flatten_list(self, list_like):
        results = []
        if isinstance(list_like, (list, tuple)):
            for x in list_like:
                results.extend(self.flatten_list(x))
            return results
        else:
            return [list_like]

    def check_results(self, rlist1, rlist2):
        assert len(rlist1) == len(rlist2)
        for r1, r2 in zip(rlist1, rlist2):
@@ -36,13 +46,21 @@ class TestResultAndGrad:
            grads.append(g.detach().cpu().numpy())
        return grads

    def run_jittor_op(self, op, input_list, weights=None):
        def _np_to_jittor(x:np.ndarray):
            if x.dtype == np.complex64 or x.dtype == np.complex128:
                nx = np.stack([np.real(x), np.imag(x)], axis=-1)
                return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
            elif x.dtype == np.float32 or x.dtype == np.float64:
                return jt.array(x, dtype=jt.float32)
    def run_jittor_op(self, op, input_list, weights=None, key_names=None, **kwargs):
        def _np_to_jittor(x):
            if isinstance(x, np.ndarray):
                if x.dtype == np.complex64 or x.dtype == np.complex128:
                    nx = np.stack([np.real(x), np.imag(x)], axis=-1)
                    return ComplexNumber(jt.array(nx, dtype=jt.float32), is_concat_value=True)
                elif x.dtype == np.float32 or x.dtype == np.float64:
                    return jt.array(x, dtype=jt.float32)
                else:
                    assert False
            elif isinstance(x, (list, tuple)):
                nx = [_np_to_jittor(vx) for vx in x]
                if isinstance(x, tuple):
                    return tuple(nx)
                return nx
            else:
                assert False
        def _jittor_to_np(x):
@@ -51,11 +69,19 @@ class TestResultAndGrad:
            elif isinstance(x, ComplexNumber):
                return x.real.numpy() + 1j * x.imag.numpy()
            assert False

        ninput_list = [_np_to_jittor(x) for x in input_list]
        output_list = op(*ninput_list)

        if key_names != None:
            assert len(ninput_list) == len(key_names)
            nkwargs = kwargs.copy()
            for k, v in zip(key_names, ninput_list):
                nkwargs[k] = v
            output_list = op(**nkwargs)
        else:
            output_list = op(*ninput_list, **kwargs)
        if isinstance(output_list, (jt.Var, ComplexNumber)):
            output_list = [output_list]
        output_list = self.flatten_list(output_list)
        losses = []
        if weights is None:
            weights = []
@@ -73,15 +99,31 @@ class TestResultAndGrad:
        output_list = [_jittor_to_np(x) for x in output_list]
        return ninput_list, output_list, losses, weights

    def run_torch_op(self, op, input_list, weights=None):
        def _np_to_torch(x:np.ndarray):
            return torch.from_numpy(x).requires_grad_(True)
    def run_torch_op(self, op, input_list, weights=None, key_names=None, **kwargs):
        def _np_to_torch(x):
            if isinstance(x, np.ndarray):
                return torch.from_numpy(x).requires_grad_(True)
            elif isinstance(x, (list, tuple)):
                nx = [_np_to_torch(vx) for vx in x]
                if isinstance(x, tuple):
                    return tuple(nx)
                return nx
            else:
                assert False
        def _torch_to_np(x:torch.Tensor) -> np.ndarray:
            return x.detach().cpu().numpy()
        ninput_list = [_np_to_torch(x) for x in input_list]
        output_list = op(*ninput_list)
        if key_names != None:
            assert len(ninput_list) == len(key_names)
            nkwargs = kwargs.copy()
            for k, v in zip(key_names, ninput_list):
                nkwargs[k] = v
            output_list = op(**nkwargs)
        else:
            output_list = op(*ninput_list, **kwargs)
        if isinstance(output_list, torch.Tensor):
            output_list = [output_list]
        output_list = self.flatten_list(output_list)
        losses = []
        if weights is None:
            weights = []
@@ -99,10 +141,10 @@ class TestResultAndGrad:
        output_list = [_torch_to_np(x) for x in output_list]
        return ninput_list, output_list, losses, weights

    def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True):
    def check_op_with_torch(self, jittor_op, torch_op, input_list, check_grad=True, jittor_knames=None, torch_knames=None, **kwargs):
        weights = None
        jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights)
        torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights)
        jittor_input, jittor_output, jittor_losses, weights = self.run_jittor_op(jittor_op, input_list, weights, key_names=jittor_knames, **kwargs)
        torch_input, torch_output, torch_losses, weights = self.run_torch_op(torch_op, input_list, weights, key_names=torch_knames, **kwargs)
        self.check_results(jittor_output, torch_output)

        if check_grad:
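The `key_names` plumbing above simply re-binds the converted positional inputs to keyword arguments before calling the wrapped op; a minimal standalone sketch of that pattern (hypothetical helper, not part of the commit):

def call_with_key_names(op, converted_inputs, key_names=None, **kwargs):
    # Re-bind positional inputs to keyword arguments when key_names is given,
    # mirroring the run_jittor_op / run_torch_op logic above.
    if key_names is not None:
        assert len(converted_inputs) == len(key_names)
        nkwargs = dict(kwargs)
        nkwargs.update(zip(key_names, converted_inputs))
        return op(**nkwargs)
    return op(*converted_inputs, **kwargs)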
@@ -195,6 +237,249 @@ class TestComplexLinalg(unittest.TestCase, TestResultAndGrad):
        inputs = [m1]
        self.check_op_with_numpy(jt.linalg.svd, np.linalg.svd, inputs)

class TestTensordot(unittest.TestCase, TestResultAndGrad):
    def random_complex_matrix(self, shape):
        r = np.random.randn(*shape)
        i = np.random.randn(*shape)
        return r + 1j * i

    def random_real_matrix(self, shape):
        return np.random.randn(*shape)

    def test_complex_tensordot_numberdim(self):
        s1 = (3, 4, 5)
        s2 = (4, 5, 6)
        dims = 2
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)

    def test_complex_tensordot_tupledim(self):
        s1 = (3, 5, 4, 6)
        s2 = (6, 4, 5, 3)
        dims = ([2, 1, 3], [1, 2, 0])
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)

    def test_real_tensordot_numberdim(self):
        s1 = (3, 4, 5)
        s2 = (4, 5, 6)
        dims = 2
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)

    def test_real_tensordot_tupledim(self):
        s1 = (3, 5, 4, 6)
        s2 = (6, 4, 5, 3)
        dims = ([2, 1, 3], [1, 2, 0])
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.tensordot, torch.tensordot, inputs, dims = dims)

class TestKron(unittest.TestCase, TestResultAndGrad):
    def random_complex_matrix(self, shape):
        r = np.random.randn(*shape)
        i = np.random.randn(*shape)
        return r + 1j * i

    def random_real_matrix(self, shape):
        return np.random.randn(*shape)

    def test_complex_firstlarge(self):
        s1 = (2, 3, 4)
        s2 = (5, 2)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

    def test_complex_second_large(self):
        s1 = (2, 3)
        s2 = (5, 2, 4)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

    def test_real_firstlarge(self):
        s1 = (2, 3, 4)
        s2 = (5, 2)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

    def test_real_second_large(self):
        s1 = (2, 3)
        s2 = (5, 2, 4)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(jt.nn.kron, torch.kron, inputs)

@unittest.skipIf(__skip_torch_test, "No Torch found")
class TestGradFunctional(unittest.TestCase, TestResultAndGrad):
    def random_complex_matrix(self, shape):
        r = np.random.randn(*shape)
        i = np.random.randn(*shape)
        return r + 1j * i

    def random_real_matrix(self, shape):
        return np.random.randn(*shape) * 0.0 + 1.0

    def test_real_jvp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(dim=1)
        s1 = (5, 6)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s1)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
            partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False)

    def test_complex_jvp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(1)
        s1 = (5, 6)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s1)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=exp_reducer, create_graph=True),
            partial(torch.autograd.functional.jvp, func=exp_reducer, create_graph=True),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False,
        )

    def test_real_jvp_add(self):
        w1, w2 = np.random.rand(), np.random.rand()
        def adder(x, y):
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s1)
        m3 = self.random_real_matrix(s1)
        m4 = self.random_real_matrix(s1)
        inputs = [(m1, m2), (m3, m4)]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=adder, create_graph=True),
            partial(torch.autograd.functional.jvp, func=adder, create_graph=True),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False,
        )

    def test_complex_jvp_add(self):
        w1r, w1i = np.random.rand(), np.random.rand()
        w2r, w2i = np.random.rand(), np.random.rand()
        def adder_pt(x, y):
            return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
        def adder_jt(x, y):
            w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
            w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s1)
        m3 = self.random_complex_matrix(s1)
        m4 = self.random_complex_matrix(s1)
        inputs = [(m1, m2), (m3, m4)]
        self.check_op_with_torch(
            partial(jt.gradfunctional.jvp, func=adder_jt, create_graph=True),
            partial(torch.autograd.functional.jvp, func=adder_pt, create_graph=True),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False,
        )

    def test_real_vjp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(dim=1)
        s1 = (5, 6)
        s2 = (5,)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=exp_reducer),
            partial(torch.autograd.functional.vjp, func=exp_reducer),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False)

    def test_complex_vjp_exp(self):
        def exp_reducer(x):
            return x.exp().sum(1)
        s1 = (5, 6)
        s2 = (5,)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s2)
        inputs = [m1, m2]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=exp_reducer, create_graph=True),
            partial(torch.autograd.functional.vjp, func=exp_reducer, create_graph=True),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False,
        )

    def test_real_vjp_add(self):
        w1, w2 = np.random.rand(), np.random.rand()
        def adder(x, y):
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_real_matrix(s1)
        m2 = self.random_real_matrix(s1)
        m3 = self.random_real_matrix(s1)
        inputs = [(m1, m2), m3]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=adder, create_graph=True),
            partial(torch.autograd.functional.vjp, func=adder, create_graph=True),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False,
        )

    def test_complex_vjp_add(self):
        w1r, w1i = np.random.rand(), np.random.rand()
        w2r, w2i = np.random.rand(), np.random.rand()
        def adder_pt(x, y):
            return (w1r + 1j * w1i) * x + (w2r + 1j * w2i) * y
        def adder_jt(x, y):
            w1 = ComplexNumber(real=jt.array(w1r).reshape(1,1), imag = jt.array(w1i).reshape(1,1))
            w2 = ComplexNumber(real=jt.array(w2r).reshape(1,1), imag = jt.array(w2i).reshape(1,1))
            return w1 * x + w2 * y
        s1 = (5, 6)
        m1 = self.random_complex_matrix(s1)
        m2 = self.random_complex_matrix(s1)
        m3 = self.random_complex_matrix(s1)
        inputs = [(m1, m2), (m3)]
        self.check_op_with_torch(
            partial(jt.gradfunctional.vjp, func=adder_jt, create_graph=True),
            partial(torch.autograd.functional.vjp, func=adder_pt, create_graph=True),
            inputs,
            jittor_knames = ['inputs', 'v'],
            torch_knames = ['inputs', 'v'],
            check_grad=False,
        )

if __name__ == "__main__":
    unittest.main()