mirror of https://github.com/Jittor/Jittor
test optimizer & nccl backward
This commit is contained in:
parent
4e79827ce9
commit
5ccf9502c7
|
@ -14,6 +14,7 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "nccl_warper.h"
|
||||
#include "ops/op_register.h"
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
@ -27,6 +28,12 @@ void NcclAllReduceOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
VarPtr NcclAllReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*nccl_all_reduce)(Var*) =
|
||||
get_op_info("nccl_all_reduce").get_constructor<VarPtr, Var*>();
|
||||
return nccl_all_reduce(dout);
|
||||
}
|
||||
|
||||
void NcclAllReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
|
|
|
@ -18,6 +18,7 @@ struct NcclAllReduceOp : Op {
|
|||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "nccl_all_reduce"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "nccl_warper.h"
|
||||
#include "ops/op_register.h"
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
@ -27,6 +28,12 @@ void NcclBroadcastOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
VarPtr NcclBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*nccl_reduce)(Var*, int) =
|
||||
get_op_info("nccl_reduce").get_constructor<VarPtr, Var*, int>();
|
||||
return nccl_reduce(dout,root);
|
||||
}
|
||||
|
||||
void NcclBroadcastOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
|
|
|
@ -19,6 +19,7 @@ struct NcclBroadcastOp : Op {
|
|||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "nccl_broadcast"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@
|
|||
#include <cuda_runtime.h>
|
||||
#include <helper_cuda.h>
|
||||
#include "nccl_warper.h"
|
||||
#include "ops/op_register.h"
|
||||
namespace jittor {
|
||||
|
||||
#ifndef JIT
|
||||
|
@ -27,6 +28,12 @@ void NcclReduceOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
}
|
||||
|
||||
VarPtr NcclReduceOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
static VarPtr(*nccl_broadcast)(Var*, int) =
|
||||
get_op_info("nccl_broadcast").get_constructor<VarPtr, Var*, int>();
|
||||
return nccl_broadcast(dout,root);
|
||||
}
|
||||
|
||||
void NcclReduceOp::jit_prepare() {
|
||||
add_jit_define("Tx", x->dtype());
|
||||
add_jit_define("XDIM", JK::hex1(x->shape.size()));
|
||||
|
|
|
@ -19,6 +19,7 @@ struct NcclReduceOp : Op {
|
|||
void infer_shape() override;
|
||||
|
||||
const char* name() const override { return "nccl_reduce"; }
|
||||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
DECLARE_jit_run;
|
||||
};
|
||||
|
||||
|
|
|
@ -31,8 +31,9 @@ def test_broadcast():
|
|||
y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
|
||||
assert np.allclose(y.data, data.data)
|
||||
g = jt.grad(y,x)
|
||||
g_ = g.data
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(g.data, np.ones([5,5])*3)
|
||||
assert np.allclose(g_, np.ones([5,5])*3)
|
||||
|
||||
def test_reduce():
|
||||
print("test reduce")
|
||||
|
|
|
@ -14,33 +14,58 @@ import numpy as np
|
|||
from jittor import nn
|
||||
from jittor import nn, Module
|
||||
import copy
|
||||
from jittor.test.test_log import find_log_with_re
|
||||
n = 2
|
||||
mpi = jt.compile_extern.mpi
|
||||
|
||||
def test_all_reduce():
|
||||
print("test all_reduce")
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.nccl_ops.nccl_all_reduce(x)
|
||||
assert np.allclose(y.data, (x*n).data)
|
||||
with jt.log_capture_scope(enable_tuner=1, log_silent=1,
|
||||
log_v=1, log_vprefix="op.cc=100,exe=1000"
|
||||
) as raw_log:
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
|
||||
assert np.allclose(y.data, (x*n).data)
|
||||
g = jt.grad(y,x)
|
||||
assert np.allclose(g.data, np.ones([5,5])*n)
|
||||
|
||||
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
|
||||
assert len(logs)==2, len(logs)
|
||||
|
||||
def test_broadcast():
|
||||
print("test broadcast")
|
||||
data = jt.random([5, 5])
|
||||
if mpi.world_rank() == 0:
|
||||
x = data
|
||||
else:
|
||||
x = jt.zeros([5, 5])
|
||||
y = jt.compile_extern.nccl_ops.nccl_broadcast(x, 0)
|
||||
assert np.allclose(y.data, data.data)
|
||||
with jt.log_capture_scope(enable_tuner=1, log_silent=1,
|
||||
log_v=1, log_vprefix="op.cc=100,exe=1000"
|
||||
) as raw_log:
|
||||
data = jt.random([5, 5])
|
||||
if mpi.world_rank() == 0:
|
||||
x = data
|
||||
else:
|
||||
x = jt.zeros([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
|
||||
assert np.allclose(y.data, data.data)
|
||||
g = jt.grad(y.sum(),x)
|
||||
g_ = g.data
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(g_, np.ones([5,5])*n)
|
||||
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_broadcast.*)")
|
||||
assert len(logs)==1, len(logs)
|
||||
|
||||
def test_reduce():
|
||||
print("test reduce")
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.nccl_ops.nccl_reduce(x, 0)
|
||||
y_ = y.data
|
||||
x_ = (x*n).data
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(y_, x_)
|
||||
with jt.log_capture_scope(enable_tuner=1, log_silent=1,
|
||||
log_v=1, log_vprefix="op.cc=100,exe=1000"
|
||||
) as raw_log:
|
||||
x = jt.random([5, 5])
|
||||
y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
|
||||
y_ = y.data
|
||||
x_ = (x*n).data
|
||||
if mpi.world_rank() == 0:
|
||||
assert np.allclose(y_, x_)
|
||||
g = jt.grad(y,x)
|
||||
assert np.allclose(g.data, np.ones([5,5]))
|
||||
logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
|
||||
assert len(logs)==1, len(logs)
|
||||
|
||||
class Model(Module):
|
||||
def __init__(self):
|
||||
|
@ -67,6 +92,37 @@ def test_sync():
|
|||
assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
|
||||
assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)
|
||||
|
||||
class Model2(Module):
|
||||
def __init__(self, input_size):
|
||||
self.linear1 = nn.Linear(input_size, 10)
|
||||
self.relu1 = nn.Relu()
|
||||
self.linear2 = nn.Linear(10, 1)
|
||||
def execute(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.relu1(x)
|
||||
return self.linear2(x)
|
||||
|
||||
def test_optimizer():
|
||||
print("test optimizer")
|
||||
def get_data(n):
|
||||
for i in range(n):
|
||||
x = np.random.rand(50, 1)
|
||||
y = x*x
|
||||
yield jt.float32(x), jt.float32(y)
|
||||
num = 2000
|
||||
model = Model2(1)
|
||||
model.mpi_sync()
|
||||
optimizer = nn.SGD(model.parameters(), 0.05)
|
||||
dataset = list(enumerate(get_data(num)))
|
||||
for i in range(mpi.world_rank(), num, n):
|
||||
id, (x, y) = dataset[i]
|
||||
pred_y = model(x)
|
||||
loss = (pred_y - y)*(pred_y - y)
|
||||
loss_mean = loss.mean()
|
||||
optimizer.step(loss_mean)
|
||||
assert loss_mean.data < 0.0025
|
||||
jt.clean()
|
||||
|
||||
def main():
|
||||
np.random.seed(0)
|
||||
jt.set_seed(3)
|
||||
|
@ -76,18 +132,7 @@ def main():
|
|||
test_all_reduce()
|
||||
test_broadcast()
|
||||
test_reduce()
|
||||
|
||||
# @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
|
||||
# class TestNcclOps(unittest.TestCase):
|
||||
# def test(self):
|
||||
# mpi = jt.compile_extern.mpi
|
||||
# if mpi.world_size() == 1 and n != 1:
|
||||
# mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
|
||||
# cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops"
|
||||
# print("run cmd", cmd)
|
||||
# jt.compiler.run_cmd(cmd)
|
||||
# else:
|
||||
# main()
|
||||
test_optimizer()
|
||||
|
||||
@unittest.skipIf(mpi is None, "no inside mpirun")
|
||||
class TestMpi(unittest.TestCase):
|
||||
|
|
Loading…
Reference in New Issue