mirror of https://github.com/Jittor/Jittor
fix bug
This commit is contained in:
parent
ab533a5f45
commit
5325225e9f
|
@ -134,7 +134,7 @@ class SGD(object):
|
|||
if jt.mpi:
|
||||
for g in gs:
|
||||
g.assign(g.mpi_all_reduce("mean"))
|
||||
if self.sgd_step%param_sync_iter==0:
|
||||
if self.sgd_step%self.param_sync_iter==0:
|
||||
for p in ps:
|
||||
p.assign(p.mpi_all_reduce("mean"))
|
||||
for p, g, v in zip(ps, gs, self.values):
|
||||
|
@ -186,7 +186,7 @@ class Adam(object):
|
|||
if jt.mpi:
|
||||
for g in gs:
|
||||
g.assign(g.mpi_all_reduce("mean"))
|
||||
if self.adam_step%param_sync_iter==0:
|
||||
if self.adam_step%self.param_sync_iter==0:
|
||||
for p in ps:
|
||||
p.assign(p.mpi_all_reduce("mean"))
|
||||
n, (b0, b1) = float(self.adam_step), self.betas
|
||||
|
|
Loading…
Reference in New Issue