polish concat error message

This commit is contained in:
Dun Liang 2020-12-24 20:33:31 +08:00
parent d6028df8dc
commit 950fd6e42a
1 changed files with 6 additions and 3 deletions

View File

@ -211,21 +211,24 @@ def setitem(x, slices, value):
jt.Var.__getitem__ = jt.Var.slice_var = getitem
jt.Var.__setitem__ = setitem
def concat(arr, dim):
def concat(arr, dim=0):
'''Concat Operator can concat a list of jt Var at a specfic dimension.
* [in] x: input var list for concat
* [in] dim: concat which dim
* [out] out: concat result
* return: concat result
Example::
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
'''
# TODO: low performance when concat lots of vars
if not isinstance(arr, Sequence):
raise TypeError("concat arr needs to be a tuple or list")
if len(arr) == 0:
raise ValueError("need at least one array to concat")
total_dim = 0
if dim < 0: dim += len(arr[0].shape)
for a in arr: