mirror of https://github.com/Jittor/Jittor
Merge branch 'ygy' of https://github.com/Jittor/jittor into ygy
This commit is contained in:
commit
6c9bd429f6
|
@ -0,0 +1,51 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_all_reduce_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "nccl_warper.h"
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
NcclAllReduceOp::NcclAllReduceOp(Var* x) : x(x) {
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
y = create_output(nullptr, x->dtype());
|
||||
ASSERT(x->dtype().is_float());
|
||||
}
|
||||
|
||||
void NcclAllReduceOp::infer_shape() {
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void NcclAllReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
void NcclAllReduceOp::jit_run() {
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
checkCudaErrors(ncclAllReduce(xp, yp, size, ncclFloat, ncclSum, comm, 0));
|
||||
checkCudaErrors(cudaStreamSynchronize(0));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,24 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct NcclAllReduceOp : Op {
|
||||
Var* x, * y;
|
||||
|
||||
NcclAllReduceOp(Var* x);
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "nccl_all_reduce"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,51 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "var.h"
|
||||
#include "nccl_broadcast_op.h"
|
||||
#include "misc/str_utils.h"
|
||||
|
||||
#include <nccl.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "nccl_warper.h"
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
NcclBroadcastOp::NcclBroadcastOp(Var* x, int root) : x(x), root(root) {
|
||||
flags.set(NodeFlags::_cpu, 0);
|
||||
flags.set(NodeFlags::_cuda, 1);
|
||||
y = create_output(nullptr, x->dtype());
|
||||
ASSERT(x->dtype().is_float());
|
||||
}
|
||||
|
||||
void NcclBroadcastOp::infer_shape() {
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void NcclBroadcastOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cuda
|
||||
|
||||
void NcclBroadcastOp::jit_run() {
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
checkCudaErrors(ncclBroadcast(xp, yp, size, ncclFloat, root, comm, 0));
|
||||
checkCudaErrors(cudaStreamSynchronize(0));
|
||||
}
|
||||
|
||||
#endif
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,25 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guoye Yang <498731903@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct NcclBroadcastOp : Op {
|
||||
Var* x, * y;
|
||||
int root;
|
||||
|
||||
NcclBroadcastOp(Var* x, int root);
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "nccl_broadcast"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,52 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
def test_all_reduce():
|
||||
print("test all_reduce")
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.nccl_ops.nccl_all_reduce(x)
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
|
||||
def test_broadcast():
|
||||
print("test broadcast")
|
||||
mpi = jt.compile_extern.mpi
|
||||
data = jt.random([5, 5])
|
||||
if mpi.world_rank() == 0:
|
||||
x = data
|
||||
else:
|
||||
x = jt.zeros([5, 5])
|
||||
y = jt.compile_extern.nccl_ops.nccl_broadcast(x, 0)
|
||||
assert np.allclose(y.data, data.data)
|
||||
|
||||
def main():
|
||||
np.random.seed(0)
|
||||
jt.set_seed(3)
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
if jt.compile_extern.nccl_ops:
|
||||
test_all_reduce()
|
||||
test_broadcast()
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||
class TestNcclOps(unittest.TestCase):
|
||||
def test(self):
|
||||
mpi = jt.compile_extern.mpi
|
||||
if mpi.world_size() == 1:
|
||||
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_nccl_ops"
|
||||
print("run cmd", cmd)
|
||||
jt.compiler.run_cmd(cmd)
|
||||
else:
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue