add gradient accumulate

Dun Liang 2021-01-16 14:31:19 +08:00
parent 68173ce507
commit fff97a1599
3 changed files with 142 additions and 5 deletions

View File

@@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.17'
__version__ = '1.2.2.18'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@@ -34,6 +34,9 @@ class Optimizer(object):
assert isinstance(pg, dict)
self.param_groups.append(pg)
self.n_step = 0
# __zero_grad is a flag for quickly determining whether the grads are already zero,
# so that we can omit the 0+x computation
self.__zero_grad = True
def add_param_group(self, group):
self.param_groups.append(group)
@@ -44,6 +47,9 @@ class Optimizer(object):
return { k:v for k, v in self.__dict__.items()
if k[0] != '_' and k not in exclude and not callable(v) }
def zero_grad(self):
self.__zero_grad = True
def pre_step(self, loss):
""" something should be done before step, such as calc gradients, mpi sync, and so on.
@@ -58,7 +64,6 @@ class Optimizer(object):
params = []
params_has_grad = []
for pg in self.param_groups:
pg["grads"] = [None] * len(pg['params'])
for p in pg['params']:
params.append(p)
if not p.is_stop_grad():
@@ -79,14 +84,52 @@ class Optimizer(object):
# set up grads in param_groups
pid = 0
for pg in self.param_groups:
if "grads" not in pg:
pg["grads"] = [ jt.zeros_like(p).stop_grad().stop_fuse() for p in pg['params'] ]
pg_grads = pg["grads"]
for i, p in enumerate(pg['params']):
if not p.is_stop_grad():
# stop grad of grad
pg_grads[i] = grads[pid].stop_grad()
# accumulate the gradient and stop grad of the gradient
g = grads[pid].stop_grad()
if not self.__zero_grad:
g = g + pg_grads[i]
pg_grads[i].update(g)
pid += 1
self.__zero_grad = False
def backward(self, loss):
'''
optimizer.backward(loss) is used to accumulate gradients over multiple steps;
it can be used as follows:
Original source code ::
n_iter = 10000
batch_size = 100
...
for i in range(n_iter):
...
loss = calc_loss()
optimizer.step(loss)
Accumulation version ::
n_iter = 10000
batch_size = 100
accumulation_steps = 10
n_iter *= accumulation_steps
batch_size //= accumulation_steps
...
for i in range(n_iter):
...
loss = calc_loss()
# if loss is a mean across the batch, we need to divide it by accumulation_steps
optimizer.backward(loss / accumulation_steps)
if (i+1) % accumulation_steps == 0:
optimizer.step()
'''
self.pre_step(loss)
def step(self, loss=None):
@@ -97,6 +140,7 @@ class Optimizer(object):
for p, g in zip(pg["params"], pg["grads"]):
if p.is_stop_grad(): continue
p.update(p - g * lr)
self.zero_grad()
class SGD(Optimizer):
@@ -146,6 +190,7 @@ class SGD(Optimizer):
p.update(p - (dp + momentum * v) * lr)
else:
p.update(p - v * lr)
self.zero_grad()
class RMSprop(Optimizer):
""" RMSprop Optimizer.
@@ -188,6 +233,7 @@ class RMSprop(Optimizer):
if p.is_stop_grad(): continue
v.update(alpha * v + (1-alpha) * g * g)
p.update(p - lr * g / (jt.sqrt(v) + eps))
self.zero_grad()
class Adam(Optimizer):
""" Adam Optimizer.
@@ -235,6 +281,7 @@ class Adam(Optimizer):
v.update(b1 * v + (1-b1) * g * g)
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
p.update(p - m * step_size / (jt.sqrt(v) + eps))
self.zero_grad()
class LRScheduler:
@@ -295,4 +342,4 @@ class LambdaLR(LRScheduler):
def get_lr(self):
return [base_lr * lmbda(self.last_epoch)
for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)]
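
To make the new accumulation path easier to follow, here is a minimal sketch (not part of the diff) of the update rule that pre_step applies to each parameter group's persistent grad buffers, written in plain NumPy. The names grads_buffer and zero_grad_flag are hypothetical stand-ins for pg["grads"] and the private __zero_grad flag.

import numpy as np

# Hypothetical stand-ins for pg["grads"] and the private __zero_grad flag.
grads_buffer = [np.zeros(3)]   # persistent per-parameter grad buffer
zero_grad_flag = True          # True right after zero_grad()

def accumulate(new_grads):
    # Mimics Optimizer.pre_step: write the new grads into the buffers,
    # skipping the 0+x addition when the buffers are known to be zero.
    global zero_grad_flag
    for buf, g in zip(grads_buffer, new_grads):
        if zero_grad_flag:
            buf[...] = g        # first backward after zero_grad: omit 0 + x
        else:
            buf[...] = buf + g  # later backward calls: accumulate
    zero_grad_flag = False

accumulate([np.ones(3)])        # buffer is now [1, 1, 1]
accumulate([np.ones(3)])        # buffer is now [2, 2, 2]
print(grads_buffer[0])

step() then applies the buffered grads and calls zero_grad(), which only resets the flag; the buffers themselves are overwritten on the next backward rather than being zeroed.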

View File

@@ -0,0 +1,90 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers: Dun Liang <randonlang@gmail.com>.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
from jittor import init, Module
import numpy as np
from jittor.optim import Optimizer
f32 = jt.float32
def matmul(a, b):
(n, m), k = a.shape, b.shape[-1]
a = a.broadcast([n,m,k], dims=[2])
b = b.broadcast([n,m,k], dims=[0])
return (a*b).sum(dim=1)
class Linear(Module):
def __init__(self, in_features, out_features, bias=True):
self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5
self.b = jt.random((out_features,))-0.5 if bias else None
def execute(self, x):
x = matmul(x, self.w)
if self.b is not None:
return x+self.b
return x
def relu(x):
return jt.maximum(x, 0.0)
Relu = jt.make_module(relu)
class Model(Module):
def __init__(self, input_size):
self.linear1 = Linear(input_size, 10)
self.relu1 = Relu()
self.linear2 = Linear(10, 1)
def execute(self, x):
x = self.linear1(x)
x = self.relu1(x)
return self.linear2(x)
class TestExample(unittest.TestCase):
def test1(self):
np.random.seed(0)
jt.set_seed(3)
n = 1000
batch_size = 50
base_lr = 0.05
# scale the number of iterations and the batch size according to accumulation_steps
accumulation_steps = 10
n *= accumulation_steps
batch_size //= accumulation_steps
# we need to stop grad on the global value to prevent a memory leak
lr = f32(base_lr).name("lr").stop_grad()
def get_data(n):
for i in range(n):
x = np.random.rand(batch_size, 1)
y = x*x
yield jt.float32(x), jt.float32(y)
model = Model(input_size=1)
ps = model.parameters()
opt = Optimizer(ps, lr)
all_loss = 0
for i,(x,y) in enumerate(get_data(n)):
pred_y = model(x).name("pred_y")
loss = ((pred_y - y)**f32(2)).name("loss")
loss_mean = loss.mean() / accumulation_steps
all_loss += loss_mean.item()
opt.backward(loss_mean)
if (i+1) % accumulation_steps == 0:
opt.step()
if i>50:
assert prev == jt.liveness_info(), f"memory leak {prev} {jt.liveness_info()}"
prev = jt.liveness_info()
print(f"step {i}, loss = {loss_mean.data.sum()} {jt.liveness_info()}")
print(all_loss)
result = 19.8639366890402
assert abs(all_loss - result) < 1e-3
jt.clean()
if __name__ == "__main__":
unittest.main()
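
As a quick sanity check on the loss scaling used above (dividing the mean loss by accumulation_steps before each backward), the sketch below verifies in plain NumPy that summing the scaled per-micro-batch gradients reproduces the full-batch gradient. The quadratic data and the scalar weight w are hypothetical and only serve the arithmetic; this is illustrative, not part of the commit.

import numpy as np

np.random.seed(0)
x = np.random.rand(50, 1)              # full batch, as in get_data above
y = x * x                              # quadratic targets
w = 0.3                                # hypothetical scalar weight, model: pred = w * x

def grad_mean_loss(xb, yb):
    # d/dw of mean((w*xb - yb)**2) = mean(2 * (w*xb - yb) * xb)
    return np.mean(2 * (w * xb - yb) * xb)

full_grad = grad_mean_loss(x, y)       # gradient of the full-batch mean loss

accumulation_steps = 10
micro_x = np.split(x, accumulation_steps)
micro_y = np.split(y, accumulation_steps)
# Each micro-batch contributes grad(mean loss) / accumulation_steps,
# matching optimizer.backward(loss_mean / accumulation_steps) in the test.
accum_grad = sum(grad_mean_loss(xb, yb) / accumulation_steps
                 for xb, yb in zip(micro_x, micro_y))

assert np.allclose(full_grad, accum_grad)
print(full_grad, accum_grad)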