Initial commit

This commit is contained in:
Sylvain Jeaugey 2017-08-08 16:18:34 -07:00
commit b188a15299
12 changed files with 2083 additions and 0 deletions

27
LICENSE.txt Normal file
View File

@ -0,0 +1,27 @@
Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of NVIDIA CORPORATION, nor the names of their
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

20
Makefile Normal file
View File

@ -0,0 +1,20 @@
#
# Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
#
# See LICENCE.txt for license information
#
.PHONY : all clean
default : src.build
TARGETS=src
all: ${TARGETS:%=%.build}
clean: ${TARGETS:%=%.clean}
%.build:
${MAKE} -C $* build
%.clean:
${MAKE} -C $* clean

62
README.md Normal file
View File

@ -0,0 +1,62 @@
# NCCL Tests
These tests check both the performance and the correctness of NCCL operations. They can be compiled against [NCCL 1](http://github.com/nvidia/nccl) and [NCCL 2](http://developer.nvidia.com/nccl).
## Build
To build the tests, just type `make`.
If CUDA is not installed in /usr/local/cuda, you may specify CUDA\_HOME. Similarly, if NCCL is not installed in /usr, you may specify NCCL\_HOME.
```shell
$ make CUDA_HOME=/path/to/cuda NCCL_HOME=/path/to/nccl
```
NCCL tests rely on MPI to work on multiple processes, hence multiple nodes. If you want to compile the tests with MPI support, you need to set MPI=1 and set MPI\_HOME to the path where MPI is installed.
```shell
$ make MPI=1 MPI_HOME=/path/to/mpi CUDA_HOME=/path/to/cuda NCCL_HOME=/path/to/nccl
```
## Usage
NCCL tests can run on multiple processes, multiple threads, and multiple CUDA devices per thread. The number of process is managed by MPI and is therefore not passed to the tests as argument. The total number of ranks (=CUDA devices) will be equal to (number of processes)\*(number of threads)\*(number of gpus per thread).
### Quick examples
Run on 8 GPUs (`-g 8`), scanning from 8 Bytes to 128MBytes :
```shell
$ ./build/all_reduce_perf -b 8 -e 128M -f 2 -g 8
```
Run with MPI on 40 processes (potentially on multiple nodes) with 4 GPUs each, disabling checks :
```shell
$ mpirun -np 40 ./build/all_reduce_perf -b 8 -e 128M -f 2 -g 4 -c 0
```
All tests support the same arguments :
* Number of GPUs
* `-t,--nthreads <num threads>` number of threads per process. Default : 1.
* `-g,--ngpus <gpus per thread>` number of gpus per process. Default : 1.
* Sizes to scan
* `-b,--minbytes <min size in bytes>` minimum size to start with. Default : 32M.
* `-e,--maxbytes <max size in bytes>` maximum size to end at. Default : 32M.
* Increments can be either fixes of a multiplication factor. Only one of those should be used
* `-i,--stepbytes <increment size>` fixed increment between sizes. Default : (max-min)/10.
* `-f,--stepfactor <increment factor>` multiplication factor between sizes. Default : disabled.
* Performance
* `-n,--iters <iteration count>` number of iterations. Default : 20.
* `-w,--warmup_iters <warmup iteration count>` number of warmup iterations (not timed). Default : 5.
* `-s,--swap_args <0/1>` when used with multiple threads, have threads manage different GPUs for each iteration. Default : 0.
* `-p,--parallel_init <0/1>` use threads to initialize NCCL in parallel.
* `-c,--check <0/1>` check correctness of results. This can be quite slow on large numbers of GPUs. Default : 1.
* NCCL operations arguments
* `-o,--op <sum/prod/min/max/all>` Specify which reduction operation to perform. Only relevant for reduction operations. Default : Sum.
* `-d,--datatype <nccltype/all>` Specify which datatype to use. Default : Float.
* `-r,--root <root/all>` Specify which root to use. Only for operations with a root like broadcast or reduce.
* `-z,--blocking <0/1>` Make NCCL collective blocking, i.e. have CPUs wait and sync after each collective. Default : 0.
## Copyright
NCCL tests are provided under the BSD licence. All source code and accompanying documentation is copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.

78
src/Makefile Normal file
View File

@ -0,0 +1,78 @@
#
# Copyright (c) 2015-2017, NVIDIA CORPORATION. All rights reserved.
#
# See LICENCE.txt for license information
#
CUDA_HOME ?= /usr/local/cuda
PREFIX ?= /usr/local
VERBOSE ?= 0
DEBUG ?= 0
CUDA_LIB ?= $(CUDA_HOME)/lib64
CUDA_INC ?= $(CUDA_HOME)/include
NVCC = $(CUDA_HOME)/bin/nvcc
# Better define NVCC_GENCODE in your environment to the minimal set
# of archs to reduce compile time.
NVCC_GENCODE ?= -gencode=arch=compute_30,code=sm_30 \
-gencode=arch=compute_35,code=sm_35 \
-gencode=arch=compute_50,code=sm_50 \
-gencode=arch=compute_52,code=sm_52 \
-gencode=arch=compute_60,code=sm_60 \
-gencode=arch=compute_61,code=sm_61 \
-gencode=arch=compute_61,code=compute_61
NVCUFLAGS := -ccbin $(CXX) $(NVCC_GENCODE) -std=c++11
LDFLAGS := -L${CUDA_LIB} -lcudart -lrt
NVLDFLAGS := -L${CUDA_LIB} -lcudart -lrt
ifeq ($(DEBUG), 0)
NVCUFLAGS += -O3
CXXFLAGS += -O3
else
NVCUFLAGS += -O0 -G -g
CXXFLAGS += -O0 -g -ggdb3
endif
ifeq ($(VERBOSE), 0)
.SILENT:
endif
.PHONY: build clean
BUILDDIR ?= ../build
ifneq ($(NCCLDIR), "")
NVCUFLAGS += -I$(NCCLDIR)/include/
NVLDFLAGS += -L$(NCCLDIR)/lib
endif
ifeq ($(MPI), 1)
NVCUFLAGS += -DMPI_SUPPORT -I$(MPI_HOME)/include
NVLDFLAGS += -L$(MPI_HOME)/lib -lmpi
endif
LIBRARIES += curand nccl nvToolsExt
NVLDFLAGS += $(LIBRARIES:%=-l%)
DST_DIR := $(BUILDDIR)
SRC_FILES := $(wildcard *.cu)
OBJ_FILES := $(SRC_FILES:%.cu=${DST_DIR}/%.o)
BIN_FILES_LIST := all_reduce all_gather broadcast reduce_scatter reduce
BIN_FILES := $(BIN_FILES_LIST:%=${DST_DIR}/%_perf)
build: ${BIN_FILES}
clean:
rm -rf ${DST_DIR}
${DST_DIR}/%.o: %.cu
@printf "Compiling %-35s > %s\n" $< $@
@mkdir -p ${DST_DIR}
$(NVCC) -o $@ $(NVCUFLAGS) -c $<
${DST_DIR}/%_perf:${DST_DIR}/%.o ${DST_DIR}/common.o
@printf "Linking %-35s > %s\n" $< $@
@mkdir -p ${DST_DIR}
$(NVCC) -o $@ $(NVCUFLAGS) $^ ${NVLDFLAGS}

106
src/all_gather.cu Normal file
View File

@ -0,0 +1,106 @@
/*************************************************************************
* Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
************************************************************************/
#include "cuda_runtime.h"
#include "common.h"
void print_header() {
PRINT("# %10s %12s %6s %6s out-of-place in-place\n", "", "", "", "");
PRINT("# %10s %12s %6s %7s %5s %5s %7s %7s %5s %5s %7s\n", "bytes", "N", "type",
"time", "algbw", "busbw", "res", "time", "algbw", "busbw", "res");
}
void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) {
PRINT("%12li %12li %6s", size, count, typeName);
}
void getCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t *procSharedCount, int *sameExpected, size_t count, int nranks) {
*sendcount = count/nranks;
*recvcount = (count/nranks)*nranks;
*sameExpected = 1;
*procSharedCount = 0;
*sendInplaceOffset = count/nranks;
*recvInplaceOffset = 0;
*paramcount = *sendcount;
}
void InitRecvResult(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, int is_first) {
size_t nBytes = args->nbytes;
size_t count = nBytes / wordSize(type);
int proc = args->proc;
int nThreads = args->nThreads;
int t = args->thread;
int nGpus = args->nGpus;
while (args->sync[args->sync_idx] != t) pthread_yield();
for (int i=0; i<nGpus; i++) {
int device;
int rank = ((args->proc*args->nThreads + args->thread)*args->nGpus + i);
NCCLCHECK(ncclCommCuDevice(args->comms[i], &device));
CUDACHECK(cudaSetDevice(device));
void* data = in_place ? (void *)((uintptr_t)args->recvbuffs[i] + args->sendInplaceOffset*rank) : args->sendbuffs[i];
CUDACHECK(cudaMemcpy((void *)((uintptr_t)args->expectedHost[0] + ((proc*nThreads + t)*nGpus + i)*nBytes),
data,
nBytes, cudaMemcpyDeviceToHost));
if (in_place == 0) {
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
}
CUDACHECK(cudaDeviceSynchronize());
}
args->sync[args->sync_idx] = t + 1;
if (t+1 == nThreads) {
#ifdef MPI_SUPPORT
// Last thread does the MPI allgather
MPI_Allgather(MPI_IN_PLACE, nBytes*nThreads*nGpus, MPI_BYTE,
args->expectedHost[0],
nBytes*nThreads*nGpus, MPI_BYTE, MPI_COMM_WORLD);
#endif
args->sync[args->sync_idx] = 0;
} else {
while (args->sync[args->sync_idx]) pthread_yield();
}
args->sync_idx=!args->sync_idx;
}
void GetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks) {
double baseBw = (double)(count * typesize * (nranks - 1)) / 1.0E9 / sec;
*algBw = baseBw;
double factor = 1;
*busBw = baseBw * factor;
}
void RunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ncclAllGather(sendbuff, recvbuff, count, type, comm, stream));
}
void RunTest(struct threadArgs_t* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName) {
ncclDataType_t *run_types;
const char **run_typenames;
int type_count;
if ((int)type != -1) {
type_count = 1;
run_types = &type;
run_typenames = &typeName;
} else {
type_count = ncclNumTypes;
run_types = test_types;
run_typenames = test_typenames;
}
for (int i=0; i<type_count; i++) {
TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, NULL, 0, 1);
}
}

130
src/all_reduce.cu Normal file
View File

@ -0,0 +1,130 @@
/*************************************************************************
* Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
************************************************************************/
#include "cuda_runtime.h"
#include "common.h"
void print_header() {
PRINT("# %10s %12s %6s %6s out-of-place in-place\n", "", "", "", "");
PRINT("# %10s %12s %6s %6s %7s %5s %5s %7s %7s %5s %5s %7s\n", "bytes", "N", "type", "op",
"time", "algbw", "busbw", "res", "time", "algbw", "busbw", "res");
}
void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) {
PRINT("%12li %12li %6s %6s", size, count, typeName, opName);
}
void getCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t *procSharedCount, int *sameExpected, size_t count, int nranks) {
*sendcount = count;
*recvcount = count;
*sameExpected = 1;
*procSharedCount = 0;
*sendInplaceOffset = 0;
*recvInplaceOffset = 0;
*paramcount = *sendcount;
}
void InitRecvResult(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, int is_first) {
size_t count = args->nbytes / wordSize(type);
while (args->sync[args->sync_idx] != args->thread) pthread_yield();
for (int i=0; i<args->nGpus; i++) {
int device;
NCCLCHECK(ncclCommCuDevice(args->comms[i], &device));
CUDACHECK(cudaSetDevice(device));
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
if (is_first && i == 0) {
CUDACHECK(cudaMemcpy(args->expected[0], data, count*wordSize(type), cudaMemcpyDeviceToHost));
} else {
Accumulate(args->expected[0], data, count, type, op);
}
if (in_place == 0) {
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->nbytes));
}
CUDACHECK(cudaDeviceSynchronize());
}
args->sync[args->sync_idx] = args->thread + 1;
if (args->thread+1 == args->nThreads) {
#ifdef MPI_SUPPORT
// Last thread does the MPI reduction
if (args->nbytes > 0) {
void* remote, *remoteHost = malloc(args->nbytes);
void* myInitialData = malloc(args->nbytes);
memcpy(myInitialData, args->expectedHost[0], args->nbytes);
CUDACHECK(cudaHostRegister(remoteHost, args->nbytes, cudaHostRegisterPortable | cudaHostRegisterMapped));
CUDACHECK(cudaHostGetDevicePointer(&remote, remoteHost, 0));
for (int i=0; i<args->nProcs; i++) {
if (i == args->proc) {
MPI_Bcast(myInitialData, args->nbytes, MPI_BYTE, i, MPI_COMM_WORLD);
free(myInitialData);
} else {
MPI_Bcast(remoteHost, args->nbytes, MPI_BYTE, i, MPI_COMM_WORLD);
Accumulate(args->expected[0], remote, count, type, op);
cudaDeviceSynchronize();
}
}
CUDACHECK(cudaHostUnregister(remoteHost));
free(remoteHost);
}
#endif
args->sync[args->sync_idx] = 0;
} else {
while (args->sync[args->sync_idx]) pthread_yield();
}
args->sync_idx = !args->sync_idx;
}
void GetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks) {
double baseBw = (double)(count * typesize) / 1.0E9 / sec;
*algBw = baseBw;
double factor = ((double)(2*(nranks - 1)))/((double)nranks);
*busBw = baseBw * factor;
}
void RunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ncclAllReduce(sendbuff, recvbuff, count, type, op, comm, stream));
}
void RunTest(struct threadArgs_t* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName) {
ncclDataType_t *run_types;
ncclRedOp_t *run_ops;
const char **run_typenames, **run_opnames;
int type_count, op_count;
if ((int)type != -1) {
type_count = 1;
run_types = &type;
run_typenames = &typeName;
} else {
type_count = ncclNumTypes;
run_types = test_types;
run_typenames = test_typenames;
}
if ((int)op != -1) {
op_count = 1;
run_ops = &op;
run_opnames = &opName;
} else {
op_count = ncclNumOps;
run_ops = test_ops;
run_opnames = test_opnames;
}
for (int i=0; i<type_count; i++) {
for (int j=0; j<op_count; j++) {
TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], 0, 1);
}
}
}

121
src/broadcast.cu Normal file
View File

@ -0,0 +1,121 @@
/*************************************************************************
* Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
************************************************************************/
#include "cuda_runtime.h"
#include "common.h"
#include <assert.h>
void print_header() {
PRINT("# %10s %12s %6s %6s out-of-place\n", "", "", "", "");
PRINT("# %10s %12s %6s %6s %7s %5s %5s %7s\n", "bytes", "N", "type", "root",
"time", "algbw", "busbw", "res");
}
void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) {
PRINT("%12li %12li %6s %6i", size, count, typeName, root);
}
void getCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t *procSharedCount, int *sameExpected, size_t count, int nranks) {
*sendcount = count;
*recvcount = count;
*sameExpected = 0;
*procSharedCount = count;
*sendInplaceOffset = 0;
*recvInplaceOffset = 0;
*paramcount = *sendcount;
}
void InitRecvResult(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, int is_first) {
int root_proc = root/(args->nThreads*args->nGpus);
int root_thread = (root/args->nGpus)%(args->nThreads);
int root_gpu = root%args->nGpus;
assert(args->expectedBytes == args->nbytes);
if (root_thread == args->thread) {
if (root_proc == args->proc) {
CUDACHECK(cudaMemcpy(args->procSharedHost,
args->sendbuffs[root_gpu],
args->nbytes, cudaMemcpyDeviceToHost));
}
#ifdef MPI_SUPPORT
MPI_Bcast(args->procSharedHost, args->nbytes, MPI_BYTE, root_proc, MPI_COMM_WORLD);
#endif
args->sync[0] = 0;
}
Barrier(args);
for (int i=0; i<args->nGpus; i++) {
int device;
NCCLCHECK(ncclCommCuDevice(args->comms[i], &device));
CUDACHECK(cudaSetDevice(device));
//set expected buf to zero at root, copy over source data at others
if ((root_proc == args->proc)
&& (root_thread == args->thread)
&& (root_gpu == i)) {
memset(args->expectedHost[i], 0, args->nbytes);
} else {
memcpy(args->expectedHost[i], args->procSharedHost, args->nbytes);
}
//reset recvbufs to zero
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->nbytes));
CUDACHECK(cudaDeviceSynchronize());
}
Barrier(args);
}
void GetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks) {
double baseBw = (double)(count * typesize) / 1.0E9 / sec;
*algBw = baseBw;
double factor = 1;
*busBw = baseBw * factor;
}
void RunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
int rank;
NCCLCHECK(ncclCommUserRank(comm, &rank));
if (rank == root) {
NCCLCHECK(ncclBcast(sendbuff, count, type, root, comm, stream));
} else {
NCCLCHECK(ncclBcast(recvbuff, count, type, root, comm, stream));
}
}
void RunTest(struct threadArgs_t* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName) {
ncclDataType_t *run_types;
const char **run_typenames;
int type_count;
int begin_root, end_root;
if ((int)type != -1) {
type_count = 1;
run_types = &type;
run_typenames = &typeName;
} else {
type_count = ncclNumTypes;
run_types = test_types;
run_typenames = test_typenames;
}
if (root != -1) {
begin_root = end_root = root;
} else {
begin_root = 0;
end_root = args->nProcs*args->nThreads*args->nGpus-1;
}
for (int i=0; i<type_count; i++) {
for (int j=begin_root; j<=end_root; j++) {
TimeTest(args, run_types[i], run_typenames[i], (ncclRedOp_t)0, NULL, j, 0);
}
}
}

1036
src/common.cu Normal file

File diff suppressed because it is too large Load Diff

158
src/common.h Normal file
View File

@ -0,0 +1,158 @@
/*************************************************************************
* Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
************************************************************************/
#include "nccl.h"
#include <stdio.h>
#include <algorithm>
#include <curand.h>
#ifdef MPI_SUPPORT
#include "mpi.h"
#endif
#include <pthread.h>
#include "nccl1_compat.h"
#define CUDACHECK(cmd) do { \
cudaError_t e = cmd; \
if( e != cudaSuccess ) { \
printf("Cuda failure %s:%d '%s'\n", \
__FILE__,__LINE__,cudaGetErrorString(e)); \
exit(EXIT_FAILURE); \
} \
} while(0)
#define NCCLCHECK(cmd) do { \
ncclResult_t r = cmd; \
if (r!= ncclSuccess) { \
printf("NCCL failure %s:%d '%s'\n", \
__FILE__,__LINE__,ncclGetErrorString(r)); \
exit(EXIT_FAILURE); \
} \
} while(0)
struct threadArgs_t {
void *proc_args;
size_t nbytes;
size_t minbytes;
size_t maxbytes;
size_t stepbytes;
size_t stepfactor;
int nProcs;
int proc;
int nThreads;
int thread;
int nGpus;
int localRank;
void** sendbuffs;
size_t sendBytes;
size_t sendInplaceOffset;
void** recvbuffs;
size_t recvInplaceOffset;
ncclUniqueId ncclId;
ncclComm_t* comms;
cudaStream_t* streams;
void** expectedHost;
void** expected;
size_t expectedBytes;
void* procSharedHost;
void* procShared;
volatile int* sync;
int sync_idx;
volatile int* barrier;
int barrier_idx;
int syncRank;
int syncNranks;
double* deltaThreads;
double* deltaHost;
double* delta;
int* errors;
double* bw;
int* bw_count;
};
#include <chrono>
// Provided by common.cu
extern void Barrier(struct threadArgs_t* args);
extern void TimeTest(struct threadArgs_t* args, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName, int root, int inPlace);
extern void Randomize(void* ptr, size_t count, ncclDataType_t type, int seed);
extern void Accumulate(void* out, void* in, size_t n, ncclDataType_t type, ncclRedOp_t op);
extern void CheckDelta(void* expected, void* results, size_t count, ncclDataType_t type, double* devmax);
extern double DeltaMaxValue(ncclDataType_t type);
// Provided by each coll
void RunTest(struct threadArgs_t* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName);
extern void GetBw(size_t count, int typeSize, double sec, double* algBw, double* busBw, int nranks);
extern void RunColl(void* sendbuf, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream);
extern void InitData(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op, int in_place, int is_first);
extern double CheckData(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op);
extern void AllocateBuffs(void **sendbuff, void **recvbuff, void **expected, void **expectedHost, size_t nbytes, int nranks);
extern void InitRecvResult(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, int is_first);
extern void getCollByteCount(size_t *sendbytes, size_t *recvbytes, size_t *parambytes, size_t *sendInlineOffset, size_t *recvInlineOffset, size_t *procSharedBytes, int *sameexpected, size_t nbytes, int nranks);
extern void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root);
extern void print_header();
#include <unistd.h>
static void getHostName(char* hostname, int maxlen) {
gethostname(hostname, maxlen);
for (int i=0; i< maxlen; i++) {
if (hostname[i] == '.') {
hostname[i] = '\0';
return;
}
}
}
#include <stdint.h>
static uint64_t getHostHash(const char* string) {
// Based on DJB2, result = result * 33 + char
uint64_t result = 5381;
for (int c = 0; string[c] != '\0'; c++){
result = ((result << 5) + result) + string[c];
}
return result;
}
static size_t wordSize(ncclDataType_t type) {
switch(type) {
case ncclChar:
#if NCCL_MAJOR >= 2
//case ncclInt8:
case ncclUint8:
#endif
return 1;
case ncclHalf:
//case ncclFloat16:
return 2;
case ncclInt:
case ncclFloat:
#if NCCL_MAJOR >= 2
//case ncclInt32:
case ncclUint32:
//case ncclFloat32:
#endif
return 4;
case ncclInt64:
case ncclUint64:
case ncclDouble:
//case ncclFloat64:
return 8;
default: return 0;
}
}
extern ncclDataType_t test_types[ncclNumTypes];
extern const char *test_typenames[ncclNumTypes];
extern ncclRedOp_t test_ops[ncclNumOps];
extern const char *test_opnames[ncclNumOps];
extern thread_local int is_main_thread;
#define PRINT if (is_main_thread) printf

47
src/nccl1_compat.h Normal file
View File

@ -0,0 +1,47 @@
/*************************************************************************
* Copyright (c) 2017, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
************************************************************************/
#ifndef NCCL1_COMPAT_H
#define NCCL1_COMPAT_H
#ifndef NCCL_MAJOR // NCCL 1.x
#define ncclNumOps nccl_NUM_OPS
#define ncclNumTypes nccl_NUM_TYPES
static ncclResult_t ncclGroupStart() { return ncclSuccess; }
static ncclResult_t ncclGroupEnd() { return ncclSuccess; }
#define CHECKCOUNT(count) if (count > INT_MAX) return ncclInvalidArgument;
static ncclResult_t ncclReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
CHECKCOUNT(count);
return ncclReduce(sendbuff, recvbuff, (int)count, datatype, op, root, comm, stream);
}
static ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
CHECKCOUNT(count);
return ncclAllReduce(sendbuff, recvbuff, (int)count, datatype, op, comm, stream);
}
static ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root,
ncclComm_t comm, cudaStream_t stream) {
CHECKCOUNT(count);
return ncclBcast(buff, (int)count, datatype, root, comm, stream);
}
static ncclResult_t ncclReduceScatter(const void* sendbuff, void* recvbuff,
size_t recvcount, ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm,
cudaStream_t stream) {
CHECKCOUNT(recvcount);
return ncclReduceScatter(sendbuff, recvbuff, (int)recvcount, datatype, op, comm, stream);
}
static ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
CHECKCOUNT(sendcount);
return ncclAllGather(sendbuff, (int)sendcount, datatype, recvbuff, comm, stream);
}
#endif
#endif

159
src/reduce.cu Normal file
View File

@ -0,0 +1,159 @@
/*************************************************************************
* Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
************************************************************************/
#include <assert.h>
#include "cuda_runtime.h"
#include "common.h"
void print_header() {
PRINT("# %10s %12s %6s %6s out-of-place in-place\n", "", "", "", "");
PRINT("# %10s %12s %6s %6s %6s %7s %5s %5s %7s %7s %5s %5s %7s\n", "bytes", "N", "type", "op", "root",
"time", "algbw", "busbw", "res", "time", "algbw", "busbw", "res");
}
void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) {
PRINT("%12li %12li %6s %6s %6i", size, count, typeName, opName, root);
}
void getCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t *procSharedCount, int *sameExpected, size_t count, int nranks) {
*sendcount = count;
*recvcount = count;
*sameExpected = 0;
*procSharedCount = count;
*sendInplaceOffset = 0;
*recvInplaceOffset = 0;
*paramcount = *sendcount;
}
void InitRecvResult(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, int is_first) {
size_t count = args->expectedBytes / wordSize(type);
int root_gpu = root%args->nGpus;
assert(args->expectedBytes == args->nbytes);
while (args->sync[args->sync_idx] != args->thread) pthread_yield();
for (int i=0; i<args->nGpus; i++) {
int device;
NCCLCHECK(ncclCommCuDevice(args->comms[i], &device));
CUDACHECK(cudaSetDevice(device));
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
if (is_first && i == 0) {
CUDACHECK(cudaMemcpy(args->procSharedHost, data, count*wordSize(type), cudaMemcpyDeviceToHost));
} else {
Accumulate(args->procShared, data, count, type, op);
}
if (in_place == 0) {
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, args->expectedBytes));
}
CUDACHECK(cudaDeviceSynchronize());
}
args->sync[args->sync_idx] = args->thread + 1;
if (args->thread+1 == args->nThreads) {
#ifdef MPI_SUPPORT
int root_proc = root/(args->nThreads*args->nGpus);
if (args->expectedBytes) {
// Last thread does the MPI reduction
if (root_proc == args->proc) {
void* temp, *tempHost = malloc(args->expectedBytes);
CUDACHECK(cudaHostRegister(tempHost, args->expectedBytes, 0));
CUDACHECK(cudaHostGetDevicePointer(&temp, tempHost, 0));
for (int i=0; i<args->nProcs; i++) {
if (i == args->proc) continue;
MPI_Recv(tempHost, args->expectedBytes, MPI_BYTE, i, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
Accumulate(args->procShared, temp, count, type, op);
CUDACHECK(cudaDeviceSynchronize());
}
CUDACHECK(cudaHostUnregister(tempHost));
free(tempHost);
} else {
MPI_Send(args->procSharedHost, args->expectedBytes, MPI_BYTE, root_proc, 0, MPI_COMM_WORLD);
}
}
#endif
args->sync[args->sync_idx] = 0;
} else {
while (args->sync[args->sync_idx]) pthread_yield();
}
//if root fill expected bytes with reduced data
// else if in_place, leave fill it with original data, else set to zero
for (int i=0; i<args->nGpus; i++) {
int rank = (args->proc*args->nThreads + args->thread)*args->nGpus + i;
if (rank == root) {
memcpy(args->expectedHost[root_gpu], args->procSharedHost, args->expectedBytes);
} else {
if (in_place == 1) {
CUDACHECK(cudaMemcpy(args->expectedHost[i], args->recvbuffs[i], args->expectedBytes, cudaMemcpyDeviceToHost));
} else {
memset(args->expectedHost[i], 0, args->expectedBytes);
}
}
}
args->sync_idx = !args->sync_idx;
}
void GetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks) {
double baseBw = (double)(count * typesize) / 1.0E9 / sec;
*algBw = baseBw;
*busBw = baseBw;
}
void RunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ncclReduce(sendbuff, recvbuff, count, type, op, root, comm, stream));
}
void RunTest(struct threadArgs_t* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName) {
ncclDataType_t *run_types;
ncclRedOp_t *run_ops;
const char **run_typenames, **run_opnames;
int type_count, op_count;
int begin_root, end_root;
if ((int)type != -1) {
type_count = 1;
run_types = &type;
run_typenames = &typeName;
} else {
type_count = ncclNumTypes;
run_types = test_types;
run_typenames = test_typenames;
}
if ((int)op != -1) {
op_count = 1;
run_ops = &op;
run_opnames = &opName;
} else {
op_count = ncclNumOps;
run_ops = test_ops;
run_opnames = test_opnames;
}
if (root != -1) {
begin_root = end_root = root;
} else {
begin_root = 0;
end_root = args->nProcs*args->nThreads*args->nGpus-1;
}
for (int i=0; i<type_count; i++) {
for (int j=0; j<op_count; j++) {
for (int k=begin_root; k<=end_root; k++) {
TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], k, 1);
}
}
}
}

139
src/reduce_scatter.cu Normal file
View File

@ -0,0 +1,139 @@
/*************************************************************************
* Copyright (c) 2016-2017, NVIDIA CORPORATION. All rights reserved.
*
* See LICENCE.txt for license information
************************************************************************/
#include "cuda_runtime.h"
#include "common.h"
void print_header() {
PRINT("# %10s %12s %6s %6s out-of-place in-place\n", "", "", "", "");
PRINT("# %10s %12s %6s %6s %7s %5s %5s %7s %7s %5s %5s %7s\n", "bytes", "N", "type", "op",
"time", "algbw", "busbw", "res", "time", "algbw", "busbw", "res");
}
void print_line_header (size_t size, size_t count, const char *typeName, const char *opName, int root) {
PRINT("%12li %12li %6s %6s", size, count, typeName, opName);
}
void getCollByteCount(size_t *sendcount, size_t *recvcount, size_t *paramcount, size_t *sendInplaceOffset, size_t *recvInplaceOffset, size_t *procSharedCount, int *sameExpected, size_t count, int nranks) {
*sendcount = (count/nranks)*nranks;
*recvcount = count/nranks;
*sameExpected = 0;
*procSharedCount = *sendcount;
*sendInplaceOffset = 0;
*recvInplaceOffset = count/nranks;
*paramcount = *recvcount;
}
void InitRecvResult(struct threadArgs_t* args, ncclDataType_t type, ncclRedOp_t op, int root, int in_place, int is_first) {
size_t recvbytes = args->expectedBytes;
size_t recvcount = args->expectedBytes / wordSize(type);
size_t sendbytes = args->sendBytes;
size_t sendcount = args->sendBytes / wordSize(type);
while (args->sync[args->sync_idx] != args->thread) pthread_yield();
for (int i=0; i<args->nGpus; i++) {
int device;
NCCLCHECK(ncclCommCuDevice(args->comms[i], &device));
CUDACHECK(cudaSetDevice(device));
void* data = in_place ? args->recvbuffs[i] : args->sendbuffs[i];
if (is_first && i == 0) {
CUDACHECK(cudaMemcpy(args->procSharedHost, data, sendbytes, cudaMemcpyDeviceToHost));
} else {
Accumulate(args->procShared, data, sendcount, type, op);
}
CUDACHECK(cudaDeviceSynchronize());
if (in_place == 0) {
CUDACHECK(cudaMemset(args->recvbuffs[i], 0, recvbytes));
}
CUDACHECK(cudaDeviceSynchronize());
}
args->sync[args->sync_idx] = args->thread + 1;
if (args->thread+1 == args->nThreads) {
#ifdef MPI_SUPPORT
if (sendbytes > 0) {
// Last thread does the MPI reduction
void* remote, *remoteHost = malloc(sendbytes);
void* myInitialData = malloc(sendbytes);
memcpy(myInitialData, args->procSharedHost, sendbytes);
CUDACHECK(cudaHostRegister(remoteHost, sendbytes, 0));
CUDACHECK(cudaHostGetDevicePointer(&remote, remoteHost, 0));
for (int i=0; i<args->nProcs; i++) {
if (i == args->proc) {
MPI_Bcast(myInitialData, sendbytes, MPI_BYTE, i, MPI_COMM_WORLD);
free(myInitialData);
} else {
MPI_Bcast(remoteHost, sendbytes, MPI_BYTE, i, MPI_COMM_WORLD);
Accumulate(args->procShared, remote, sendcount, type, op);
cudaDeviceSynchronize();
}
}
CUDACHECK(cudaHostUnregister(remoteHost));
free(remoteHost);
}
#endif
args->sync[args->sync_idx] = 0;
} else {
while (args->sync[args->sync_idx]) pthread_yield();
}
for (int i=0; i<args->nGpus; i++) {
int offset = ((args->proc*args->nThreads + args->thread)*args->nGpus + i)*recvbytes;
memcpy(args->expectedHost[i], (void *)((uintptr_t)args->procSharedHost + offset), recvbytes);
}
args->sync_idx = !args->sync_idx;
}
void GetBw(size_t count, int typesize, double sec, double* algBw, double* busBw, int nranks) {
double baseBw = (double)(count * typesize * (nranks - 1)) / 1.0E9 / sec;
*algBw = baseBw;
double factor = 1;
*busBw = baseBw * factor;
}
void RunColl(void* sendbuff, void* recvbuff, size_t count, ncclDataType_t type, ncclRedOp_t op, int root, ncclComm_t comm, cudaStream_t stream) {
NCCLCHECK(ncclReduceScatter(sendbuff, recvbuff, count, type, op, comm, stream));
}
void RunTest(struct threadArgs_t* args, int root, ncclDataType_t type, const char* typeName, ncclRedOp_t op, const char* opName) {
ncclDataType_t *run_types;
ncclRedOp_t *run_ops;
const char **run_typenames, **run_opnames;
int type_count, op_count;
if ((int)type != -1) {
type_count = 1;
run_types = &type;
run_typenames = &typeName;
} else {
type_count = ncclNumTypes;
run_types = test_types;
run_typenames = test_typenames;
}
if ((int)op != -1) {
run_ops = &op;
run_opnames = &opName;
op_count = 1;
} else {
op_count = sizeof(test_ops)/sizeof(test_ops[0]);
run_ops = test_ops;
run_opnames = test_opnames;
}
for (int i=0; i<type_count; i++) {
for (int j=0; j<op_count; j++) {
TimeTest(args, run_types[i], run_typenames[i], run_ops[j], run_opnames[j], 0, 1);
}
}
}