mirror of https://github.com/Jittor/Jittor
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Wenyang Zhou <576825820@qq.com>
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np

class Optimizer(object):
    """ Basic class of Optimizer.

    Example:
    ```
    optimizer = nn.SGD(model.parameters(), lr)
    optimizer.step(loss)
    ```
    """
    def __init__(self, params, lr, param_sync_iter=10000):
        self.param_groups = []
        self.lr = lr
        self.param_sync_iter = param_sync_iter

        assert len(params) > 0, "Length of parameters should not be zero"
        # accept either a flat list of parameters or a list of
        # param-group dicts; normalize everything to the dict form
        if not isinstance(params[0], dict):
            params = [{'params': params}]
        for pg in params:
            assert isinstance(pg, dict)
            self.param_groups.append(pg)
        self.n_step = 0
    def pre_step(self, loss):
        """ Work that should be done before each step,
        such as computing gradients, MPI synchronization, and so on.

        Example:
        ```
        class MyOptimizer(Optimizer):
            def step(self, loss):
                self.pre_step(loss)
                ...
        ```
        """
        # clean prev grads
        params = []
        params_has_grad = []
        for pg in self.param_groups:
            pg["grads"] = [None] * len(pg['params'])
            for p in pg['params']:
                params.append(p)
                if not p.is_stop_grad():
                    params_has_grad.append(p)

        # sync params, reduce computing graph size
        jt.sync(params)

        # get gradient
        grads = jt.grad(loss, params_has_grad)

        # sync grads and model if in mpi
        if jt.mpi:
            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))
            if self.n_step % self.param_sync_iter == 0:
                for p in params:
                    p.assign(p.mpi_all_reduce("mean"))
        self.n_step += 1

        # set up grads in param_groups
        pid = 0
        for pg in self.param_groups:
            pg_grads = pg["grads"]
            for i, p in enumerate(pg['params']):
                if not p.is_stop_grad():
                    pg_grads[i] = grads[pid]
                    pid += 1
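    # Note (added for clarity): after pre_step, each param group's "grads"
    # list is aligned one-to-one with its "params" list; entries stay None
    # for parameters marked stop_grad, which the step() methods skip.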
    def step(self, loss):
        self.pre_step(loss)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p -= g * lr
                # detach with the prev graph to reduce memory consumption
                p.detach_inplace()

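# Illustrative note (not part of the original file): because step() reads
# per-group settings with pg.get("lr", self.lr), param groups may override
# the default learning rate. A hypothetical setup (the names `backbone` and
# `head` are placeholders) could look like:
#
#     optimizer = SGD([
#         {'params': backbone.parameters(), 'lr': 0.01},
#         {'params': head.parameters()},  # falls back to the default lr
#     ], lr=0.1, momentum=0.9)
#     optimizer.step(loss)
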
class SGD(Optimizer):
    """ SGD Optimizer.

    Example:
    ```
    optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
    optimizer.step(loss)
    ```
    """
    def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
        super().__init__(params, lr)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # initialize the momentum buffer ("values"), one entry per parameter
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
    def step(self, loss):
        self.pre_step(loss)
        for pg in self.param_groups:
            # get per-group arguments, falling back to the optimizer defaults
            lr = pg.get("lr", self.lr)
            momentum = pg.get("momentum", self.momentum)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            dampening = pg.get("dampening", self.dampening)
            nesterov = pg.get("nesterov", self.nesterov)

            # main parameter update
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                if p.is_stop_grad(): continue
                dp = p * weight_decay + g
                v.assign(momentum * v + dp * (1 - dampening))
                if nesterov:
                    p -= (dp + momentum * v) * lr
                else:
                    p -= v * lr
                p.detach_inplace()

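# A restatement (added for clarity) of the update SGD.step performs above:
#     dp = weight_decay * p + g
#     v  = momentum * v + (1 - dampening) * dp
#     p  = p - lr * (dp + momentum * v)    # if nesterov
#     p  = p - lr * v                      # otherwise
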
class Adam(Optimizer):
    """ Adam Optimizer.

    Example:
    ```
    optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
    optimizer.step(loss)
    ```
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        # self.weight_decay = weight_decay
        assert weight_decay == 0, "weight_decay is not supported yet"

        # initialize per-parameter state: "m" holds the first-moment
        # estimates and "values" the second-moment estimates
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
    def step(self, loss):
        self.pre_step(loss)
        n = float(self.n_step)
        for pg in self.param_groups:
            # get per-group arguments, falling back to the optimizer defaults
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                if p.is_stop_grad(): continue
                m.assign(b0 * m + (1 - b0) * g)
                v.assign(b1 * v + (1 - b1) * g * g)
                # bias-corrected step size: lr * sqrt(1 - b1^n) / (1 - b0^n)
                step_size = lr * jt.sqrt(1 - b1 ** n) / (1 - b0 ** n)
                p -= m * step_size / (jt.sqrt(v) + eps)
                p.detach_inplace()

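if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original file).
    # It assumes jt.array, the ** operator, Var.sum() and Var.data are
    # available, as in standard Jittor; swap in a real model in practice.
    target = jt.array(np.ones((3, 3), dtype="float32")).stop_grad()
    w = jt.array(np.zeros((3, 3), dtype="float32"))

    optimizer = SGD([w], lr=0.4)
    for _ in range(20):
        loss = ((w - target) ** 2).sum()
        optimizer.step(loss)
    print(w.data)  # w should be close to the all-ones target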