polish concat

This commit is contained in:
li-xl 2021-06-25 16:26:48 +08:00
parent ef66d6d832
commit 4285d5d61d
1 changed files with 20 additions and 1 deletions

View File

@ -206,6 +206,23 @@ def setitem(x, slices, value):
jt.Var.__getitem__ = jt.Var.slice_var = getitem
jt.Var.__setitem__ = setitem
def _merge_dtypes(dtypes):
s = -1
e = -1
names = ["bool","uint","int","float"]
dbytes = ["8","16","32","64"]
for d in dtypes:
for name in names:
if d.startswith(name):
s = max(s,names.index(name))
for db in dbytes:
if d.endswith(db):
e = max(e,dbytes.index(db))
assert s>=0 and s<4 and e<4
dtype = names[s]+("" if e ==-1 else dbytes[e])
return dtype
def concat(arr, dim=0):
'''Concat Operator can concat a list of jt Var at a specfic dimension.
@ -226,12 +243,14 @@ Example::
raise ValueError("need at least one array to concat")
total_dim = 0
if dim < 0: dim += len(arr[0].shape)
dtypes = []
for a in arr:
total_dim += a.shape[dim]
dtypes.append(str(a.dtype))
cdim = 0
shape = list(a.shape)
shape[dim] = total_dim
s = jt.empty(shape, a.dtype)
s = jt.empty(shape, dtype = _merge_dtypes(dtypes))
slices = [slice(None)]*len(a.shape)
for a in arr:
if a.shape[dim] == 0: