mirror of https://github.com/Jittor/Jittor
polish fused cpu and gpu op
This commit is contained in:
parent
1987728950
commit
5062b2d6e6
|
@ -509,6 +509,11 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
||||||
var->alloc(cpu_allocator);
|
var->alloc(cpu_allocator);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
for (Var* v : op->inputs()) {
|
||||||
|
if (!v->allocator->is_cuda())
|
||||||
|
migrate_to_gpu(v, allocator);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
#ifdef NODE_MEMCHECK
|
#ifdef NODE_MEMCHECK
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
#include "op_compiler.h"
|
#include "op_compiler.h"
|
||||||
#include "profiler/profiler.h"
|
#include "profiler/profiler.h"
|
||||||
#include "misc/fast_shared_ptr.h"
|
#include "misc/fast_shared_ptr.h"
|
||||||
|
#include "misc/cuda_flags.h"
|
||||||
|
|
||||||
namespace jittor {
|
namespace jittor {
|
||||||
|
|
||||||
|
@ -42,6 +43,7 @@ void FusedOp::update_ops() {
|
||||||
loop_options_tuned.clear();
|
loop_options_tuned.clear();
|
||||||
loop_options = loop_options_origin = nullptr;
|
loop_options = loop_options_origin = nullptr;
|
||||||
|
|
||||||
|
_inputs.clear();
|
||||||
_outputs.clear();
|
_outputs.clear();
|
||||||
for (Op* op : ops) {
|
for (Op* op : ops) {
|
||||||
for (Var* o : op->outputs()) {
|
for (Var* o : op->outputs()) {
|
||||||
|
@ -101,6 +103,7 @@ void FusedOp::update_ops() {
|
||||||
if (!(c&2)) {
|
if (!(c&2)) {
|
||||||
c += 2 + vars.size()*4;
|
c += 2 + vars.size()*4;
|
||||||
vars.push_back({i, 0});
|
vars.push_back({i, 0});
|
||||||
|
_inputs.emplace_back((Node*)i);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for (Var* o : opi->outputs()) {
|
for (Var* o : opi->outputs()) {
|
||||||
|
@ -135,6 +138,7 @@ FusedOp::FusedOp(const FusedOp& other) {
|
||||||
}
|
}
|
||||||
|
|
||||||
FusedOp::~FusedOp() {
|
FusedOp::~FusedOp() {
|
||||||
|
_inputs.clear();
|
||||||
_outputs.clear();
|
_outputs.clear();
|
||||||
Op::number_of_lived_ops++;
|
Op::number_of_lived_ops++;
|
||||||
}
|
}
|
||||||
|
@ -159,20 +163,15 @@ void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
|
||||||
|
|
||||||
void FusedOp::do_jit_prepare(JK& jk) {
|
void FusedOp::do_jit_prepare(JK& jk) {
|
||||||
jk.clear();
|
jk.clear();
|
||||||
int8 flags = 3;
|
|
||||||
for (uint i=0; i<ops.size(); i++) {
|
for (uint i=0; i<ops.size(); i++) {
|
||||||
Op* op = ops[i];
|
Op* op = ops[i];
|
||||||
jk << "[opkey" << i << JK::val;
|
jk << "[opkey" << i << JK::val;
|
||||||
op->do_jit_prepare(jk);
|
jk << op->name();
|
||||||
|
op->jit_prepare(jk);
|
||||||
jk << JK::end;
|
jk << JK::end;
|
||||||
if (op->flags.get(NodeFlags::_cpu))
|
|
||||||
flags &= 1; // only cpu
|
|
||||||
else
|
|
||||||
flags &= 2; // only gpu
|
|
||||||
}
|
}
|
||||||
ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops.";
|
|
||||||
jk << _CS("[JIT:1]");
|
jk << _CS("[JIT:1]");
|
||||||
if (flags==1) {
|
if (!use_cuda) {
|
||||||
// only cpu
|
// only cpu
|
||||||
jk << _CS("[JIT_cpu:1]");
|
jk << _CS("[JIT_cpu:1]");
|
||||||
this->flags.set(NodeFlags::_cuda, 0);
|
this->flags.set(NodeFlags::_cuda, 0);
|
||||||
|
@ -189,9 +188,17 @@ void FusedOp::do_jit_prepare(JK& jk) {
|
||||||
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
||||||
}
|
}
|
||||||
jk << _CS("][var_info:") << JK::val;
|
jk << _CS("][var_info:") << JK::val;
|
||||||
for (auto& vi : vars)
|
bool use_int64_t = false;
|
||||||
|
for (auto& vi : vars) {
|
||||||
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
||||||
|
if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max())
|
||||||
|
use_int64_t = true;
|
||||||
|
}
|
||||||
jk << JK::end;
|
jk << JK::end;
|
||||||
|
if (use_int64_t)
|
||||||
|
jk << _CS("[index_t:int64]");
|
||||||
|
else
|
||||||
|
jk << _CS("[index_t:int32]");
|
||||||
if (loop_options->size()) {
|
if (loop_options->size()) {
|
||||||
if (get_loop_option("compile_shapes")) {
|
if (get_loop_option("compile_shapes")) {
|
||||||
jk << _CS("[shapes:");
|
jk << _CS("[shapes:");
|
||||||
|
|
|
@ -123,43 +123,24 @@ void Op::do_jit_prepare(JK& jk) {
|
||||||
if (has_cuda && has_cpu && !use_cuda)
|
if (has_cuda && has_cpu && !use_cuda)
|
||||||
flags.set(NodeFlags::_cuda, 0);
|
flags.set(NodeFlags::_cuda, 0);
|
||||||
} else {
|
} else {
|
||||||
// check use int64_t as index_t if array is too big
|
|
||||||
int in_id=0, out_id=0;
|
|
||||||
bool use_int64_t = false;
|
bool use_int64_t = false;
|
||||||
// TODO: fused op do not have inputs,
|
// TODO: fused op do not have inputs,
|
||||||
// check use_cuda_op from outputs may not be enough
|
// check use_cuda_op from outputs may not be enough
|
||||||
bool use_cuda_op = use_cuda;
|
bool use_cuda_op = use_cuda;
|
||||||
for (Var* var : inputs()) {
|
for (Var* var : inputs()) {
|
||||||
if (var->mem_ptr) {
|
|
||||||
/* jit key don't include here, because
|
|
||||||
parallel compiler don't known
|
|
||||||
jk << JK::key << "alloc_i" << JK::hex1(in_id)
|
|
||||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
|
||||||
*/
|
|
||||||
use_cuda_op &= var->allocator->is_cuda();
|
|
||||||
}
|
|
||||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||||
use_int64_t = true;
|
use_int64_t = true;
|
||||||
in_id ++;
|
|
||||||
}
|
}
|
||||||
for (Var* var : outputs()) {
|
for (Var* var : outputs()) {
|
||||||
if (var->mem_ptr) {
|
|
||||||
/*
|
|
||||||
jk << JK::key << "alloc_o" << JK::hex1(in_id)
|
|
||||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
|
||||||
*/
|
|
||||||
use_cuda_op &= var->allocator->is_cuda();
|
|
||||||
}
|
|
||||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||||
use_int64_t = true;
|
use_int64_t = true;
|
||||||
out_id ++;
|
|
||||||
}
|
}
|
||||||
jk << _CS("[JIT:1]");
|
jk << _CS("[JIT:1]");
|
||||||
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
||||||
jk << _CS("[JIT_cuda:1]");
|
jk << _CS("[JIT_cuda:1]");
|
||||||
flags.set(NodeFlags::_cpu, 0);
|
flags.set(NodeFlags::_cpu, 0);
|
||||||
// TODO: 64bit index in CUDA
|
// TODO: 64bit index in CUDA
|
||||||
use_int64_t = false;
|
// use_int64_t = false;
|
||||||
} else {
|
} else {
|
||||||
if (use_cuda==2) {
|
if (use_cuda==2) {
|
||||||
if (flags.get(NodeFlags::_cuda))
|
if (flags.get(NodeFlags::_cuda))
|
||||||
|
|
|
@ -40,12 +40,7 @@ void CopyOp::run() {
|
||||||
auto y_ptr = outputs().front()->mem_ptr;
|
auto y_ptr = outputs().front()->mem_ptr;
|
||||||
#ifdef HAS_CUDA
|
#ifdef HAS_CUDA
|
||||||
if (flags.get(NodeFlags::_cuda)) {
|
if (flags.get(NodeFlags::_cuda)) {
|
||||||
// TODO: check why cpu allocator in x
|
|
||||||
#ifdef IS_CUDA
|
|
||||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
|
|
||||||
#else
|
|
||||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0));
|
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0));
|
||||||
#endif
|
|
||||||
} else
|
} else
|
||||||
#endif
|
#endif
|
||||||
{
|
{
|
||||||
|
|
|
@ -101,6 +101,12 @@ class TestCuda(unittest.TestCase):
|
||||||
assert a.shape == [3,4,5] and a.dtype == 'float'
|
assert a.shape == [3,4,5] and a.dtype == 'float'
|
||||||
assert (-na.flatten() == range(3*4*5)).all(), na
|
assert (-na.flatten() == range(3*4*5)).all(), na
|
||||||
|
|
||||||
|
def test_cuda_fused_op(self):
|
||||||
|
a = jt.array([1,2,3])
|
||||||
|
a.sync()
|
||||||
|
with jt.flag_scope(use_cuda=1):
|
||||||
|
((a+a)*2).data
|
||||||
|
|
||||||
|
|
||||||
@unittest.skipIf(jt.compiler.has_cuda, "Only test without CUDA")
|
@unittest.skipIf(jt.compiler.has_cuda, "Only test without CUDA")
|
||||||
class TestNoCuda(unittest.TestCase):
|
class TestNoCuda(unittest.TestCase):
|
||||||
|
|
|
@ -162,7 +162,6 @@ def check_share():
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
kernel<<<1024,16*16>>>(in0_p, out0_p);
|
kernel<<<1024,16*16>>>(in0_p, out0_p);
|
||||||
LOGir << "aaa";
|
|
||||||
""").sync()
|
""").sync()
|
||||||
jt.sync_all(True)
|
jt.sync_all(True)
|
||||||
# print(a[0]+1)
|
# print(a[0]+1)
|
||||||
|
|
Loading…
Reference in New Issue