update adam

li-xl 2021-03-17 16:49:38 +08:00
parent 7e7649e345
commit 33f85afe6c
1 changed file with 4 additions and 2 deletions

@@ -247,8 +247,8 @@ class Adam(Optimizer):
         super().__init__(params, lr)
         self.eps = eps
         self.betas = betas
-        # self.weight_decay = weight_decay
-        assert weight_decay==0, "weight_decay is not supported yet"
+        self.weight_decay = weight_decay
+        # assert weight_decay==0, "weight_decay is not supported yet"
 
         # initialize required arguments for each param_groups
         for pg in self.param_groups:
@@ -274,9 +274,11 @@ class Adam(Optimizer):
             # get arguments from each param_groups
             lr = pg.get("lr", self.lr)
             eps = pg.get("eps", self.eps)
+            weight_decay = pg.get("weight_decay", self.weight_decay)
             b0, b1 = pg.get("betas", self.betas)
             for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                 if p.is_stop_grad(): continue
+                g = p * weight_decay + g
                 m.update(b0 * m + (1-b0) * g)
                 v.update(b1 * v + (1-b1) * g * g)
                 step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
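
Note: the second hunk folds weight_decay into the gradient before the moment estimates, i.e. coupled L2 regularization. A minimal NumPy sketch of the resulting update, under that reading; the function name adam_step and the in-place array updates are illustrative, not part of the Jittor codebase:

import numpy as np

def adam_step(p, g, m, v, n, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0):
    # One Adam step on NumPy arrays, mirroring the diff above.
    b0, b1 = betas
    g = p * weight_decay + g                 # coupled L2 weight decay (this commit)
    m[:] = b0 * m + (1 - b0) * g             # biased first-moment estimate
    v[:] = b1 * v + (1 - b1) * g * g         # biased second-moment estimate
    step_size = lr * np.sqrt(1 - b1 ** n) / (1 - b0 ** n)  # bias correction
    p[:] = p - m * step_size / (np.sqrt(v) + eps)

This differs from decoupled weight decay (AdamW), where lr * weight_decay * p is subtracted from the parameter directly and the decay term never enters m or v.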