mirror of https://github.com/Jittor/Jittor

commit a089378ec6 ("fix"), parent fd80146901

Summary of the diff: the NCCL all_reduce/broadcast/reduce ops gain a T_NCCL dtype dispatch (float32, int32, float64, int64) in place of the hard-coded ncclFloat and drop the float-only ASSERT; jittor's Module gains an mpi_sync() method that broadcasts parameters from rank 0; SGD.step and Adam.step all-reduce gradients across MPI ranks and the NCCL-based SGD.sync is removed; the MPI/NCCL tests are updated to use net.mpi_sync().
@@ -21,7 +21,6 @@ NcclAllReduceOp::NcclAllReduceOp(Var* x) : x(x) {
     flags.set(NodeFlags::_cpu, 0);
     flags.set(NodeFlags::_cuda, 1);
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 
 void NcclAllReduceOp::infer_shape() {
@@ -37,11 +36,17 @@ void NcclAllReduceOp::jit_prepare() {
 #ifdef JIT_cuda
 
 void NcclAllReduceOp::jit_run() {
+    @define(T_NCCL,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
+        @if(@strcmp(@Tx,float64)==0, ncclFloat64)
+        @if(@strcmp(@Tx,int64)==0, ncclInt64)
+    )
     @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
     int size = 1 @for(i, 0, XDIM, * xshape@{i});
     auto* __restrict__ xp = x->ptr<Tx>();
     auto* __restrict__ yp = y->ptr<Tx>();
-    checkCudaErrors(ncclAllReduce(xp, yp, size, ncclFloat, ncclSum, comm, 0));
+    checkCudaErrors(ncclAllReduce(xp, yp, size, @T_NCCL, ncclSum, comm, 0));
 }
 
 #endif
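Note on the change above: the @define(T_NCCL, ...) block lets the JIT pick the NCCL datatype from the operand dtype (float32, int32, float64 or int64) instead of the previously hard-coded ncclFloat, which is also why the float-only ASSERT was dropped from the constructor. A minimal Python-side usage sketch; the nccl_all_reduce binding name is assumed from the nccl_reduce/nccl_broadcast calls in the tests further down, and the process is assumed to run under mpirun with a CUDA build:

    import jittor as jt
    jt.flags.use_cuda = 1                  # the op only sets the _cuda flag, so CUDA must be enabled

    x = jt.random([5, 5])                  # each rank holds its own tensor; int32/int64 inputs are now accepted too
    y = jt.compile_extern.nccl_ops.nccl_all_reduce(x)
    print(y.data)                          # every rank sees the element-wise sum over all ranks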
@@ -21,7 +21,6 @@ NcclBroadcastOp::NcclBroadcastOp(Var* x, int root) : x(x), root(root) {
     flags.set(NodeFlags::_cpu, 0);
     flags.set(NodeFlags::_cuda, 1);
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 
 void NcclBroadcastOp::infer_shape() {
@@ -37,11 +36,17 @@ void NcclBroadcastOp::jit_prepare() {
 #ifdef JIT_cuda
 
 void NcclBroadcastOp::jit_run() {
+    @define(T_NCCL,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
+        @if(@strcmp(@Tx,float64)==0, ncclFloat64)
+        @if(@strcmp(@Tx,int64)==0, ncclInt64)
+    )
     @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
     int size = 1 @for(i, 0, XDIM, * xshape@{i});
     auto* __restrict__ xp = x->ptr<Tx>();
     auto* __restrict__ yp = y->ptr<Tx>();
-    checkCudaErrors(ncclBroadcast(xp, yp, size, ncclFloat, root, comm, 0));
+    checkCudaErrors(ncclBroadcast(xp, yp, size, @T_NCCL, root, comm, 0));
 }
 
 #endif
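The broadcast op receives the same dtype dispatch. A short Python-side sketch with root rank 0, mirroring the nccl_broadcast(p, 0) call in the tests below:

    import jittor as jt
    jt.flags.use_cuda = 1

    x = jt.random([5, 5])
    y = jt.compile_extern.nccl_ops.nccl_broadcast(x, 0)   # every rank receives rank 0's copy of x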
@@ -21,7 +21,6 @@ NcclReduceOp::NcclReduceOp(Var* x, int root) : x(x), root(root) {
     flags.set(NodeFlags::_cpu, 0);
     flags.set(NodeFlags::_cuda, 1);
     y = create_output(nullptr, x->dtype());
-    ASSERT(x->dtype().is_float());
 }
 
 void NcclReduceOp::infer_shape() {
@@ -37,11 +36,17 @@ void NcclReduceOp::jit_prepare() {
 #ifdef JIT_cuda
 
 void NcclReduceOp::jit_run() {
+    @define(T_NCCL,
+        @if(@strcmp(@Tx,float)==0 || @strcmp(@Tx,float32)==0, ncclFloat)
+        @if(@strcmp(@Tx,int)==0 || @strcmp(@Tx,int32)==0, ncclInt)
+        @if(@strcmp(@Tx,float64)==0, ncclFloat64)
+        @if(@strcmp(@Tx,int64)==0, ncclInt64)
+    )
     @for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
     int size = 1 @for(i, 0, XDIM, * xshape@{i});
     auto* __restrict__ xp = x->ptr<Tx>();
     auto* __restrict__ yp = y->ptr<Tx>();
-    checkCudaErrors(ncclReduce(xp, yp, size, ncclFloat, ncclSum, root, comm, 0));
+    checkCudaErrors(ncclReduce(xp, yp, size, @T_NCCL, ncclSum, root, comm, 0));
 }
 
 #endif
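nccl_reduce is the rooted counterpart of all_reduce: the element-wise sum is delivered to the root rank only. Sketch matching the nccl_reduce(x, 0) call in the tests:

    import jittor as jt
    jt.flags.use_cuda = 1

    x = jt.random([5, 5])
    y = jt.compile_extern.nccl_ops.nccl_reduce(x, 0)   # sum over all ranks; the result is meaningful on rank 0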
@@ -7,10 +7,7 @@
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "nccl_warper.h"

#ifdef HAS_CUDA
#include "event_queue.h"
#endif

const char *_cudaGetErrorEnum(ncclResult_t error) {
    return ncclGetErrorString(error);
@@ -16,7 +16,7 @@ with lock.lock_scope():
     from jittor_core import *
     from jittor_core.ops import *
     from . import compile_extern
-    from .compile_extern import mkl_ops
+    from .compile_extern import mkl_ops, inside_mpi, mpi_ops
 
 import contextlib
 import numpy as np
@@ -579,6 +579,15 @@ class Module:
         for p in self.parameters():
             if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
                 p.start_grad()
+
+    def mpi_sync(self):
+        if not inside_mpi():
+            return
+        ps = self.parameters()
+        for p in ps:
+            temp = mpi_ops.mpi_broadcast(p, 0)
+            p.assign(temp.detach())
+            p.detach_inplace()
 
 def make_module(func, exec_n_args=1):
     class MakeModule(Module):
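Module.mpi_sync (added above) overwrites every parameter with rank 0's copy via mpi_broadcast and detaches it in place, so all MPI workers start from identical weights; outside MPI it is a no-op. A sketch of the intended call pattern, with a tiny Linear model as a placeholder:

    import jittor as jt
    from jittor import nn

    net = nn.Linear(3, 1)                 # any Module; each rank builds its own random init
    net.mpi_sync()                        # rank 0's parameters now replace every other rank's
    opt = nn.SGD(net.parameters(), 0.1, 0.9, 0.00001)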
@@ -175,6 +175,9 @@ class SGD(object):
     def step(self, loss):
         ps = self.parameters
         gs = jt.grad(loss, ps)
+        if jt.compile_extern.inside_mpi():
+            for g in gs:
+                g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
         for p, g, v in zip(ps, gs, self.values):
             dp = p * self.weight_decay + g
             v.assign(self.momentum * v + dp * (1 - self.dampening))
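With this block SGD.step becomes data-parallel when the process runs under MPI: each rank computes gradients on its own batch, mpi_all_reduce combines them across ranks, and every rank then applies the same update, so the replicas stay in sync. A hedged end-to-end sketch; the model, data and loss are placeholders, and the script would be launched with something like mpirun -np 2 python train.py:

    import jittor as jt
    from jittor import nn

    net = nn.Linear(3, 1)
    net.mpi_sync()                                     # identical initial weights on every rank
    opt = nn.SGD(net.parameters(), 0.1, 0.9, 0.00001)

    for _ in range(5):
        x = jt.random([8, 3])                          # stand-in for this rank's local mini-batch
        y = jt.random([8, 1])
        d = net(x) - y
        loss = (d * d).sum()
        opt.step(loss)                                 # under MPI, gradients are all-reduced across ranks inside step()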
@@ -189,14 +192,6 @@ class SGD(object):
         # sync such parameters to reduce memory consumption
         jt.sync(self.no_grad_parameters)
-
-    def sync(self):
-        ps = self.parameters
-        for p in ps:
-            temp = jt.compile_extern.nccl_ops.nccl_broadcast(p, 0)
-            p -= p
-            p += temp
-            p.detach_inplace()
 
 class Adam(object):
     """ Usage:
         optimizer = nn.Adam(model.parameters(), lr)
@@ -225,6 +220,9 @@ class Adam(object):
     def step(self, loss):
         ps = self.parameters
         gs = jt.grad(loss, ps)
+        if jt.compile_extern.inside_mpi():
+            for g in gs:
+                g.assign(jt.compile_extern.mpi_ops.mpi_all_reduce(g))
         self.adam_step += 1
         n, (b0, b1) = float(self.adam_step), self.betas
         for p, g, v, m in zip(ps, gs, self.values, self.m):
@@ -25,7 +25,6 @@ def test_all_reduce():

def test_broadcast():
    print("test broadcast")
    mpi = jt.compile_extern.mpi
    data = jt.random([5, 5])
    if mpi.world_rank() == 0:
        x = data
@@ -36,7 +35,6 @@ def test_broadcast():

def test_reduce():
    print("test reduce")
    mpi = jt.compile_extern.mpi
    x = jt.random([5, 5])
    y = jt.compile_extern.nccl_ops.nccl_reduce(x, 0)
    y_ = y.data
@@ -55,9 +53,8 @@ class Model(Module):
        return self.linear2(x)

def test_sync():
    mpi = jt.compile_extern.mpi
    print("test mpi_sync")
    net = Model()
    SGD = nn.SGD(net.parameters(), 0.1, 0.9, 0.00001)
    if mpi.world_rank() == 0:
        net.linear1.weight *= 0
        net.linear2.weight *= 0
@@ -65,12 +62,11 @@ def test_sync():
    net.linear1.weight += 1
    net.linear2.weight += 1
    net.linear1.bias += 1
-    SGD.sync()
+    net.mpi_sync()
    assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
    assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
    assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)


def main():
    np.random.seed(0)
    jt.set_seed(3)
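The updated test_sync exercises the new path end to end: rank 0 zeroes the weights, every rank then adds one to the weights and bias, and net.mpi_sync() pulls rank 0's parameters onto all ranks, so the asserts expect all-ones tensors on every rank. Like the other tests here it only does something meaningful under MPI with more than one process, launched with something like mpirun -np 2 on the test script (the exact script path is not shown in this diff). A condensed, hypothetical standalone version of the same check:

    import numpy as np
    import jittor as jt

    mpi = jt.compile_extern.mpi
    w = jt.random([4, 4])
    if mpi.world_rank() == 0:
        w *= 0
    w += 1                                                            # rank 0 holds ones, other ranks hold random values + 1
    w.assign(jt.compile_extern.mpi_ops.mpi_broadcast(w, 0).detach())  # broadcast rank 0's copy, as mpi_sync does per parameter
    assert np.allclose(w.data, jt.ones(w.shape).data)                 # every rank now agrees with rank 0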