add node_order to control execution order

Dun Liang 2022-05-06 12:57:47 +08:00
parent 3e6fb4cad8
commit 88bb84255f
22 changed files with 191 additions and 306 deletions

View File

@ -1526,7 +1526,8 @@ def jittor_exit():
elif hooks.exception is not None:
pass
else:
core.sync_all(True)
pass
# core.sync_all(True)
core.cleanup()
atexit.register(jittor_exit)

View File

@ -1146,7 +1146,10 @@ if os.name == 'nt':
return cmd
if ' -O' not in cc_flags:
opt_flags += " -O2 "
if os.environ.get("debug", "0") == "1":
opt_flags += " -O0 "
else:
opt_flags += " -O2 "
kernel_opt_flags += " -Ofast "
lto_flags = ""
if os.environ.get("enable_lto") == "1":

View File

@ -98,16 +98,39 @@ class Optimizer(object):
def zero_grad(self):
self.__zero_grad = True
def pre_step(self, loss, retain_graph=False):
""" something should be done before step, such as calc gradients, mpi sync, and so on.
def backward(self, loss, retain_graph=False):
'''
optimizer.backward(loss) is used to accumulate gradients over multiple steps.
It can be used as follows:
Example::
Original source code ::
class MyOptimizer(Optimizer):
def step(self, loss):
self.post_step(loss)
...
"""
n_iter = 10000
batch_size = 100
...
for i in range(n_iter):
...
loss = calc_loss()
optimizer.step(loss)
Accumulation version ::
n_iter = 10000
batch_size = 100
accumulation_steps = 10
n_iter *= accumulation_steps
batch_size //= accumulation_steps
...
for i in range(n_iter):
...
loss = calc_loss()
# if loss is a mean across the batch, we need to divide it by accumulation_steps
optimizer.backward(loss / accumulation_steps)
if (i+1) % accumulation_steps == 0:
optimizer.step()
'''
# clean prev grads
params = []
params_has_grad = []
@ -117,6 +140,9 @@ class Optimizer(object):
if not p.is_stop_grad():
params_has_grad.append(p)
# sync prev params
jt.sync(params_has_grad)
# get gradient
grads = jt.grad(loss, params_has_grad, retain_graph)
@ -153,50 +179,44 @@ class Optimizer(object):
pid += 1
self.__zero_grad = False
def backward(self, loss, retain_graph=False):
'''
optimizer.backward(loss) is used to accumulate gradients over multiple steps.
It can be used as follows:
def pre_step(self, loss, retain_graph=False):
""" something should be done before step, such as calc gradients, mpi sync, and so on.
Original source code ::
Example::
n_iter = 10000
batch_size = 100
...
for i in range(n_iter):
...
loss = calc_loss()
optimizer.step(loss)
class MyOptimizer(Optimizer):
def step(self, loss):
self.pre_step(loss)
...
self.post_step(loss)
"""
if loss is not None:
self.backward(loss, retain_graph)
jt.flags.node_order = 1
Accumulation version ::
def post_step(self):
""" something should be done before step, such as zero grad, and so on.
n_iter = 10000
batch_size = 100
accumulation_steps = 10
n_iter *= accumulation_steps
batch_size //= accumulation_steps
...
for i in range(n_iter):
...
loss = calc_loss()
# if loss is a mean across the batch, we need to divide it by accumulation_steps
optimizer.backward(loss / accumulation_steps)
if (i+1) % accumulation_steps == 0:
optimizer.step()
Example::
class MyOptimizer(Optimizer):
def step(self, loss):
self.pre_step(loss)
...
self.post_step(loss)
"""
jt.flags.node_order = 0
self.zero_grad()
'''
self.pre_step(loss, retain_graph)
def step(self, loss=None, retain_graph=False):
if loss is not None:
self.pre_step(loss, retain_graph)
self.pre_step(loss, retain_graph)
for pg in self.param_groups:
lr = pg.get("lr", self.lr)
for p, g in zip(pg["params"], pg["grads"]):
if p.is_stop_grad(): continue
p.update(p - g * lr)
self.zero_grad()
self.post_step()
def _build_grad_map(self):
_grad_map = {}
@ -255,9 +275,9 @@ class SGD(Optimizer):
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph=False)
jt.flags.node_order = 1
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)
@ -275,7 +295,7 @@ class SGD(Optimizer):
p.update(p - (dp + momentum * v) * lr)
else:
p.update(p - v * lr)
self.zero_grad()
self.post_step()
class RMSprop(Optimizer):
""" RMSprop Optimizer.
@ -306,9 +326,8 @@ class RMSprop(Optimizer):
values.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph)
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)
@ -318,7 +337,7 @@ class RMSprop(Optimizer):
if p.is_stop_grad(): continue
v.update(alpha * v + (1-alpha) * g * g)
p.update(p - lr * g / (jt.sqrt(v) + eps))
self.zero_grad()
self.post_step()
class Adam(Optimizer):
""" Adam Optimizer.
@ -351,10 +370,10 @@ class Adam(Optimizer):
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph)
n = float(self.n_step)
jt.flags.node_order = 1
for pg in self.param_groups:
# get arguments from each param_groups
lr = pg.get("lr", self.lr)
@ -368,7 +387,7 @@ class Adam(Optimizer):
v.update(b1 * v + (1-b1) * g * g)
step_size = lr * jt.sqrt(1-b1**n) / (1-b0 ** n)
p.update(p - m * step_size / (jt.sqrt(v) + eps))
self.zero_grad()
self.post_step()
class AdamW(Optimizer):
@ -402,9 +421,8 @@ class AdamW(Optimizer):
m.append(jt.zeros(p.shape, p.dtype).stop_grad())
self.param_groups.append(group)
def step(self, loss=None):
if loss is not None:
self.pre_step(loss)
def step(self, loss=None, retain_graph=False):
self.pre_step(loss, retain_graph)
n = float(self.n_step)
for pg in self.param_groups:
# get arguments from each param_groups
@ -422,7 +440,7 @@ class AdamW(Optimizer):
denom = jt.sqrt(v) / jt.sqrt(bias_correction2) + eps
step_size = lr / bias_correction1
p.update(p - step_size * m / denom)
self.zero_grad()
self.post_step()
class LRScheduler:
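A minimal usage sketch of the new backward()/step()/post_step() flow for gradient accumulation. The toy model, data, and hyper-parameters below are hypothetical and only exercise the API shown in this diff.

import jittor as jt

# hypothetical toy model and settings
model = jt.nn.Linear(10, 1)
opt = jt.optim.SGD(model.parameters(), lr=0.1)
accumulation_steps = 4

for i in range(20):
    x = jt.rand(8, 10)                       # stand-in mini-batch
    loss = (model(x) ** 2).mean()
    # loss is a mean over the batch, so divide by accumulation_steps
    opt.backward(loss / accumulation_steps)  # accumulate grads, no update yet
    if (i + 1) % accumulation_steps == 0:
        # step() runs pre_step(None) (sets jt.flags.node_order = 1),
        # applies the parameter updates, then post_step() (resets the
        # flag and zeroes the accumulated grads)
        opt.step()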

View File

@ -9,6 +9,7 @@
// ***************************************************************
#include <algorithm>
#include <functional>
#include <queue>
#ifdef HAS_CUDA
#include <cuda_runtime.h>
#include "helper_cuda.h"
@ -209,11 +210,13 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
// get all nodes need to be executed
int need_opt = 0;
auto t = ++Node::tflag_count;
int64 max_id = 0;
for (Var* v : vars)
if (!v->is_finished() && v->tflag != t) {
v->tflag = t;
start_var_num++;
bfs_q.push_back(v);
max_id = std::max(max_id, v->id);
}
for (int i=0; i<bfs_q.size(); i++) {
auto node = bfs_q[i];
@ -225,12 +228,14 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
bfs_q.push_back(i.node);
}
// this var has been fetched
if (node->flags.get(NodeFlags::_fetch)) {
if (weak_sync || node->flags.get(NodeFlags::_fetch)) {
for (auto& n : node->_outputs) {
// if not in queue and is fetch op
if (n.node->tflag != t &&
n.node->pending_liveness &&
!n.node->is_finished() &&
n.node->flags.get(NodeFlags::_fetch)) {
(n.node->id <= max_id ||
n.node->flags.get(NodeFlags::_fetch))) {
n.node->tflag = t;
need_opt += n.node->flags.get(NodeFlags::_has_gopt);
bfs_q.push_back(n.node);
@ -330,7 +335,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
{
// queue.clear();
#ifndef JT_bfs_executor
map<int64, int> p_queue;
std::priority_queue<pair<int64,int64>> p_queue;
#endif
for (int root : roots) {
for (int i=root; i>=0; i=next[i]) {
@ -349,7 +354,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
queue.push_back(root);
#else
if (deps[root] == 0)
p_queue[ops[root]->id] = root;
p_queue.emplace(-ops[root]->order(), root);
#endif
}
#ifdef JT_bfs_executor
@ -361,8 +366,8 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
#ifdef JT_bfs_executor
int op_id = queue[s];
#else
int op_id = p_queue.begin()->second;
p_queue.erase(p_queue.begin());
int op_id = p_queue.top().second;
p_queue.pop();
queue.push_back(op_id);
#endif
for (int i=op_id; i>=0; i=next[i]) {
@ -382,7 +387,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync, bool weak_sync) {
queue.push_back(op2_id);
#else
if (deps[op2_id] == 0)
p_queue[op2->id] = op2_id;
p_queue.emplace(-op2->order(), op2_id);
#endif
}
}
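A rough Python sketch (not Jittor source) of the scheduling change above: ready ops are popped by Node::order() instead of plain id/BFS order. std::priority_queue is a max-heap, hence the emplace(-order, id) trick in the C++; heapq below is a min-heap, so the key is used directly. The op/deps/consumers representation is hypothetical.

import heapq

def order(op):
    # mirrors Node::order(): the low flag runs first, the high flag runs
    # last, otherwise fall back to the node's creation id
    if op.get("order_low"):
        return 0
    if op.get("order_high"):
        return 1 << 60
    return op["id"]

def schedule(ops, deps, consumers):
    # deps[i]: number of unmet inputs of op i
    # consumers[i]: indices of ops that consume op i's outputs
    heap = [(order(ops[i]), i) for i in range(len(ops)) if deps[i] == 0]
    heapq.heapify(heap)
    run_order = []
    while heap:
        _, i = heapq.heappop(heap)
        run_order.append(i)
        for j in consumers[i]:
            deps[j] -= 1
            if deps[j] == 0:
                heapq.heappush(heap, (order(ops[j]), j))
    return run_order

# tiny example: op 2 carries the low-order flag, so it runs before op 1
ops = [{"id": 0}, {"id": 1}, {"id": 2, "order_low": True}]
print(schedule(ops, deps=[0, 1, 1], consumers=[[1, 2], [], []]))  # [0, 2, 1]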

View File

@ -14,7 +14,6 @@ namespace jittor {
DEFINE_FLAG(int, check_graph, 0, "Unify graph sanity check.");
EXTERN_LIB unordered_map<void*, int64> lived_nodes;
template <typename T>
string ss_convert(T x) {

View File

@ -29,7 +29,6 @@
#include "mem/allocator/stat_allocator.h"
#include "mem/allocator/temp_allocator.h"
#include "mem/mem_info.h"
#include "update_queue.h"
#include "executor.h"
namespace jittor {
@ -67,8 +66,6 @@ void display_memory_info(const char* fileline, bool dump_var, bool red_color) {
log << "hold_vars:" << hold_vars.size()
<< "lived_vars:" << Var::number_of_lived_vars
<< "lived_ops:" << Op::number_of_lived_ops >> '\n';
log << "update queue:" << update_queue.queue.size()
>> '/' >> update_queue.map.size() >> '\n';
#ifdef NODE_MEMCHECK
// get the oldest var

View File

@ -13,9 +13,14 @@
namespace jittor {
EXTERN_LIB unordered_map<void*, int64> lived_nodes;
EXTERN_LIB unordered_map<int64, Node*> lived_nodes_id;
EXTERN_LIB int64 total_node;
EXTERN_LIB int64 nt;
EXTERN_LIB vector<Node*> free_buffer;
EXTERN_LIB uint8 node_order;
inline static Node* get_node(int64 id)
{ return lived_nodes_id.count(id) ? lived_nodes_id[id] : nullptr; }
struct NodeFlags {
typedef uint32 nf_t;
@ -29,12 +34,14 @@ struct NodeFlags {
_stop_grad=2,
// bit3: is fetch
_fetch=3,
_n=4,
// bit4: node order low
_node_order_low=4,
_node_order_high=5,
_n=6,
// var related flags
_force_fuse=_n+0,
_stop_fuse=_n+1,
_in_update_queue=_n+2,
_needed_by_backward=_n+3,
_out_hint=_n+4,
@ -129,19 +136,28 @@ struct Node {
list<input_t> _inputs;
list<output_t> _outputs;
#ifdef NODE_MEMCHECK
inline Node() {
lived_nodes[(void*)this] = id = ++total_node;
int64 order() {
if (flags.get(NodeFlags::_node_order_low)) return 0;
if (flags.get(NodeFlags::_node_order_high)) return 1ll<<60;
return id;
}
inline virtual ~Node() {
inline Node() {
id = ++total_node;
#ifdef NODE_MEMCHECK
lived_nodes_id[id] = this;
lived_nodes[(void*)this] = id;
#endif
flags.set(NodeFlags::_node_order_low, node_order, 2);
}
inline virtual ~Node() {
#ifdef NODE_MEMCHECK
lived_nodes_id.erase(id);
lived_nodes.erase((void*)this);
#endif
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);
}
#else
inline Node() { id = ++total_node; };
inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);};
#endif
inline Var* var() { return (Var*)this; }
inline Op* op() { return (Op*)this; }
inline Node* node() { return this; }
@ -155,13 +171,6 @@ struct Node {
#endif
}
void memcheck_all_exist() const;
inline int64 __id() const {
#ifdef NODE_MEMCHECK
return lived_nodes.at((void*)this);
#else
return 0;
#endif
}
// release from counter and memory checker
void __release();
#define CHECK_NODE_EXIST(node) \

View File

@ -298,9 +298,6 @@ std::ostream& operator<<(std::ostream& os, const Op* op) {
os << "->" << (void*)v;
}
os << ')';
#ifdef NODE_MEMCHECK
os << '<' << op->__id() << '>';
#endif
if (trace_py_var) {
os << '{';
print_node_trace(op, os);

View File

@ -113,8 +113,6 @@ VarPtr BroadcastToOp::grad(Var* out, Var* dout, Var* v, int v_index) {
if (v_index==1) return nullptr;
if (bcast_mask==0) return dout;
VarPtr dv = make_reduce(dout, ns_add, bcast_mask, keepdims_mask);
if (dv->shape.size() != v->shape.size())
dv->shape = v->shape;
return dv;
}
@ -150,7 +148,7 @@ void BroadcastToOp::infer_shape() {
}
auto mask = ((xshape==1 && (yshape!=1 || !bx))&1) << i;
bcast_mask |= mask;
keepdims_mask |= mask;
if (bx) keepdims_mask |= mask;
int64 zs;
if ((xshape == 1 || yshape == 1) && (xshape != yshape)) {
zs = xshape * yshape;

View File

@ -39,6 +39,9 @@ CodeOp::CodeOp(NanoVector shape, NanoString dtype, vector<Var*>&& inputs,
if (_outputs[0]->num < 0) {
check_vary_shape(_outputs[0]->shape);
}
if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0) {
flags.set(NodeFlags::_manual_set_vnbb);
}
}
@ -60,6 +63,8 @@ CodeOp::CodeOp(
check_vary_shape(_outputs[i]->shape);
}
}
if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0)
flags.set(NodeFlags::_manual_set_vnbb);
}
CodeOp::CodeOp(
@ -81,6 +86,8 @@ CodeOp::CodeOp(
TODO: vary shape not allowed in direct output
*/
}
if (this->cuda_grad_src.size() == 0 && this->cpu_grad_src.size() == 0)
flags.set(NodeFlags::_manual_set_vnbb);
}

View File

@ -245,8 +245,7 @@ void TraceData::record_node(Node* node, bool record_stack) {
}
} else {
}
if (node->__id())
data.attrs["__id"] = S(node->__id());
data.attrs["__id"] = S(node->id);
data.attrs["is_var"] = node->is_var() ? "1" : "0";
data.attrs["name"] = "unname";
node_data[data.id] = move(data);

View File

@ -1,152 +0,0 @@
// ***************************************************************
// Copyright (c) 2022 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include "update_queue.h"
#include "executor.h"
#include "node.h"
#include "var.h"
#include "var_holder.h"
namespace jittor {
/*
The update queue is designed to batch update parameters asynchronously.
It maintains several queues internally.
Each updated parameter corresponds to a queue,
and the elements in the queue represent several updates of this parameter.
When a parameter is updated,
jittor internally updates the previous parameter several times
instead of the current parameter.
The fig below shows an async update process.

first iter
\ iter    0
params
a         0
b         0
c         0
d         0

second iter
\ iter    0 1
params
a         0 1
b         0 1
c         0 1
d         0 1

third iter begin (update 0 is flushed and executed)
\ iter    0  1  2
params
a        [0] 1  2
b        [0] 1
c        [0] 1
d        [0] 1

third iter end
\ iter    0  1  2
params
a            1  2
b            1  2
c            1  2
d            1  2

update_queue_auto_flush_delay controls how many iterations an update may
stay pending before it is auto-flushed.

The update queue was introduced mainly to keep the unified computation
graph from growing without bound (the number of lived vars kept
increasing). It cannot simply be handled inside the optimizer's
optim.step, because not every update goes through optim.step, for
example:

* a GAN uses one SGD for the generator and another SGD for the
  discriminator, and the discriminator's batch norm statistics are
  updated even while the generator is being optimized;
* during inference after model.eval, no SGD step is called at all, but
  batch norm statistics are still updated.

A flush can also be forced with jt.sync_all. A variable is put into the
update queue via var.update(new_var).
*/
DEFINE_FLAG(int, update_queue_auto_flush_delay, 2, "when size of a update queue is great than this value, update queue trigger auto flush(default 2).");
UpdateQueue update_queue;
void UpdateQueue::auto_flush() {
vector<Var*> vars;
vars.reserve(queue.size());
for (auto& l : queue) {
while (l.size() && l.size() >= update_queue_auto_flush_delay) {
auto iter = l.end(); iter--;
auto v = iter->v;
vars.push_back(v);
map.erase(v);
v->flags.set(NodeFlags::_in_update_queue, 0);
l.pop_back();
}
}
LOGvv << "auto flush var size" << vars.size();
exe.run_sync(move(vars), false);
}
void UpdateQueue::push(Var* v, Var* prev) {
if (v->flags.get(NodeFlags::_in_update_queue))
return;
list<list<Item>>::iterator owner;
if (prev->flags.get(NodeFlags::_in_update_queue)) {
auto iter = map.find(prev);
ASSERT(iter != map.end());
owner = iter->second->owner;
} else {
queue.emplace_front();
owner = queue.begin();
}
if (owner->size() >= update_queue_auto_flush_delay) {
auto_flush();
}
v->flags.set(NodeFlags::_in_update_queue);
owner->emplace_front(UpdateQueue::Item{owner, v});
map[v] = owner->begin();
// if total size of update queue is too big,
// force sync all
if (map.size() > 100000)
sync_all();
}
void UpdateQueue::pop(Var* v) {
auto iter = map.find(v);
iter->second->owner->erase(iter->second);
if (iter->second->owner->size() == 0)
queue.erase(iter->second->owner);
map.erase(iter);
}
} // jittor
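To make the flush behaviour described in the comment above concrete, here is a toy plain-Python model of the per-parameter delayed flush (hypothetical names, not the deleted C++ API; the real auto_flush drained every parameter's queue at once, while this toy only touches the pushed one).

from collections import deque

AUTO_FLUSH_DELAY = 2   # mirrors the default of update_queue_auto_flush_delay

class ToyUpdateQueue:
    def __init__(self):
        self.queues = {}               # one deque of pending updates per parameter

    def push(self, name, update):
        q = self.queues.setdefault(name, deque())
        flushed = []
        while len(q) >= AUTO_FLUSH_DELAY:
            flushed.append(q.pop())    # oldest pending update gets executed
        q.appendleft(update)           # newest update stays pending
        return flushed

uq = ToyUpdateQueue()
print(uq.push("a", 0))   # []   first iter: update 0 stays pending
print(uq.push("a", 1))   # []   second iter: still below the flush delay
print(uq.push("a", 2))   # [0]  third iter: update 0 is finally executed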

View File

@ -1,28 +0,0 @@
// ***************************************************************
// Copyright (c) 2022 Jittor. All Rights Reserved.
// Maintainers: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include "common.h"
namespace jittor {
struct UpdateQueue {
struct Item {
list<list<Item>>::iterator owner;
Var* v;
};
list<list<Item>> queue;
unordered_map<Var*, list<Item>::iterator> map;
void push(Var* v, Var* prev);
void pop(Var* v);
void auto_flush();
};
EXTERN_LIB UpdateQueue update_queue;
} // jittor

View File

@ -69,7 +69,7 @@ constexpr char yellow[] = "\x1b[1;33m";
inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
if (level == 'i') {
if (level == 'i' || level == 'I') {
if (verbose == 0) color_begin = "\x1b[1;32m"; else
if (verbose < 10) color_begin = "\x1b[1;32m"; else
if (verbose < 100) color_begin = "\x1b[1;32m"; else
@ -90,7 +90,7 @@ constexpr char red[] = "\033[38;5;1m";
constexpr char yellow[] = "\033[38;5;3m";
inline static void get_color(char level, int verbose, const char*& color_begin, const char*& color_end) {
if (level == 'i') {
if (level == 'i' || level == 'I') {
if (verbose == 0) color_begin = "\033[38;5;2m"; else
if (verbose < 10) color_begin = "\033[38;5;250m"; else
if (verbose < 100) color_begin = "\033[38;5;244m"; else
@ -280,9 +280,10 @@ bool check_vlog(const char* fileline, int verbose);
#define LOGrrrr LOGvvvv >> jittor::red
#define LOGyyyy LOGvvvv >> jittor::yellow
#define LOGir LOGi >> jittor::red
#define LOGig LOGi >> jittor::green
#define LOGiy LOGi >> jittor::yellow
#define LOGI jittor::LogVoidify() && jittor::Log(__FILELINE__, 'I', 0)
#define LOGir LOGI >> jittor::red
#define LOGig LOGI >> jittor::green
#define LOGiy LOGI >> jittor::yellow
void system_with_check(const char* cmd, const char* cwd=nullptr);

View File

@ -10,7 +10,6 @@
#include "op.h"
#include "mem/allocator.h"
#include "pybind/py_var_tracer.h"
#include "update_queue.h"
namespace jittor {
@ -22,6 +21,7 @@ DEFINE_FLAG(bool, no_grad, 0,
"No grad for all jittor Var creation");
DEFINE_FLAG(bool, no_fuse, 0,
"No fusion optimization for all jittor Var creation");
DEFINE_FLAG(uint8, node_order, 0, "id prior");
// TODO: fuse multiple flags
DEFINE_FLAG(int, amp_reg, 0, "Auto mixed-precision control registers, bit 0: prefer 32; bit 1: prefer 16; bit 2: keep reduce type; bit 3 keep white list type; bit 4: array like op prefer too");
@ -101,11 +101,8 @@ std::ostream& operator<<(std::ostream& os, const Var& var) {
<< ":s" << var.is_finished()
<< ":n" << var.flags.get(NodeFlags::_needed_by_backward)
<< ','
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr
<< var.dtype().to_cstring() << ',' << var.name << ',' << std::hex <<(uint64)var.mem_ptr << std::dec
<< ')' << var.shape;
#ifdef NODE_MEMCHECK
os << '<' << var.__id() << '>';
#endif
if (trace_py_var) {
os << '{';
print_node_trace(&var, os);
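A small, hedged usage sketch of the flag defined above: nodes created while jt.flags.node_order is nonzero inherit the order bits, which is exactly what the optimizer's pre_step/post_step toggle around parameter updates. The computation itself is a placeholder.

import jittor as jt

x = jt.rand(3, 3)

jt.flags.node_order = 1      # new nodes get _node_order_low, i.e. order() == 0
y = (x * 2.0).sum()          # these ops are preferred by the executor's priority queue
jt.flags.node_order = 0      # back to the default id-based ordering

y.sync()
print(y)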

View File

@ -13,7 +13,6 @@
#include "var.h"
#include "executor.h"
#include "graph.h"
#include "update_queue.h"
#include "mem/allocator/cuda_dual_allocator.h"
#include "ops/op_register.h"
@ -27,7 +26,7 @@ list<VarHolder*>::iterator sync_ptr = hold_vars.end();
void add_hold_vars(VarHolder* self) {
hold_vars.push_front(self);
self->iter = hold_vars.begin();
if (lazy_execution) return;
if (lazy_execution && Op::number_of_lived_ops < 100000) return;
auto v = self->var;
for (int i=0; i<5; i++) {
auto op = v->input();
@ -115,19 +114,15 @@ VarHolder* VarHolder::assign(VarHolder* v) {
}
VarHolder* VarHolder::update(VarHolder* v) {
auto dv = jittor::detach(v->var);
update_queue.push(dv.ptr, var);
*this = move(dv);
return this;
v->var->flags.set(NodeFlags::_out_hint);
return assign(v);
}
VarHolder* VarHolder::_update(VarHolder* v) {
auto dv = jittor::detach(v->var);
if (var->flags.get(NodeFlags::_in_update_queue))
update_queue.push(dv.ptr, var);
v->var->own_both_liveness();
var->release_both_liveness();
var = dv.ptr;
dv.ptr = nullptr;
var = v->var;
var->flags.set(NodeFlags::_out_hint);
return this;
}

View File

@ -204,7 +204,8 @@ struct VarHolder {
inline VarHolder* start_grad() {
if (!var->dtype().is_float())
LOGw << "cannot enable grad of a non-float value:" << var;
_update(this);
auto dvar = jittor::detach(var);
std::swap(dvar.ptr, var);
return this;
}

View File

@ -40,6 +40,13 @@ class TestACL(unittest.TestCase):
y = jt.float32(x)
np.testing.assert_allclose(x, y.numpy())
@jt.flag_scope(use_acl=1)
def test_rand(self):
a = jt.rand(10)
b = a*10
b.sync()
print(b)
def test_meminfo(self):
jt.display_memory_info()

View File

@ -194,14 +194,11 @@ class TestArray(unittest.TestCase):
np.testing.assert_allclose(a, c)
def test_scalar_fuse_unary(self):
c = jt.ones(10)
jt.sync_all()
with jt.profile_scope() as rep:
a = jt.array([1])
b = -a
a = a.clone()
b = b.clone()
jt.sync([a, b])
assert a.data == 1
assert b.data == -1
b = c-1
assert b.data[1] == 0
assert len(rep) == 2
@unittest.skipIf(not jt.has_cuda, "Cuda not found")

View File

@ -218,6 +218,32 @@ class TestCore(unittest.TestCase):
assert cnt2 == 10
assert cnt1 <= 33, cnt1
def test_node_order(self):
a = jt.nn.Sequential()
for i in range(10):
a.append(jt.nn.Linear(10,10, bias=False))
sgd = jt.optim.SGD(a.parameters(), 0.1)
jt.sync_all()
with jt.log_capture_scope(log_silent=1,
log_vprefix="exe=100") as logs:
x = jt.rand(3,10)
y = a(x)
sgd.step(y*y)
jt.sync_all()
orders = []
for l in logs:
msg = l["msg"]
if "Finished" in msg:
# print(msg)
if "weight" in msg:
assert msg.count("Var") >= 2
order = int(msg.split('fused ')[1].split("/")[0])
# print(order)
orders.append(order)
assert len(orders) == 10, orders
for i in range(10):
assert orders[i] <= 14+i*3
if __name__ == "__main__":
unittest.main()

View File

@ -85,14 +85,14 @@ class TestFusedOp(unittest.TestCase):
(hv, lv, lo))
for i in range(8):
check(0,0,0)
a = jt.array(1.0).name('a').stop_fuse()
b = (a+jt.array(1.0).name('t1').stop_fuse()).name('b')
c = (b+jt.array(1.0).name('t2').stop_fuse()).name('c')
a = jt.array([1.0,1.0]).name('a').stop_fuse()
b = (a+jt.array([1.0,1.0]).name('t1').stop_fuse()).name('b')
c = (b+jt.array([1.0,1.0]).name('t2').stop_fuse()).name('c')
check(3,5,5)
graph = jt.dump_all_graphs()
# for n in graph.nodes_info:
# print(n)
self.assertEqual(c.data, 3)
np.testing.assert_allclose(c.data, [3,3])
graph2 = jt.dump_all_graphs()
print("check", i)
for n in graph2.nodes_info:

View File

@ -83,6 +83,14 @@ class TestTransposeOp(unittest.TestCase):
b = a.transpose((1,0))
b.sync()
@unittest.skipIf(not jt.compiler.has_cuda, "No CUDA found")
@jt.flag_scope(use_cuda=1)
def test_cutt_bug(self):
a = jt.rand(640000,4,3)
b = a.transpose(0,2,1)
b.sync(True)
print(a.shape, b.shape)
class TestFuseTransposeOp(unittest.TestCase):