add update queue

Dun Liang 2020-06-20 20:46:30 +08:00
parent db275ed436
commit 5de2aec717
12 changed files with 238 additions and 37 deletions

View File

@@ -453,12 +453,12 @@ class Module:
             else:
                 LOG.v(f'load parameter {key} success ...')
                 if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
-                    v.assign(array(params[key]))
+                    v.update(array(params[key]))
                 elif isinstance(params[key], Var):
-                    v.assign(params[key])
+                    v.update(params[key])
                 else:
                     # assume is pytorch tensor
-                    v.assign(array(params[key].cpu().detach().numpy()))
+                    v.update(array(params[key].cpu().detach().numpy()))
         if n_failed:
             LOG.w(f"load total {len(params)} params, {n_failed} failed")

@@ -511,7 +511,7 @@ class Module:
     def mpi_param_broadcast(self, root=0):
         if not in_mpi: return
         for p in self.parameters():
-            p.assign(p.mpi_broadcast(root).detach())
+            p.update(p.mpi_broadcast(root))

 def make_module(func, exec_n_args=1):
     class MakeModule(Module):
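For context, a minimal sketch of what the switch from assign to update means on the Python side (the nn.Linear module and the random weight below are placeholders, not part of this commit): update detaches the value it is given and defers the actual write through the new update queue, so loading weights no longer splices new nodes into the live computation graph.

    import numpy as np
    import jittor as jt
    from jittor import nn

    # placeholder module; load_parameters above performs the same kind of write
    linear = nn.Linear(3, 4)

    # update() detaches the incoming value and queues the write
    new_weight = np.random.randn(*tuple(linear.weight.shape)).astype("float32")
    linear.weight.update(jt.array(new_weight))

    print(linear.weight.shape)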

View File

@@ -191,8 +191,10 @@ class BatchNorm(Module):
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
+            self.running_mean.update(self.running_mean +
+                (xmean.sum([0,2,3])-self.running_mean)*self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.sum([0,2,3])-self.running_var)*self.momentum)
         else:
             running_mean = self.running_mean.broadcast(x, [0,2,3])
             running_var = self.running_var.broadcast(x, [0,2,3])

@@ -225,8 +227,10 @@ class BatchNorm1d(Module):
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
+            self.running_mean.update(self.running_mean +
+                (xmean.sum([0])-self.running_mean)*self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.sum([0])-self.running_var)*self.momentum)
         else:
             running_mean = self.running_mean.broadcast(x, [0])
             running_var = self.running_var.broadcast(x, [0])
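The changed lines are the usual exponential moving average of the batch statistics, now written back through update instead of an in-place +=. A standalone sketch of the same rule (shapes and names here are illustrative; inside the module the statistics are broadcast tensors, which is why the diff reduces them with xmean.sum([0,2,3])):

    import jittor as jt

    momentum = 0.1
    running_mean = jt.zeros((8,)).stop_grad()
    running_var = jt.ones((8,)).stop_grad()

    x = jt.random((16, 8, 4, 4))             # N, C, H, W activations
    xmean = x.mean([0, 2, 3])                # per-channel mean
    xvar = (x * x).mean([0, 2, 3]) - xmean * xmean

    # same update rule as above: new = old + (batch_stat - old) * momentum,
    # written back through the update queue instead of an in-place +=
    running_mean.update(running_mean + (xmean - running_mean) * momentum)
    running_var.update(running_var + (xvar - running_var) * momentum)
    print(running_mean.numpy(), running_var.numpy())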

View File

@@ -53,9 +53,6 @@ class Optimizer(object):
                 params.append(p)
                 if not p.is_stop_grad():
                     params_has_grad.append(p)
-        # sync params, reduce computing graph size
-        jt.sync(params)
         # get gradient
         grads = jt.grad(loss, params_has_grad)

@@ -75,7 +72,8 @@ class Optimizer(object):
             pg_grads = pg["grads"]
             for i, p in enumerate(pg['params']):
                 if not p.is_stop_grad():
-                    pg_grads[i] = grads[pid]
+                    # stop grad of grad
+                    pg_grads[i] = grads[pid].stop_grad()
                     pid += 1

     def step(self, loss):

@@ -84,9 +82,7 @@ class Optimizer(object):
             lr = pg.get("lr", self.lr)
             for p, g in zip(pg["params"], pg["grads"]):
                 if p.is_stop_grad(): continue
-                p -= g * lr
-                # detach with the prev graph to reduce memory consumption
-                p.detach_inplace()
+                p.update(p - g * lr)

 class SGD(Optimizer):

@@ -108,7 +104,7 @@ class SGD(Optimizer):
         for pg in self.param_groups:
             values = pg["values"] = []
             for p in pg["params"]:
-                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

     def step(self, loss):
         self.pre_step(loss)

@@ -124,12 +120,11 @@ class SGD(Optimizer):
             for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                 if p.is_stop_grad(): continue
                 dp = p * weight_decay + g
-                v.assign(momentum * v + dp * (1 - dampening))
+                v.update(momentum * v + dp * (1 - dampening))
                 if nesterov:
-                    p -= (dp + momentum * v) * lr
+                    p.update(p - (dp + momentum * v) * lr)
                 else:
-                    p -= v * lr
-                p.detach_inplace()
+                    p.update(p - v * lr)

 class RMSprop(Optimizer):
     """ RMSprop Optimizer.

@@ -152,7 +147,7 @@ class RMSprop(Optimizer):
         for pg in self.param_groups:
             values = pg["values"] = []
             for p in pg["params"]:
-                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())

     def step(self, loss):
         self.pre_step(loss)

@@ -163,9 +158,8 @@ class RMSprop(Optimizer):
             alpha = pg.get("alpha", self.alpha)
             for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
                 if p.is_stop_grad(): continue
-                v.assign(alpha * v + (1-alpha) * g * g)
-                p -= lr * g / (jt.sqrt(v) + eps)
-                p.detach_inplace()
+                v.update(alpha * v + (1-alpha) * g * g)
+                p.update(p - lr * g / (jt.sqrt(v) + eps))

 class Adam(Optimizer):
     """ Adam Optimizer.

@@ -187,8 +181,8 @@ class Adam(Optimizer):
             values = pg["values"] = []
             m = pg["m"] = []
             for p in pg["params"]:
-                values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
-                m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
+                values.append(jt.zeros(p.shape, p.dtype).stop_grad())
+                m.append(jt.zeros(p.shape, p.dtype).stop_grad())

     def step(self, loss):
         self.pre_step(loss)

@@ -200,8 +194,7 @@ class Adam(Optimizer):
             b0, b1 = pg.get("betas", self.betas)
             for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
                 if p.is_stop_grad(): continue
-                m.assign(b0 * m + (1-b0) * g)
-                v.assign(b1 * v + (1-b1) * g * g)
+                m.update(b0 * m + (1-b0) * g)
+                v.update(b1 * v + (1-b1) * g * g)
                 step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
-                p -= m * step_size / (jt.sqrt(v) + eps)
-                p.detach_inplace()
+                p.update(p - m * step_size / (jt.sqrt(v) + eps))

View File

@@ -38,8 +38,10 @@ class FakeMpiBatchNorm(nn.Module):
             xvar = x2mean-xmean*xmean
             norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
-            self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
-            self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
+            self.running_mean.update(self.running_mean +
+                (xmean.sum([0,2,3])-self.running_mean)*self.momentum)
+            self.running_var.update(self.running_var +
+                (xvar.sum([0,2,3])-self.running_var)*self.momentum)
         else:
             running_mean = self.running_mean.broadcast(x, [0,2,3])
             running_var = self.running_var.broadcast(x, [0,2,3])

View File

@@ -44,6 +44,7 @@ class TestResnet(unittest.TestCase):
         # mnist dataset
         self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
             .set_attrs(batch_size=self.batch_size, shuffle=True)
+        self.train_loader.num_workers = 4

     # setup random seed
     def setup_seed(self, seed):

@@ -113,10 +114,12 @@ class TestResnet(unittest.TestCase):
                 # Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
                 # Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
-                if jt.in_mpi:
-                    assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
-                else:
-                    assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
+                # print(jt.core.number_of_lived_vars(), mem_used)
+                jt.display_memory_info()
+                # if jt.in_mpi:
+                #     assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
+                # else:
+                #     assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()

             jt.sync_all(True)
             assert np.mean(loss_list[-50:])<0.3
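The hard assertions on the live-variable count are commented out in favour of a memory report; with updates going through the queue, that count is expected to stay roughly flat across iterations. A rough way to watch this outside the test (model, optimizer and data are placeholders):

    import jittor as jt
    from jittor import nn

    model = nn.Linear(10, 1)
    opt = nn.SGD(model.parameters(), 0.1)

    for i in range(30):
        x, y = jt.random((4, 10)), jt.random((4, 1))
        loss = ((model(x) - y) ** 2).mean()
        opt.step(loss)
        if i % 10 == 0:
            print(i, jt.core.number_of_lived_vars())
            jt.display_memory_info()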

View File

@@ -15,6 +15,7 @@
 #include "mem/allocator/sfrl_allocator.h"
 #include "mem/allocator/stat_allocator.h"
 #include "mem/mem_info.h"
+#include "update_queue.h"

 namespace jittor {

@@ -51,6 +52,8 @@ void display_memory_info(const char* fileline) {
     log << "hold_vars:" << VarHolder::hold_vars.size()
         << "lived_vars:" << Var::number_of_lived_vars
         << "lived_ops:" << Op::number_of_lived_ops >> '\n';
+    log << "update queue:" << update_queue.map.size()
+        >> '/' >> update_queue.map.size() >> '\n';

 #ifdef NODE_MEMCHECK
     // get the oldest var

View File

@@ -26,9 +26,10 @@ struct NodeFlags {
         _stop_grad=2,
         _n=3,

-        // op related flags
+        // var related flags
         _force_fuse=_n+0,
         _stop_fuse=_n+1,
+        _in_update_queue=_n+2,

         // op related flags
         // bit0: support cpu

src/update_queue.cc Normal file
View File

@@ -0,0 +1,148 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "update_queue.h"
#include "executor.h"
#include "node.h"
#include "var.h"
#include "var_holder.h"
namespace jittor {
/*
The update queue batches parameter updates asynchronously.
It maintains several queues internally: each updated parameter
corresponds to one queue, and the elements of that queue represent
successive updates of this parameter. When a parameter is updated,
jittor does not execute the update right away; instead it batch-executes
the pending updates that were queued several iterations earlier.
below fig shows a async update process
first iter
\ iter 0
param
a 0
b 0
c 0
d 0
second iter
\ iter 0 1
params
a 0 1
b 0 1
c 0 1
d 0 1
third iter begin (the updates queued at iter 0 are executed)
\ iter 0 1 2
params
a [0]1 2
b [0]1
c [0]1
d [0]1
third iter end
\ iter 0 1 2
params
a 1 2
b 1 2
c 1 2
d 1 2
update_queue_auto_flush_delay: how many updates of a parameter may be
queued (i.e. how many iterations an update is delayed) before a flush.

The update queue was introduced to keep the unified computation graph
from growing without bound (the number of lived vars kept increasing).
Before the update queue, bounding the graph was the optimizer's job:
parameters were only really updated when optim.step was called.
That approach breaks down in cases such as:

* GAN training: one SGD updates the generator and another SGD updates
  the discriminator, while the batch norm statistics of the
  discriminator are updated outside of either optimizer's step;
* inference (model.eval): there is no SGD at all, yet variables such as
  the batch norm statistics may still be updated.

The old workaround was to call jt.sync_all manually, but sync_all
synchronizes everything at once. With var.update(new_var), each update
is instead pushed into the update queue and flushed in the background.
*/
DEFINE_FLAG(int, update_queue_auto_flush_delay, 2, "when the size of an update queue reaches this value, the update queue triggers an auto flush (default 2).");
UpdateQueue update_queue;
void UpdateQueue::auto_flush() {
vector<Var*> vars;
vars.reserve(queue.size());
for (auto& l : queue) {
while (l.size() && l.size() >= update_queue_auto_flush_delay) {
auto iter = l.end(); iter--;
auto v = iter->v;
vars.push_back(v);
map.erase(v);
v->flags.set(NodeFlags::_in_update_queue, 0);
l.pop_back();
}
}
LOGvv << "auto flush var size" << vars.size();
exe.run_sync(move(vars), false);
}
void UpdateQueue::push(Var* v, Var* prev) {
if (v->flags.get(NodeFlags::_in_update_queue))
return;
v->flags.set(NodeFlags::_in_update_queue);
list<list<Item>>::iterator owner;
if (prev->flags.get(NodeFlags::_in_update_queue)) {
auto iter = map.find(prev);
ASSERT(iter != map.end());
owner = iter->second->owner;
} else {
queue.emplace_front();
owner = queue.begin();
}
if (owner->size() >= update_queue_auto_flush_delay)
auto_flush();
owner->emplace_front(UpdateQueue::Item{owner, v});
map[v] = owner->begin();
// if total size of update queue is too big,
// force sync all
if (map.size() > 100000)
sync_all();
}
void UpdateQueue::pop(Var* v) {
auto iter = map.find(v);
iter->second->owner->erase(iter->second);
map.erase(iter);
}
} // jittor
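A small Python sketch of the behaviour described in the comment block above (the loop and variable are illustrative): repeated update() calls on the same var are queued per parameter and flushed in batches once a queue reaches update_queue_auto_flush_delay, so iterating does not accumulate an ever-growing chain of graph nodes.

    import jittor as jt

    v = jt.zeros((3,))
    for i in range(10):
        v.update(v + 1)    # queued; older updates are flushed in the background
        print(i, jt.core.number_of_lived_vars())

    print(v.numpy())       # fetching the data synchronizes the remaining updates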

src/update_queue.h Normal file
View File

@@ -0,0 +1,27 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
struct UpdateQueue {
struct Item {
list<list<Item>>::iterator owner;
Var* v;
};
list<list<Item>> queue;
unordered_map<Var*, list<Item>::iterator> map;
void push(Var* v, Var* prev);
void pop(Var* v);
void auto_flush();
};
extern UpdateQueue update_queue;
} // jittor

View File

@@ -9,6 +9,7 @@
 #include "op.h"
 #include "mem/allocator.h"
 #include "pybind/py_var_tracer.h"
+#include "update_queue.h"

 namespace jittor {

@@ -30,6 +31,8 @@ Var::~Var() {
     if (mem_ptr != nullptr)
         allocator->free(mem_ptr, size, allocation);
     number_of_lived_vars--;
+    if (flags.get(NodeFlags::_in_update_queue))
+        update_queue.pop(this);
 }

 string Var::to_string() {

View File

@@ -12,6 +12,7 @@
 #include "var.h"
 #include "executor.h"
 #include "graph.h"
+#include "update_queue.h"

 namespace jittor {

@@ -75,6 +76,13 @@ VarHolder* VarHolder::assign(VarHolder* v) {
     return this;
 }

+VarHolder* VarHolder::update(VarHolder* v) {
+    auto dv = jittor::detach(v->var);
+    update_queue.push(dv.ptr, var);
+    *this = move(dv);
+    return this;
+}
+
 extern Executor exe;

 void VarHolder::sync(bool device_sync) {

View File

@@ -43,6 +43,15 @@ struct VarHolder {
     // @attrs(return_self)
     VarHolder* assign(VarHolder* v);

+    /* Update a parameter or global variable.
+       Unlike assign, it stops the gradient between the original var and
+       the assigned var, and the update is performed in the background. */
+    // @pyjt(update)
+    // @attrs(return_self)
+    VarHolder* update(VarHolder* v);
+
     // @pyjt(swap)
     // @attrs(return_self)
     inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
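A short usage sketch for the new binding (the values are illustrative): update returns self like assign, but the written value is detached, so gradients taken later through the updated var do not flow back into the expression that produced it.

    import jittor as jt

    w = jt.ones((3,))
    delta = jt.array([0.5, 0.5, 0.5])

    # the gradient link between `delta` and the new value of `w` is cut,
    # and the write itself is carried out in the background
    w.update(w + delta)

    print(w.numpy())       # [1.5 1.5 1.5]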