Merge branch 'cjld' of https://github.com/Jittor/jittor into polish_conv_tuner

Dun Liang 2020-05-12 18:02:59 +08:00
commit 32651c5c7d
3 changed files with 198 additions and 98 deletions

python/jittor/nn.py

@@ -15,6 +15,7 @@ from jittor import init, Module
import numpy as np
import math
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
from jittor.optim import *

def matmul_transpose(a, b):
    '''
@@ -154,104 +155,6 @@ class BCEWithLogitsLoss(Module):
        output = self.bce(output, target)
        return output

class SGD(object):
    """ Usage:
    optimizer = nn.SGD(model.parameters(), lr)
    optimizer.step(loss)
    """
    def __init__(self, parameters, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False, param_sync_iter=10000):
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov
        self.sgd_step = 0
        self.param_sync_iter = param_sync_iter
        self.no_grad_parameters = []
        self.parameters = []
        self.values = []
        for p in parameters:
            # broadcast parameters from node 0 at initialization
            if jt.mpi:
                p.assign(p.mpi_broadcast().detach())
            if p.is_stop_grad():
                self.no_grad_parameters.append(p)
                continue
            self.parameters.append(p)
            self.values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())

    def step(self, loss):
        self.sgd_step += 1
        ps = self.parameters
        gs = jt.grad(loss, ps)
        if jt.mpi:
            for g in gs:
                g.assign(g.mpi_all_reduce("mean"))
            if self.sgd_step%self.param_sync_iter==0:
                for p in ps:
                    p.assign(p.mpi_all_reduce("mean"))
        for p, g, v in zip(ps, gs, self.values):
            dp = p * self.weight_decay + g
            v.assign(self.momentum * v + dp * (1 - self.dampening))
            if self.nesterov:
                p -= (dp + self.momentum * v) * self.lr
            else:
                p -= v * self.lr
            # detach from the previous graph to reduce memory consumption
            p.detach_inplace()
        # sync all no-grad parameters, such as
        # moving_mean and moving_var in batch_norm;
        # syncing them reduces memory consumption
        jt.sync(self.no_grad_parameters)

class Adam(object):
    """ Usage:
    optimizer = nn.Adam(model.parameters(), lr)
    optimizer.step(loss)
    """
    def __init__(self, parameters, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0, param_sync_iter=10000):
        self.lr = lr
        self.eps = eps
        self.betas = betas
        # self.weight_decay = weight_decay
        assert weight_decay==0, "weight_decay is not supported yet"
        self.adam_step = 0
        self.param_sync_iter = param_sync_iter
        self.no_grad_parameters = []
        self.parameters = []
        self.values = []
        self.m = []
        for p in parameters:
            if jt.mpi:
                p.assign(p.mpi_broadcast().detach())
            if p.is_stop_grad():
                self.no_grad_parameters.append(p)
                continue
            self.parameters.append(p)
            self.values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
            self.m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())

    def step(self, loss):
        self.adam_step += 1
        ps = self.parameters
        gs = jt.grad(loss, ps)
        if jt.mpi:
            for g in gs:
                g.assign(g.mpi_all_reduce("mean"))
            if self.adam_step%self.param_sync_iter==0:
                for p in ps:
                    p.assign(p.mpi_all_reduce("mean"))
        n, (b0, b1) = float(self.adam_step), self.betas
        for p, g, v, m in zip(ps, gs, self.values, self.m):
            m.assign(b0 * m + (1-b0) * g)
            v.assign(b1 * v + (1-b1) * g * g)
            step_size = self.lr * jt.sqrt(1-b1**n) / (1-b0**n)
            p -= m * step_size / (jt.sqrt(v) + self.eps)
            p.detach_inplace()
        jt.sync(self.no_grad_parameters)

def softmax(x, dim = None):
    if dim is None:
        x = (x - x.max()).exp()
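With the SGD and Adam classes removed here and `from jittor.optim import *` added at the top of nn.py, the old entry points should keep working as aliases of the new classes. A minimal sketch of what that implies (my own check, not part of this commit; it assumes optim.py exports its public names unfiltered):

```python
import jittor as jt
from jittor import nn, optim

# nn.py re-exports jittor.optim, so the old names resolve to the new classes
assert nn.SGD is optim.SGD and nn.Adam is optim.Adam

# the old flat-list call style is still accepted: Optimizer.__init__ wraps
# a plain parameter list into a single param_group
p = jt.ones((1,))
data = jt.ones((1,))
opt = nn.SGD([p], 0.1)
opt.step(p * data)   # with gradient 1 and lr 0.1, p becomes 0.9
```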

python/jittor/optim.py (new file)

@@ -0,0 +1,170 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Guowei Yang <471184555@qq.com>
#     Guoye Yang <498731903@qq.com>
#     Wenyang Zhou <576825820@qq.com>
#     Meng-Hao Guo <guomenghao1997@gmail.com>
#     Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np

class Optimizer(object):
    """ Base class of optimizers.

    Example:
    ```
    optimizer = nn.SGD(model.parameters(), lr)
    optimizer.step(loss)
    ```
    """
    def __init__(self, params, lr, param_sync_iter=10000):
        self.param_groups = []
        self.lr = lr
        self.param_sync_iter = param_sync_iter

        assert len(params) > 0, "Length of parameters should not be zero"
        if not isinstance(params[0], dict):
            params = [{'params': params}]
        for pg in params:
            assert isinstance(pg, dict)
            self.param_groups.append(pg)
        self.n_step = 0

    def pre_step(self, loss):
        """ Things that should be done before the optimizer step,
        such as computing gradients, MPI synchronization, and so on.

        Example:
        ```
        class MyOptimizer(Optimizer):
            def step(self, loss):
                self.pre_step(loss)
                ...
        ```
        """
        # clear previous grads
        params = []
        params_has_grad = []
        for pg in self.param_groups:
            pg["grads"] = [None] * len(pg['params'])
            for p in pg['params']:
                params.append(p)
                if not p.is_stop_grad():
                    params_has_grad.append(p)

        # sync params, reduce computing graph size
        jt.sync(params)

        # get gradients
        grads = jt.grad(loss, params_has_grad)

        # sync grads and model if in mpi
        if jt.mpi:
            for g in grads:
                g.assign(g.mpi_all_reduce("mean"))
            if self.n_step % self.param_sync_iter == 0:
                for p in params:
                    p.assign(p.mpi_all_reduce("mean"))
        self.n_step += 1

        # set up grads in param_groups
        pid = 0
        for pg in self.param_groups:
            pg_grads = pg["grads"]
            for i, p in enumerate(pg['params']):
                if not p.is_stop_grad():
                    pg_grads[i] = grads[pid]
                    pid += 1

    def step(self, loss):
        self.pre_step(loss)
        for pg in self.param_groups:
            lr = pg.get("lr", self.lr)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p -= g * lr
                # detach from the previous graph to reduce memory consumption
                p.detach_inplace()

class SGD(Optimizer):
    """ SGD Optimizer.

    Example:
    ```
    optimizer = nn.SGD(model.parameters(), lr, momentum=0.9)
    optimizer.step(loss)
    ```
    """
    def __init__(self, params, lr, momentum=0, weight_decay=0, dampening=0, nesterov=False):
        super().__init__(params, lr)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.dampening = dampening
        self.nesterov = nesterov

        # initialize required arguments
        for pg in self.param_groups:
            values = pg["values"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())

    def step(self, loss):
        self.pre_step(loss)
        for pg in self.param_groups:
            # get arguments for this param_group
            lr = pg.get("lr", self.lr)
            momentum = pg.get("momentum", self.momentum)
            weight_decay = pg.get("weight_decay", self.weight_decay)
            dampening = pg.get("dampening", self.dampening)
            nesterov = pg.get("nesterov", self.nesterov)

            # main optimization loop
            for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                dp = p * weight_decay + g
                v.assign(momentum * v + dp * (1 - dampening))
                if nesterov:
                    p -= (dp + momentum * v) * lr
                else:
                    p -= v * lr
                p.detach_inplace()

class Adam(Optimizer):
    """ Adam Optimizer.

    Example:
    ```
    optimizer = nn.Adam(model.parameters(), lr, eps=1e-8, betas=(0.9, 0.999))
    optimizer.step(loss)
    ```
    """
    def __init__(self, params, lr, eps=1e-8, betas=(0.9, 0.999), weight_decay=0):
        super().__init__(params, lr)
        self.eps = eps
        self.betas = betas
        # self.weight_decay = weight_decay
        assert weight_decay==0, "weight_decay is not supported yet"

        # initialize required arguments for each param_group
        for pg in self.param_groups:
            values = pg["values"] = []
            m = pg["m"] = []
            for p in pg["params"]:
                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
                m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())

    def step(self, loss):
        self.pre_step(loss)
        n = float(self.n_step)
        for pg in self.param_groups:
            # get arguments for this param_group
            lr = pg.get("lr", self.lr)
            eps = pg.get("eps", self.eps)
            b0, b1 = pg.get("betas", self.betas)
            for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                m.assign(b0 * m + (1-b0) * g)
                v.assign(b1 * v + (1-b1) * g * g)
                step_size = lr * jt.sqrt(1-b1**n) / (1-b0**n)
                p -= m * step_size / (jt.sqrt(v) + eps)
                p.detach_inplace()
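The pre_step docstring above only hints at how subclasses are meant to use it. The sketch below (my own example, not part of the commit; the class name and the weight_decay default are invented) follows the same pattern as SGD.step and Adam.step: call pre_step to compute and synchronize gradients, then walk param_groups and apply the update rule, honoring per-group overrides.

```python
import jittor as jt
from jittor.optim import Optimizer

class DecayedGD(Optimizer):
    """Hypothetical optimizer: plain gradient descent with L2 weight decay."""
    def __init__(self, params, lr, weight_decay=1e-4):
        super().__init__(params, lr)
        self.weight_decay = weight_decay

    def step(self, loss):
        # pre_step computes gradients, handles the MPI sync, and fills pg["grads"]
        self.pre_step(loss)
        for pg in self.param_groups:
            # per-group settings override the optimizer-wide defaults
            lr = pg.get("lr", self.lr)
            wd = pg.get("weight_decay", self.weight_decay)
            for p, g in zip(pg["params"], pg["grads"]):
                if p.is_stop_grad(): continue
                p -= (g + wd * p) * lr
                # detach from the previous graph, as the built-in optimizers do
                p.detach_inplace()
```

Usage would mirror the built-in classes: `opt = DecayedGD(model.parameters(), 0.01)` followed by `opt.step(loss)`.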

New test file

@@ -0,0 +1,27 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from jittor import nn

class TestOptimizer(unittest.TestCase):
    def test_param_groups(self):
        pa = jt.ones((1,))
        pb = jt.ones((1,))
        data = jt.ones((1,))
        opt = nn.SGD([
            {"params":[pa], "lr":0.1},
            {"params":[pb]},
        ], 1)
        opt.step(pa*data+pb*data)
        assert pa.data == 0.9 and pb.data == 0, (pa, pb)

if __name__ == "__main__":
    unittest.main()
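The new test only exercises a per-group "lr" override. A similar check for the per-group "momentum" override read by SGD.step could look like the sketch below (not part of this commit; the expected values follow from two steps with a constant gradient of 1 and lr 0.1):

```python
import numpy as np
import jittor as jt
from jittor import nn

def check_param_group_momentum():
    pa, pb = jt.ones((1,)), jt.ones((1,))
    data = jt.ones((1,))
    # pa uses momentum 0.9, pb falls back to the optimizer-wide default of 0
    opt = nn.SGD([
        {"params": [pa], "momentum": 0.9},
        {"params": [pb]},
    ], 0.1)
    for _ in range(2):
        opt.step(pa * data + pb * data)
    # pa: v = 1, then v = 0.9*1 + 1 = 1.9  ->  1 - 0.1 - 0.19 = 0.71
    # pb: v = 1 on both steps              ->  1 - 0.1 - 0.10 = 0.80
    assert np.allclose(pa.data, 0.71) and np.allclose(pb.data, 0.8), (pa, pb)
```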