Merge branch 'ygy' of https://github.com/Jittor/jittor into ygy

guowei yang 2020-04-06 22:16:22 +08:00
commit f22a8ec1fe
3 changed files with 85 additions and 0 deletions

extern/cuda/nccl/ops/nccl_reduce_op.cc (vendored new file, 51 lines)

@@ -0,0 +1,51 @@
// ***************************************************************
// Copyright (c) 2020
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "var.h"
#include "nccl_reduce_op.h"
#include "misc/str_utils.h"
#include <nccl.h>
#include <cuda_runtime.h>
#include <helper_cuda.h>
#include "nccl_warper.h"
namespace jittor {
#ifndef JIT
NcclReduceOp::NcclReduceOp(Var* x, int root) : x(x), root(root) {
    // CUDA-only op: disable the CPU flag, enable the CUDA flag.
    flags.set(NodeFlags::_cpu, 0);
    flags.set(NodeFlags::_cuda, 1);
    y = create_output(nullptr, x->dtype());
    ASSERT(x->dtype().is_float());
}

void NcclReduceOp::infer_shape() {
    // The reduced output has the same shape as the input.
    y->set_shape(x->shape);
}

void NcclReduceOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("XDIM", JK::hex1(x->shape.size()));
}
#else // JIT
#ifdef JIT_cuda
void NcclReduceOp::jit_run() {
    @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
    int size = 1 @for(i, 0, XDIM, * xshape@{i});
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    // Sum-reduce x across all ranks; only the root rank receives the result in y.
    // ncclReduce(sendbuff, recvbuff, count, datatype, op, root, comm, stream)
    checkCudaErrors(ncclReduce(xp, yp, size, ncclFloat, ncclSum, root, comm, 0));
    checkCudaErrors(cudaStreamSynchronize(0));
}
#endif
#endif // JIT
} // jittor
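
For readers unfamiliar with NCCL's reduce semantics, the snippet below is a CPU-only simulation (not part of the commit) of what the ncclReduce(..., ncclSum, root, ...) call in jit_run produces: every rank contributes its tensor, and only the root rank receives the element-wise sum; NCCL leaves the other ranks' output buffers unspecified. The helper name reduce_to_root is made up for illustration.

import numpy as np

def reduce_to_root(per_rank_tensors, root=0):
    """Host-side stand-in for ncclReduce with ncclSum: return a dict that
    holds the element-wise sum under the root rank's key only."""
    total = np.sum(np.stack(per_rank_tensors), axis=0)
    return {root: total}

# three "ranks", each holding a constant 2x2 tensor
tensors = [np.full((2, 2), r + 1.0, dtype=np.float32) for r in range(3)]
print(reduce_to_root(tensors))  # {0: array filled with 1+2+3 = 6.0}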

extern/cuda/nccl/ops/nccl_reduce_op.h (vendored new file, 25 lines)

@@ -0,0 +1,25 @@
// ***************************************************************
// Copyright (c) 2020
// Guoye Yang <498731903@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"
namespace jittor {
struct NcclReduceOp : Op {
    Var* x, * y;
    int root;

    NcclReduceOp(Var* x, int root);
    void infer_shape() override;
    const char* name() const override { return "nccl_reduce"; }
    DECLARE_jit_run;
};
} // jittor

NCCL ops Python test (9 lines added)

@@ -28,6 +28,14 @@ def test_broadcast():
    y = jt.compile_extern.nccl_ops.nccl_broadcast(x, 0)
    assert np.allclose(y.data, data.data)

def test_reduce():
    print("test reduce")
    mpi = jt.compile_extern.mpi
    x = jt.random([5, 5])
    # sum-reduce x onto the root rank (0) with the new op
    y = jt.compile_extern.nccl_ops.nccl_reduce(x, 0)
    if mpi.world_rank() == 0:
        # all ranks share the same seed, so with 3 ranks the sum equals 3*x
        assert np.allclose(y.data, (x*3).data)

def main():
    np.random.seed(0)
    jt.set_seed(3)
@@ -35,6 +43,7 @@ def main():
    if jt.compile_extern.nccl_ops:
        test_all_reduce()
        test_broadcast()
        test_reduce()

@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
class TestNcclOps(unittest.TestCase):
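
As a usage note (not part of the commit): the reduce test only makes sense under multiple MPI ranks, typically launched with mpirun using 3 ranks to match the x*3 check. Below is a hedged sketch of calling the new op from user code; the nccl_ops.nccl_reduce(x, root) binding name is an assumption, mirroring the nccl_broadcast and nccl_all_reduce bindings used in the test above.

import numpy as np
import jittor as jt

jt.flags.use_cuda = 1
mpi = jt.compile_extern.mpi
# each rank contributes rank+1, so with 3 ranks the root's sum is 1+2+3 = 6
local = jt.array(np.float32([mpi.world_rank() + 1]))
total = jt.compile_extern.nccl_ops.nccl_reduce(local, 0)
if mpi.world_rank() == 0:
    assert np.allclose(total.data, [6.0])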