mirror of https://github.com/Jittor/Jittor
stop grad continue
This commit is contained in:
parent
d95e8c0f01
commit
be141b067c
|
@ -123,6 +123,7 @@ class SGD(Optimizer):
|
|||
|
||||
# optimize main body
|
||||
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
|
||||
if p.is_stop_grad(): continue
|
||||
dp = p * weight_decay + g
|
||||
v.assign(momentum * v + dp * (1 - dampening))
|
||||
if nesterov:
|
||||
|
@ -163,6 +164,7 @@ class Adam(Optimizer):
|
|||
eps = pg.get("eps", self.eps)
|
||||
b0, b1 = pg.get("betas", self.betas)
|
||||
for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
|
||||
if p.is_stop_grad(): continue
|
||||
m.assign(b0 * m + (1-b0) * g)
|
||||
v.assign(b1 * v + (1-b1) * g * g)
|
||||
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
|
||||
|
|
Loading…
Reference in New Issue