add nccl mpi test together

Author: Dun Liang, 2020-04-03 14:54:12 +08:00
parent 45b8375e80
commit 5e5c8de82f
7 changed files with 156 additions and 20 deletions

extern/cuda/nccl/inc/nccl_warper.h vendored Normal file

@@ -0,0 +1,21 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "mpi_warper.h"
#include <cuda_runtime.h>
#include <nccl.h>
#include <helper_cuda.h>

// error-string overload so checkCudaErrors (from helper_cuda.h) can format
// ncclResult_t; the definition lives in src/nccl_warper.cc
const char *_cudaGetErrorEnum(ncclResult_t error);

namespace jittor {

extern ncclComm_t comm;
extern ncclUniqueId id;

} // jittor
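With the communicator exposed as a global, other CUDA ops can reuse it directly. A minimal sketch of such a call site (hypothetical helper, not part of this commit; buffer, count, and stream are placeholders):

#include "nccl_warper.h"

namespace jittor {

// hypothetical example: sum-reduce a float buffer across all ranks, in place,
// over the communicator initialized in src/nccl_warper.cc
static void all_reduce_sum(float* buf, size_t count, cudaStream_t stream) {
    checkCudaErrors(ncclAllReduce((const void*)buf, (void*)buf, count,
                                  ncclFloat, ncclSum, comm, stream));
    checkCudaErrors(cudaStreamSynchronize(stream));
}

} // jittor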

extern/cuda/nccl/ops/nccl_test_op.cc
@@ -7,15 +7,8 @@
#include "nccl_test_op.h"
#include "misc/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include "nccl_warper.h"
#ifndef JIT
const char *_cudaGetErrorEnum(ncclResult_t error) {
return ncclGetErrorString(error);
}
#endif
namespace jittor {
@@ -33,16 +26,41 @@ void NcclTestOp::jit_prepare() {
#else // JIT
#ifdef JIT_cuda
static void test_with_mpi() {
    int size = 32*1024*1024;
    int myRank = mpi_world_rank;
    int nRanks = mpi_world_size;
    int localRank = mpi_local_rank;
    float *sendbuff, *recvbuff;
    cudaStream_t s;
    // two 128 MB device buffers (32M floats); contents are left
    // uninitialized, the test only checks that the collective completes
    checkCudaErrors(cudaMalloc(&sendbuff, size * sizeof(float)));
    checkCudaErrors(cudaMalloc(&recvbuff, size * sizeof(float)));
    checkCudaErrors(cudaStreamCreate(&s));
    // communicating using NCCL
    checkCudaErrors(ncclAllReduce((const void*)sendbuff, (void*)recvbuff, size, ncclFloat, ncclSum,
        comm, s));
    // completing NCCL operation by synchronizing on the CUDA stream
    checkCudaErrors(cudaStreamSynchronize(s));
    // free device buffers
    checkCudaErrors(cudaFree(sendbuff));
    checkCudaErrors(cudaFree(recvbuff));
    checkCudaErrors(cudaStreamDestroy(s));
    LOGi << "MPI rank" << myRank << "Success";
}
void NcclTestOp::jit_run() {
    auto args = split(cmd, " ");
    if (!cmd.size()) args.clear();
    vector<char*> v(args.size());
    for (uint i=0; i<args.size(); i++)
        v[i] = &args[i][0];
    output->ptr<T>()[0] = 123;
    if (cmd == "test_with_mpi") {
        test_with_mpi();
        return;
    }
    // managing all visible devices
    int nDev;
    checkCudaErrors(cudaGetDeviceCount(&nDev));
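The op unconditionally writes 123 into its output so the Python side can assert that it actually ran; the "test_with_mpi" command takes the early-return MPI path above, while other invocations fall through to the pre-existing single-process multi-GPU test (the rest of which lies outside this hunk).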

extern/cuda/nccl/src/nccl_warper.cc vendored Normal file

@@ -0,0 +1,40 @@
// ***************************************************************
// Copyright (c) 2020 Jittor.
// Authors:
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "nccl_warper.h"
// definition of the error-string overload declared in inc/nccl_warper.h;
// lets checkCudaErrors report NCCL failures by name
const char *_cudaGetErrorEnum(ncclResult_t error) {
    return ncclGetErrorString(error);
}
namespace jittor {
ncclComm_t comm;
ncclUniqueId id;
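// Bootstrap NCCL over MPI: rank 0 creates the unique id, MPI_Bcast hands it
// to every rank, then each rank binds one GPU (by local rank) and joins the
// communicator.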
struct nccl_initer {
    nccl_initer() {
        if (mpi_world_rank == 0)
            checkCudaErrors(ncclGetUniqueId(&id));
        MPI_CHECK(MPI_Bcast((void *)&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD));
        checkCudaErrors(cudaSetDevice(mpi_local_rank));
        checkCudaErrors(ncclCommInitRank(&comm, mpi_world_size, id, mpi_world_rank));
    }
    ~nccl_initer() {
        checkCudaErrors(ncclCommDestroy(comm));
    }
};
static nccl_initer nccl_init;
} // jittor

extern/mpi/inc/mpi_warper.h

@@ -6,6 +6,7 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#define OMPI_SKIP_MPICXX
#include <mpi.h>
@@ -25,6 +26,7 @@ namespace jittor {
extern int mpi_world_size;
extern int mpi_world_rank;
extern int mpi_local_rank;
// @pyjt(world_size)
int _mpi_world_size();
@@ -32,4 +34,7 @@ int _mpi_world_size();
// @pyjt(world_rank)
int _mpi_world_rank();
// @pyjt(local_rank)
int _mpi_local_rank();
} // jittor
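The // @pyjt(world_size) style annotations are consumed by Jittor's binding generator, which should surface these getters on the compiled mpi module as world_size(), world_rank(), and local_rank(), so Python code can query rank information without extra glue.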

extern/mpi/src/mpi_warper.cc

@@ -6,6 +6,10 @@
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <unistd.h>
#include <stdint.h>
#include <stdio.h>
#include "mpi_warper.h"
#include "common.h"
@@ -26,6 +30,7 @@ namespace jittor {
int mpi_world_size;
int mpi_world_rank;
int mpi_local_rank;
int _mpi_world_size() {
    return mpi_world_size;
@@ -35,6 +40,31 @@ int _mpi_world_rank() {
    return mpi_world_rank;
}
int _mpi_local_rank() {
    return mpi_local_rank;
}
static uint64_t getHostHash(const char* string) {
    // Based on DJB2, result = result * 33 + char
    uint64_t result = 5381;
    for (int c = 0; string[c] != '\0'; c++){
        result = ((result << 5) + result) + string[c];
    }
    return result;
}
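// worked example: getHostHash("ab") = (5381*33 + 'a')*33 + 'b'
//                                   = (177573 + 97)*33 + 98 = 5863208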
static void getHostName(char* hostname, int maxlen) {
    gethostname(hostname, maxlen);
    // truncate at the first '.' so short and fully-qualified
    // host names hash to the same value
    for (int i=0; i< maxlen; i++) {
        if (hostname[i] == '.') {
            hostname[i] = '\0';
            return;
        }
    }
}
struct mpi_initer {
@@ -42,6 +72,18 @@ mpi_initer() {
    MPI_CHECK(MPI_Init(NULL, NULL));
    MPI_CHECK(MPI_Comm_size(MPI_COMM_WORLD, &mpi_world_size));
    MPI_CHECK(MPI_Comm_rank(MPI_COMM_WORLD, &mpi_world_rank));
    // calculating localRank based on hostname which is used in selecting a GPU;
    // one hash slot per rank: MPI_Allgather fills mpi_world_size entries
    uint64_t hostHashs[mpi_world_size];
    char hostname[1024];
    getHostName(hostname, 1024);
    hostHashs[mpi_world_rank] = getHostHash(hostname);
    MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, hostHashs, sizeof(uint64_t), MPI_BYTE, MPI_COMM_WORLD));
    mpi_local_rank = 0;
    for (int p=0; p<mpi_world_size; p++) {
        if (p == mpi_world_rank) break;
        if (hostHashs[p] == hostHashs[mpi_world_rank]) mpi_local_rank++;
    }
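    // example: ranks 0,1 on host A and ranks 2,3 on host B get
    // mpi_local_rank 0,1 and 0,1, so each process binds a distinct GPU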
}
~mpi_initer() {

python/jittor/compile_extern.py

@@ -265,7 +265,7 @@ def setup_nccl():
    global nccl_ops, use_nccl
    use_nccl = os.environ.get("use_nccl", "1")=="1"
    nccl_ops = None
-   if not has_cuda:
+   if not has_cuda or mpi is None:
        use_nccl = False
        return
    if not use_nccl: return
@@ -293,9 +293,13 @@ def setup_nccl():
    # We do not link manually; linking is done in the custom ops
    ctypes.CDLL(nccl_lib_name, dlopen_flags)
-   nccl_op_dir = os.path.join(jittor_path, "extern", "cuda", "nccl", "ops")
-   nccl_op_files = [os.path.join(nccl_op_dir, name) for name in os.listdir(nccl_op_dir)]
-   nccl_ops = compile_custom_ops(nccl_op_files,
+   nccl_src_dir = os.path.join(jittor_path, "extern", "cuda", "nccl")
+   nccl_src_files = []
+   for r, _, f in os.walk(nccl_src_dir):
+       for fname in f:
+           nccl_src_files.append(os.path.join(r, fname))
+   nccl_ops = compile_custom_ops(nccl_src_files,
        extra_flags=f" -I'{nccl_include_path}' {mpi_compile_flags} ")
    LOG.vv("Get nccl_ops: "+str(dir(nccl_ops)))
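Walking the whole extern/cuda/nccl tree, instead of listing only the ops directory, pulls src/nccl_warper.cc into the same compiled module, so every NCCL op shares the communicator initialized there.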
@@ -353,6 +357,7 @@ def setup_mpi():
    # add mpi compile flags needed by nccl
    mpi_compile_flags += f" -I'{os.path.join(mpi_src_dir, 'inc')}' "
    mpi_compile_flags = mpi_compile_flags.replace("-pthread", "")
    mpi = compile_custom_ops(mpi_src_files,
        extra_flags=f" {mpi_flags} ", return_module=True)

python/jittor/test/test_mpi.py

@@ -12,7 +12,12 @@ import jittor as jt
import numpy as np
def main():
-   jt.compile_extern.mpi_ops.mpi_test("").data == 123
+   print("test mpi_test")
+   assert jt.compile_extern.mpi_ops.mpi_test("").data == 123
+   if jt.compile_extern.nccl_ops:
+       print("test test_with_mpi")
+       with jt.flag_scope(use_cuda=1):
+           assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestMpi(unittest.TestCase):
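Assuming an OpenMPI-style launcher, the combined test would be exercised with something like mpirun -np 2 python3 -m jittor.test.test_mpi, one process per GPU (the module path here is inferred from the file name above).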