mirror of https://github.com/Jittor/Jittor
polish concat
This commit is contained in:
parent
ef66d6d832
commit
4285d5d61d
|
@ -206,6 +206,23 @@ def setitem(x, slices, value):
|
|||
jt.Var.__getitem__ = jt.Var.slice_var = getitem
|
||||
jt.Var.__setitem__ = setitem
|
||||
|
||||
|
||||
def _merge_dtypes(dtypes):
|
||||
s = -1
|
||||
e = -1
|
||||
names = ["bool","uint","int","float"]
|
||||
dbytes = ["8","16","32","64"]
|
||||
for d in dtypes:
|
||||
for name in names:
|
||||
if d.startswith(name):
|
||||
s = max(s,names.index(name))
|
||||
for db in dbytes:
|
||||
if d.endswith(db):
|
||||
e = max(e,dbytes.index(db))
|
||||
assert s>=0 and s<4 and e<4
|
||||
dtype = names[s]+("" if e ==-1 else dbytes[e])
|
||||
return dtype
|
||||
|
||||
def concat(arr, dim=0):
|
||||
'''Concat Operator can concat a list of jt Var at a specfic dimension.
|
||||
|
||||
|
@ -226,12 +243,14 @@ Example::
|
|||
raise ValueError("need at least one array to concat")
|
||||
total_dim = 0
|
||||
if dim < 0: dim += len(arr[0].shape)
|
||||
dtypes = []
|
||||
for a in arr:
|
||||
total_dim += a.shape[dim]
|
||||
dtypes.append(str(a.dtype))
|
||||
cdim = 0
|
||||
shape = list(a.shape)
|
||||
shape[dim] = total_dim
|
||||
s = jt.empty(shape, a.dtype)
|
||||
s = jt.empty(shape, dtype = _merge_dtypes(dtypes))
|
||||
slices = [slice(None)]*len(a.shape)
|
||||
for a in arr:
|
||||
if a.shape[dim] == 0:
|
||||
|
|
Loading…
Reference in New Issue