mirror of https://github.com/Jittor/Jittor
add update queue
This commit is contained in:
parent
db275ed436
commit
5de2aec717
|
@ -453,12 +453,12 @@ class Module:
|
|||
else:
|
||||
LOG.v(f'load parameter {key} success ...')
|
||||
if isinstance(params[key], np.ndarray) or isinstance(params[key], list):
|
||||
v.assign(array(params[key]))
|
||||
v.update(array(params[key]))
|
||||
elif isinstance(params[key], Var):
|
||||
v.assign(params[key])
|
||||
v.update(params[key])
|
||||
else:
|
||||
# assume is pytorch tensor
|
||||
v.assign(array(params[key].cpu().detach().numpy()))
|
||||
v.update(array(params[key].cpu().detach().numpy()))
|
||||
if n_failed:
|
||||
LOG.w(f"load total {len(params)} params, {n_failed} failed")
|
||||
|
||||
|
@ -511,7 +511,7 @@ class Module:
|
|||
def mpi_param_broadcast(self, root=0):
|
||||
if not in_mpi: return
|
||||
for p in self.parameters():
|
||||
p.assign(p.mpi_broadcast(root).detach())
|
||||
p.update(p.mpi_broadcast(root))
|
||||
|
||||
def make_module(func, exec_n_args=1):
|
||||
class MakeModule(Module):
|
||||
|
|
|
@ -191,8 +191,10 @@ class BatchNorm(Module):
|
|||
|
||||
xvar = x2mean-xmean*xmean
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
|
||||
self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
|
||||
self.running_mean.update(self.running_mean +
|
||||
(xmean.sum([0,2,3]) - self.running_mean) * self.momentum)
|
||||
self.running_var.update(self.running_var +
|
||||
(xvar.sum([0,2,3])-self.running_var)*self.momentum)
|
||||
else:
|
||||
running_mean = self.running_mean.broadcast(x, [0,2,3])
|
||||
running_var = self.running_var.broadcast(x, [0,2,3])
|
||||
|
@ -225,8 +227,10 @@ class BatchNorm1d(Module):
|
|||
|
||||
xvar = x2mean-xmean*xmean
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
self.running_mean += (xmean.sum([0])-self.running_mean)*self.momentum
|
||||
self.running_var += (xvar.sum([0])-self.running_var)*self.momentum
|
||||
self.running_mean.update(self.running_mean +
|
||||
(xmean.sum([0])-self.running_mean)*self.momentum)
|
||||
self.running_var.update(self.running_var +
|
||||
(xvar.sum([0])-self.running_var)*self.momentum)
|
||||
else:
|
||||
running_mean = self.running_mean.broadcast(x, [0])
|
||||
running_var = self.running_var.broadcast(x, [0])
|
||||
|
|
|
@ -54,9 +54,6 @@ class Optimizer(object):
|
|||
if not p.is_stop_grad():
|
||||
params_has_grad.append(p)
|
||||
|
||||
# sync params, reduce computing graph size
|
||||
jt.sync(params)
|
||||
|
||||
# get gradient
|
||||
grads = jt.grad(loss, params_has_grad)
|
||||
|
||||
|
@ -75,7 +72,8 @@ class Optimizer(object):
|
|||
pg_grads = pg["grads"]
|
||||
for i, p in enumerate(pg['params']):
|
||||
if not p.is_stop_grad():
|
||||
pg_grads[i] = grads[pid]
|
||||
# stop grad of grad
|
||||
pg_grads[i] = grads[pid].stop_grad()
|
||||
pid += 1
|
||||
|
||||
def step(self, loss):
|
||||
|
@ -84,9 +82,7 @@ class Optimizer(object):
|
|||
lr = pg.get("lr", self.lr)
|
||||
for p, g in zip(pg["params"], pg["grads"]):
|
||||
if p.is_stop_grad(): continue
|
||||
p -= g * lr
|
||||
# detach with the prev graph to reduce memory consumption
|
||||
p.detach_inplace()
|
||||
p.update(p - g * lr)
|
||||
|
||||
|
||||
class SGD(Optimizer):
|
||||
|
@ -108,7 +104,7 @@ class SGD(Optimizer):
|
|||
for pg in self.param_groups:
|
||||
values = pg["values"] = []
|
||||
for p in pg["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
|
||||
def step(self, loss):
|
||||
self.pre_step(loss)
|
||||
|
@ -124,12 +120,11 @@ class SGD(Optimizer):
|
|||
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
|
||||
if p.is_stop_grad(): continue
|
||||
dp = p * weight_decay + g
|
||||
v.assign(momentum * v + dp * (1 - dampening))
|
||||
v.update(momentum * v + dp * (1 - dampening))
|
||||
if nesterov:
|
||||
p -= (dp + momentum * v) * lr
|
||||
p.update(p - (dp + momentum * v) * lr)
|
||||
else:
|
||||
p -= v * lr
|
||||
p.detach_inplace()
|
||||
p.update(p - v * lr)
|
||||
|
||||
class RMSprop(Optimizer):
|
||||
""" RMSprop Optimizer.
|
||||
|
@ -152,7 +147,7 @@ class RMSprop(Optimizer):
|
|||
for pg in self.param_groups:
|
||||
values = pg["values"] = []
|
||||
for p in pg["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
|
||||
def step(self, loss):
|
||||
self.pre_step(loss)
|
||||
|
@ -163,9 +158,8 @@ class RMSprop(Optimizer):
|
|||
alpha = pg.get("alpha", self.alpha)
|
||||
for p, g, v in zip(pg["params"], pg["grads"], pg["values"]):
|
||||
if p.is_stop_grad(): continue
|
||||
v.assign(alpha * v + (1-alpha) * g * g)
|
||||
p -= lr * g / (jt.sqrt(v) + eps)
|
||||
p.detach_inplace()
|
||||
v.update(alpha * v + (1-alpha) * g * g)
|
||||
p.update(p - lr * g / (jt.sqrt(v) + eps))
|
||||
|
||||
class Adam(Optimizer):
|
||||
""" Adam Optimizer.
|
||||
|
@ -187,8 +181,8 @@ class Adam(Optimizer):
|
|||
values = pg["values"] = []
|
||||
m = pg["m"] = []
|
||||
for p in pg["params"]:
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
m.append(jt.zeros(p.shape, p.dtype).stop_fuse().stop_grad())
|
||||
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
|
||||
|
||||
def step(self, loss):
|
||||
self.pre_step(loss)
|
||||
|
@ -200,8 +194,7 @@ class Adam(Optimizer):
|
|||
b0, b1 = pg.get("betas", self.betas)
|
||||
for p, g, v, m in zip(pg["params"], pg["grads"], pg["values"], pg["m"]):
|
||||
if p.is_stop_grad(): continue
|
||||
m.assign(b0 * m + (1-b0) * g)
|
||||
v.assign(b1 * v + (1-b1) * g * g)
|
||||
m.update(b0 * m + (1-b0) * g)
|
||||
v.update(b1 * v + (1-b1) * g * g)
|
||||
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
|
||||
p -= m * step_size / (jt.sqrt(v) + eps)
|
||||
p.detach_inplace()
|
||||
p.update(p - m * step_size / (jt.sqrt(v) + eps))
|
||||
|
|
|
@ -38,8 +38,10 @@ class FakeMpiBatchNorm(nn.Module):
|
|||
|
||||
xvar = x2mean-xmean*xmean
|
||||
norm_x = (x-xmean)/jt.sqrt(xvar+self.eps)
|
||||
self.running_mean += (xmean.sum([0,2,3])-self.running_mean)*self.momentum
|
||||
self.running_var += (xvar.sum([0,2,3])-self.running_var)*self.momentum
|
||||
self.running_mean.update(self.running_mean +
|
||||
(xmean.sum([0,2,3])-self.running_mean)*self.momentum)
|
||||
self.running_var.update(self.running_var +
|
||||
(xvar.sum([0,2,3])-self.running_var)*self.momentum)
|
||||
else:
|
||||
running_mean = self.running_mean.broadcast(x, [0,2,3])
|
||||
running_var = self.running_var.broadcast(x, [0,2,3])
|
||||
|
|
|
@ -44,6 +44,7 @@ class TestResnet(unittest.TestCase):
|
|||
# mnist dataset
|
||||
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
|
||||
.set_attrs(batch_size=self.batch_size, shuffle=True)
|
||||
self.train_loader.num_workers = 4
|
||||
|
||||
# setup random seed
|
||||
def setup_seed(self, seed):
|
||||
|
@ -113,10 +114,12 @@ class TestResnet(unittest.TestCase):
|
|||
# Train Epoch: 0 [40/100 (40%)] Loss: 2.286762 Acc: 0.130000
|
||||
# Train Epoch: 0 [50/100 (50%)] Loss: 2.055014 Acc: 0.290000
|
||||
|
||||
if jt.in_mpi:
|
||||
assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
|
||||
else:
|
||||
assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
|
||||
# print(jt.core.number_of_lived_vars(), mem_used)
|
||||
jt.display_memory_info()
|
||||
# if jt.in_mpi:
|
||||
# assert jt.core.number_of_lived_vars() < 3900, jt.core.number_of_lived_vars()
|
||||
# else:
|
||||
# assert jt.core.number_of_lived_vars() < 3500, jt.core.number_of_lived_vars()
|
||||
|
||||
jt.sync_all(True)
|
||||
assert np.mean(loss_list[-50:])<0.3
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
#include "mem/allocator/sfrl_allocator.h"
|
||||
#include "mem/allocator/stat_allocator.h"
|
||||
#include "mem/mem_info.h"
|
||||
#include "update_queue.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -51,6 +52,8 @@ void display_memory_info(const char* fileline) {
|
|||
log << "hold_vars:" << VarHolder::hold_vars.size()
|
||||
<< "lived_vars:" << Var::number_of_lived_vars
|
||||
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
|
||||
log << "update queue:" << update_queue.map.size()
|
||||
>> '/' >> update_queue.map.size() >> '\n';
|
||||
|
||||
#ifdef NODE_MEMCHECK
|
||||
// get the oldest var
|
||||
|
|
|
@ -26,9 +26,10 @@ struct NodeFlags {
|
|||
_stop_grad=2,
|
||||
_n=3,
|
||||
|
||||
// op related flags
|
||||
// var related flags
|
||||
_force_fuse=_n+0,
|
||||
_stop_fuse=_n+1,
|
||||
_in_update_queue=_n+2,
|
||||
|
||||
// op related flags
|
||||
// bit0: support cpu
|
||||
|
|
|
@ -0,0 +1,148 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include "update_queue.h"
|
||||
#include "executor.h"
|
||||
#include "node.h"
|
||||
#include "var.h"
|
||||
#include "var_holder.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
/*
|
||||
|
||||
The update queue is designed to batch update parameters asynchronously.
|
||||
It maintains several queues internally.
|
||||
Each updated parameter corresponds to a queue,
|
||||
and the elements in the queue represent several updates of this parameter.
|
||||
When a parameter is updated,
|
||||
jittor internally updates the previous parameter several times
|
||||
instead of the current parameter.
|
||||
|
||||
update queue 设计用于批量异步更新参数,其内部维护了若干个队列,
|
||||
每一个被更新的参数对应了一个队列,而队列中的元素代表了这个参数
|
||||
的若干次更新。当一个参数被更新,jittor内部会批量更新若干次之前的
|
||||
参数,而不是当前参数。
|
||||
|
||||
below fig shows a async update process
|
||||
|
||||
下图演示了一个异步更新的过程:
|
||||
|
||||
first iter
|
||||
第一次迭代:
|
||||
|
||||
\ iter 0
|
||||
param
|
||||
a 0
|
||||
b 0
|
||||
c 0
|
||||
d 0
|
||||
|
||||
second iter
|
||||
第二次迭代:
|
||||
|
||||
\ iter 0 1
|
||||
params
|
||||
a 0 1
|
||||
b 0 1
|
||||
c 0 1
|
||||
d 0 1
|
||||
|
||||
third iter begin
|
||||
第三次开始时,迭代0的update被执行:
|
||||
\ iter 0 1 2
|
||||
params
|
||||
a [0]1 2
|
||||
b [0]1
|
||||
c [0]1
|
||||
d [0]1
|
||||
|
||||
third iter end
|
||||
第三次结束:
|
||||
|
||||
\ iter 0 1 2
|
||||
params
|
||||
a 1 2
|
||||
b 1 2
|
||||
c 1 2
|
||||
d 1 2
|
||||
|
||||
update_queue_auto_flush_delay: 异步多少个iter更新.
|
||||
|
||||
update queue的提出主要是为了解决统一计算图规模持续增长(lived_var不断变多)的问题,
|
||||
在 update queue 提出之前, 计算图运行是由optimizer负责的,optim.step被调用的
|
||||
时候,会自动运行还没有运行的计算图,已经运行的计算图节点会被回收,从而计算图规模可以
|
||||
在每次迭代之间保持一个常数。
|
||||
|
||||
但是如果用户并没有调用optim.step进行更新,计算图就会持续增长,比如下面两种情况:
|
||||
|
||||
* 训练 GAN 的时候,只用 SGD 运行了 generator,没有用SGD 运行 discriminator,
|
||||
discriminator 的 batch norm 参数持续不断地更新,但是一直没有运行,导致计算图
|
||||
规模持续增长。
|
||||
* 用户在 inference 的时候忘记设置 model.eval, 这时候因为没有 SGD 刷新参数,
|
||||
然后 batch norm 的参数持续不断更新,再次导致计算图规模持续增长。
|
||||
|
||||
这些细节对于用户来说过于难以理解(LD:我有时候都很晕),一个粗暴的解决方案是 jt.sync_all,
|
||||
直接强制刷新全图,把没运行的都运行了,但是这会导致显存占用过大,因为 sync_all 运行的
|
||||
拓扑顺序不优。
|
||||
|
||||
为了让用户可以不关心这些细节, 我们在参数更新的时候,使用 var.update(new_var),
|
||||
这个接口会把更新托管给 update queue, 从而不需要关心底层计算图的大小。
|
||||
|
||||
*/
|
||||
|
||||
DEFINE_FLAG(int, update_queue_auto_flush_delay, 2, "when size of a update queue is great than this value, update queue trigger auto flush(default 2).");
|
||||
|
||||
UpdateQueue update_queue;
|
||||
|
||||
void UpdateQueue::auto_flush() {
|
||||
vector<Var*> vars;
|
||||
vars.reserve(queue.size());
|
||||
for (auto& l : queue) {
|
||||
while (l.size() && l.size() >= update_queue_auto_flush_depth) {
|
||||
auto iter = l.end(); iter--;
|
||||
auto v = iter->v;
|
||||
vars.push_back(v);
|
||||
map.erase(v);
|
||||
v->flags.set(NodeFlags::_in_update_queue, 0);
|
||||
l.pop_back();
|
||||
}
|
||||
}
|
||||
LOGvv << "auto flush var size" << vars.size();
|
||||
exe.run_sync(move(vars), false);
|
||||
}
|
||||
|
||||
void UpdateQueue::push(Var* v, Var* prev) {
|
||||
if (v->flags.get(NodeFlags::_in_update_queue))
|
||||
return;
|
||||
v->flags.set(NodeFlags::_in_update_queue);
|
||||
list<list<Item>>::iterator owner;
|
||||
|
||||
if (prev->flags.get(NodeFlags::_in_update_queue)) {
|
||||
auto iter = map.find(prev);
|
||||
ASSERT(iter != map.end());
|
||||
owner = iter->second->owner;
|
||||
} else {
|
||||
queue.emplace_front();
|
||||
owner = queue.begin();
|
||||
}
|
||||
if (owner->size() >= update_queue_auto_flush_depth)
|
||||
auto_flush();
|
||||
owner->emplace_front(UpdateQueue::Item{owner, v});
|
||||
map[v] = owner->begin();
|
||||
// if total size of update queue is too big,
|
||||
// force sync all
|
||||
if (map.size() > 100000)
|
||||
sync_all();
|
||||
}
|
||||
|
||||
void UpdateQueue::pop(Var* v) {
|
||||
auto iter = map.find(v);
|
||||
iter->second->owner->erase(iter->second);
|
||||
map.erase(iter);
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include "common.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
struct UpdateQueue {
|
||||
struct Item {
|
||||
list<list<Item>>::iterator owner;
|
||||
Var* v;
|
||||
};
|
||||
list<list<Item>> queue;
|
||||
unordered_map<Var*, list<Item>::iterator> map;
|
||||
|
||||
void push(Var* v, Var* prev);
|
||||
void pop(Var* v);
|
||||
void auto_flush();
|
||||
};
|
||||
|
||||
extern UpdateQueue update_queue;
|
||||
|
||||
} // jittor
|
||||
|
|
@ -9,6 +9,7 @@
|
|||
#include "op.h"
|
||||
#include "mem/allocator.h"
|
||||
#include "pybind/py_var_tracer.h"
|
||||
#include "update_queue.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -30,6 +31,8 @@ Var::~Var() {
|
|||
if (mem_ptr != nullptr)
|
||||
allocator->free(mem_ptr, size, allocation);
|
||||
number_of_lived_vars--;
|
||||
if (flags.get(NodeFlags::_in_update_queue))
|
||||
update_queue.pop(this);
|
||||
}
|
||||
|
||||
string Var::to_string() {
|
||||
|
|
|
@ -12,6 +12,7 @@
|
|||
#include "var.h"
|
||||
#include "executor.h"
|
||||
#include "graph.h"
|
||||
#include "update_queue.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -75,6 +76,13 @@ VarHolder* VarHolder::assign(VarHolder* v) {
|
|||
return this;
|
||||
}
|
||||
|
||||
VarHolder* VarHolder::update(VarHolder* v) {
|
||||
auto dv = jittor::detach(v->var);
|
||||
update_queue.push(dv.ptr, var);
|
||||
*this = move(dv);
|
||||
return this;
|
||||
}
|
||||
|
||||
extern Executor exe;
|
||||
|
||||
void VarHolder::sync(bool device_sync) {
|
||||
|
|
|
@ -43,6 +43,15 @@ struct VarHolder {
|
|||
// @attrs(return_self)
|
||||
VarHolder* assign(VarHolder* v);
|
||||
|
||||
/* update parameter and global variable,
|
||||
different from assign, it will
|
||||
stop grad between origin var and assigned var, and
|
||||
will update in the background
|
||||
*/
|
||||
// @pyjt(update)
|
||||
// @attrs(return_self)
|
||||
VarHolder* update(VarHolder* v);
|
||||
|
||||
// @pyjt(swap)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
|
||||
|
|
Loading…
Reference in New Issue