test batchnorm backward

This commit is contained in:
guowei yang 2020-04-21 22:59:19 +08:00
parent ab872906c2
commit d4cbd373f2
4 changed files with 65 additions and 18 deletions

View File

@ -38,6 +38,8 @@ MpiBroadcastOp::MpiBroadcastOp(Var* x, int root) : x(x), root(root) {
void MpiBroadcastOp::infer_shape() {
y->set_shape(x->shape);
if (root == mpi_world_rank)
y->share_with(x);
}
VarPtr MpiBroadcastOp::grad(Var* out, Var* dout, Var* v, int v_index) {
@ -56,14 +58,8 @@ void MpiBroadcastOp::jit_prepare() {
void MpiBroadcastOp::jit_run() {
@for(i, 0, XDIM, index_t xshape@i = x->shape[@i];)
int size = 1 @for(i, 0, XDIM, * xshape@{i});
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
if (mpi_world_rank == root) {
for (int i = 0; i < mpi_world_size; i++) {
MPI_Send(xp, size, MPI_FLOAT, i, 0, MPI_COMM_WORLD);
}
}
MPI_Recv(yp, size, MPI_FLOAT, root, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
MPI_Bcast(yp, size, MPI_FLOAT, root, MPI_COMM_WORLD);
}
#else
void MpiBroadcastOp::jit_run() {

View File

@ -59,8 +59,9 @@ void MpiReduceOp::jit_run() {
auto* __restrict__ xp = x->ptr<Tx>();
auto* __restrict__ yp = y->ptr<Tx>();
index_t num = y->num;
for (index_t i=0; i<num; i++)
yp[i] = 0;
if (root != mpi_world_rank)
for (index_t i=0; i<num; i++)
yp[i] = 0;
MPI_Reduce(xp, yp, size, MPI_FLOAT, MPI_SUM, root, MPI_COMM_WORLD);
}
#else

View File

@ -270,15 +270,13 @@ class BatchNorm(Module):
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
def execute(self, x):
mpi = jt.compile_extern.mpi
if self.is_train:
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
if self.sync and jt.compile_extern.mpi_ops is not None:
tmpx = jt.compile_extern.mpi_ops.mpi_all_reduce(x)/jt.compile_extern.mpi.world_size()
tmpx2 = jt.compile_extern.mpi_ops.mpi_all_reduce(x*x)/jt.compile_extern.mpi.world_size()
xmean = jt.mean(tmpx, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(tmpx2, dims=[0,2,3], keepdims=1)
else:
xmean = jt.mean(x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(x*x, dims=[0,2,3], keepdims=1)
xmean = jt.compile_extern.mpi_ops.mpi_all_reduce(xmean)/jt.compile_extern.mpi.world_size()
x2mean = jt.compile_extern.mpi_ops.mpi_all_reduce(x2mean)/jt.compile_extern.mpi.world_size()
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)

View File

@ -9,12 +9,43 @@
import unittest
import os, sys
import jittor as jt
from jittor import init
from jittor import nn
import numpy as np
from jittor.test.test_mpi import run_mpi_test
mpi = jt.compile_extern.mpi
class FakeMpiBatchNorm(nn.Module):
def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=None, is_train=True):
assert affine == None
self.num_features = num_features
self.is_train = is_train
self.eps = eps
self.momentum = momentum
self.weight = init.constant((num_features,), "float32", 1.0)
self.bias = init.constant((num_features,), "float32", 0.0)
self.running_mean = init.constant((num_features,), "float32", 0.0).stop_grad()
self.running_var = init.constant((num_features,), "float32", 1.0).stop_grad()
def execute(self, x, global_x):
if self.is_train:
xmean = jt.mean(global_x, dims=[0,2,3], keepdims=1)
x2mean = jt.mean(global_x*global_x, dims=[0,2,3], keepdims=1)
xvar = x2mean-xmean*xmean
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
else:
running_mean = self.running_mean.broadcast(x, [0,2,3])
running_var = self.running_var.broadcast(x, [0,2,3])
norm_x = (x-running_mean)/jt.sqrt(running_var+self.eps)
w = self.weight.broadcast(x, [0,2,3])
b = self.bias.broadcast(x, [0,2,3])
return norm_x * w + b
@unittest.skipIf(mpi is None, "no inside mpirun")
class TestMpiBatchnorm(unittest.TestCase):
@classmethod
@ -28,19 +59,40 @@ class TestMpiBatchnorm(unittest.TestCase):
x1 = jt.array(data)
x2 = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
bn1 = nn.BatchNorm(3, sync=True)
bn2 = nn.BatchNorm(3, sync=False)
bn1 = nn.BatchNorm(3, sync=False)
bn2 = nn.BatchNorm(3, sync=True)
bn3 = FakeMpiBatchNorm(3)
y1 = bn1(x1).data
y2 = bn2(x2).data
y3 = bn3(x2,x1).data
assert np.allclose(y2, y3, atol=1e-4), (y2, y3)
assert np.allclose(bn1.running_mean.data, bn2.running_mean.data), \
(bn1.running_mean.data, bn2.running_mean.data)
assert np.allclose(bn1.running_var.data, bn2.running_var.data)
def test_batchnorm_backward(self):
mpi = jt.compile_extern.mpi
data = np.random.rand(30,3,10,10).astype("float32")
global_x = jt.array(data)
x = jt.array(data[mpi.world_rank()*10:(mpi.world_rank()+1)*10,...])
bn1 = nn.BatchNorm(3, sync=True)
bn2 = FakeMpiBatchNorm(3)
y1 = bn1(x)
y2 = bn2(x,global_x)
gs1 = jt.grad(y1,bn1.parameters())
gs2 = jt.grad(y2,bn2.parameters())
assert np.allclose(y1.data, y2.data, atol=1e-5),(mpi.world_rank(),y1.data, y2.data, y1.data-y2.data)
for i in range(len(gs1)):
assert np.allclose(gs1[i].data, gs2[i].data, rtol=1e-3),(mpi.world_rank(),gs1[i].data, gs2[i].data,gs1[i].data-gs2[i].data)
@unittest.skipIf(not jt.has_cuda, "no cuda")
@jt.flag_scope(use_cuda=1)
def test_batchnorm_cuda(self):
self.test_batchnorm()
self.test_batchnorm_backward()
@unittest.skipIf(not jt.compile_extern.has_mpi, "no mpi found")