From 5062b2d6e6b9670d220f15485733127a194c8e08 Mon Sep 17 00:00:00 2001 From: Dun Liang Date: Tue, 22 Mar 2022 22:47:59 +0800 Subject: [PATCH] polish fused cpu and gpu op --- python/jittor/src/executor.cc | 5 +++++ python/jittor/src/fused_op.cc | 25 ++++++++++++++++--------- python/jittor/src/op.cc | 21 +-------------------- python/jittor/src/ops/copy_op.cc | 5 ----- python/jittor/test/test_cuda.py | 6 ++++++ python/jittor/test/test_fp16.py | 1 - 6 files changed, 28 insertions(+), 35 deletions(-) diff --git a/python/jittor/src/executor.cc b/python/jittor/src/executor.cc index 64afde7f..186eeb0c 100644 --- a/python/jittor/src/executor.cc +++ b/python/jittor/src/executor.cc @@ -509,6 +509,11 @@ void Executor::run_sync(vector vars, bool device_sync) { var->alloc(cpu_allocator); } } + } else { + for (Var* v : op->inputs()) { + if (!v->allocator->is_cuda()) + migrate_to_gpu(v, allocator); + } } #endif #ifdef NODE_MEMCHECK diff --git a/python/jittor/src/fused_op.cc b/python/jittor/src/fused_op.cc index 74da506b..924a1113 100644 --- a/python/jittor/src/fused_op.cc +++ b/python/jittor/src/fused_op.cc @@ -9,6 +9,7 @@ #include "op_compiler.h" #include "profiler/profiler.h" #include "misc/fast_shared_ptr.h" +#include "misc/cuda_flags.h" namespace jittor { @@ -42,6 +43,7 @@ void FusedOp::update_ops() { loop_options_tuned.clear(); loop_options = loop_options_origin = nullptr; + _inputs.clear(); _outputs.clear(); for (Op* op : ops) { for (Var* o : op->outputs()) { @@ -101,6 +103,7 @@ void FusedOp::update_ops() { if (!(c&2)) { c += 2 + vars.size()*4; vars.push_back({i, 0}); + _inputs.emplace_back((Node*)i); } } for (Var* o : opi->outputs()) { @@ -135,6 +138,7 @@ FusedOp::FusedOp(const FusedOp& other) { } FusedOp::~FusedOp() { + _inputs.clear(); _outputs.clear(); Op::number_of_lived_ops++; } @@ -159,20 +163,15 @@ void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) { void FusedOp::do_jit_prepare(JK& jk) { jk.clear(); - int8 flags = 3; for (uint i=0; ido_jit_prepare(jk); + jk << op->name(); + op->jit_prepare(jk); jk << JK::end; - if (op->flags.get(NodeFlags::_cpu)) - flags &= 1; // only cpu - else - flags &= 2; // only gpu } - ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops."; jk << _CS("[JIT:1]"); - if (flags==1) { + if (!use_cuda) { // only cpu jk << _CS("[JIT_cpu:1]"); this->flags.set(NodeFlags::_cuda, 0); @@ -189,9 +188,17 @@ void FusedOp::do_jit_prepare(JK& jk) { jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ','; } jk << _CS("][var_info:") << JK::val; - for (auto& vi : vars) + bool use_int64_t = false; + for (auto& vi : vars) { jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size()); + if (vi.type != 1 && vi.var->num >= std::numeric_limits::max()) + use_int64_t = true; + } jk << JK::end; + if (use_int64_t) + jk << _CS("[index_t:int64]"); + else + jk << _CS("[index_t:int32]"); if (loop_options->size()) { if (get_loop_option("compile_shapes")) { jk << _CS("[shapes:"); diff --git a/python/jittor/src/op.cc b/python/jittor/src/op.cc index c24b8b30..35a42f06 100644 --- a/python/jittor/src/op.cc +++ b/python/jittor/src/op.cc @@ -123,43 +123,24 @@ void Op::do_jit_prepare(JK& jk) { if (has_cuda && has_cpu && !use_cuda) flags.set(NodeFlags::_cuda, 0); } else { - // check use int64_t as index_t if array is too big - int in_id=0, out_id=0; bool use_int64_t = false; // TODO: fused op do not have inputs, // check use_cuda_op from outputs may not be enough bool use_cuda_op = use_cuda; for (Var* var : inputs()) { - if (var->mem_ptr) { - /* jit key don't include here, because - parallel compiler don't known - jk << JK::key << "alloc_i" << JK::hex1(in_id) - << JK::hex1(var->allocator->flags()) << JK::end; - */ - use_cuda_op &= var->allocator->is_cuda(); - } if (var->num >= std::numeric_limits::max()) use_int64_t = true; - in_id ++; } for (Var* var : outputs()) { - if (var->mem_ptr) { - /* - jk << JK::key << "alloc_o" << JK::hex1(in_id) - << JK::hex1(var->allocator->flags()) << JK::end; - */ - use_cuda_op &= var->allocator->is_cuda(); - } if (var->num >= std::numeric_limits::max()) use_int64_t = true; - out_id ++; } jk << _CS("[JIT:1]"); if (use_cuda_op && flags.get(NodeFlags::_cuda)) { jk << _CS("[JIT_cuda:1]"); flags.set(NodeFlags::_cpu, 0); // TODO: 64bit index in CUDA - use_int64_t = false; + // use_int64_t = false; } else { if (use_cuda==2) { if (flags.get(NodeFlags::_cuda)) diff --git a/python/jittor/src/ops/copy_op.cc b/python/jittor/src/ops/copy_op.cc index 3dc3ff9a..5d62e1c8 100644 --- a/python/jittor/src/ops/copy_op.cc +++ b/python/jittor/src/ops/copy_op.cc @@ -40,12 +40,7 @@ void CopyOp::run() { auto y_ptr = outputs().front()->mem_ptr; #ifdef HAS_CUDA if (flags.get(NodeFlags::_cuda)) { - // TODO: check why cpu allocator in x - #ifdef IS_CUDA - checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0)); - #else checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0)); - #endif } else #endif { diff --git a/python/jittor/test/test_cuda.py b/python/jittor/test/test_cuda.py index 4c28436c..16a16546 100644 --- a/python/jittor/test/test_cuda.py +++ b/python/jittor/test/test_cuda.py @@ -101,6 +101,12 @@ class TestCuda(unittest.TestCase): assert a.shape == [3,4,5] and a.dtype == 'float' assert (-na.flatten() == range(3*4*5)).all(), na + def test_cuda_fused_op(self): + a = jt.array([1,2,3]) + a.sync() + with jt.flag_scope(use_cuda=1): + ((a+a)*2).data + @unittest.skipIf(jt.compiler.has_cuda, "Only test without CUDA") class TestNoCuda(unittest.TestCase): diff --git a/python/jittor/test/test_fp16.py b/python/jittor/test/test_fp16.py index fe599386..86569a59 100644 --- a/python/jittor/test/test_fp16.py +++ b/python/jittor/test/test_fp16.py @@ -162,7 +162,6 @@ def check_share(): } } kernel<<<1024,16*16>>>(in0_p, out0_p); - LOGir << "aaa"; """).sync() jt.sync_all(True) # print(a[0]+1)