mirror of https://github.com/Jittor/Jittor

Commit 2315928240: polish test_mpi_batchnorm
Parent: d60b37bb07
@@ -271,7 +271,7 @@ class BatchNorm(Module):
     def execute(self, x):
         if self.is_train:
-            if self.sync and not (jt.compile_extern.mpi_ops is None):
+            if self.sync and jt.compile_extern.mpi_ops is not None:
                 tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
                 tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
                 xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
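Note: the sync branch above rebuilds global batch statistics from two all-reduced moments, E[x] and E[x*x]. A minimal NumPy sketch of the same arithmetic, with a list of per-rank shards standing in for mpi_all_reduce (the helper name and shapes are illustrative, not part of the commit):

import numpy as np

def sync_batch_stats(shards):
    # shards: one (N, C, H, W) array per rank; summing them element-wise plays
    # the role of mpi_all_reduce, and dividing by world_size keeps an average
    world_size = len(shards)
    tmpx = sum(shards) / world_size
    tmpx2 = sum(s * s for s in shards) / world_size
    # reduce over batch and spatial dims, keep the channel dim (dims=[0,2,3], keepdims=1)
    xmean = tmpx.mean(axis=(0, 2, 3), keepdims=True)
    x2mean = tmpx2.mean(axis=(0, 2, 3), keepdims=True)
    xvar = x2mean - xmean * xmean   # Var[x] = E[x^2] - E[x]^2
    return xmean, xvar

rng = np.random.default_rng(0)
shards = [rng.random((10, 3, 8, 8)) for _ in range(3)]
xmean, _ = sync_batch_stats(shards)
full = np.concatenate(shards, axis=0)
assert np.allclose(xmean.squeeze(), full.mean(axis=(0, 2, 3)))

With equally sized shards this matches the statistics a single process would compute on the concatenated batch, which is what the batchnorm test further down relies on.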
@@ -17,7 +17,7 @@ class TestMpi(unittest.TestCase):
     def test_mpi_test_op(self):
         assert jt.compile_extern.mpi_ops.mpi_test("").data == 123

-    @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no inccl")
+    @unittest.skipIf(jt.compile_extern.nccl_ops is None, "no nccl")
     @jt.flag_scope(use_cuda=1)
     def test_nccl_with_mpi(self):
         assert jt.compile_extern.nccl_ops.nccl_test("test_with_mpi").data == 123
@@ -47,14 +47,17 @@ class TestMpi(unittest.TestCase):
             c = np.array(range(offset+i*toy.batch_size, offset+(i+1)*toy.batch_size))
             assert (c==a.data).all()

+def run_mpi_test(num_procs, name):
+    if not jt.compile_extern.inside_mpi():
+        mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
+        cmd = f"{mpirun_path} -np {num_procs} {sys.executable} -m jittor.test.{name} -v"
+        print("run cmd:", cmd)
+        assert os.system(cmd)==0, "run cmd failed: "+cmd
+
 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
 class TestMpiEntry(unittest.TestCase):
     def test_entry(self):
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
-            cmd = f"{mpirun_path} -np 2 {sys.executable} -m jittor.test.test_mpi -v"
-            print("run cmd:", cmd)
-            assert os.system(cmd)==0, "run cmd failed: "+cmd
+        run_mpi_test(2, "test_mpi")

 if __name__ == "__main__":
     unittest.main()
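Note: run_mpi_test is the shared entry-point helper this commit introduces. Outside an mpirun launch it shells out to "mpirun -np <num_procs> python -m jittor.test.<name> -v"; inside one it does nothing, so the per-rank test classes in that module run directly. A hedged usage sketch, with an illustrative class name not taken from the commit:

import unittest
import jittor as jt
from jittor.test.test_mpi import run_mpi_test

@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
class TestSomethingEntry(unittest.TestCase):
    def test(self):
        # relaunches jittor.test.test_mpi under "mpirun -np 2" when needed
        run_mpi_test(2, "test_mpi")

if __name__ == "__main__":
    unittest.main()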
@@ -11,41 +11,42 @@ import os, sys
 import jittor as jt
 from jittor import nn
 import numpy as np
+from jittor.test.test_mpi import run_mpi_test

-def test_batchnorm():
-    print("test batchnorm")
-    mpi = jt.compile_extern.mpi
-    data = np.random.rand(30,3,10,10)
-    x1 = jt.array(data)
-    x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])

-    bn1 = nn.BatchNorm(3)
-    bn2 = nn.SyncBatchNorm(3)
-    y1 = bn1(x1).data
-    y2 = bn2(x2).data
+mpi = jt.compile_extern.mpi

-    assert bn1.running_mean==bn2.running_mean
-    assert bn1.running_var==bn2.running_var
+@unittest.skipIf(mpi is None, "no inside mpirun")
+class TestMpiBatchnorm(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        np.random.seed(0)
+        jt.seed(3)

+    def test_batchnorm(self):
+        mpi = jt.compile_extern.mpi
+        data = np.random.rand(30,3,10,10).astype("float32")
+        x1 = jt.array(data)
+        x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])

+        bn1 = nn.BatchNorm(3, sync=True)
+        bn2 = nn.BatchNorm(3, sync=False)
+        y1 = bn1(x1).data
+        y2 = bn2(x2).data

+        assert np.allclose(bn1.running_mean.data, bn2.running_mean.data), \
+            (bn1.running_mean.data, bn2.running_mean.data)
+        assert np.allclose(bn1.running_var.data, bn2.running_var.data)

+    @unittest.skipIf(not jt.has_cuda, "no cuda")
+    @jt.flag_scope(use_cuda=1)
+    def test_batchnorm_cuda(self):
+        self.test_batchnorm()

-def main():
-    np.random.seed(0)
-    jt.set_seed(3)
-    with jt.flag_scope(use_cuda=0):
-        test_batchnorm()
-    with jt.flag_scope(use_cuda=1):
-        test_batchnorm()

 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
-class TestMpiOps(unittest.TestCase):
+class TestMpiBatchnormEntry(unittest.TestCase):
     def test(self):
-        mpi = jt.compile_extern.mpi
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
-            cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_batchnorm"
-            print("run cmd", cmd)
-            jt.compiler.run_cmd(cmd)
-        else:
-            main()
+        run_mpi_test(3, "test_mpi_batchnorm")

 if __name__ == "__main__":
-    unittest.main()
+    unittest.main()
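Note: the rewritten test shards one 30-sample batch across 3 ranks with data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...] and compares the synced running statistics against a full-batch reference. A small NumPy sketch of that split, with a fake world size in place of jt.compile_extern.mpi:

import numpy as np

data = np.random.rand(30, 3, 10, 10).astype("float32")
world_size, per_rank = 3, 10
shards = [data[r*per_rank:(r+1)*per_rank, ...] for r in range(world_size)]
# Equal-sized, disjoint slices: averaging the per-shard means reproduces the
# full-batch mean, which is why the synced running_mean is expected to match
# the unsynced full-batch one (and likewise for running_var via E[x^2]).
per_rank_means = [s.mean(axis=(0, 2, 3)) for s in shards]
assert np.allclose(np.mean(per_rank_means, axis=0), data.mean(axis=(0, 2, 3)))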
@@ -10,65 +10,54 @@ import unittest
 import os, sys
 import jittor as jt
 import numpy as np
+from jittor.test.test_mpi import run_mpi_test

-def test_all_reduce():
-    print("test all_reduce")
-    mpi = jt.compile_extern.mpi
-    x = jt.random([5, 5])
-    y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
-    assert np.allclose(y.data, (x*3).data)
-    g = jt.grad(y,x)
-    assert np.allclose(g.data, np.ones([5,5])*3)
+mpi = jt.compile_extern.mpi

-def test_broadcast():
-    print("test broadcast")
-    mpi = jt.compile_extern.mpi
-    data = jt.random([5, 5])
-    if mpi.world_rank() == 0:
-        x = data
-    else:
-        x = jt.zeros([5, 5])
-    y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
-    assert np.allclose(y.data, data.data)
-    g = jt.grad(y,x)
-    if mpi.world_rank() == 0:
-        assert np.allclose(g.data, np.ones([5,5])*3)
-    else:
-        assert np.allclose(g.data, np.zeros([5,5]))
+@unittest.skipIf(mpi is None, "no inside mpirun")
+class TestMpiOps(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        np.random.seed(0)
+        jt.seed(3)

-def test_reduce():
-    print("test reduce")
-    mpi = jt.compile_extern.mpi
-    x = jt.random([5, 5])
-    y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
-    y.sync()
-    if mpi.world_rank() == 0:
+    def test_all_reduce(self):
+        x = jt.random([5, 5])
+        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
         assert np.allclose(y.data, (x*3).data)
-    else:
-        assert np.allclose(y.data, np.zeros([5,5]))
-    g = jt.grad(y,x)
-    assert np.allclose(g.data, np.ones([5,5]))
+        g = jt.grad(y,x)
+        assert np.allclose(g.data, np.ones([5,5])*3)

+    def test_broadcast(self):
+        data = jt.random([5, 5])
+        if mpi.world_rank() == 0:
+            x = data
+        else:
+            x = jt.zeros([5, 5])
+        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+        assert np.allclose(y.data, data.data)
+        g = jt.grad(y,x)
+        if mpi.world_rank() == 0:
+            assert np.allclose(g.data, np.ones([5,5])*3)
+        else:
+            assert np.allclose(g.data, np.zeros([5,5]))

+    def test_reduce(self):
+        x = jt.random([5, 5])
+        y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
+        y.sync()
+        if mpi.world_rank() == 0:
+            assert np.allclose(y.data, (x*3).data)
+        else:
+            assert np.allclose(y.data, np.zeros([5,5]))
+        g = jt.grad(y,x)
+        assert np.allclose(g.data, np.ones([5,5]))

-def main():
-    np.random.seed(0)
-    jt.set_seed(3)
-    with jt.flag_scope(use_cuda=0):
-        if jt.compile_extern.mpi_ops:
-            test_all_reduce()
-            test_broadcast()
-            test_reduce()

 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
-class TestMpiOps(unittest.TestCase):
+class TestMpiOpsEntry(unittest.TestCase):
     def test(self):
-        mpi = jt.compile_extern.mpi
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
-            cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_op"
-            print("run cmd", cmd)
-            jt.compiler.run_cmd(cmd)
-        else:
-            main()
+        run_mpi_test(3, "test_mpi_op")

 if __name__ == "__main__":
     unittest.main()
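Note: the asserts in this file encode the gradient rules of the three collectives: the backward of mpi_all_reduce is another all_reduce (3*ones with three ranks), the backward of mpi_broadcast accumulates onto the root (3*ones on rank 0, zeros elsewhere), and the backward of mpi_reduce broadcasts the root's gradient (ones everywhere). A hedged NumPy simulation of those identities for world_size == 3, assuming jt.grad seeds the output gradient with ones; lists stand in for per-rank tensors:

import numpy as np

world_size = 3
ones = np.ones([5, 5])
upstream = [ones for _ in range(world_size)]   # per-rank upstream gradients

# all_reduce: forward sums x over ranks, backward sums the upstream grads again
g_all_reduce = sum(upstream)
assert np.allclose(g_all_reduce, ones * 3)

# broadcast from rank 0: forward copies rank 0's x, backward reduces grads onto rank 0
g_broadcast = [sum(upstream) if r == 0 else np.zeros([5, 5]) for r in range(world_size)]
assert np.allclose(g_broadcast[0], ones * 3) and np.allclose(g_broadcast[1], 0)

# reduce onto rank 0: forward sums onto rank 0, backward broadcasts rank 0's grad
g_reduce = [upstream[0] for _ in range(world_size)]
assert np.allclose(g_reduce[2], ones)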
@@ -15,138 +15,129 @@ from jittor import nn
 from jittor import nn, Module
 import copy
 from jittor.test.test_log import find_log_with_re
+from jittor.test.test_mpi import run_mpi_test
+from jittor.compile_extern import mpi, mpi_ops, nccl_ops
 n = 2
-mpi = jt.compile_extern.mpi

-def test_all_reduce():
-    print("test all_reduce")
-    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
-        log_v=1, log_vprefix="op.cc=100,exe=1000"
-    ) as raw_log:
-        x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
-        assert np.allclose(y.data, (x*n).data)
-        g = jt.grad(y,x)
-        assert np.allclose(g.data, np.ones([5,5])*n)

-    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
-    assert len(logs)==2, len(logs)
+@unittest.skipIf(nccl_ops is None, "nccl not found")
+class TestNcclOps(unittest.TestCase):
+    @classmethod
+    def setUpClass(self):
+        np.random.seed(0)
+        jt.seed(3)

-def test_broadcast():
-    print("test broadcast")
-    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
-        log_v=1, log_vprefix="op.cc=100,exe=1000"
-    ) as raw_log:
-        data = jt.random([5, 5])
+    @jt.flag_scope(use_cuda=1)
+    def test_all_reduce(self):
+        with jt.log_capture_scope(enable_tuner=1, log_silent=1,
+            log_v=1, log_vprefix="op.cc=100,exe=1000"
+        ) as raw_log:
+            x = jt.random([5, 5])
+            y = jt.compile_extern.mpi_ops.mpi_all_reduce(x)
+            assert np.allclose(y.data, (x*n).data)
+            g = jt.grad(y,x)
+            assert np.allclose(g.data, np.ones([5,5])*n)

+        logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_all_reduce.*)")
+        assert len(logs)==2, len(logs)

+    @jt.flag_scope(use_cuda=1)
+    def test_broadcast(self):
+        with jt.log_capture_scope(enable_tuner=1, log_silent=1,
+            log_v=1, log_vprefix="op.cc=100,exe=1000"
+        ) as raw_log:
+            data = jt.random([5, 5])
+            if mpi.world_rank() == 0:
+                x = data
+            else:
+                x = jt.zeros([5, 5])
+            y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
+            assert np.allclose(y.data, data.data)
+            g = jt.grad(y.sum(),x)
+            g_ = g.data
+            if mpi.world_rank() == 0:
+                assert np.allclose(g_, np.ones([5,5])*n)
+        logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_broadcast.*)")
+        assert len(logs)==1, len(logs)

+    @jt.flag_scope(use_cuda=1)
+    def test_reduce(self):
+        with jt.log_capture_scope(enable_tuner=1, log_silent=1,
+            log_v=1, log_vprefix="op.cc=100,exe=1000"
+        ) as raw_log:
+            x = jt.random([5, 5])
+            y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
+            y_ = y.data
+            x_ = (x*n).data
+            if mpi.world_rank() == 0:
+                assert np.allclose(y_, x_)
+            g = jt.grad(y,x)
+            assert np.allclose(g.data, np.ones([5,5]))
+        logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
+        assert len(logs)==1, len(logs)

+    @jt.flag_scope(use_cuda=1)
+    def test_sync(self):

+        class Model(Module):
+            def __init__(self):
+                self.linear1 = nn.Linear(3, 3)
+                self.linear2 = nn.Linear(3, 1024, False)

+            def execute(self, x):
+                x = self.linear1(x)
+                x = nn.relu(x)
+                return self.linear2(x)

+        net = Model()
         if mpi.world_rank() == 0:
-            x = data
-        else:
-            x = jt.zeros([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_broadcast(x, 0)
-        assert np.allclose(y.data, data.data)
-        g = jt.grad(y.sum(),x)
-        g_ = g.data
-        if mpi.world_rank() == 0:
-            assert np.allclose(g_, np.ones([5,5])*n)
-    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_broadcast.*)")
-    assert len(logs)==1, len(logs)
+            net.linear1.weight *= 0
+            net.linear2.weight *= 0
+            net.linear1.bias *= 0
+            net.linear1.weight += 1
+            net.linear2.weight += 1
+            net.linear1.bias += 1
+        net.mpi_sync()
+        assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
+        assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
+        assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)

-def test_reduce():
-    print("test reduce")
-    with jt.log_capture_scope(enable_tuner=1, log_silent=1,
-        log_v=1, log_vprefix="op.cc=100,exe=1000"
-    ) as raw_log:
-        x = jt.random([5, 5])
-        y = jt.compile_extern.mpi_ops.mpi_reduce(x, 0)
-        y_ = y.data
-        x_ = (x*n).data
-        if mpi.world_rank() == 0:
-            assert np.allclose(y_, x_)
-        g = jt.grad(y,x)
-        assert np.allclose(g.data, np.ones([5,5]))
-    logs = find_log_with_re(raw_log, "(Jit op key (not )?found: nccl_reduce.*)")
-    assert len(logs)==1, len(logs)
+    @jt.flag_scope(use_cuda=1)
+    def test_optimizer(self):

-class Model(Module):
-    def __init__(self):
-        self.linear1 = nn.Linear(3, 3)
-        self.linear2 = nn.Linear(3, 1024, False)
+        class Model2(Module):
+            def __init__(self, input_size):
+                self.linear1 = nn.Linear(input_size, 10)
+                self.relu1 = nn.Relu()
+                self.linear2 = nn.Linear(10, 1)
+            def execute(self, x):
+                x = self.linear1(x)
+                x = self.relu1(x)
+                return self.linear2(x)

+        def get_data(n):
+            for i in range(n):
+                x = np.random.rand(50, 1)
+                y = x*x
+                yield jt.float32(x), jt.float32(y)

-    def execute(self, x):
-        x = self.linear1(x)
-        x = nn.relu(x)
-        return self.linear2(x)

-def test_sync():
-    print("test mpi_sync")
-    net = Model()
-    if mpi.world_rank() == 0:
-        net.linear1.weight *= 0
-        net.linear2.weight *= 0
-        net.linear1.bias *= 0
-        net.linear1.weight += 1
-        net.linear2.weight += 1
-        net.linear1.bias += 1
-    net.mpi_sync()
-    assert np.allclose(net.linear1.weight.data, jt.ones(net.linear1.weight.shape).data)
-    assert np.allclose(net.linear2.weight.data, jt.ones(net.linear2.weight.shape).data)
-    assert np.allclose(net.linear1.bias.data, jt.ones(net.linear1.bias.shape).data)

-class Model2(Module):
-    def __init__(self, input_size):
-        self.linear1 = nn.Linear(input_size, 10)
-        self.relu1 = nn.Relu()
-        self.linear2 = nn.Linear(10, 1)
-    def execute(self, x):
-        x = self.linear1(x)
-        x = self.relu1(x)
-        return self.linear2(x)

-def test_optimizer():
-    print("test optimizer")
-    def get_data(n):
-        for i in range(n):
-            x = np.random.rand(50, 1)
-            y = x*x
-            yield jt.float32(x), jt.float32(y)
-    num = 2000
-    model = Model2(1)
-    model.mpi_sync()
-    optimizer = nn.SGD(model.parameters(), 0.05)
-    dataset = list(enumerate(get_data(num)))
-    for i in range(mpi.world_rank(), num, n):
-        id, (x, y) = dataset[i]
-        pred_y = model(x)
-        loss = (pred_y - y)*(pred_y - y)
-        loss_mean = loss.mean()
-        optimizer.step(loss_mean)
-    assert loss_mean.data < 0.0025
-    jt.clean()

-def main():
-    np.random.seed(0)
-    jt.set_seed(3)
-    with jt.flag_scope(use_cuda=1):
-        if jt.compile_extern.nccl_ops:
-            test_sync()
-            test_all_reduce()
-            test_broadcast()
-            test_reduce()
-            test_optimizer()

-@unittest.skipIf(mpi is None, "no inside mpirun")
-class TestMpi(unittest.TestCase):
-    def test(self):
-        main()
+        num = 2000
+        model = Model2(1)
+        model.mpi_sync()
+        optimizer = nn.SGD(model.parameters(), 0.05)
+        dataset = list(enumerate(get_data(num)))
+        for i in range(mpi.world_rank(), num, n):
+            id, (x, y) = dataset[i]
+            pred_y = model(x)
+            loss = (pred_y - y)*(pred_y - y)
+            loss_mean = loss.mean()
+            optimizer.step(loss_mean)
+        assert loss_mean.data < 0.0025
+        jt.clean()

 @unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
-class TestNcclOps(unittest.TestCase):
-    def test_entry(self):
-        if not jt.compile_extern.inside_mpi():
-            mpirun_path = jt.compile_extern.mpicc_path.replace("mpicc", "mpirun")
-            cmd = f"{mpirun_path} -np {n} {sys.executable} -m jittor.test.test_nccl_ops -v"
-            print("run cmd:", cmd)
-            assert os.system(cmd)==0, "run cmd failed: "+cmd
+class TestNcclOpsEntry(unittest.TestCase):
+    def test(self):
+        run_mpi_test(2, "test_nccl_ops")

 if __name__ == "__main__":
     unittest.main()