mirror of https://github.com/Jittor/Jittor
update adam
This commit is contained in:
parent
7e7649e345
commit
33f85afe6c
|
@ -247,8 +247,8 @@ class Adam(Optimizer):
|
|||
super().__init__(params, lr)
|
||||
self.eps = eps
|
||||
self.betas = betas
|
||||
# self.weight_decay = weight_decay
|
||||
assert weight_decay==0, "weight_decay is not supported yet"
|
||||
self.weight_decay = weight_decay
|
||||
# assert weight_decay==0, "weight_decay is not supported yet"
|
||||
|
||||
# initialize required arguments for each param_groups
|
||||
for pg in self.param_groups:
|
||||
|
@ -274,9 +274,11 @@ class Adam(Optimizer):
|
|||
# get arguments from each param_groups
|
||||
lr = pg.get("lr", self.lr)
|
||||
eps = pg.get("eps", self.eps)
|
||||
weight_decay = pg.get("weight_decay", self.weight_decay)
|
||||
b0, b1 = pg.get("betas", self.betas)
|
||||
for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
|
||||
if p.is_stop_grad(): continue
|
||||
g = p * weight_decay + g
|
||||
m.update(b0 * m + (1-b0) * g)
|
||||
v.update(b1 * v + (1-b1) * g * g)
|
||||
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
|
||||
|
|
Loading…
Reference in New Issue