diff --git a/python/jittor/contrib.py b/python/jittor/contrib.py index f06f44cf..e1982f7e 100644 --- a/python/jittor/contrib.py +++ b/python/jittor/contrib.py @@ -211,21 +211,24 @@ def setitem(x, slices, value): jt.Var.__getitem__ = jt.Var.slice_var = getitem jt.Var.__setitem__ = setitem -def concat(arr, dim): +def concat(arr, dim=0): '''Concat Operator can concat a list of jt Var at a specfic dimension. * [in] x: input var list for concat * [in] dim: concat which dim - * [out] out: concat result + * return: concat result Example:: jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1) # return [[1],[2],[2],[2]] ''' - # TODO: low performance when concat lots of vars + if not isinstance(arr, Sequence): + raise TypeError("concat arr needs to be a tuple or list") + if len(arr) == 0: + raise ValueError("need at least one array to concat") total_dim = 0 if dim < 0: dim += len(arr[0].shape) for a in arr: