From d2c5d04ecfc5abae37e6b5b4aee061134d5c595e Mon Sep 17 00:00:00 2001 From: li-xl <1905692338@qq.com> Date: Wed, 16 Dec 2020 18:05:31 +0800 Subject: [PATCH] add nan checker --- python/jittor/compiler.py | 15 ++++++-- python/jittor/misc.py | 3 +- src/executor.cc | 8 +++- src/grad.cc | 1 + src/misc/nan_checker.cc | 71 ++++++++++++++++++++++++++++++++++++ src/misc/nan_checker.cu | 47 ++++++++++++++++++++++++ src/misc/nan_checker.h | 13 +++++++ src/opt/gopt/setitem_gopt.cc | 24 ++++++++++-- 8 files changed, 172 insertions(+), 10 deletions(-) create mode 100644 src/misc/nan_checker.cc create mode 100644 src/misc/nan_checker.cu create mode 100644 src/misc/nan_checker.h diff --git a/python/jittor/compiler.py b/python/jittor/compiler.py index 3839e813..7f7f86b9 100644 --- a/python/jittor/compiler.py +++ b/python/jittor/compiler.py @@ -74,10 +74,17 @@ def compile(compiler, flags, inputs, output, combind_build=False): for input, obj_file in zip(inputs, obj_files): cc = compiler nflags = oflags - if has_cuda and input.endswith(".cu"): - nflags = convert_nvcc_flags(oflags) - cc = nvcc_path + if input.endswith(".cu"): + if has_cuda: + nflags = convert_nvcc_flags(oflags) + cc = nvcc_path + else: + continue cmd = f"{cc} {input} {nflags} -c {lto_flags} -o {obj_file}" + if "nan_checker" in input: + # nan checker needs to disable fast_math + cmd = cmd.replace("--use_fast_math", "") + cmd = cmd.replace("-Ofast", "-O2") cmds.append(cmd) jit_utils.run_cmds(cmds, cache_path, jittor_path, "Compiling "+base_output) cmd = f"{compiler} {' '.join(obj_files)} {flags} {lto_flags} {link} -o {output}" @@ -945,7 +952,7 @@ pyjt_gen_src = pyjt_compiler.compile(cache_path, jittor_path) # 3. op_utils # 4. other files2 = pyjt_gen_src -files4 = run_cmd('find -L src | grep "cc$"', jittor_path).splitlines() +files4 = run_cmd('find -L src | grep "c[cu]$"', jittor_path).splitlines() at_beginning = [ "src/ops/op_utils.cc", "src/event_queue.cc", diff --git a/python/jittor/misc.py b/python/jittor/misc.py index 99fc79dd..ff01aec1 100644 --- a/python/jittor/misc.py +++ b/python/jittor/misc.py @@ -627,12 +627,11 @@ def gather(x,dim,index): return x.reindex(ins) jt.Var.gather = gather -def prod(x,dim=0): +def _prod(x,dim=0): x = jt.log(x) x = x.sum(dim=dim) return jt.exp(x) -jt.Var.prod = prod def cumsum_forward(np, data): a = data['inputs'][0] diff --git a/src/executor.cc b/src/executor.cc index 978a444e..51dea96e 100644 --- a/src/executor.cc +++ b/src/executor.cc @@ -21,6 +21,7 @@ #include "fuser.h" #include "profiler/profiler_guard.h" #include "parallel_compiler.h" +#include "misc/nan_checker.h" namespace jittor { @@ -46,7 +47,10 @@ void load_fused_op(FusedOp& fused_op, vector& fuse_ops, vector& ops, i for (Op* op : fused_op.ops) { uint fid1 = op->custom_data; int iid = 0; - for (Var* v : op->inputs()) { + for (auto ve : op->_inputs) { + // this is a control dependency edge, dont used + if (ve.back->index<0) continue; + auto v = ve.node->var(); iid++; int iop_id; int iv_id; @@ -450,6 +454,8 @@ void Executor::run_sync(vector vars, bool device_sync) { if (use_cuda) checkCudaErrors(cudaDeviceSynchronize()); #endif + for (Var* var : op->outputs()) + check_nan(var); } LOGvvv << "Finished Op(" >> op->name() << rid >> "/" >> queue.size() >> ") output:" << op->outputs(); diff --git a/src/grad.cc b/src/grad.cc index 34e5faf9..33706480 100644 --- a/src/grad.cc +++ b/src/grad.cc @@ -22,6 +22,7 @@ static auto make_number = get_op_info("number") VarPtr make_grad(Op* op, Var* out, Var* dout, Var* x, int x_index) { if (dout == nullptr) return nullptr; + if (x_index<0) return nullptr; LOGvvvv << "Make grad op:" >> op->name() << "inputs:" >> op->inputs() << "out:" >> out << "dout:" >> dout << "x:" >> x << "xid:" >> x_index; auto dx = op->grad(out, dout, x, x_index); diff --git a/src/misc/nan_checker.cc b/src/misc/nan_checker.cc new file mode 100644 index 00000000..f3a8c5bc --- /dev/null +++ b/src/misc/nan_checker.cc @@ -0,0 +1,71 @@ +// *************************************************************** +// Copyright (c) 2020 Jittor. All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#include +#include +#include "misc/nan_checker.h" +#include "misc/cuda_flags.h" +#include +#include "helper_cuda.h" +#include "mem/allocator.h" +#include "op.h" + +namespace jittor { + + +#ifdef HAS_CUDA +extern void check_nan_float32(float32* ptr, int64 num); +extern void check_nan_float64(float64* ptr, int64 num); +#endif + +bool check_nan(Var* v) { + if (!v->dtype().is_float()) return true; + if (v->input() && ( + v->input()->name() == string("empty") || + v->input()->name() == string("setitem"))) + return true; + #ifdef HAS_CUDA + if (v->allocator->is_cuda()) { + if (v->dtype() == ns_float32) { + check_nan_float32((float32*)v->mem_ptr, v->num); + } else + if (v->dtype() == ns_float64) { + check_nan_float64((float64*)v->mem_ptr, v->num); + } + ASSERT(cudaDeviceSynchronize()==0) << "detect nan or inf at" << v; + } else + #endif + { + if (v->dtype() == ns_float32) { + auto* __restrict__ ptr = v->ptr(); + auto num = v->num; + bool ok = true; + int64 i=0; + for (; idtype() == ns_float64) { + auto* __restrict__ ptr = v->ptr(); + auto num = v->num; + bool ok = true; + int64 i=0; + for (; i +#include "helper_cuda.h" +#include + +namespace jittor { + + +#ifdef HAS_CUDA +__global__ void _check_nan_float32(float32* __restrict__ ptr, int64 num) { + int64 i = threadIdx.x + blockIdx.x * (int64)blockDim.x; + if (i>>(ptr, num); +} + +void check_nan_float32(float32* ptr, int64 num) { + int block_num = std::max((int64)1, (num-1)/1024+1); + int thread_num = std::min((int64)1024, num); + _check_nan_float32<<>>(ptr, num); +} + +#endif + +} \ No newline at end of file diff --git a/src/misc/nan_checker.h b/src/misc/nan_checker.h new file mode 100644 index 00000000..1b865c96 --- /dev/null +++ b/src/misc/nan_checker.h @@ -0,0 +1,13 @@ +// *************************************************************** +// Copyright (c) 2020 Jittor. All Rights Reserved. +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "var.h" + +namespace jittor { + +bool check_nan(Var* v); + +} \ No newline at end of file diff --git a/src/opt/gopt/setitem_gopt.cc b/src/opt/gopt/setitem_gopt.cc index 5a9a0dcd..c3b97e02 100644 --- a/src/opt/gopt/setitem_gopt.cc +++ b/src/opt/gopt/setitem_gopt.cc @@ -16,6 +16,17 @@ inline static bool fast_strcmp(const char* a, const char* b) { return !*b; } +// add dependency b -> a +static inline void add_dependency(Node* a, const vector& b) { + a->add_inputs(b); + auto edge = a->_inputs.end(); + for (int i=0; iback->index = -1; + } +} + static void setitem_inplace(SetitemOp* op) { // LOGir << "in setitem_inplace"; auto input = op->inputs().front(); @@ -37,7 +48,7 @@ static void setitem_inplace(SetitemOp* op) { } auto output = op->outputs().front(); output->share_with(input); - return; + // return; // LOGir << "pass setitem optim one"; @@ -52,7 +63,12 @@ static void setitem_inplace(SetitemOp* op) { } VarSlices vs = op->vs; - if (!(data->is_finished() == 0 && (data->outputs().size() == 1 || (!input_op || input_op->inputs().size() == 0)))) + if (!(data->is_finished() == 0 && + (data->outputs().size() == 1 || + (!input_op + || input_op->inputs().size() == 0)))) + return; + if (data->allocator) return; auto in_shape = input->shape; @@ -73,7 +89,7 @@ static void setitem_inplace(SetitemOp* op) { else if (s.is_slice()) size = s.slice.start * input->size / in_shape[0]; - data->input()->add_inputs(vector{input}); + add_dependency(data->input(), {input->node()}); data->share_with(input, size); // LOGir << "pass setitem optim two"; } @@ -176,6 +192,7 @@ static void getitem_inplace(GetitemOp* op) { void SetitemOp::graph_optimize() { // LOGir << "hello graph_optimize"; setitem_inplace(this); + (void)setitem_inplace; } void GetitemOp::graph_optimize() { @@ -185,6 +202,7 @@ void GetitemOp::graph_optimize() { (void)setitem_grad_opt; // (void)getitem_inplace; getitem_inplace(this); + (void)getitem_inplace; } }