mirror of https://github.com/Jittor/Jittor
merge mpi ops
commit fd80146901
@@ -0,0 +1,64 @@
// ***************************************************************
// Copyright (c) 2020
// Guowei Yang <471184555@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"

#include "var.h"
#include "mpi_all_reduce_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "misc/cuda_flags.h"

namespace jittor {

#ifndef JIT
MpiAllReduceOp::MpiAllReduceOp(Var* x) : x(x) {
#ifdef HAS_CUDA
    // when running on GPU, dispatch to the NCCL all-reduce op if it was compiled in
    if (use_cuda) {
        static VarPtr(*nccl_all_reduce)(Var*) = nullptr;
        if (!nccl_all_reduce && has_op("nccl_all_reduce")) {
            nccl_all_reduce = get_op_info("nccl_all_reduce")
                .get_constructor<VarPtr, Var*>();
        }
        if (nccl_all_reduce) {
            LOGr << "nccl";
            auto var = nccl_all_reduce(x);
            forward(var);
            return;
        }
    }
#endif
    y = create_output(nullptr, x->dtype());
    ASSERT(x->dtype().is_float());
}

void MpiAllReduceOp::infer_shape() {
    y->set_shape(x->shape);
}

void MpiAllReduceOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("XDIM", JK::hex1(x->shape.size()));
}

#else // JIT
#ifdef JIT_cpu
void MpiAllReduceOp::jit_run() {
    @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
    int size = 1 @for(i, 0, XDIM, * xshape@{i});
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    // elementwise sum across all ranks; buffers are sent as MPI_FLOAT,
    // so float32 inputs are assumed here
    MPI_Allreduce(xp, yp, size, MPI_FLOAT, MPI_SUM, MPI_COMM_WORLD);
}
#else
void MpiAllReduceOp::jit_run() {
    // cuda device code
}
#endif // JIT_cpu
#endif // JIT

} // jittor
@@ -0,0 +1,24 @@
// ***************************************************************
// Copyright (c) 2020
// Guowei Yang <471184555@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

struct MpiAllReduceOp : Op {
    Var* x, * y;

    MpiAllReduceOp(Var* x);
    void infer_shape() override;

    const char* name() const override { return "mpi_all_reduce"; }
    DECLARE_jit_run;
};

} // jittor
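For reference, the op defined above is exposed to Python as `jt.compile_extern.mpi_ops.mpi_all_reduce` (see the tests added at the end of this commit). A minimal usage sketch, assuming the script is launched under `mpirun` so that `mpi_ops` is populated:

```python
import jittor as jt
import numpy as np

# run with: mpirun -np 3 python this_script.py
n = jt.compile_extern.mpi.world_size()
x = jt.array(np.ones((5, 5), "float32"))          # identical on every rank
y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)   # elementwise sum over ranks
assert np.allclose(y.data, n * np.ones((5, 5)))   # each entry is now n
```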
@@ -0,0 +1,69 @@
// ***************************************************************
// Copyright (c) 2020
// Guowei Yang <471184555@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"

#include "var.h"
#include "mpi_broadcast_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "misc/cuda_flags.h"

namespace jittor {

#ifndef JIT
MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
#ifdef HAS_CUDA
    // when running on GPU, dispatch to the NCCL broadcast op if it was compiled in
    if (use_cuda) {
        static VarPtr(*nccl_broadcast)(Var*, int) = nullptr;
        if (!nccl_broadcast && has_op("nccl_broadcast")) {
            nccl_broadcast = get_op_info("nccl_broadcast")
                .get_constructor<VarPtr, Var*, int>();
        }
        if (nccl_broadcast) {
            LOGr << "nccl";
            auto var = nccl_broadcast(x, root);
            forward(var);
            return;
        }
    }
#endif
    y = create_output(nullptr, x->dtype());
    ASSERT(x->dtype().is_float());
}

void MpiBroadcastOp::infer_shape() {
    y->set_shape(x->shape);
}

void MpiBroadcastOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("XDIM", JK::hex1(x->shape.size()));
}

#else // JIT
#ifdef JIT_cpu
void MpiBroadcastOp::jit_run() {
    @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
    int size = 1 @for(i, 0, XDIM, * xshape@{i});
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    if (mpi_world_rank == root) {
        // copy locally on the root: a blocking MPI_Send to self can deadlock
        for (int i=0; i<size; i++) yp[i] = xp[i];
        for (int i=0; i<mpi_world_size; i++)
            if (i != root)
                MPI_Send(xp, size, MPI_FLOAT, i, 0, MPI_COMM_WORLD);
    } else {
        MPI_Recv(yp, size, MPI_FLOAT, root, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
    }
}
#else
void MpiBroadcastOp::jit_run() {
    // cuda device code
}
#endif // JIT_cpu
#endif // JIT

} // jittor
@@ -0,0 +1,25 @@
// ***************************************************************
// Copyright (c) 2020
// Guowei Yang <471184555@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

struct MpiBroadcastOp : Op {
    Var* x, * y;
    int root;

    MpiBroadcastOp(Var* x, int root);
    void infer_shape() override;

    const char* name() const override { return "mpi_broadcast"; }
    DECLARE_jit_run;
};

} // jittor
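The broadcast op follows the same dispatch pattern and is exposed as `jt.compile_extern.mpi_ops.mpi_broadcast`: the root rank's buffer replaces everyone else's. A sketch of its semantics, again assuming an `mpirun` launch:

```python
import jittor as jt
import numpy as np

mpi = jt.compile_extern.mpi
if mpi.world_rank() == 0:
    x = jt.array(np.arange(25, dtype="float32").reshape(5, 5))
else:
    x = jt.zeros([5, 5])                               # placeholder on non-root ranks
y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)      # root = rank 0
# every rank now holds rank 0's values in y
```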
@@ -0,0 +1,64 @@
// ***************************************************************
// Copyright (c) 2020
// Guowei Yang <471184555@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "mpi_warper.h"

#include "var.h"
#include "mpi_reduce_op.h"
#include "ops/op_register.h"
#include "misc/str_utils.h"
#include "misc/cuda_flags.h"

namespace jittor {

#ifndef JIT
MpiReduceOp::MpiReduceOp(Var* x, int root) : x(x), root(root) {
#ifdef HAS_CUDA
    // when running on GPU, dispatch to the NCCL reduce op if it was compiled in
    if (use_cuda) {
        static VarPtr(*nccl_reduce)(Var*, int) = nullptr;
        if (!nccl_reduce && has_op("nccl_reduce")) {
            nccl_reduce = get_op_info("nccl_reduce")
                .get_constructor<VarPtr, Var*, int>();
        }
        if (nccl_reduce) {
            LOGr << "nccl";
            auto var = nccl_reduce(x, root);
            forward(var);
            return;
        }
    }
#endif
    y = create_output(nullptr, x->dtype());
    ASSERT(x->dtype().is_float());
}

void MpiReduceOp::infer_shape() {
    y->set_shape(x->shape);
}

void MpiReduceOp::jit_prepare() {
    add_jit_define("Tx", x->dtype());
    add_jit_define("XDIM", JK::hex1(x->shape.size()));
}

#else // JIT
#ifdef JIT_cpu
void MpiReduceOp::jit_run() {
    @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
    int size = 1 @for(i, 0, XDIM, * xshape@{i});
    auto* __restrict__ xp = x->ptr<Tx>();
    auto* __restrict__ yp = y->ptr<Tx>();
    // elementwise sum delivered to the root rank only; buffers are sent as
    // MPI_FLOAT, so float32 inputs are assumed here
    MPI_Reduce(xp, yp, size, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
}
#else
void MpiReduceOp::jit_run() {
    // cuda device code
}
#endif // JIT_cpu
#endif // JIT

} // jittor
@@ -0,0 +1,25 @@
// ***************************************************************
// Copyright (c) 2020
// Guowei Yang <471184555@qq.com>.
// Dun Liang <randonlang@gmail.com>.
// All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "op.h"

namespace jittor {

struct MpiReduceOp : Op {
    Var* x, * y;
    int root;

    MpiReduceOp(Var* x, int root);
    void infer_shape() override;

    const char* name() const override { return "mpi_reduce"; }
    DECLARE_jit_run;
};

} // jittor
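Unlike all_reduce, only the root rank receives the reduced values. The op is exposed as `jt.compile_extern.mpi_ops.mpi_reduce`; a sketch under the same `mpirun` assumption:

```python
import jittor as jt
import numpy as np

mpi = jt.compile_extern.mpi
x = jt.array(np.ones((5, 5), "float32"))
y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)   # sum delivered to root = 0
y.sync()
if mpi.world_rank() == 0:
    assert np.allclose(y.data, mpi.world_size() * np.ones((5, 5)))
# on other ranks, y is left undefined by the CPU implementation
```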
@@ -66,6 +66,33 @@ def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
    return norm_x * w + b

@jt.var_scope('sync_batch_norm')
def sync_batch_norm(x, is_train, eps=1e-5, momentum=0.1):
    assert jt.compile_extern.mpi_ops is not None
    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))

    w = w.broadcast(x, [0,2,3])
    b = b.broadcast(x, [0,2,3])
    if is_train:
        # all-reduce x and x*x so every rank sees the global batch statistics
        tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
        tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
        xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
        x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
        xvar = x2mean-xmean*xmean
        norm_x = (x-xmean)/jt.sqrt(xvar+eps)

        running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
        running_var += (xvar.sum([0,2,3])-running_var)*momentum
    else:
        running_mean = running_mean.broadcast(x, [0,2,3])
        running_var = running_var.broadcast(x, [0,2,3])
        norm_x = (x-running_mean)/jt.sqrt(running_var+eps)

    return norm_x * w + b

@jt.var_scope('conv')
def conv(x, in_planes, out_planes, kernel_size, padding, stride = 1, init_method=None):
    Kw = kernel_size
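Note on the two all-reduces in sync_batch_norm: they give every rank the global means of x and x*x, from which the global variance follows via the identity Var[x] = E[x^2] - (E[x])^2. That is exactly the `xvar = x2mean - xmean*xmean` line, and it avoids a second communication round over centered values.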
@@ -277,6 +304,39 @@ class BatchNorm(Module):
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

class SyncBatchNorm(Module):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
        assert affine is None

        self.num_features = num_features
        self.is_train = is_train
        self.eps = eps
        self.momentum = momentum
        self.weight = init.constant((num_features,), "float32", 1.0)
        self.bias = init.constant((num_features,), "float32", 0.0)
        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()

    def execute(self, x):
        if self.is_train:
            assert jt.compile_extern.mpi_ops is not None
            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)

            xvar = x2mean-xmean*xmean
            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
        else:
            running_mean = self.running_mean.broadcast(x, [0,2,3])
            running_var = self.running_var.broadcast(x, [0,2,3])
            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
        w = self.weight.broadcast(x, [0,2,3])
        b = self.bias.broadcast(x, [0,2,3])
        return norm_x * w + b

Relu = jt.make_module(relu)
ReLU = Relu
Leaky_relu = jt.make_module(leaky_relu, 2)
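A minimal sketch of the intended use of SyncBatchNorm, mirroring the test added below (a batch of 30 sharded across 3 MPI ranks, 10 samples each):

```python
import jittor as jt
from jittor import nn
import numpy as np

# run with: mpirun -np 3 python this_script.py
mpi = jt.compile_extern.mpi
data = np.random.rand(30, 3, 10, 10)
# each rank normalizes its own shard, but mean/var are synchronized globally
x = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10, ...])
bn = nn.SyncBatchNorm(3)
y = bn(x)
```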
@@ -0,0 +1,51 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
from jittor import nn
import numpy as np

def test_batchnorm():
    print("test batchnorm")
    mpi = jt.compile_extern.mpi
    data = np.random.rand(30,3,10,10)
    x1 = jt.array(data)
    # each of the 3 ranks normalizes a 10-sample shard of the same batch
    x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])

    bn1 = nn.BatchNorm(3)
    bn2 = nn.SyncBatchNorm(3)
    y1 = bn1(x1).data
    y2 = bn2(x2).data

    # synchronized statistics must match those of full-batch BatchNorm
    assert np.allclose(bn1.running_mean.data, bn2.running_mean.data)
    assert np.allclose(bn1.running_var.data, bn2.running_var.data)

def main():
    np.random.seed(0)
    jt.set_seed(3)
    with jt.flag_scope(use_cuda=0):
        test_batchnorm()
    with jt.flag_scope(use_cuda=1):
        test_batchnorm()

@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiOps(unittest.TestCase):
    def test(self):
        mpi = jt.compile_extern.mpi
        if not jt.compile_extern.inside_mpi():
            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
            cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_batchnorm"
            print("run cmd", cmd)
            jt.compiler.run_cmd(cmd)
        else:
            main()

if __name__ == "__main__":
    unittest.main()
@@ -0,0 +1,67 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Guowei Yang <471184555@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os, sys
import jittor as jt
import numpy as np

def test_all_reduce():
    print("test all_reduce")
    x = jt.random([5, 5])
    y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
    # all 3 ranks share the same seed, so the reduced sum is 3*x
    assert np.allclose(y.data, (x*3).data)

def test_broadcast():
    print("test broadcast")
    mpi = jt.compile_extern.mpi
    data = jt.random([5, 5])
    if mpi.world_rank() == 0:
        x = data
    else:
        x = jt.zeros([5, 5])
    y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
    assert np.allclose(y.data, data.data)

def test_reduce():
    print("test reduce")
    mpi = jt.compile_extern.mpi
    x = jt.random([5, 5])
    y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
    y.sync()
    # only the root rank receives the reduced result
    if mpi.world_rank() == 0:
        assert np.allclose(y.data, (x*3).data)

def main():
    np.random.seed(0)
    jt.set_seed(3)
    with jt.flag_scope(use_cuda=0):
        if jt.compile_extern.mpi_ops:
            test_all_reduce()
            test_broadcast()
            test_reduce()
    with jt.flag_scope(use_cuda=1):
        if jt.compile_extern.mpi_ops:
            test_all_reduce()
            test_broadcast()
            test_reduce()

@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestMpiOps(unittest.TestCase):
    def test(self):
        mpi = jt.compile_extern.mpi
        if not jt.compile_extern.inside_mpi():
            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
            cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_op"
            print("run cmd", cmd)
            jt.compiler.run_cmd(cmd)
        else:
            main()

if __name__ == "__main__":
    unittest.main()