mirror of https://github.com/Jittor/Jittor
polish concat
This commit is contained in:
parent
950fd6e42a
commit
16897047aa
|
@ -40,12 +40,16 @@ Example::
|
|||
cdim += a.shape[dim]
|
||||
return s
|
||||
|
||||
def numpy_concat(arr, dim):
|
||||
arr = [ a.numpy() for a in arr ]
|
||||
return np.concatenate(arr, dim)
|
||||
|
||||
class TestConcatOp(unittest.TestCase):
|
||||
def test_concat_op(self):
|
||||
def check(tmp, dim=0):
|
||||
res1 = jt.WIP_concat(tmp, dim=dim)
|
||||
res1 = numpy_concat(tmp, dim=dim)
|
||||
res2 = jt.contrib.concat(tmp, dim=dim)
|
||||
assert (res1!=res2).data.sum()==0, "concat fail..."
|
||||
assert (res2!=res1).data.sum()==0, "concat fail..."
|
||||
check([jt.array([[1],[2]]), jt.array([[2],[2]])])
|
||||
check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
|
||||
check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
|
||||
|
|
|
@ -1,82 +0,0 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#ifdef HAS_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#include "helper_cuda.h"
|
||||
#endif
|
||||
#include <algorithm>
|
||||
#include "var.h"
|
||||
#include "ops/concat_op.h"
|
||||
#include <vector>
|
||||
#include "executor.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
#include "ops/op_register.h"
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
||||
ConcatOp::ConcatOp(vector<Var*>&& x, int dim)
|
||||
: x(x), dim(dim) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
CHECK(x.size()>0) << "size of x cannot be empty.";
|
||||
CHECK(dim==0) << "only support concat at dim 0 now.";
|
||||
NanoVector shape = x[0]->shape;
|
||||
NanoString type = x[0]->dtype();
|
||||
uint size = x[0]->shape.size();
|
||||
for (uint i = 1; i < x.size(); ++i) {
|
||||
NanoVector _shape = x[i]->shape;
|
||||
CHECK(x[i]->dtype()==type) << "type of x must be same.";
|
||||
CHECK(_shape.size()==size) << "shape of x must have same length.";
|
||||
for (uint j = 0; j < _shape.size(); ++j) {
|
||||
if (j==dim) continue;
|
||||
CHECK(_shape[j]==shape[j]) << "shape of x except dim must be same.";
|
||||
}
|
||||
}
|
||||
y = create_output(nullptr, x[0]->dtype());
|
||||
}
|
||||
|
||||
void ConcatOp::infer_shape() {
|
||||
NanoVector shape;
|
||||
uint concat_dim = 0;
|
||||
for (Var* x : inputs()) {
|
||||
concat_dim += x->shape[dim];
|
||||
}
|
||||
for (uint i = 0; i < x[0]->shape.size(); ++i) {
|
||||
if (i != dim) {
|
||||
shape.push_back(x[0]->shape[i]);
|
||||
}
|
||||
else {
|
||||
shape.push_back(concat_dim);
|
||||
}
|
||||
}
|
||||
y->set_shape(shape);
|
||||
}
|
||||
void ConcatOp::jit_prepare(JK& jk) {
|
||||
jk << _CS("[T:int]");
|
||||
}
|
||||
|
||||
VarPtr ConcatOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
|
||||
void ConcatOp::jit_run() {
|
||||
auto* y_ptr = (char*)y->mem_ptr;
|
||||
for (Var* x : inputs()) {
|
||||
#ifdef JIT_cpu
|
||||
std::memcpy(y_ptr, x->mem_ptr, x->size);
|
||||
#else
|
||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x->mem_ptr, x->size, cudaMemcpyDefault, 0));
|
||||
#endif
|
||||
y_ptr += x->size;
|
||||
}
|
||||
}
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -1,41 +0,0 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Maintainers: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct ConcatOp : Op {
|
||||
vector<Var*> x;
|
||||
Var* y;
|
||||
int dim;
|
||||
/**
|
||||
Concat Operator can concat a list of jt Var at a specfic dimension
|
||||
(WIP: this op don't have grad and working in progress, use jt.contrib.concat instead).
|
||||
|
||||
* [in] x: input var list for concat
|
||||
|
||||
* [in] dim: concat which dim
|
||||
|
||||
* [out] out: concat result
|
||||
|
||||
Example::
|
||||
|
||||
jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=1)
|
||||
# return [[1],[2],[2],[2]]
|
||||
*/
|
||||
// @pybind(WIP_concat)
|
||||
ConcatOp(vector<Var*>&& x, int dim=0);
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "concat"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
Loading…
Reference in New Issue