mirror of https://github.com/Jittor/Jittor
polish concat error message
This commit is contained in:
parent
d6028df8dc
commit
950fd6e42a
|
@ -211,21 +211,24 @@ def setitem(x, slices, value):
|
|||
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
||||
jt.Var.__setitem__ = setitem
|
||||
|
||||
def concat(arr, dim):
|
||||
def concat(arr, dim=0):
|
||||
'''Concat Operator can concat a list of jt Var at a specfic dimension.
|
||||
|
||||
* [in] x: input var list for concat
|
||||
|
||||
* [in] dim: concat which dim
|
||||
|
||||
* [out] out: concat result
|
||||
* return: concat result
|
||||
|
||||
Example::
|
||||
|
||||
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
|
||||
# return [[1],[2],[2],[2]]
|
||||
'''
|
||||
# TODO: low performance when concat lots of vars
|
||||
if not isinstance(arr, Sequence):
|
||||
raise TypeError("concat arr needs to be a tuple or list")
|
||||
if len(arr) == 0:
|
||||
raise ValueError("need at least one array to concat")
|
||||
total_dim = 0
|
||||
if dim < 0: dim += len(arr[0].shape)
|
||||
for a in arr:
|
||||
|
|
Loading…
Reference in New Issue