auto sync param

This commit is contained in:
guowei yang 2020-04-24 14:42:20 +08:00
parent c58f3fc3b7
commit 225ddd137f
1 changed files with 13 additions and 3 deletions

View File

@ -105,12 +105,14 @@ class SGD(object):
optimizer = nn.SGD(model.parameters(), lr)
optimizer.step(loss)
"""
def __init__(self, parameters, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
def __init__(self, parameters, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False, param_sync_iter=10000):
self.lr = lr
self.momentum = momentum
self.weight_decay = weight_decay
self.dampening = dampening
self.nesterov = nesterov
self.sgd_step = 0
self.param_sync_iter = param_sync_iter
self.no_grad_parameters = []
self.parameters = []
@ -126,11 +128,15 @@ class SGD(object):
self.values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
def step(self, loss):
self.sgd_step += 1
ps = self.parameters
gs = jt.grad(loss, ps)
if jt.mpi:
for g in gs:
g.assign(g.mpi_all_reduce("mean"))
if self.sgd_step%param_sync_iter==0:
for p in ps:
p.assign(p.mpi_all_reduce("mean"))
for p, g, v in zip(ps, gs, self.values):
dp = p * self.weight_decay + g
v.assign(self.momentum * v + dp * (1 - self.dampening))
@ -150,13 +156,14 @@ class Adam(object):
optimizer = nn.Adam(model.parameters(), lr)
optimizer.step(loss)
"""
def __init__(self, parameters, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
def __init__(self, parameters, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0, param_sync_iter=10000):
self.lr = lr
self.eps = eps
self.betas = betas
# self.weight_decay = weight_decay
assert weight_decay==0, "weight_decay is not supported yet"
self.adam_step = 0
self.param_sync_iter = param_sync_iter
self.no_grad_parameters = []
self.parameters = []
@ -173,12 +180,15 @@ class Adam(object):
self.m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
def step(self, loss):
self.adam_step += 1
ps = self.parameters
gs = jt.grad(loss, ps)
if jt.mpi:
for g in gs:
g.assign(g.mpi_all_reduce("mean"))
self.adam_step += 1
if self.adam_step%param_sync_iter==0:
for p in ps:
p.assign(p.mpi_all_reduce("mean"))
n, (b0, b1) = float(self.adam_step), self.betas
for p, g, v, m in zip(ps, gs, self.values, self.m):
m.assign(b0 * m + (1-b0) * g)