mirror of https://github.com/Jittor/Jittor
Fix bug in NCCL ops
This commit is contained in:
parent
e745e46bb3
commit
93d7f59985
|
@@ -42,7 +42,6 @@ void NcclBroadcastOp::jit_run() {
|
|||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
checkCudaErrors(ncclBroadcast(xp, yp, size, ncclFloat, root, comm, 0));
|
||||
- checkCudaErrors(cudaStreamSynchronize(0));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@@ -41,8 +41,7 @@ void NcclReduceOp::jit_run() {
|
|||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
- checkCudaErrors(ncclReduce(xp, yp, size, ncclFloat, root, comm, 0));
|
||||
- checkCudaErrors(cudaStreamSynchronize(0));
|
||||
+ checkCudaErrors(ncclReduce(xp, yp, size, ncclFloat, ncclSum, root, comm, 0));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@@ -32,7 +32,8 @@ def test_reduce():
|
|||
print("test reduce")
|
||||
mpi = jt.compile_extern.mpi
|
||||
x = jt.random([5, 5])
|
||||
- y = jt.compile_extern.nccl_ops.nccl_all_reduce(x)
|
||||
+ y = jt.compile_extern.nccl_ops.nccl_reduce(x, 0)
|
||||
y.sync()
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
|
||||
|
|
Loading…
Reference in New Issue