mirror of https://github.com/Jittor/Jittor
WIP: concat
This commit is contained in:
parent
802cf56162
commit
3fca6fa321
|
@ -755,3 +755,4 @@ Var.double = Var.float64
|
|||
from . import nn
|
||||
from .nn import matmul
|
||||
from . import contrib
|
||||
from .contrib import concat
|
||||
|
|
|
@ -15,6 +15,19 @@ def argmax_pool(x, size, stride, padding=0):
|
|||
return pool.pool(x, size, 'maximum', padding, stride)
|
||||
|
||||
def concat(arr, dim):
|
||||
'''Concat Operator can concat a list of jt Var at a specfic dimension.
|
||||
|
||||
* [in] x: input var list for concat
|
||||
|
||||
* [in] dim: concat which dim
|
||||
|
||||
* [out] out: concat result
|
||||
|
||||
Example::
|
||||
|
||||
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
|
||||
# return [[1],[2],[2],[2]]
|
||||
'''
|
||||
# TODO: low performance when concat lots of vars
|
||||
total_dim = 0
|
||||
if dim < 0: dim += len(arr[0].shape)
|
||||
|
|
|
@ -11,7 +11,7 @@ import numpy as np
|
|||
class TestConcatOp(unittest.TestCase):
|
||||
def test_concat_op(self):
|
||||
def check(tmp, dim=0):
|
||||
res1 = jt.concat(tmp, dim=dim)
|
||||
res1 = jt.WIP_concat(tmp, dim=dim)
|
||||
res2 = jt.contrib.concat(tmp, dim=dim)
|
||||
assert (res1!=res2).data.sum()==0, "concat fail..."
|
||||
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
|
||||
|
|
|
@ -14,7 +14,8 @@ struct ConcatOp : Op {
|
|||
Var* y;
|
||||
int dim;
|
||||
/**
|
||||
Concat Operator can concat a list of jt Var at a specfic dimension.
|
||||
Concat Operator can concat a list of jt Var at a specfic dimension
|
||||
(WIP: this op don't have grad and working in progress, use jt.contrib.concat instead).
|
||||
|
||||
* [in] x: input var list for concat
|
||||
|
||||
|
@ -27,6 +28,7 @@ struct ConcatOp : Op {
|
|||
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
|
||||
# return [[1],[2],[2],[2]]
|
||||
*/
|
||||
// @pybind(WIP_concat)
|
||||
ConcatOp(vector<Var*>&& x, int dim=0);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
|
|
Loading…
Reference in New Issue