mirror of https://github.com/Jittor/Jittor
polish concat error message
This commit is contained in:
parent
d6028df8dc
commit
950fd6e42a
|
@ -211,21 +211,24 @@ def setitem(x, slices, value):
|
||||||
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
||||||
jt.Var.__setitem__ = setitem
|
jt.Var.__setitem__ = setitem
|
||||||
|
|
||||||
def concat(arr, dim):
|
def concat(arr, dim=0):
|
||||||
'''Concat Operator can concat a list of jt Var at a specfic dimension.
|
'''Concat Operator can concat a list of jt Var at a specfic dimension.
|
||||||
|
|
||||||
* [in] x: input var list for concat
|
* [in] x: input var list for concat
|
||||||
|
|
||||||
* [in] dim: concat which dim
|
* [in] dim: concat which dim
|
||||||
|
|
||||||
* [out] out: concat result
|
* return: concat result
|
||||||
|
|
||||||
Example::
|
Example::
|
||||||
|
|
||||||
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
|
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
|
||||||
# return [[1],[2],[2],[2]]
|
# return [[1],[2],[2],[2]]
|
||||||
'''
|
'''
|
||||||
# TODO: low performance when concat lots of vars
|
if not isinstance(arr, Sequence):
|
||||||
|
raise TypeError("concat arr needs to be a tuple or list")
|
||||||
|
if len(arr) == 0:
|
||||||
|
raise ValueError("need at least one array to concat")
|
||||||
total_dim = 0
|
total_dim = 0
|
||||||
if dim < 0: dim += len(arr[0].shape)
|
if dim < 0: dim += len(arr[0].shape)
|
||||||
for a in arr:
|
for a in arr:
|
||||||
|
|
Loading…
Reference in New Issue