mirror of https://github.com/Jittor/Jittor
mpi op
This commit is contained in:
parent
f22a8ec1fe
commit
e745e46bb3
|
@ -0,0 +1,64 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guowei Yang <471184555@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mpi_warper.h"
|
||||
#include "var.h"
|
||||
#include "mpi_all_reduce_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
MpiAllReduceOp::MpiAllReduceOp(Var* x) : x(x) {
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static VarPtr(*nccl_all_reduce)(Var*) = nullptr;
|
||||
if (!nccl_all_reduce && has_op("nccl_all_reduce")) {
|
||||
nccl_all_reduce = get_op_info("nccl_all_reduce")
|
||||
.get_constructor<VarPtr, Var*>();
|
||||
}
|
||||
if (nccl_all_reduce) {
|
||||
LOGr << "nccl";
|
||||
auto var = nccl_all_reduce(x);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
y = create_output(nullptr, x->dtype());
|
||||
ASSERT(x->dtype().is_float());
|
||||
}
|
||||
|
||||
void MpiAllReduceOp::infer_shape() {
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void MpiAllReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void MpiAllReduceOp::jit_run() {
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
MPI_Allreduce(xp, yp, size, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
|
||||
}
|
||||
#else
|
||||
void MpiAllReduceOp::jit_run() {
|
||||
// cuda device code
|
||||
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,24 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guowei Yang <471184555@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct MpiAllReduceOp : Op {
|
||||
Var* x, * y;
|
||||
|
||||
MpiAllReduceOp(Var* x);
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_all_reduce"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,69 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guowei Yang <471184555@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mpi_warper.h"
|
||||
#include "var.h"
|
||||
#include "mpi_broadcast_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static VarPtr(*nccl_broadcast)(Var*, int) = nullptr;
|
||||
if (!nccl_broadcast && has_op("nccl_broadcast")) {
|
||||
nccl_broadcast = get_op_info("nccl_broadcast")
|
||||
.get_constructor<VarPtr, Var*, int>();
|
||||
}
|
||||
if (nccl_broadcast) {
|
||||
LOGr << "nccl";
|
||||
auto var = nccl_broadcast(x, root);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
y = create_output(nullptr, x->dtype());
|
||||
ASSERT(x->dtype().is_float());
|
||||
}
|
||||
|
||||
void MpiBroadcastOp::infer_shape() {
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void MpiBroadcastOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void MpiBroadcastOp::jit_run() {
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
if (mpi_world_rank == root) {
|
||||
for (int i = 0; i < mpi_world_size; i++) {
|
||||
MPI_Send(xp, size, MPI_FLOAT, i, 0, MPI_COMM_WORLD);
|
||||
}
|
||||
}
|
||||
MPI_Recv(yp, size, MPI_FLOAT, root, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
|
||||
}
|
||||
#else
|
||||
void MpiBroadcastOp::jit_run() {
|
||||
// cuda device code
|
||||
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,25 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guowei Yang <471184555@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct MpiBroadcastOp : Op {
|
||||
Var* x, * y;
|
||||
int root;
|
||||
|
||||
MpiBroadcastOp(Var* x, int root);
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_broadcast"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,64 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guowei Yang <471184555@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "mpi_warper.h"
|
||||
#include "var.h"
|
||||
#include "mpi_reduce_op.h"
|
||||
#include "ops/op_register.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "misc/cuda_flags.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
MpiReduceOp::MpiReduceOp(Var* x, int root) : x(x), root(root) {
|
||||
#ifdef HAS_CUDA
|
||||
if (use_cuda) {
|
||||
static VarPtr(*nccl_reduce)(Var*, int) = nullptr;
|
||||
if (!nccl_reduce && has_op("nccl_reduce")) {
|
||||
nccl_reduce = get_op_info("nccl_reduce")
|
||||
.get_constructor<VarPtr, Var*, int>();
|
||||
}
|
||||
if (nccl_reduce) {
|
||||
LOGr << "nccl";
|
||||
auto var = nccl_reduce(x, root);
|
||||
forward(var);
|
||||
return;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
y = create_output(nullptr, x->dtype());
|
||||
ASSERT(x->dtype().is_float());
|
||||
}
|
||||
|
||||
void MpiReduceOp::infer_shape() {
|
||||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
void MpiReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
}
|
||||
|
||||
#else // JIT
|
||||
#ifdef JIT_cpu
|
||||
void MpiReduceOp::jit_run() {
|
||||
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
|
||||
int size = 1 @for(i, 0, XDIM, * xshape@{i});
|
||||
auto* __restrict__ xp = x->ptr<Tx>();
|
||||
auto* __restrict__ yp = y->ptr<Tx>();
|
||||
MPI_Reduce(xp, yp, size, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
|
||||
}
|
||||
#else
|
||||
void MpiReduceOp::jit_run() {
|
||||
// cuda device code
|
||||
}
|
||||
#endif // JIT_cpu
|
||||
#endif // JIT
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,25 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020
|
||||
// Guowei Yang <471184555@qq.com>.
|
||||
// Dun Liang <randonlang@gmail.com>.
|
||||
// All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct MpiReduceOp : Op {
|
||||
Var* x, * y;
|
||||
int root;
|
||||
|
||||
MpiReduceOp(Var* x, int root);
|
||||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "mpi_reduce"; }
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,67 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Guowei Yang <471184555@qq.com>
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import os, sys
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
|
||||
def test_all_reduce():
|
||||
print("test all_reduce")
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
|
||||
def test_broadcast():
|
||||
print("test broadcast")
|
||||
mpi = jt.compile_extern.mpi
|
||||
data = jt.random([5, 5])
|
||||
if mpi.world_rank() == 0:
|
||||
x = data
|
||||
else:
|
||||
x = jt.zeros([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
|
||||
assert np.allclose(y.data, data.data)
|
||||
|
||||
def test_reduce():
|
||||
print("test reduce")
|
||||
mpi = jt.compile_extern.mpi
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
|
||||
y.sync()
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(y.data, (x*3).data)
|
||||
|
||||
def main():
|
||||
np.random.seed(0)
|
||||
jt.set_seed(3)
|
||||
with jt.flag_scope(use_cuda=0):
|
||||
if jt.compile_extern.mpi_ops:
|
||||
test_all_reduce()
|
||||
test_broadcast()
|
||||
test_reduce()
|
||||
with jt.flag_scope(use_cuda=1):
|
||||
if jt.compile_extern.mpi_ops:
|
||||
test_all_reduce()
|
||||
test_broadcast()
|
||||
test_reduce()
|
||||
|
||||
@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
|
||||
class TestMpiOps(unittest.TestCase):
|
||||
def test(self):
|
||||
mpi = jt.compile_extern.mpi
|
||||
if mpi.world_size() == 1:
|
||||
mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_op"
|
||||
print("run cmd", cmd)
|
||||
jt.compiler.run_cmd(cmd)
|
||||
else:
|
||||
main()
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue