mirror of https://github.com/Jittor/Jittor
auto sync param
This commit is contained in:
parent
c58f3fc3b7
commit
225ddd137f
|
@ -105,12 +105,14 @@ class SGD(object):
|
|||
optimizer = nn.SGD(model.parameters(), lr)
|
||||
optimizer.step(loss)
|
||||
"""
|
||||
def __init__(self, parameters, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
|
||||
def __init__(self, parameters, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False, param_sync_iter=10000):
|
||||
self.lr = lr
|
||||
self.momentum = momentum
|
||||
self.weight_decay = weight_decay
|
||||
self.dampening = dampening
|
||||
self.nesterov = nesterov
|
||||
self.sgd_step = 0
|
||||
self.param_sync_iter = param_sync_iter
|
||||
|
||||
self.no_grad_parameters = []
|
||||
self.parameters = []
|
||||
|
@ -126,11 +128,15 @@ class SGD(object):
|
|||
self.values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
|
||||
def step(self, loss):
|
||||
self.sgd_step += 1
|
||||
ps = self.parameters
|
||||
gs = jt.grad(loss, ps)
|
||||
if jt.mpi:
|
||||
for g in gs:
|
||||
g.assign(g.mpi_all_reduce("mean"))
|
||||
if self.sgd_step%param_sync_iter==0:
|
||||
for p in ps:
|
||||
p.assign(p.mpi_all_reduce("mean"))
|
||||
for p, g, v in zip(ps, gs, self.values):
|
||||
dp = p * self.weight_decay + g
|
||||
v.assign(self.momentum * v + dp * (1 - self.dampening))
|
||||
|
@ -150,13 +156,14 @@ class Adam(object):
|
|||
optimizer = nn.Adam(model.parameters(), lr)
|
||||
optimizer.step(loss)
|
||||
"""
|
||||
def __init__(self, parameters, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
|
||||
def __init__(self, parameters, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0, param_sync_iter=10000):
|
||||
self.lr = lr
|
||||
self.eps = eps
|
||||
self.betas = betas
|
||||
# self.weight_decay = weight_decay
|
||||
assert weight_decay==0, "weight_decay is not supported yet"
|
||||
self.adam_step = 0
|
||||
self.param_sync_iter = param_sync_iter
|
||||
|
||||
self.no_grad_parameters = []
|
||||
self.parameters = []
|
||||
|
@ -173,12 +180,15 @@ class Adam(object):
|
|||
self.m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
|
||||
def step(self, loss):
|
||||
self.adam_step += 1
|
||||
ps = self.parameters
|
||||
gs = jt.grad(loss, ps)
|
||||
if jt.mpi:
|
||||
for g in gs:
|
||||
g.assign(g.mpi_all_reduce("mean"))
|
||||
self.adam_step += 1
|
||||
if self.adam_step%param_sync_iter==0:
|
||||
for p in ps:
|
||||
p.assign(p.mpi_all_reduce("mean"))
|
||||
n, (b0, b1) = float(self.adam_step), self.betas
|
||||
for p, g, v, m in zip(ps, gs, self.values, self.m):
|
||||
m.assign(b0 * m + (1-b0) * g)
|
||||
|
|
Loading…
Reference in New Issue