mirror of https://github.com/Jittor/Jittor
add nccl mpi test together
This commit is contained in:
parent
45b8375e80
commit
5e5c8de82f
|
@ -0,0 +1,21 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "mpi_warper.h"
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
#include <nccl.h>
|
||||
#include <helper_cuda.h>
|
||||
|
||||
namespace jittor {
|
||||
|
||||
extern ncclComm_t comm;
|
||||
extern ncclUniqueId id;
|
||||
|
||||
} // jittor
|
|
@ -7,15 +7,8 @@
|
|||
#include "nccl_test_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "nccl_warper.h"
|
||||
|
||||
#ifndef JIT
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error) {
|
||||
return ncclGetErrorString(error);
|
||||
}
|
||||
#endif
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -33,16 +26,41 @@ void NcclTestOp::jit_prepare() {
|
|||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
static void test_with_mpi() {
|
||||
int size = 32*1024*1024;
|
||||
int myRank = mpi_world_rank;
|
||||
int nRanks = mpi_world_size;
|
||||
int localRank = mpi_local_rank;
|
||||
|
||||
float *sendbuff, *recvbuff;
|
||||
cudaStream_t s;
|
||||
checkCudaErrors(cudaMalloc(&sendbuff, size * sizeof(float)));
|
||||
checkCudaErrors(cudaMalloc(&recvbuff, size * sizeof(float)));
|
||||
checkCudaErrors(cudaStreamCreate(&s));
|
||||
|
||||
//communicating using NCCL
|
||||
checkCudaErrors(ncclAllReduce((const void*)sendbuff, (void*)recvbuff, size, ncclFloat, ncclSum,
|
||||
comm, s));
|
||||
|
||||
//completing NCCL operation by synchronizing on the CUDA stream
|
||||
checkCudaErrors(cudaStreamSynchronize(s));
|
||||
|
||||
//free device buffers
|
||||
checkCudaErrors(cudaFree(sendbuff));
|
||||
checkCudaErrors(cudaFree(recvbuff));
|
||||
checkCudaErrors(cudaStreamDestroy(s));
|
||||
|
||||
LOGi << "MPI rank" << myRank << "Success";
|
||||
}
|
||||
|
||||
void NcclTestOp::jit_run() {
|
||||
auto args = split(cmd, " ");
|
||||
if (!cmd.size()) args.clear();
|
||||
vector<char*> v(args.size());
|
||||
for (uint i=0; i<args.size(); i++)
|
||||
v[i] = &args[i][0];
|
||||
output->ptr<T>()[0] = 123;
|
||||
if (cmd == "test_with_mpi") {
|
||||
test_with_mpi();
|
||||
return;
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
//managing 4 devices
|
||||
int nDev;
|
||||
checkCudaErrors(cudaGetDeviceCount(&nDev));
|
||||
|
|
|
@ -0,0 +1,40 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor.
|
||||
// Authors:
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "nccl_warper.h"
|
||||
|
||||
|
||||
const char *_cudaGetErrorEnum(ncclResult_t error) {
|
||||
return ncclGetErrorString(error);
|
||||
}
|
||||
|
||||
namespace jittor {
|
||||
|
||||
ncclComm_t comm;
|
||||
ncclUniqueId id;
|
||||
|
||||
|
||||
struct nccl_initer {
|
||||
|
||||
nccl_initer() {
|
||||
if (mpi_world_rank == 0)
|
||||
checkCudaErrors(ncclGetUniqueId(&id));
|
||||
MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
|
||||
checkCudaErrors(cudaSetDevice(mpi_local_rank));
|
||||
checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
|
||||
}
|
||||
|
||||
~nccl_initer() {
|
||||
checkCudaErrors(ncclCommDestroy(comm));
|
||||
}
|
||||
|
||||
};
|
||||
|
||||
static nccl_initer nccl_init;
|
||||
|
||||
} // jittor
|
|
@ -6,6 +6,7 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#define OMPI_SKIP_MPICXX
|
||||
#include <mpi.h>
|
||||
|
||||
|
@ -25,6 +26,7 @@ namespace jittor {
|
|||
|
||||
extern int mpi_world_size;
|
||||
extern int mpi_world_rank;
|
||||
extern int mpi_local_rank;
|
||||
|
||||
// @pyjt(world_size)
|
||||
int _mpi_world_size();
|
||||
|
@ -32,4 +34,7 @@ int _mpi_world_size();
|
|||
// @pyjt(world_rank)
|
||||
int _mpi_world_rank();
|
||||
|
||||
// @pyjt(local_rank)
|
||||
int _mpi_local_rank();
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -6,6 +6,10 @@
|
|||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <unistd.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "mpi_warper.h"
|
||||
#include "common.h"
|
||||
|
||||
|
@ -26,6 +30,7 @@ namespace jittor {
|
|||
|
||||
int mpi_world_size;
|
||||
int mpi_world_rank;
|
||||
int mpi_local_rank;
|
||||
|
||||
int _mpi_world_size() {
|
||||
return mpi_world_size;
|
||||
|
@ -35,6 +40,31 @@ int _mpi_world_rank() {
|
|||
return mpi_world_rank;
|
||||
}
|
||||
|
||||
int _mpi_local_rank() {
|
||||
return mpi_local_rank;
|
||||
}
|
||||
|
||||
|
||||
|
||||
static uint64_t getHostHash(const char* string) {
|
||||
// Based on DJB2, result = result * 33 + char
|
||||
uint64_t result = 5381;
|
||||
for (int c = 0; string[c] != '\0'; c++){
|
||||
result = ((result << 5) + result) + string[c];
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
|
||||
static void getHostName(char* hostname, int maxlen) {
|
||||
gethostname(hostname, maxlen);
|
||||
for (int i=0; i< maxlen; i++) {
|
||||
if (hostname[i] == '.') {
|
||||
hostname[i] = '\0';
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct mpi_initer {
|
||||
|
||||
|
@ -42,6 +72,18 @@ mpi_initer() {
|
|||
MPI_CHECK(MPI_Init(NULL, NULL));
|
||||
MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
|
||||
MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
|
||||
|
||||
//calculating localRank based on hostname which is used in selecting a GPU
|
||||
uint64_t hostHashs[mpi_world_rank];
|
||||
char hostname[1024];
|
||||
getHostName(hostname, 1024);
|
||||
hostHashs[mpi_world_rank] = getHostHash(hostname);
|
||||
MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
|
||||
mpi_local_rank = 0;
|
||||
for (int p=0; p<mpi_world_size; p++) {
|
||||
if (p == mpi_world_rank) break;
|
||||
if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
|
||||
}
|
||||
}
|
||||
|
||||
~mpi_initer() {
|
||||
|
|
|
@ -265,7 +265,7 @@ def setup_nccl():
|
|||
global nccl_ops, use_nccl
|
||||
use_nccl = os.environ.get("use_nccl", "1")=="1"
|
||||
nccl_ops = None
|
||||
if not has_cuda:
|
||||
if not has_cuda or mpi is None:
|
||||
use_nccl = False
|
||||
return
|
||||
if not use_nccl: return
|
||||
|
@ -293,9 +293,13 @@ def setup_nccl():
|
|||
# We do not link manualy, link in custom ops
|
||||
ctypes.CDLL(nccl_lib_name, dlopen_flags)
|
||||
|
||||
nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
|
||||
nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
|
||||
nccl_ops = compile_custom_ops(nccl_op_files,
|
||||
nccl_src_dir = os.path.join(jittor_path, "extern", "cuda", "nccl")
|
||||
nccl_src_files = []
|
||||
for r, _, f in os.walk(nccl_src_dir):
|
||||
for fname in f:
|
||||
nccl_src_files.append(os.path.join(r, fname))
|
||||
|
||||
nccl_ops = compile_custom_ops(nccl_src_files,
|
||||
extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
|
||||
LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
|
||||
|
||||
|
@ -353,6 +357,7 @@ def setup_mpi():
|
|||
|
||||
# mpi compile flags add for nccl
|
||||
mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
|
||||
mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
|
||||
|
||||
mpi = compile_custom_ops(mpi_src_files,
|
||||
extra_flags=f" {mpi_flags} ", return_module=True)
|
||||
|
|
|
@ -12,7 +12,12 @@ import jittor as jt
|
|||
import numpy as np
|
||||
|
||||
def main():
|
||||
jt.compile_extern.mpi_ops.mpi_test("").data == 123
|
||||
print("test mpi_test")
|
||||
assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
|
||||
if jt.compile_extern.nccl_ops:
|
||||
print("test test_with_mpi")
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||
class TestMpi(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue