fix pow cuda grad nan && mpi version

This commit is contained in:
Dun Liang 2020-04-15 16:31:32 +08:00
parent 273be0db93
commit 68608dd74b
3 changed files with 15 additions and 2 deletions

View File

@ -361,7 +361,8 @@ def setup_mpi():
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
if get_version(mpicc_path).startswith("(1."):
mpi_version = get_version(mpicc_path)
if mpi_version.startswith("(1.") or mpi_version.startswith("(2."):
# mpi version 1.x need to link like this
manual_link(mpi_flags)
# mpi(4.x) cannot use deepbind, it need to

View File

@ -129,5 +129,13 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
assert a.min().data == a.data.min(), (a.min(), a.data.min())
assert a.max().data == a.data.max(), (a.max(), a.data.max())
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cuda_pow_grad_nan(self):
    # Regression check: the CUDA gradient of x**2 must stay NaN-free
    # even for negative inputs (x^(y-1) with fractional intermediates).
    x = jt.float32([1, -1, -1000.1])
    dx = jt.grad(x ** 2, x)
    assert np.isnan(dx.data).sum() == 0, dx.data
# Run this test module's suite when executed directly.
if __name__ == "__main__":
    unittest.main()

View File

@ -13,6 +13,8 @@
namespace jittor {
#ifndef JIT
// Look up op constructors in Jittor's op registry by name.
// make_array builds a Var from raw host data (ptr, shape, dtype); the diff
// below uses it to make an int32 scalar "1" for the pow-gradient exponent.
static auto make_array = get_op_info("array")
    .get_constructor<VarPtr, const void*, NanoVector, NanoString>();
// NOTE(review): presumably broadcasts a Var to match another Var's shape
// (name suggests so) -- confirm against broadcast_to op definition.
static auto make_broadcast_to = get_op_info("broadcast_to")
    .get_constructor<VarPtr, Var*, Var*, NanoVector>();
static auto make_binary = get_op_info("binary")
@ -122,7 +124,9 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index == 0) {
// dout * y * x^(y-1)
auto d = make_binary(dout, y, ns_multiply);
auto ones = make_number(1, dout);
// auto ones = make_number(1, dout);
int number = 1;
auto ones = make_array(&number, 1, ns_int32);
auto y_1 = make_binary(y, ones, ns_subtract);
auto x_y_1 = make_binary(x, y_1, ns_pow);
return make_binary(d, x_y_1, ns_multiply);