mirror of https://github.com/Jittor/Jittor
fix pow cuda grad nan && mpi version
This commit is contained in:
parent
273be0db93
commit
68608dd74b
|
@ -361,7 +361,8 @@ def setup_mpi():
|
|||
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
|
||||
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
|
||||
|
||||
if get_version(mpicc_path).startswith("(1."):
|
||||
mpi_version = get_version(mpicc_path)
|
||||
if mpi_version.startswith("(1.") or mpi_version.startswith("(2."):
|
||||
# mpi versions 1.x and 2.x need to be linked like this
|
||||
manual_link(mpi_flags)
|
||||
# mpi(4.x) cannot use deepbind, it need to
|
||||
|
|
|
@ -129,5 +129,13 @@ jt.mkl_ops.mkl_conv(x, w, 1, 2).sync()
|
|||
assert a.min().data == a.data.min(), (a.min(), a.data.min())
|
||||
assert a.max().data == a.data.max(), (a.max(), a.data.max())
|
||||
|
||||
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_cuda_pow_grad_nan(self):
|
||||
a = jt.float32([1,-1, -1000.1])
|
||||
da = jt.grad(a**2, a)
|
||||
assert np.isnan(da.data).sum()==0, da.data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -13,6 +13,8 @@
|
|||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
static auto make_array = get_op_info("array")
|
||||
.get_constructor<VarPtr, const void*, NanoVector, NanoString>();
|
||||
static auto make_broadcast_to = get_op_info("broadcast_to")
|
||||
.get_constructor<VarPtr, Var*, Var*, NanoVector>();
|
||||
static auto make_binary = get_op_info("binary")
|
||||
|
@ -122,7 +124,9 @@ VarPtr BinaryOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
|||
if (v_index == 0) {
|
||||
// dout * y * x^(y-1)
|
||||
auto d = make_binary(dout, y, ns_multiply);
|
||||
auto ones = make_number(1, dout);
|
||||
// auto ones = make_number(1, dout);
|
||||
int number = 1;
|
||||
auto ones = make_array(&number, 1, ns_int32);
|
||||
auto y_1 = make_binary(y, ones, ns_subtract);
|
||||
auto x_y_1 = make_binary(x, y_1, ns_pow);
|
||||
return make_binary(d, x_y_1, ns_multiply);
|
||||
|
|
Loading…
Reference in New Issue