polish test_mpi_batchnorm

Dun Liang 2020-04-20 22:05:38 +08:00
parent d60b37bb07
commit 2315928240
5 changed files with 195 additions and 211 deletions

View File

@@ -271,7 +271,7 @@ class BatchNorm(Module):
     def execute(self, x):
         if self.is_train:
-            if self.sync and not (jt.compile_extern.mpi_ops is None):
+            if self.sync and jt.compile_extern.mpi_ops is not None:
                 tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
                 tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
                 xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
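
The hunk above only touches the guard, but the surrounding lines are the whole synchronized-batch-norm trick: instead of gathering activations, each worker all-reduces x and x*x and derives global statistics from those two moments. A NumPy sketch of that identity (illustrative only, not Jittor code):

```python
# Each "rank" holds a shard of the batch; averaging the all-reduced x and
# x*x reproduces the global per-channel mean and (biased) variance without
# ever gathering the raw activations.
import numpy as np

np.random.seed(0)
shards = [np.random.rand(10, 3, 4, 4) for _ in range(3)]  # one shard per rank

# mpi_all_reduce(x) / world_size == elementwise mean over ranks
tmpx  = sum(shards) / len(shards)
tmpx2 = sum(s * s for s in shards) / len(shards)

xmean  = tmpx.mean(axis=(0, 2, 3))    # E[x]   per channel
x2mean = tmpx2.mean(axis=(0, 2, 3))   # E[x^2] per channel
xvar   = x2mean - xmean ** 2          # Var[x] = E[x^2] - E[x]^2

full = np.concatenate(shards, axis=0)  # what one big batch would see
assert np.allclose(xmean, full.mean(axis=(0, 2, 3)))
assert np.allclose(xvar, full.var(axis=(0, 2, 3)))
```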

View File

@@ -17,7 +17,7 @@ class TestMpi(unittest.TestCase):
     def test_mpi_test_op(self):
         assert jt.compile_extern.mpi_ops.mpi_test("").data == 123

-    @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no inccl")
+    @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl")
     @jt.flag_scope(use_cuda=1)
     def test_nccl_with_mpi(self):
         assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
@@ -47,14 +47,17 @@ class TestMpi(unittest.TestCase):
             c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
             assert (c==a.data).all()

+def run_mpi_test(num_procs, name):
+    if not jt.compile_extern.inside_mpi():
+        mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
+        cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v"
+        print("run cmd:", cmd)
+        assert os.system(cmd)==0, "run cmd failed: "+cmd
+
 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
 class TestMpiEntry(unittest.TestCase):
     def test_entry(self):
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
-            cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi -v"
-            print("run cmd:", cmd)
-            assert os.system(cmd)==0, "run cmd failed: "+cmd
+        run_mpi_test(2, "test_mpi")

 if __name__ == "__main__":
     unittest.main()
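
The new run_mpi_test helper factors out the relaunch idiom every MPI test in this commit now shares: outside an MPI job, re-exec the same test module under mpirun; inside one, let unittest run normally. A standalone sketch of the pattern (the OMPI_COMM_WORLD_SIZE check is an assumption about Open MPI's environment; Jittor wraps the real check as jt.compile_extern.inside_mpi()):

```python
# Minimal sketch of the self-relaunch pattern, independent of Jittor.
import os, sys, shutil, subprocess

def inside_mpi():
    # Open MPI exports this to every rank it launches (assumption).
    return "OMPI_COMM_WORLD_SIZE" in os.environ

def relaunch_under_mpi(num_procs, module):
    mpirun = shutil.which("mpirun") or "mpirun"
    cmd = [mpirun, "-np", str(num_procs), sys.executable, "-m", module, "-v"]
    print("run cmd:", " ".join(cmd))
    subprocess.run(cmd, check=True)  # raises if any rank exits non-zero

if __name__ == "__main__":
    if inside_mpi():
        print("already a rank; run the tests in-process")
    else:
        relaunch_under_mpi(2, "jittor.test.test_mpi")
```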

View File

@@ -11,41 +11,42 @@ import os, sys
 import jittor as jt
 from jittor import nn
 import numpy as np
+from jittor.test.test_mpi import run_mpi_test

-def test_batchnorm():
-    print("test batchnorm")
-    mpi = jt.compile_extern.mpi
-    data = np.random.rand(30,3,10,10)
-    x1 = jt.array(data)
-    x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
-    bn1 = nn.BatchNorm(3)
-    bn2 = nn.SyncBatchNorm(3)
-    y1 = bn1(x1).data
-    y2 = bn2(x2).data
-    assert bn1.running_mean==bn2.running_mean
-    assert bn1.running_var==bn2.running_var
+mpi = jt.compile_extern.mpi

-def main():
-    np.random.seed(0)
-    jt.set_seed(3)
-    with jt.flag_scope(use_cuda=0):
-        test_batchnorm()
-    with jt.flag_scope(use_cuda=1):
-        test_batchnorm()
+@unittest.skipIf(mpi is None, "no inside mpirun")
+class TestMpiBatchnorm(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        np.random.seed(0)
+        jt.seed(3)
+
+    def test_batchnorm(self):
+        mpi = jt.compile_extern.mpi
+        data = np.random.rand(30,3,10,10).astype("float32")
+        x1 = jt.array(data)
+        x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
+        bn1 = nn.BatchNorm(3, sync=False)
+        bn2 = nn.BatchNorm(3, sync=True)
+        y1 = bn1(x1).data
+        y2 = bn2(x2).data
+        assert np.allclose(bn1.running_mean.data, bn2.running_mean.data), \
+            (bn1.running_mean.data, bn2.running_mean.data)
+        assert np.allclose(bn1.running_var.data, bn2.running_var.data)
+
+    @unittest.skipIf(not jt.has_cuda, "no cuda")
+    @jt.flag_scope(use_cuda=1)
+    def test_batchnorm_cuda(self):
+        self.test_batchnorm()

 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
-class TestMpiOps(unittest.TestCase):
+class TestMpiBatchnormEntry(unittest.TestCase):
     def test(self):
-        mpi = jt.compile_extern.mpi
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
-            cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_batchnorm"
-            print("run cmd", cmd)
-            jt.compiler.run_cmd(cmd)
-        else:
-            main()
+        run_mpi_test(3, "test_mpi_batchnorm")

 if __name__ == "__main__":
     unittest.main()
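
Note the pairing: the non-sync layer sees the full 30-sample batch, while the sync layer sees only its rank's 10-sample slice; with three equal shards, averaging per-rank moments reproduces the global moments exactly, so the two layers' running statistics should coincide up to float error. A NumPy check of that 30-sample / 3-rank split (illustrative only):

```python
# With equal-sized shards, the mean of per-shard means equals the global
# mean, which is exactly what the all-reduce in sync batch norm computes.
import numpy as np

np.random.seed(0)
data = np.random.rand(30, 3, 10, 10).astype("float32")
slices = [data[r*10:(r+1)*10] for r in range(3)]  # rank r's shard

global_mean = data.mean(axis=(0, 2, 3))                       # non-sync, full batch
sync_mean = sum(s.mean(axis=(0, 2, 3)) for s in slices) / 3   # averaged shard means

assert np.allclose(global_mean, sync_mean, atol=1e-5)

# An unequal split would break this: the all-reduce weights every rank's
# statistics equally, regardless of shard size.
```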

View File

@@ -10,65 +10,54 @@ import unittest
 import os, sys
 import jittor as jt
 import numpy as np
+from jittor.test.test_mpi import run_mpi_test

-def test_all_reduce():
-    print("test all_reduce")
-    mpi = jt.compile_extern.mpi
-    x = jt.random([5, 5])
-    y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
-    assert np.allclose(y.data, (x*3).data)
-    g = jt.grad(y,x)
-    assert np.allclose(g.data, np.ones([5,5])*3)
+mpi = jt.compile_extern.mpi

-def test_broadcast():
-    print("test broadcast")
-    mpi = jt.compile_extern.mpi
-    data = jt.random([5, 5])
-    if mpi.world_rank() == 0:
-        x = data
-    else:
-        x = jt.zeros([5, 5])
-    y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
-    assert np.allclose(y.data, data.data)
-    g = jt.grad(y,x)
-    if mpi.world_rank() == 0:
-        assert np.allclose(g.data, np.ones([5,5])*3)
-    else:
-        assert np.allclose(g.data, np.zeros([5,5]))
+@unittest.skipIf(mpi is None, "no inside mpirun")
+class TestMpiOps(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        np.random.seed(0)
+        jt.seed(3)

-def test_reduce():
-    print("test reduce")
-    mpi = jt.compile_extern.mpi
-    x = jt.random([5, 5])
-    y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
-    y.sync()
-    if mpi.world_rank() == 0:
-        assert np.allclose(y.data, (x*3).data)
-    else:
-        assert np.allclose(y.data, np.zeros([5,5]))
-    g = jt.grad(y,x)
-    assert np.allclose(g.data, np.ones([5,5]))
+    def test_all_reduce(self):
+        x = jt.random([5, 5])
+        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
+        assert np.allclose(y.data, (x*3).data)
+        g = jt.grad(y,x)
+        assert np.allclose(g.data, np.ones([5,5])*3)

+    def test_broadcast(self):
+        data = jt.random([5, 5])
+        if mpi.world_rank() == 0:
+            x = data
+        else:
+            x = jt.zeros([5, 5])
+        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+        assert np.allclose(y.data, data.data)
+        g = jt.grad(y,x)
+        if mpi.world_rank() == 0:
+            assert np.allclose(g.data, np.ones([5,5])*3)
+        else:
+            assert np.allclose(g.data, np.zeros([5,5]))

-def main():
-    np.random.seed(0)
-    jt.set_seed(3)
-    with jt.flag_scope(use_cuda=0):
-        if jt.compile_extern.mpi_ops:
-            test_all_reduce()
-            test_broadcast()
-            test_reduce()
+    def test_reduce(self):
+        x = jt.random([5, 5])
+        y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
+        y.sync()
+        if mpi.world_rank() == 0:
+            assert np.allclose(y.data, (x*3).data)
+        else:
+            assert np.allclose(y.data, np.zeros([5,5]))
+        g = jt.grad(y,x)
+        assert np.allclose(g.data, np.ones([5,5]))

 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
-class TestMpiOps(unittest.TestCase):
+class TestMpiOpsEntry(unittest.TestCase):
     def test(self):
-        mpi = jt.compile_extern.mpi
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
-            cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_op"
-            print("run cmd", cmd)
-            jt.compiler.run_cmd(cmd)
-        else:
-            main()
+        run_mpi_test(3, "test_mpi_op")

 if __name__ == "__main__":
     unittest.main()
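
Two details make these asserts work: every rank seeds its RNG identically in setUpClass, so jt.random draws the same tensor on all 3 ranks and the all-reduce equals x*3; and each collective's backward pass is its adjoint collective. A NumPy simulation of the forward claim (no MPI required):

```python
# Every "rank" seeds identically, so the all-reduced sum is exactly 3*x.
import numpy as np

world_size = 3
xs = [np.random.default_rng(0).random((5, 5)) for _ in range(world_size)]

y = sum(xs)                                # forward of mpi_all_reduce
assert np.allclose(y, xs[0] * world_size)  # hence the (x*3).data check above

# The gradient asserts follow from the collectives' adjoint pairs:
#   all_reduce is its own adjoint -> grad is ones * world_size on every rank
#   broadcast's adjoint is reduce -> grad is ones * world_size at the root,
#                                    zeros elsewhere
#   reduce's adjoint is broadcast -> grad is ones on every rank
```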

View File

@@ -15,138 +15,129 @@ from jittor import nn
 from jittor import nn, Module
 import copy
 from jittor.test.test_log import find_log_with_re
+from jittor.test.test_mpi import run_mpi_test
+from jittor.compile_extern import mpi, mpi_ops, nccl_ops
 n = 2
-mpi = jt.compile_extern.mpi

-def test_all_reduce():
-    print("test all_reduce")
-    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
-        log_v=1, log_vprefix="op.cc=100,exe=1000"
-    ) as raw_log:
-        x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
-        assert np.allclose(y.data, (x*n).data)
-        g = jt.grad(y,x)
-        assert np.allclose(g.data, np.ones([5,5])*n)
-    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
-    assert len(logs)==2, len(logs)
+@unittest.skipIf(nccl_ops is None, "nccl not found")
+class TestNcclOps(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        np.random.seed(0)
+        jt.seed(3)

-def test_broadcast():
-    print("test broadcast")
-    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
-        log_v=1, log_vprefix="op.cc=100,exe=1000"
-    ) as raw_log:
-        data = jt.random([5, 5])
-        if mpi.world_rank() == 0:
-            x = data
-        else:
-            x = jt.zeros([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
-        assert np.allclose(y.data, data.data)
-        g = jt.grad(y.sum(),x)
-        g_ = g.data
-        if mpi.world_rank() == 0:
-            assert np.allclose(g_, np.ones([5,5])*n)
-    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_broadcast.*)")
-    assert len(logs)==1, len(logs)
+    @jt.flag_scope(use_cuda=1)
+    def test_all_reduce(self):
+        with jt.log_capture_scope(enable_tuner=1, log_silent=1,
+            log_v=1, log_vprefix="op.cc=100,exe=1000"
+        ) as raw_log:
+            x = jt.random([5, 5])
+            y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
+            assert np.allclose(y.data, (x*n).data)
+            g = jt.grad(y,x)
+            assert np.allclose(g.data, np.ones([5,5])*n)
+        logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
+        assert len(logs)==2, len(logs)

+    @jt.flag_scope(use_cuda=1)
+    def test_broadcast(self):
+        with jt.log_capture_scope(enable_tuner=1, log_silent=1,
+            log_v=1, log_vprefix="op.cc=100,exe=1000"
+        ) as raw_log:
+            data = jt.random([5, 5])
+            if mpi.world_rank() == 0:
+                x = data
+            else:
+                x = jt.zeros([5, 5])
+            y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+            assert np.allclose(y.data, data.data)
+            g = jt.grad(y.sum(),x)
+            g_ = g.data
+            if mpi.world_rank() == 0:
+                assert np.allclose(g_, np.ones([5,5])*n)
+        logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_broadcast.*)")
+        assert len(logs)==1, len(logs)

-def test_reduce():
-    print("test reduce")
-    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
-        log_v=1, log_vprefix="op.cc=100,exe=1000"
-    ) as raw_log:
-        x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
-        y_ = y.data
-        x_ = (x*n).data
-        if mpi.world_rank() == 0:
-            assert np.allclose(y_, x_)
-        g = jt.grad(y,x)
-        assert np.allclose(g.data, np.ones([5,5]))
-    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
-    assert len(logs)==1, len(logs)
+    @jt.flag_scope(use_cuda=1)
+    def test_reduce(self):
+        with jt.log_capture_scope(enable_tuner=1, log_silent=1,
+            log_v=1, log_vprefix="op.cc=100,exe=1000"
+        ) as raw_log:
+            x = jt.random([5, 5])
+            y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
+            y_ = y.data
+            x_ = (x*n).data
+            if mpi.world_rank() == 0:
+                assert np.allclose(y_, x_)
+            g = jt.grad(y,x)
+            assert np.allclose(g.data, np.ones([5,5]))
+        logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
+        assert len(logs)==1, len(logs)

-class Model(Module):
-    def __init__(self):
-        self.linear1 = nn.Linear(3, 3)
-        self.linear2 = nn.Linear(3, 1024, False)
-    def execute(self, x):
-        x = self.linear1(x)
-        x = nn.relu(x)
-        return self.linear2(x)
+    @jt.flag_scope(use_cuda=1)
+    def test_sync(self):
+        class Model(Module):
+            def __init__(self):
+                self.linear1 = nn.Linear(3, 3)
+                self.linear2 = nn.Linear(3, 1024, False)
+            def execute(self, x):
+                x = self.linear1(x)
+                x = nn.relu(x)
+                return self.linear2(x)
+        net = Model()
+        if mpi.world_rank() == 0:
+            net.linear1.weight *= 0
+            net.linear2.weight *= 0
+            net.linear1.bias *= 0
+            net.linear1.weight += 1
+            net.linear2.weight += 1
+            net.linear1.bias += 1
+        net.mpi_sync()
+        assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
+        assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
+        assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)

-def test_sync():
-    print("test mpi_sync")
-    net = Model()
-    if mpi.world_rank() == 0:
-        net.linear1.weight *= 0
-        net.linear2.weight *= 0
-        net.linear1.bias *= 0
-        net.linear1.weight += 1
-        net.linear2.weight += 1
-        net.linear1.bias += 1
-    net.mpi_sync()
-    assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
-    assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
-    assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)

-class Model2(Module):
-    def __init__(self, input_size):
-        self.linear1 = nn.Linear(input_size, 10)
-        self.relu1 = nn.Relu()
-        self.linear2 = nn.Linear(10, 1)
-    def execute(self, x):
-        x = self.linear1(x)
-        x = self.relu1(x)
-        return self.linear2(x)
-
-def test_optimizer():
-    print("test optimizer")
-    def get_data(n):
-        for i in range(n):
-            x = np.random.rand(50, 1)
-            y = x*x
-            yield jt.float32(x), jt.float32(y)
-    num = 2000
-    model = Model2(1)
-    model.mpi_sync()
-    optimizer = nn.SGD(model.parameters(), 0.05)
-    dataset = list(enumerate(get_data(num)))
-    for i in range(mpi.world_rank(), num, n):
-        id, (x, y) = dataset[i]
-        pred_y = model(x)
-        loss = (pred_y - y)*(pred_y - y)
-        loss_mean = loss.mean()
-        optimizer.step(loss_mean)
-    assert loss_mean.data < 0.0025
-    jt.clean()
+    @jt.flag_scope(use_cuda=1)
+    def test_optimizer(self):
+        class Model2(Module):
+            def __init__(self, input_size):
+                self.linear1 = nn.Linear(input_size, 10)
+                self.relu1 = nn.Relu()
+                self.linear2 = nn.Linear(10, 1)
+            def execute(self, x):
+                x = self.linear1(x)
+                x = self.relu1(x)
+                return self.linear2(x)
+        def get_data(n):
+            for i in range(n):
+                x = np.random.rand(50, 1)
+                y = x*x
+                yield jt.float32(x), jt.float32(y)
+        num = 2000
+        model = Model2(1)
+        model.mpi_sync()
+        optimizer = nn.SGD(model.parameters(), 0.05)
+        dataset = list(enumerate(get_data(num)))
+        for i in range(mpi.world_rank(), num, n):
+            id, (x, y) = dataset[i]
+            pred_y = model(x)
+            loss = (pred_y - y)*(pred_y - y)
+            loss_mean = loss.mean()
+            optimizer.step(loss_mean)
+        assert loss_mean.data < 0.0025
+        jt.clean()

-def main():
-    np.random.seed(0)
-    jt.set_seed(3)
-    with jt.flag_scope(use_cuda=1):
-        if jt.compile_extern.nccl_ops:
-            test_sync()
-            test_all_reduce()
-            test_broadcast()
-            test_reduce()
-            test_optimizer()
-
-@unittest.skipIf(mpi is None, "no inside mpirun")
-class TestMpi(unittest.TestCase):
-    def test(self):
-        main()

 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
-class TestNcclOps(unittest.TestCase):
-    def test_entry(self):
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
-            cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops -v"
-            print("run cmd:", cmd)
-            assert os.system(cmd)==0, "run cmd failed: "+cmd
+class TestNcclOpsEntry(unittest.TestCase):
+    def test(self):
+        run_mpi_test(2, "test_nccl_ops")

 if __name__ == "__main__":
     unittest.main()
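
test_optimizer distributes work by striding the dataset with the rank: with n processes, rank r consumes batches r, r+n, r+2n, ... after a one-time model.mpi_sync() aligns the initial parameters. A minimal sketch of that disjoint split:

```python
# Rank-strided sharding: every index is consumed by exactly one rank.
num, world_size = 12, 3
shards = {r: list(range(r, num, world_size)) for r in range(world_size)}
# {0: [0, 3, 6, 9], 1: [1, 4, 7, 10], 2: [2, 5, 8, 11]}

# The shards partition the dataset: disjoint, and together they cover it.
assert sorted(i for s in shards.values() for i in s) == list(range(num))
```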