mirror of https://github.com/Jittor/Jittor
polish fused cpu and gpu op
This commit is contained in:
parent
1987728950
commit
5062b2d6e6
|
@ -509,6 +509,11 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
var->alloc(cpu_allocator);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
for (Var* v : op->inputs()) {
|
||||
if (!v->allocator->is_cuda())
|
||||
migrate_to_gpu(v, allocator);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
#ifdef NODE_MEMCHECK
|
||||
|
|
|
@ -9,6 +9,7 @@
|
|||
#include "op_compiler.h"
|
||||
#include "profiler/profiler.h"
|
||||
#include "misc/fast_shared_ptr.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -42,6 +43,7 @@ void FusedOp::update_ops() {
|
|||
loop_options_tuned.clear();
|
||||
loop_options = loop_options_origin = nullptr;
|
||||
|
||||
_inputs.clear();
|
||||
_outputs.clear();
|
||||
for (Op* op : ops) {
|
||||
for (Var* o : op->outputs()) {
|
||||
|
@ -101,6 +103,7 @@ void FusedOp::update_ops() {
|
|||
if (!(c&2)) {
|
||||
c += 2 + vars.size()*4;
|
||||
vars.push_back({i, 0});
|
||||
_inputs.emplace_back((Node*)i);
|
||||
}
|
||||
}
|
||||
for (Var* o : opi->outputs()) {
|
||||
|
@ -135,6 +138,7 @@ FusedOp::FusedOp(const FusedOp& other) {
|
|||
}
|
||||
|
||||
FusedOp::~FusedOp() {
|
||||
_inputs.clear();
|
||||
_outputs.clear();
|
||||
Op::number_of_lived_ops++;
|
||||
}
|
||||
|
@ -159,20 +163,15 @@ void FusedOp::statistics(uint64_t& in, uint64_t& out, uint64_t& compute) {
|
|||
|
||||
void FusedOp::do_jit_prepare(JK& jk) {
|
||||
jk.clear();
|
||||
int8 flags = 3;
|
||||
for (uint i=0; i<ops.size(); i++) {
|
||||
Op* op = ops[i];
|
||||
jk << "[opkey" << i << JK::val;
|
||||
op->do_jit_prepare(jk);
|
||||
jk << op->name();
|
||||
op->jit_prepare(jk);
|
||||
jk << JK::end;
|
||||
if (op->flags.get(NodeFlags::_cpu))
|
||||
flags &= 1; // only cpu
|
||||
else
|
||||
flags &= 2; // only gpu
|
||||
}
|
||||
ASSERT(flags) << "FusedOp cannot contain both cpu and cuda ops.";
|
||||
jk << _CS("[JIT:1]");
|
||||
if (flags==1) {
|
||||
if (!use_cuda) {
|
||||
// only cpu
|
||||
jk << _CS("[JIT_cpu:1]");
|
||||
this->flags.set(NodeFlags::_cuda, 0);
|
||||
|
@ -189,9 +188,17 @@ void FusedOp::do_jit_prepare(JK& jk) {
|
|||
jk << JK::hex2(i) << JK::hex1(j) << JK::hex2(k) << JK::hex1(l) << ',';
|
||||
}
|
||||
jk << _CS("][var_info:") << JK::val;
|
||||
for (auto& vi : vars)
|
||||
bool use_int64_t = false;
|
||||
for (auto& vi : vars) {
|
||||
jk << JK::hex1(vi.type) << JK::hex1(vi.var->shape.size());
|
||||
if (vi.type != 1 && vi.var->num >= std::numeric_limits<int32_t>::max())
|
||||
use_int64_t = true;
|
||||
}
|
||||
jk << JK::end;
|
||||
if (use_int64_t)
|
||||
jk << _CS("[index_t:int64]");
|
||||
else
|
||||
jk << _CS("[index_t:int32]");
|
||||
if (loop_options->size()) {
|
||||
if (get_loop_option("compile_shapes")) {
|
||||
jk << _CS("[shapes:");
|
||||
|
|
|
@ -123,43 +123,24 @@ void Op::do_jit_prepare(JK& jk) {
|
|||
if (has_cuda && has_cpu && !use_cuda)
|
||||
flags.set(NodeFlags::_cuda, 0);
|
||||
} else {
|
||||
// check use int64_t as index_t if array is too big
|
||||
int in_id=0, out_id=0;
|
||||
bool use_int64_t = false;
|
||||
// TODO: fused op do not have inputs,
|
||||
// check use_cuda_op from outputs may not be enough
|
||||
bool use_cuda_op = use_cuda;
|
||||
for (Var* var : inputs()) {
|
||||
if (var->mem_ptr) {
|
||||
/* jit key don't include here, because
|
||||
parallel compiler don't known
|
||||
jk << JK::key << "alloc_i" << JK::hex1(in_id)
|
||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
||||
*/
|
||||
use_cuda_op &= var->allocator->is_cuda();
|
||||
}
|
||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||
use_int64_t = true;
|
||||
in_id ++;
|
||||
}
|
||||
for (Var* var : outputs()) {
|
||||
if (var->mem_ptr) {
|
||||
/*
|
||||
jk << JK::key << "alloc_o" << JK::hex1(in_id)
|
||||
<< JK::hex1(var->allocator->flags()) << JK::end;
|
||||
*/
|
||||
use_cuda_op &= var->allocator->is_cuda();
|
||||
}
|
||||
if (var->num >= std::numeric_limits<int32_t>::max())
|
||||
use_int64_t = true;
|
||||
out_id ++;
|
||||
}
|
||||
jk << _CS("[JIT:1]");
|
||||
if (use_cuda_op && flags.get(NodeFlags::_cuda)) {
|
||||
jk << _CS("[JIT_cuda:1]");
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
// TODO: 64bit index in CUDA
|
||||
use_int64_t = false;
|
||||
// use_int64_t = false;
|
||||
} else {
|
||||
if (use_cuda==2) {
|
||||
if (flags.get(NodeFlags::_cuda))
|
||||
|
|
|
@ -40,12 +40,7 @@ void CopyOp::run() {
|
|||
auto y_ptr = outputs().front()->mem_ptr;
|
||||
#ifdef HAS_CUDA
|
||||
if (flags.get(NodeFlags::_cuda)) {
|
||||
// TODO: check why cpu allocator in x
|
||||
#ifdef IS_CUDA
|
||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDefault, 0));
|
||||
#else
|
||||
checkCudaErrors(cudaMemcpyAsync(y_ptr, x_ptr, size, cudaMemcpyDeviceToDevice, 0));
|
||||
#endif
|
||||
} else
|
||||
#endif
|
||||
{
|
||||
|
|
|
@ -101,6 +101,12 @@ class TestCuda(unittest.TestCase):
|
|||
assert a.shape == [3,4,5] and a.dtype == 'float'
|
||||
assert (-na.flatten() == range(3*4*5)).all(), na
|
||||
|
||||
def test_cuda_fused_op(self):
|
||||
a = jt.array([1,2,3])
|
||||
a.sync()
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
((a+a)*2).data
|
||||
|
||||
|
||||
@unittest.skipIf(jt.compiler.has_cuda, "Only test without CUDA")
|
||||
class TestNoCuda(unittest.TestCase):
|
||||
|
|
|
@ -162,7 +162,6 @@ def check_share():
|
|||
}
|
||||
}
|
||||
kernel<<<1024,16*16>>>(in0_p, out0_p);
|
||||
LOGir << "aaa";
|
||||
""").sync()
|
||||
jt.sync_all(True)
|
||||
# print(a[0]+1)
|
||||
|
|
Loading…
Reference in New Issue