mirror of https://github.com/Jittor/Jittor
add gradient accumulate
parent 68173ce507
commit fff97a1599
@@ -8,7 +8,7 @@
 # This file is subject to the terms and conditions defined in
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************
-__version__ = '1.2.2.17'
+__version__ = '1.2.2.18'
 from . import lock
 with lock.lock_scope():
     ori_int = int
@@ -34,6 +34,9 @@ class Optimizer(object):
             assert isinstance(pg, dict)
             self.param_groups.append(pg)
         self.n_step = 0
+        # __zero_grad is a flag for quickly determining whether the grad is zero,
+        # so we can omit computing 0 + x
+        self.__zero_grad = True

     def add_param_group(self, group):
         self.param_groups.append(group)
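The __zero_grad flag added above is purely an optimization: while it is set, the stored gradient is known to be zero, so the first accumulation can overwrite it directly instead of computing 0 + x. A minimal plain-Python sketch of the same fast path (the Accum class is hypothetical, not part of this commit) ::

    import numpy as np

    class Accum:
        def __init__(self, shape):
            self.grad = np.zeros(shape)
            self.zero = True              # mirrors Optimizer.__zero_grad

        def add_grad(self, g):
            if self.zero:
                self.grad = g.copy()      # skip the useless 0 + g
            else:
                self.grad = self.grad + g
            self.zero = False

    acc = Accum(3)
    acc.add_grad(np.array([1.0, 2.0, 3.0]))   # first call: direct copy
    acc.add_grad(np.array([0.5, 0.5, 0.5]))   # later calls: accumulate
    print(acc.grad)                           # [1.5 2.5 3.5]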
@@ -44,6 +47,9 @@ class Optimizer(object):
         return { k:v for k, v in self.__dict__.items()
             if k[0] != '_' and k not in exclude and not callable(v) }

+    def zero_grad(self):
+        self.__zero_grad = True
+
     def pre_step(self, loss):
         """ Things that should be done before step, such as calculating gradients, MPI sync, and so on.
@@ -58,7 +64,6 @@ class Optimizer(object):
         params = []
         params_has_grad = []
         for pg in self.param_groups:
-            pg["grads"] = [None] * len(pg['params'])
             for p in pg['params']:
                 params.append(p)
                 if not p.is_stop_grad():
@@ -79,14 +84,52 @@ class Optimizer(object):
         # set up grads in param_groups
         pid = 0
         for pg in self.param_groups:
+            if "grads" not in pg:
+                pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ]
             pg_grads = pg["grads"]
             for i, p in enumerate(pg['params']):
                 if not p.is_stop_grad():
-                    # stop grad of grad
-                    pg_grads[i] = grads[pid].stop_grad()
+                    # accumulate grad and stop grad of grad
+                    g = grads[pid].stop_grad()
+                    if not self.__zero_grad:
+                        g = g + pg_grads[i]
+                    pg_grads[i].update(g)
                 pid += 1
+        self.__zero_grad = False
+
+    def backward(self, loss):
+        '''
+        optimizer.backward(loss) accumulates gradients over multiple steps;
+        it can be used as follows:
+
+        Original source code ::
+
+            n_iter = 10000
+            batch_size = 100
+            ...
+            for i in range(n_iter):
+                ...
+                loss = calc_loss()
+                optimizer.step(loss)
+
+        Accumulation version ::
+
+            n_iter = 10000
+            batch_size = 100
+            accumulation_steps = 10
+            n_iter *= accumulation_steps
+            batch_size //= accumulation_steps
+            ...
+            for i in range(n_iter):
+                ...
+                loss = calc_loss()
+                # if the loss is a mean over the batch, divide it by accumulation_steps
+                optimizer.backward(loss / accumulation_steps)
+                if (i+1) % accumulation_steps == 0:
+                    optimizer.step()
+
+        '''
+        self.pre_step(loss)

     def step(self, loss=None):
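The division by accumulation_steps in the docstring above is what makes accumulated micro-batch gradients reproduce the full-batch gradient when the loss is a mean. A small numpy sketch checking that arithmetic (illustrative only, not part of this commit) ::

    import numpy as np

    np.random.seed(0)
    w = np.random.rand(3)                     # parameters of a toy linear model
    X = np.random.rand(40, 3)                 # full batch of 40 samples
    y = X @ np.array([1.0, -2.0, 0.5])
    accumulation_steps = 4

    def grad_mse(Xb, yb, w):
        # gradient of mean((Xb @ w - yb)**2) with respect to w
        return 2 * Xb.T @ (Xb @ w - yb) / len(yb)

    full_grad = grad_mse(X, y, w)

    acc_grad = np.zeros_like(w)
    for Xb, yb in zip(np.split(X, accumulation_steps), np.split(y, accumulation_steps)):
        # each micro-batch loss is a mean, so scale it down before accumulating
        acc_grad += grad_mse(Xb, yb, w) / accumulation_steps

    assert np.allclose(full_grad, acc_grad)   # accumulated == full-batch gradient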
@@ -97,6 +140,7 @@ class Optimizer(object):
             for p, g in zip(pg["params"], pg["grads"]):
                 if p.is_stop_grad(): continue
                 p.update(p - g * lr)
+        self.zero_grad()


 class SGD(Optimizer):
@@ -146,6 +190,7 @@ class SGD(Optimizer):
                     p.update(p - (dp + momentum * v) * lr)
                 else:
                     p.update(p - v * lr)
+        self.zero_grad()

 class RMSprop(Optimizer):
     """ RMSprop Optimizer.
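The SGD hunk above touches the Nesterov and plain-momentum update branches. A numpy sketch of one such step (illustrative only; the dp and v recurrences are assumed from the standard SGD formulation, since those lines fall outside the hunk) ::

    import numpy as np

    p = np.array([1.0, -2.0]); g = np.array([0.1, 0.3]); v = np.zeros(2)
    lr, momentum, weight_decay, dampening, nesterov = 0.1, 0.9, 0.0, 0.0, True

    dp = p * weight_decay + g                 # assumed weight-decay term
    v = momentum * v + dp * (1 - dampening)   # assumed momentum buffer update
    if nesterov:
        p = p - (dp + momentum * v) * lr      # matches the first branch above
    else:
        p = p - v * lr                        # matches the else branch above
    print(p, v)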
@@ -188,6 +233,7 @@ class RMSprop(Optimizer):
                 if p.is_stop_grad(): continue
                 v.update(alpha * v + (1-alpha) * g * g)
                 p.update(p - lr * g / (jt.sqrt(v) + eps))
+        self.zero_grad()

 class Adam(Optimizer):
     """ Adam Optimizer.
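The RMSprop update in the hunk above keeps a running average of squared gradients and scales each step by 1/sqrt(v). A minimal numpy sketch of that update rule (illustrative only) ::

    import numpy as np

    p = np.array([1.0, -2.0]); v = np.zeros(2)
    lr, alpha, eps = 0.01, 0.99, 1e-8
    for g in (np.array([0.5, -0.1]), np.array([0.4, -0.2])):
        v = alpha * v + (1 - alpha) * g * g   # running mean of squared gradients
        p = p - lr * g / (np.sqrt(v) + eps)   # step scaled by 1/sqrt(v)
    print(p, v)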
@@ -235,6 +281,7 @@ class Adam(Optimizer):
                 v.update(b1 * v + (1-b1) * g * g)
                 step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
                 p.update(p - m * step_size / (jt.sqrt(v) + eps))
+        self.zero_grad()


 class LRScheduler:
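The step_size expression in the Adam hunk above folds the bias corrections into a single scalar: p - m * step_size / (sqrt(v) + eps) matches the textbook update p - lr * m_hat / (sqrt(v_hat) + eps), differing only in how eps is scaled. A numpy sketch comparing the two forms (illustrative only) ::

    import numpy as np

    lr, b0, b1, eps, n = 0.001, 0.9, 0.999, 1e-8, 10
    m, v = 0.02, 0.0005                       # biased first/second moment estimates

    # textbook form with explicit bias-corrected moments
    m_hat = m / (1 - b0**n)
    v_hat = v / (1 - b1**n)
    textbook = lr * m_hat / (np.sqrt(v_hat) + eps)

    # the form used above: corrections folded into step_size
    step_size = lr * np.sqrt(1 - b1**n) / (1 - b0**n)
    folded = m * step_size / (np.sqrt(v) + eps)

    print(textbook, folded)                   # nearly identical; only the eps scaling differs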
@@ -295,4 +342,4 @@ class LambdaLR(LRScheduler):

     def get_lr(self):
         return [base_lr * lmbda(self.last_epoch)
-            for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
+            for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
@@ -0,0 +1,90 @@
+# ***************************************************************
+# Copyright (c) 2021 Jittor. All Rights Reserved.
+# Maintainers: Dun Liang <randonlang@gmail.com>.
+# This file is subject to the terms and conditions defined in
+# file 'LICENSE.txt', which is part of this source code package.
+# ***************************************************************
+import unittest
+import jittor as jt
+from jittor import init, Module
+import numpy as np
+from jittor.optim import Optimizer
+f32 = jt.float32
+
+def matmul(a, b):
+    (n, m), k = a.shape, b.shape[-1]
+    a = a.broadcast([n,m,k], dims=[2])
+    b = b.broadcast([n,m,k], dims=[0])
+    return (a*b).sum(dim=1)
+
+class Linear(Module):
+    def __init__(self, in_features, out_features, bias=True):
+        self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5
+        self.b = jt.random((out_features,))-0.5 if bias else None
+    def execute(self, x):
+        x = matmul(x, self.w)
+        if self.b is not None:
+            return x+self.b
+        return x
+
+def relu(x):
+    return jt.maximum(x, 0.0)
+Relu = jt.make_module(relu)
+
+class Model(Module):
+    def __init__(self, input_size):
+        self.linear1 = Linear(input_size, 10)
+        self.relu1 = Relu()
+        self.linear2 = Linear(10, 1)
+    def execute(self, x):
+        x = self.linear1(x)
+        x = self.relu1(x)
+        return self.linear2(x)
+
+class TestExample(unittest.TestCase):
+    def test1(self):
+        np.random.seed(0)
+        jt.set_seed(3)
+        n = 1000
+        batch_size = 50
+        base_lr = 0.05
+        # tune accumulation_steps for step and batch_size
+        accumulation_steps = 10
+        n *= accumulation_steps
+        batch_size //= accumulation_steps
+        # we need to stop grad of global value to prevent memory leak
+        lr = f32(base_lr).name("lr").stop_grad()
+
+        def get_data(n):
+            for i in range(n):
+                x = np.random.rand(batch_size, 1)
+                y = x*x
+                yield jt.float32(x), jt.float32(y)
+
+        model = Model(input_size=1)
+        ps = model.parameters()
+        opt = Optimizer(ps, lr)
+        all_loss = 0
+
+        for i,(x,y) in enumerate(get_data(n)):
+            pred_y = model(x).name("pred_y")
+            loss = ((pred_y - y)**f32(2)).name("loss")
+            loss_mean = loss.mean() / accumulation_steps
+            all_loss += loss_mean.item()
+
+            opt.backward(loss_mean)
+            if (i+1) % accumulation_steps == 0:
+                opt.step()
+
+            if i>50:
+                assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
+            prev = jt.liveness_info()
+            print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}")
+
+        print(all_loss)
+        result = 19.8639366890402
+        assert abs(all_loss - result) < 1e-3
+        jt.clean()
+
+if __name__ == "__main__":
+    unittest.main()