polish concat

This commit is contained in:
Dun Liang 2020-12-24 20:39:05 +08:00
parent 950fd6e42a
commit 16897047aa
3 changed files with 6 additions and 125 deletions

View File

@ -40,12 +40,16 @@ Example::
cdim += a.shape[dim]
return s
def numpy_concat(arr, dim):
    """Reference concat: convert each Var to numpy, then delegate to numpy.

    arr: list of objects exposing .numpy(); dim: axis to concatenate along.
    Returns a plain numpy array.
    """
    as_np = [v.numpy() for v in arr]
    return np.concatenate(as_np, dim)
class TestConcatOp(unittest.TestCase):
    def test_concat_op(self):
        # Compare jt.contrib.concat against the numpy reference on several
        # shapes, including inputs whose dim-0 extents differ.
        # NOTE: the flat source contained diff residue — a dead store
        # `res1 = jt.WIP_concat(...)` immediately overwritten, and a duplicate
        # symmetric assert; both removed here.
        def check(tmp, dim=0):
            res1 = numpy_concat(tmp, dim=dim)
            res2 = jt.contrib.concat(tmp, dim=dim)
            # Element-wise inequality must be all-zero for an exact match.
            assert (res2!=res1).data.sum()==0, "concat fail..."
        check([jt.array([[1],[2]]), jt.array([[2],[2]])])
        check([jt.array(np.array(range(24))).reshape((1,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])
        check([jt.array(np.array(range(120))).reshape((5,2,3,4)), jt.array(np.array(range(24))).reshape((1,2,3,4))])

View File

@ -1,82 +0,0 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"
#endif
#include <algorithm>
#include <utility>
#include "var.h"
#include "ops/concat_op.h"
#include <vector>
#include "executor.h"
#include "misc/cuda_flags.h"
#include "ops/op_register.h"
namespace jittor {
#ifndef JIT
// Construct a concat op over `x` along axis `dim` (only dim 0 supported).
// Validates dtype and shape compatibility of all inputs; the output shape
// is resolved later in infer_shape().
ConcatOp::ConcatOp(vector<Var*>&& x, int dim)
    : x(std::move(x)), dim(dim) {
    flags.set(NodeFlags::_cpu);
    flags.set(NodeFlags::_cuda);
    // The parameter was moved into the member above (the original copied the
    // rvalue vector), so all validation below must read this->x, not the
    // now-empty parameter.
    auto& xs = this->x;
    CHECK(xs.size()>0) << "size of x cannot be empty.";
    CHECK(dim==0) << "only support concat at dim 0 now.";
    NanoVector shape = xs[0]->shape;
    NanoString type = xs[0]->dtype();
    uint size = xs[0]->shape.size();
    for (uint i = 1; i < xs.size(); ++i) {
        NanoVector _shape = xs[i]->shape;
        CHECK(xs[i]->dtype()==type) << "type of x must be same.";
        CHECK(_shape.size()==size) << "shape of x must have same length.";
        for (uint j = 0; j < _shape.size(); ++j) {
            // Extents may differ only along the concat axis.
            if (j == (uint)dim) continue;
            CHECK(_shape[j]==shape[j]) << "shape of x except dim must be same.";
        }
    }
    // Shape deferred to infer_shape(); only the dtype is fixed here.
    y = create_output(nullptr, xs[0]->dtype());
}
void ConcatOp::infer_shape() {
NanoVector shape;
uint concat_dim = 0;
for (Var* x : inputs()) {
concat_dim += x->shape[dim];
}
for (uint i = 0; i < x[0]->shape.size(); ++i) {
if (i != dim) {
shape.push_back(x[0]->shape[i]);
}
else {
shape.push_back(concat_dim);
}
}
y->set_shape(shape);
}
// Emit JIT specialization keys. A single variant suffices: jit_run copies
// raw bytes (memcpy / cudaMemcpyAsync), so T is pinned to int regardless
// of the actual element dtype.
void ConcatOp::jit_prepare(JK& jk) {
jk << _CS("[T:int]");
}
// WIP: backward is not implemented — returns nullptr (treated as no
// gradient). Per the header docs, use jt.contrib.concat when grad is needed.
VarPtr ConcatOp::grad(Var* out, Var* dout, Var* v, int v_index) {
return nullptr;
}
#else // JIT
// Concatenate along dim 0 by copying each input's buffer back-to-back into
// the output. Valid because the constructor enforced identical shapes on
// all non-concat axes, so inputs are contiguous row-major slabs of y.
void ConcatOp::jit_run() {
// Byte-wise write cursor into the output buffer.
auto* y_ptr = (char*)y->mem_ptr;
for (Var* x : inputs()) {
#ifdef JIT_cpu
std::memcpy(y_ptr, x->mem_ptr, x->size);
#else
// Async copy on the default stream; cudaMemcpyDefault lets the driver
// infer the transfer direction from the pointers.
checkCudaErrors(cudaMemcpyAsync(y_ptr, x->mem_ptr, x->size, cudaMemcpyDefault, 0));
#endif
// Advance by this input's raw byte size.
y_ptr += x->size;
}
}
#endif // JIT
} // jittor

View File

@ -1,41 +0,0 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct ConcatOp : Op {
    vector<Var*> x;  // inputs to concatenate (same dtype and rank; validated in ctor)
    Var* y;          // output holding the concatenated result
    int dim;         // concat axis (only 0 is supported by the current implementation)
    /**
    Concat Operator can concat a list of jt Var at a specific dimension
    (WIP: this op doesn't have grad and is a work in progress, use jt.contrib.concat instead).

    * [in] x: input var list for concat

    * [in] dim: concat which dim (only dim 0 is supported now)

    * [out] out: concat result

    Example::

        jt.concat([jt.array([[1],[2]]), jt.array([[2],[2]])], dim=0)
        # return [[1],[2],[2],[2]]
    */
    // @pybind(WIP_concat)
    ConcatOp(vector<Var*>&& x, int dim=0);
    VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
    void infer_shape() override;
    const char* name() const override { return "concat"; }
    DECLARE_jit_run;
};
} // jittor