stop grad continue

This commit is contained in:
Dun Liang 2020-05-13 16:28:43 +08:00
parent d95e8c0f01
commit be141b067c
1 changed files with 2 additions and 0 deletions

View File

@ -123,6 +123,7 @@ class SGD(Optimizer):
# optimize main body
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
if p.is_stop_grad(): continue
dp = p * weight_decay + g
v.assign(momentum * v + dp * (1 - dampening))
if nesterov:
@ -163,6 +164,7 @@ class Adam(Optimizer):
eps = pg.get("eps", self.eps)
b0, b1 = pg.get("betas", self.betas)
for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
if p.is_stop_grad(): continue
m.assign(b0 * m + (1-b0) * g)
v.assign(b1 * v + (1-b1) * g * g)
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)