test mpi batchnorm

guowei yang 2020-04-14 15:35:15 +08:00
parent 8cf35074a3
commit 43279ae58f
3 changed files with 115 additions and 10 deletions

python/jittor/nn.py

@@ -51,14 +51,35 @@ def batch_norm(x, is_train, eps=1e-5, momentum=0.1):
     w = w.broadcast(x, [0,2,3])
     b = b.broadcast(x, [0,2,3])
     if is_train:
-        if (jt.compile_extern.mpi_ops is None):
-            xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
-        else:
-            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
-            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
-            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
-            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+        xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
+        x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
         xvar = x2mean-xmean*xmean
         norm_x = (x-xmean)/jt.sqrt(xvar+eps)
         running_mean += (xmean.sum([0,2,3])-running_mean)*momentum
         running_var += (xvar.sum([0,2,3])-running_var)*momentum
     else:
         running_mean = running_mean.broadcast(x, [0,2,3])
         running_var = running_var.broadcast(x, [0,2,3])
         norm_x = (x-running_mean)/jt.sqrt(running_var+eps)
     return norm_x * w + b
+
+@jt.var_scope('sync_batch_norm')
+def sync_batch_norm(x, is_train, eps=1e-5, momentum=0.1):
+    assert jt.compile_extern.mpi_ops is not None
+    w = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
+    b = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
+    running_mean = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 0.0))
+    running_var = jt.make_var([x.shape[1]], init=lambda *a: init.constant(*a, 1.0))
+    w = w.broadcast(x, [0,2,3])
+    b = b.broadcast(x, [0,2,3])
+    if is_train:
+        # all_reduce sums x and x*x across ranks; /world_size turns the sums into averages
+        tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
+        tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
+        xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
+        x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+        xvar = x2mean-xmean*xmean
+        norm_x = (x-xmean)/jt.sqrt(xvar+eps)
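
The MPI branch is pulled out of plain batch_norm and isolated in the new sync_batch_norm. Its statistics match a single-process run over the concatenated batch: mpi_all_reduce sums x and x*x across ranks, dividing by world_size() averages them, and taking the per-channel mean of those averages reproduces the global mean, with the variance recovered as E[x^2] - E[x]^2. A minimal NumPy sketch of that identity, simulating 3 equal-sized ranks (illustration only, not Jittor code):

import numpy as np

world_size = 3
data = np.random.rand(30, 3, 10, 10)          # global batch: 30 samples, 3 channels
shards = np.split(data, world_size, axis=0)   # each simulated rank holds 10 samples

# all_reduce(x)/world_size is an element-wise average of the per-rank tensors
avg_x = sum(shards) / world_size
avg_x2 = sum(s * s for s in shards) / world_size

xmean = avg_x.mean(axis=(0, 2, 3))            # per-channel E[x]
x2mean = avg_x2.mean(axis=(0, 2, 3))          # per-channel E[x^2]
xvar = x2mean - xmean * xmean                 # Var[x] = E[x^2] - E[x]^2

assert np.allclose(xmean, data.mean(axis=(0, 2, 3)))
assert np.allclose(xvar, data.var(axis=(0, 2, 3)))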
@@ -274,6 +295,39 @@ class BatchNorm(Module):
         b = self.bias.broadcast(x, [0,2,3])
         return norm_x * w + b
+
+class SyncBatchNorm(Module):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
+        assert affine is None
+        self.num_features = num_features
+        self.is_train = is_train
+        self.eps = eps
+        self.momentum = momentum
+        self.weight = init.constant((num_features,), "float32", 1.0)
+        self.bias = init.constant((num_features,), "float32", 0.0)
+        self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
+        self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
+
+    def execute(self, x):
+        if self.is_train:
+            assert jt.compile_extern.mpi_ops is not None
+            tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
+            tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
+            xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
+            x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
+            xvar = x2mean-xmean*xmean
+            norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
+            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
+            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
+        else:
+            running_mean = self.running_mean.broadcast(x, [0,2,3])
+            running_var = self.running_var.broadcast(x, [0,2,3])
+            norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
+        w = self.weight.broadcast(x, [0,2,3])
+        b = self.bias.broadcast(x, [0,2,3])
+        return norm_x * w + b
+
 Relu = jt.make_module(relu)
 ReLU = Relu
 Leaky_relu = jt.make_module(leaky_relu, 2)
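
Both sync implementations share the running-statistics update. Since xmean keeps shape [1, C, 1, 1] under keepdims=1, .sum([0,2,3]) merely squeezes it to a length-C vector, so the in-place update is the standard exponential moving average new = (1 - momentum) * old + momentum * batch_stat. A tiny NumPy check of that equivalence (illustration only):

import numpy as np

momentum = 0.1
running = np.ones(3)                    # e.g. running_var starts at 1.0
batch_stat = np.array([0.4, 0.5, 0.6])  # per-channel statistic from the current batch

updated = running + (batch_stat - running) * momentum
assert np.allclose(updated, (1 - momentum) * running + momentum * batch_stat)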

python/jittor/test/test_mpi_batchnorm.py

@@ -0,0 +1,51 @@
+# ***************************************************************
+# Copyright (c) 2020 Jittor. Authors:
+#     Guowei Yang <471184555@qq.com>
+#     Dun Liang <randonlang@gmail.com>.
+# All Rights Reserved.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import os, sys
+import jittor as jt
+from jittor import nn
+import numpy as np
+
+def test_batchnorm():
+    print("test batchnorm")
+    mpi = jt.compile_extern.mpi
+    data = np.random.rand(30,3,10,10)
+    x1 = jt.array(data)
+    # rank r takes the disjoint slice [10*r, 10*(r+1)); the 3 ranks cover all 30 samples
+    x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
+    bn1 = nn.BatchNorm(3)
+    bn2 = nn.SyncBatchNorm(3)
+    y1 = bn1(x1).data
+    y2 = bn2(x2).data
+    assert np.allclose(bn1.running_mean.data, bn2.running_mean.data)
+    assert np.allclose(bn1.running_var.data, bn2.running_var.data)
+
+def main():
+    np.random.seed(0)
+    jt.set_seed(3)
+    with jt.flag_scope(use_cuda=0):
+        test_batchnorm()
+    with jt.flag_scope(use_cuda=1):
+        test_batchnorm()
+
+@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
+class TestMpiOps(unittest.TestCase):
+    def test(self):
+        mpi = jt.compile_extern.mpi
+        if not jt.compile_extern.inside_mpi():
+            mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
+            cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_batchnorm"
+            print("run cmd", cmd)
+            jt.compiler.run_cmd(cmd)
+        else:
+            main()
+
+if __name__ == "__main__":
+    unittest.main()
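
The test relaunches itself: started outside an MPI job, it shells out to mpirun -np 3 ... -m jittor.test.test_mpi_batchnorm, and each of the 3 spawned ranks then takes the else branch and runs main(). bn1 normalizes all 30 samples while each rank's bn2 sees only its 10-sample slice, so equal running statistics confirm the all-reduce path. A hedged sketch of the relaunch pattern follows; the real jt.compile_extern.inside_mpi() is not shown in this diff, and the OMPI_COMM_WORLD_SIZE check below (set by Open MPI's mpirun) is an assumption for illustration only:

import os, subprocess, sys

def inside_mpi():
    # hypothetical detection; the actual helper may inspect other variables
    return "OMPI_COMM_WORLD_SIZE" in os.environ

def relaunch_under_mpi(module, nprocs=3):
    # re-execute this interpreter under mpirun with nprocs ranks
    cmd = ["mpirun", "-np", str(nprocs), sys.executable, "-m", module]
    print("run cmd", " ".join(cmd))
    subprocess.run(cmd, check=True)

if __name__ == "__main__":
    if inside_mpi():
        print("rank is inside an mpirun job; would run main() here")
    else:
        relaunch_under_mpi("jittor.test.test_mpi_batchnorm")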

python/jittor/test/test_mpi_op.py

@@ -51,11 +51,11 @@ def main():
     test_broadcast()
     test_reduce()

-@unittest.skipIf(jt.compile_extern.mpi_ops is None, "no mpi found")
+@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")
 class TestMpiOps(unittest.TestCase):
     def test(self):
         mpi = jt.compile_extern.mpi
-        if mpi.world_size() == 1:
+        if not jt.compile_extern.inside_mpi():
             mpirun_path = jt.compiler.env_or_try_find('mpirun_path', 'mpirun')
             cmd = f"{mpirun_path} -np 3 {sys.executable} -m jittor.test.test_mpi_op"
             print("run cmd", cmd)