mirror of https://github.com/Jittor/Jittor
Fix bug where cudaDeviceSynchronize could not complete successfully
This commit is contained in:
parent
6c9bd429f6
commit
592cf0df78
|
@ -42,7 +42,6 @@ void NcclAllReduceOp::jit_run() {
|
|||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
checkCudaErrors(ncclAllReduce(xp, yp, size, ncclFloat, ncclSum, comm, 0));
|
||||
checkCudaErrors(cudaStreamSynchronize(0));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
|
|
@ -8,6 +8,9 @@
|
|||
// ***************************************************************
|
||||
#include "nccl_warper.h"
|
||||
|
||||
#ifdef HAS_CUDA
|
||||
#include "event_queue.h"
|
||||
#endif
|
||||
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error) {
|
||||
return ncclGetErrorString(error);
|
||||
|
@ -26,6 +29,11 @@ nccl_initer() {
|
|||
checkCudaErrors(ncclGetUniqueId(&id));
|
||||
MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
|
||||
checkCudaErrors(cudaSetDevice(mpi_local_rank));
|
||||
#ifdef HAS_CUDA
|
||||
event_queue.run_sync([]() {
|
||||
checkCudaErrors(cudaSetDevice(mpi_local_rank));
|
||||
});
|
||||
#endif
|
||||
checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue