doc: added arg_reduce, broadcast_to, reduce, reshape

This commit is contained in:
lzhengning 2021-03-16 15:41:39 +08:00
parent 5dcf826ef1
commit 331964afe8
4 changed files with 318 additions and 0 deletions

View File

@ -18,6 +18,32 @@ struct ArgReduceOp : Op {
NanoString op;
int dim;
bool keepdims;
/**
Returns the indices of the maximum / minimum of the input across a dimension.
----------------
* [in] x: the input jt.Var.
* [in] op: "max" or "min".
* [in] dim: int. Specifies which dimension to reduce.
* [in] keepdim: bool. Whether the output has ``dim`` retained or not.
----------------
Example-1::
>>> x = jt.randint(0, 10, shape=(2, 3))
>>> x
jt.Var([[4 2 5]
[6 7 1]], dtype=int32)
>>> jt.arg_reduce(x, 'max', dim=1, keepdims=False)
[jt.Var([2 1], dtype=int32), jt.Var([5 7], dtype=int32)]
>>> jt.arg_reduce(x, 'min', dim=1, keepdims=False)
[jt.Var([1 2], dtype=int32), jt.Var([5 7], dtype=int32)]
*/
// @attrs(multiple_outputs)
ArgReduceOp(Var* x, NanoString op, int dim, bool keepdims);
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;

View File

@ -16,8 +16,72 @@ struct BroadcastToOp : Op {
uint16 bcast_mask;
uint16 keepdims_mask;
/**
Broadcast ``x`` to a given shape.
----------------
* [in] x: the input jt.Var.
* [in] shape: the output shape.
* [in] dims: specifies the new dimension in the output shape, an integer array.
----------------
Example-1::
>>> x = jt.randint(0, 10, shape=(2, 2))
>>> x
jt.Var([[8 1]
[7 6]], dtype=int32)
>>> jt.broadcast(x, shape=(2, 3, 2), dims=[1])
jt.Var([[[8 1]
[8 1]
[8 1]],
[[7 6]
[7 6]
[7 6]]], dtype=int32)
*/
// @pybind(broadcast)
BroadcastToOp(Var* x, NanoVector shape, NanoVector dims=NanoVector());
/**
Broadcast ``x`` to the same shape as ``y``.
----------------
* [in] x: the input jt.Var.
* [in] y: the reference jt.Var.
* [in] dims: specifies the new dimension in the output shape, an integer array.
----------------
.. note::
jt.broadcast_var(x, y, dims) is an alias of jt.broadcast(x, y, dims)
Example-1::
>>> x = jt.randint(0, 10, shape=(2, 2))
>>> x
jt.Var([[8 1]
[7 6]], dtype=int32)
>>> y = jt.randint(0, 10, shape=(2, 3, 2))
>>> jt.broadcast(x, y, dims=[1])
jt.Var([[[8 1]
[8 1]
[8 1]],
[[7 6]
[7 6]
[7 6]]], dtype=int32)
>>> jt.broadcast_var(x, y, dims=[1])
jt.Var([[[8 1]
[8 1]
[8 1]],
[[7 6]
[7 6]
[7 6]]], dtype=int32)
*/
// @pybind(broadcast,broadcast_var)
BroadcastToOp(Var* x, Var* y, NanoVector dims=NanoVector());
// @pybind(None)

View File

@ -27,22 +27,224 @@ static auto make_number = get_op_info("number")
NanoString binary_dtype_infer(NanoString op, Var* dx, Var* dy);
unordered_set<string> reduce_ops = {
/**
Returns the maximum elements in the input.
----------------
* [in] x: the input jt.Var.
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
----------------
Example-1::
>>> x = jt.randint(10, shape=(2, 3))
>>> x
jt.Var([[4 1 2]
[0 2 4]], dtype=int32)
>>> jt.max(x)
jt.Var([4], dtype=int32)
>>> x.max()
jt.Var([4], dtype=int32)
>>> x.max(dim=1)
jt.Var([4 4], dtype=int32)
>>> x.max(dim=1, keepdims=True)
jt.Var([[4]
[4]], dtype=int32)
*/
// @pybind(max, reduce_maximum)
"maximum",
/**
Returns the minimum elements in the input.
----------------
* [in] x: the input jt.Var.
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
----------------
Example-1::
>>> x = jt.randint(10, shape=(2, 3))
>>> x
jt.Var([[4 1 2]
[0 2 4]], dtype=int32)
>>> jt.min(x)
jt.Var([0], dtype=int32)
>>> x.min()
jt.Var([0], dtype=int32)
>>> x.min(dim=1)
jt.Var([1 0], dtype=int32)
>>> x.min(dim=1, keepdims=True)
jt.Var([[1]
[0]], dtype=int32)
*/
// @pybind(min, reduce_minimum)
"minimum",
/**
Returns the sum of the input.
----------------
* [in] x: the input jt.Var.
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
----------------
Example-1::
>>> x = jt.randint(10, shape=(2, 3))
>>> x
jt.Var([[4 1 2]
[0 2 4]], dtype=int32)
>>> jt.sum(x)
jt.Var([13], dtype=int32)
>>> x.sum()
jt.Var([13], dtype=int32)
>>> x.sum(dim=1)
jt.Var([7 6], dtype=int32)
>>> x.sum(dim=1, keepdims=True)
jt.Var([[7]
[6]], dtype=int32)
*/
// @pybind(sum, reduce_add)
"add",
/**
Returns the product of all the elements in the input.
----------------
* [in] x: the input jt.Var.
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
----------------
Example-1::
>>> x = jt.randint(10, shape=(2, 3))
>>> x
jt.Var([[7 5 5]
[5 7 5]], dtype=int32)
>>> jt.prod(x)
jt.Var([30625], dtype=int32)
>>> x.prod()
jt.Var([30625], dtype=int32)
>>> x.prod(dim=1)
jt.Var([175 175], dtype=int32)
>>> x.prod(dim=1, keepdims=True)
jt.Var([[175]
[175]], dtype=int32)
*/
// @pybind(prod, product, reduce_multiply)
"multiply",
/**
Tests if all elements in input evaluate to True.
----------------
* [in] x: the input jt.Var.
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
----------------
Example-1::
>>> x = jt.randint(2, shape=(2, 3))
>>> x
jt.Var([[1 1 1]
[0 1 0]], dtype=int32)
>>> jt.all_(x)
jt.Var([False], dtype=int32)
>>> x.all_()
jt.Var([False], dtype=int32)
>>> x.all_(dim=1)
jt.Var([True False], dtype=int32)
>>> x.all_(dim=1, keepdims=True)
jt.Var([[True]
[False]], dtype=int32)
*/
// @pybind(reduce_logical_and, all_)
"logical_and",
/**
Tests if any elements in input evaluate to True.
----------------
* [in] x: the input jt.Var.
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
----------------
Example-1::
>>> x = jt.randint(2, shape=(2, 3))
>>> x
jt.Var([[1 0 1]
[0 0 0]], dtype=int32)
>>> jt.any_(x)
jt.Var([True], dtype=int32)
>>> x.any_()
jt.Var([True], dtype=int32)
>>> x.any_(dim=1)
jt.Var([True False], dtype=int32)
>>> x.any_(dim=1, keepdims=True)
jt.Var([[True]
[False]], dtype=int32)
*/
// @pybind(reduce_logical_or, any_)
"logical_or",
"logical_xor",
"bitwise_and",
"bitwise_or",
"bitwise_xor",
/**
Returns the mean value of the input.
----------------
* [in] x: the input jt.Var.
* [in] dim: int or tuples of ints (optional). If specified, reduce along the given the dimension(s).
* [in] keepdim: bool (optional). Whether the output has ``dim`` retained or not. Defaults to be False.
----------------
Example-1::
>>> x = jt.randint(10, shape=(2, 3))
>>> x
jt.Var([[9 4 4]
[1 9 6]], dtype=int32)
>>> jt.mean(x)
jt.Var([5.5000005], dtype=float32)
>>> x.mean()
jt.Var([5.5000005], dtype=float32)
>>> x.mean(dim=1)
jt.Var([5.666667 5.3333335], dtype=float32)
>>> x.mean(dim=1, keepdims=True)
jt.Var([[5.666667 ]
[5.3333335]], dtype=float32)
*/
// @pybind(mean)
"mean",
};

View File

@ -15,6 +15,32 @@ namespace jittor {
struct ReshapeOp : Op {
Var* x, * y;
NanoVector shape;
/**
Returns a tensor with the same data and number of elements as input, but with the specified shape.
A single dimension may be -1, in which case its inferred from the remaining dimensions and the number of elements in input.
----------------
* [in] x: the input jt.Var
* [in] shape: the output shape, an integer array
----------------
Example-1::
>>> a = jt.randint(0, 10, shape=(12,))
>>> a
jt.Var([4 0 8 4 6 3 1 8 1 1 2 2], dtype=int32)
>>> jt.reshape(a, (3, 4))
jt.Var([[4 0 8 4]
[6 3 1 8]
[1 1 2 2]], dtype=int32)
>>> jt.reshape(a, (-1, 6))
jt.Var([[4 0 8 4 6 3]
[1 8 1 1 2 2]], dtype=int32)
*/
ReshapeOp(Var* x, NanoVector shape);
const char* name() const override { return "reshape"; }