fix bug in nccl ops (missing ncclSum reduce-op argument in ncclReduce)

This commit is contained in:
guowei yang 2020-04-08 19:16:46 +08:00
parent e745e46bb3
commit 93d7f59985
3 changed files with 3 additions and 4 deletions

View File

@ -42,7 +42,6 @@ void NcclBroadcastOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclBroadcast(xp, yp, size, ncclFloat, root, comm, 0));
checkCudaErrors(cudaStreamSynchronize(0));
}
#endif

View File

@ -41,8 +41,7 @@ void NcclReduceOp::jit_run() {
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclReduce(xp, yp, size, ncclFloat, root, comm, 0));
checkCudaErrors(cudaStreamSynchronize(0));
checkCudaErrors(ncclReduce(xp, yp, size, ncclFloat, ncclSum, root, comm, 0));
}
#endif

View File

@ -32,7 +32,8 @@ def test_reduce():
print("test reduce")
mpi = jt.compile_extern.mpi
x = jt.random([5, 5])
y = jt.compile_extern.nccl_ops.nccl_all_reduce(x)
y = jt.compile_extern.nccl_ops.nccl_reduce(x, 0)
y.sync()
if mpi.world_rank() == 0:
assert np.allclose(y.data, (x*3).data)