cxjyxx_me 2020-04-16 15:16:21 +08:00
parent fd80146901
commit a089378ec6
7 changed files with 39 additions and 24 deletions

View File

@@ -21,7 +21,6 @@ NcclAllReduceOp::NcclAllReduceOp(Var* x) : x(x) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
ASSERT(x->dtype().is_float());
}
void NcclAllReduceOp::infer_shape() {
@@ -37,11 +36,17 @@ void NcclAllReduceOp::jit_prepare() {
#ifdef JIT_cuda
void NcclAllReduceOp::jit_run() {
@define(T_NCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclAllReduce(xp, yp, size, ncclFloat, ncclSum, comm, 0));
checkCudaErrors(ncclAllReduce(xp, yp, size, @T_NCCL, ncclSum, comm, 0));
}
#endif
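The @define(T_NCCL, ...) block above lets the JIT stage pick the NCCL data type that matches the operand dtype @Tx, which is why the float-only ASSERT in the constructor can be dropped; the same dispatch is added to the broadcast and reduce ops below. A rough Python sketch of the mapping the macro encodes (the dict and helper here are illustrative only, not part of this commit):

# Illustrative sketch: the dtype -> NCCL type selection that the
# @define(T_NCCL, ...) JIT macro performs at code-generation time.
NCCL_TYPE = {
    "float": "ncclFloat",   "float32": "ncclFloat",
    "int": "ncclInt",       "int32": "ncclInt",
    "float64": "ncclFloat64",
    "int64": "ncclInt64",
}

def nccl_type(dtype_name):
    # dtype_name stands in for the string form of x->dtype(); an unknown
    # dtype simply produces no match, mirroring the empty @if branches.
    return NCCL_TYPE.get(dtype_name)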

View File

@@ -21,7 +21,6 @@ NcclBroadcastOp::NcclBroadcastOp(Var* x, int root) : x(x), root(root) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
ASSERT(x->dtype().is_float());
}
void NcclBroadcastOp::infer_shape() {
@@ -37,11 +36,17 @@ void NcclBroadcastOp::jit_prepare() {
#ifdef JIT_cuda
void NcclBroadcastOp::jit_run() {
@define(T_NCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclBroadcast(xp, yp, size, ncclFloat, root, comm, 0));
checkCudaErrors(ncclBroadcast(xp, yp, size, @T_NCCL, root, comm, 0));
}
#endif

View File

@@ -21,7 +21,6 @@ NcclReduceOp::NcclReduceOp(Var* x, int root) : x(x), root(root) {
flags.set(NodeFlags::_cpu, 0);
flags.set(NodeFlags::_cuda, 1);
y = create_output(nullptr, x->dtype());
ASSERT(x->dtype().is_float());
}
void NcclReduceOp::infer_shape() {
@@ -37,11 +36,17 @@ void NcclReduceOp::jit_prepare() {
#ifdef JIT_cuda
void NcclReduceOp::jit_run() {
@define(T_NCCL,
@if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
@if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
@if(@strcmp(@Tx,float64)==0, ncclFloat64)
@if(@strcmp(@Tx,int64)==0, ncclInt64)
)
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
checkCudaErrors(ncclReduce(xp, yp, size, ncclFloat, ncclSum, root, comm, 0));
checkCudaErrors(ncclReduce(xp, yp, size, @T_NCCL, ncclSum, root, comm, 0));
}
#endif

View File

@@ -7,10 +7,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "nccl_warper.h"
#ifdef HAS_CUDA
#include "event_queue.h"
#endif
const char *_cudaGetErrorEnum(ncclResult_t error) {
return ncclGetErrorString(error);

View File

@@ -16,7 +16,7 @@ with lock.lock_scope():
from jittor_core import *
from jittor_core.ops import *
from . import compile_extern
from .compile_extern import mkl_ops
from .compile_extern import mkl_ops, inside_mpi, mpi_ops
import contextlib
import numpy as np
@@ -579,6 +579,15 @@ class Module:
for p in self.parameters():
if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
p.start_grad()
def mpi_sync(self):
if not inside_mpi():
return
ps = self.parameters()
for p in ps:
temp = mpi_ops.mpi_broadcast(p, 0)
p.assign(temp.detach())
p.detach_inplace()
def make_module(func, exec_n_args=1):
class MakeModule(Module):
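The new Module.mpi_sync broadcasts rank 0's parameters to every other MPI rank and returns early outside MPI. A minimal usage sketch, assuming the script is launched under MPI and that Net is a placeholder for any user-defined Module:

# Sketch: make all data-parallel replicas start from identical weights.
import jittor as jt
from jittor import nn

net = Net()      # Net is a placeholder Module, not from this commit
net.mpi_sync()   # broadcasts rank-0 parameters; no-op if not inside_mpi()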

View File

@@ -175,6 +175,9 @@ class SGD(object):
def step(self, loss):
ps = self.parameters
gs = jt.grad(loss, ps)
if jt.compile_extern.inside_mpi():
for g in gs:
g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
for p, g, v in zip(ps, gs, self.values):
dp = p * self.weight_decay + g
v.assign(self.momentum * v + dp * (1 - self.dampening))
@@ -189,14 +192,6 @@ class SGD(object):
# sync such parameters to reduce memory consumption
jt.sync(self.no_grad_parameters)
def sync(self):
ps = self.parameters
for p in ps:
temp = jt.compile_extern.nccl_ops.nccl_broadcast(p, 0)
p -= p
p += temp
p.detach_inplace()
class Adam(object):
""" Usage:
optimizer = nn.Adam(model.parameters(), lr)
@@ -225,6 +220,9 @@ class Adam(object):
def step(self, loss):
ps = self.parameters
gs = jt.grad(loss, ps)
if jt.compile_extern.inside_mpi():
for g in gs:
g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
self.adam_step += 1
n, (b0, b1) = float(self.adam_step), self.betas
for p, g, v, m in zip(ps, gs, self.values, self.m):
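With these two hunks, SGD.step and Adam.step all-reduce the freshly computed gradients across MPI ranks before applying the update, so every replica applies the same reduced gradient; the old SGD.sync parameter broadcast is removed in favor of Module.mpi_sync. A minimal training-loop sketch under these assumptions (net, the learning rate, and the data source are placeholders):

# Sketch of a data-parallel step; the gradient all-reduce happens inside
# optimizer.step(loss) whenever jt.compile_extern.inside_mpi() is true.
optimizer = nn.SGD(net.parameters(), 0.1, momentum=0.9)
for images, labels in loader:            # loader is a placeholder iterable
    pred = net(images)
    loss = ((pred - labels) ** 2).mean()
    optimizer.step(loss)                 # grads combined across ranks here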

View File

@@ -25,7 +25,6 @@ def test_all_reduce():
def test_broadcast():
print("test broadcast")
mpi = jt.compile_extern.mpi
data = jt.random([5, 5])
if mpi.world_rank() == 0:
x = data
@@ -36,7 +35,6 @@ def test_broadcast():
def test_reduce():
print("test reduce")
mpi = jt.compile_extern.mpi
x = jt.random([5, 5])
y = jt.compile_extern.nccl_ops.nccl_reduce(x, 0)
y_ = y.data
@@ -55,9 +53,8 @@ class Model(Module):
return self.linear2(x)
def test_sync():
mpi = jt.compile_extern.mpi
print("test mpi_sync")
net = Model()
SGD = nn.SGD(net.parameters(), 0.1, 0.9, 0.00001)
if mpi.world_rank() == 0:
net.linear1.weight *= 0
net.linear2.weight *= 0
@@ -65,12 +62,11 @@ def test_sync():
net.linear1.weight += 1
net.linear2.weight += 1
net.linear1.bias += 1
SGD.sync()
net.mpi_sync()
assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)
def main():
np.random.seed(0)
jt.set_seed(3)