WIP: concat

This commit is contained in:
Dun Liang 2020-07-15 14:47:40 +08:00
parent 802cf56162
commit 3fca6fa321
4 changed files with 18 additions and 2 deletions

View File

@ -755,3 +755,4 @@ Var.double = Var.float64
from . import nn
from .nn import matmul
from . import contrib
from .contrib import concat

View File

@ -15,6 +15,19 @@ def argmax_pool(x, size, stride, padding=0):
return pool.pool(x, size, 'maximum', padding, stride)
def concat(arr, dim):
'''Concat Operator can concat a list of jt Var at a specfic dimension.
* [in] x: input var list for concat
* [in] dim: concat which dim
* [out] out: concat result
Example::
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
'''
# TODO: low performance when concat lots of vars
total_dim = 0
if dim < 0: dim += len(arr[0].shape)

View File

@ -11,7 +11,7 @@ import numpy as np
class TestConcatOp(unittest.TestCase):
def test_concat_op(self):
def check(tmp, dim=0):
res1 = jt.concat(tmp, dim=dim)
res1 = jt.WIP_concat(tmp, dim=dim)
res2 = jt.contrib.concat(tmp, dim=dim)
assert (res1!=res2).data.sum()==0, "concat fail..."
check([jt.array([[1],[2]]), jt.array([[2],[2]])])

View File

@ -14,7 +14,8 @@ struct ConcatOp : Op {
Var* y;
int dim;
/**
Concat Operator can concat a list of jt Var at a specfic dimension.
Concat Operator can concat a list of jt Var at a specfic dimension
(WIP: this op don't have grad and working in progress, use jt.contrib.concat instead).
* [in] x: input var list for concat
@ -27,6 +28,7 @@ struct ConcatOp : Op {
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
# return [[1],[2],[2],[2]]
*/
// @pybind(WIP_concat)
ConcatOp(vector<Var*>&& x, int dim=0);
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;