This commit is contained in:
li-xl 2020-11-08 16:04:16 +08:00
commit 32af287461
18 changed files with 590 additions and 104 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.1.0'
__version__ = '1.2.1.1'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -1247,10 +1247,10 @@ class Sequential(Module):
ret = callback(parents, k, self, n_children)
if ret == False:
return
parents.append(self)
for k,v in self.layers.items():
parents.append(self)
v.dfs(parents, k, callback, callback_leave)
parents.pop()
parents.pop()
if callback_leave:
callback_leave(parents, k, self, n_children)
def append(self, mod):

View File

@ -0,0 +1,122 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from jittor import Module
from jittor.models import resnet
import pickle
f32 = jt.float32
def matmul(a, b):
    """Multiply a 2-D `a` with `b` via broadcast, element-wise product and sum.

    `a` is unpacked as shape (rows, inner); the last dimension of `b` gives
    the number of output columns. Both operands are broadcast to
    [rows, inner, cols] and the product is reduced over the `inner` axis.
    """
    rows, inner = a.shape
    cols = b.shape[-1]
    lhs = a.broadcast([rows, inner, cols], dims=[2])
    rhs = b.broadcast([rows, inner, cols], dims=[0])
    product = lhs * rhs
    return product.sum(dim=1)
def relu(x):
    """Return the element-wise maximum of `x` and 0.0."""
    floor = 0.0
    return jt.maximum(x, floor)


# Module wrapper around `relu`, built with jt.make_module.
Relu = jt.make_module(relu)
class Model(Module):
    """Small traced model: Linear(input_size, 10) -> Relu -> Linear(10, 1)."""

    def __init__(self, input_size):
        hidden_size = 10
        self.linear1 = Linear(input_size, hidden_size)
        self.relu1 = Relu()
        self.linear2 = Linear(hidden_size, 1)

    def execute(self, x):
        """Run `x` through linear1, relu1 and linear2 in that order."""
        activated = self.relu1(self.linear1(x))
        return self.linear2(activated)
class Linear(Module):
    """Affine layer built on the local `matmul` helper.

    The weight is jt.random of shape (in_features, out_features), shifted by
    -0.5 and divided by sqrt(in_features). The bias, when enabled, is
    jt.random of shape (out_features,) shifted by -0.5; otherwise it is None.
    """

    def __init__(self, in_features, out_features, bias=True):
        centered = jt.random((in_features, out_features)) - 0.5
        self.w = centered / in_features**0.5
        if bias:
            self.b = jt.random((out_features,)) - 0.5
        else:
            self.b = None

    def execute(self, x):
        """Return matmul(x, w), plus the bias when one is present."""
        out = matmul(x, self.w)
        if self.b is None:
            return out
        return out + self.b
class TestTraceVar(unittest.TestCase):
    """Smoke tests for the `trace_py_var` flag and the trace-data dump API.

    Each test runs a model with tracing enabled, then dumps and clears the
    collected trace data. The dump is checked for its two top-level keys so
    that a test fails if the dump loses its basic layout.
    """

    def _dump_and_clear(self):
        """Dump the trace data, check its top-level keys, and always clear it.

        Clearing happens in a `finally` block so that a failed assertion does
        not leave trace data behind for the next test.

        To inspect a dump manually, pickle the returned dict, e.g.:
            with open("/tmp/trace.pkl", "wb") as f:
                pickle.dump(data, f)
        """
        try:
            data = jt.dump_trace_data()
            self.assertIn("node_data", data)
            self.assertIn("execute_op_info", data)
            return data
        finally:
            jt.clear_trace_data()

    def test_simple_model(self):
        with jt.flag_scope(trace_py_var=2):
            model = Model(input_size=1)
            batch_size = 10
            x = jt.float32(np.random.rand(batch_size, 1))
            y = model(x)
            y.sync()
            self._dump_and_clear()

    def test_simple_model_train(self):
        with jt.flag_scope(trace_py_var=2):
            model = Model(input_size=1)
            opt = jt.optim.SGD(model.parameters(), 0.1)
            batch_size = 10
            x = jt.float32(np.random.rand(batch_size, 1))
            y = model(x)
            opt.step(y**2)
            jt.sync_all()
            self._dump_and_clear()

    def test_resnet(self):
        with jt.flag_scope(trace_py_var=2):
            resnet18 = resnet.Resnet18()
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)
            y.sync()
            self._dump_and_clear()

    def test_resnet_train(self):
        with jt.flag_scope(trace_py_var=2):
            resnet18 = resnet.Resnet18()
            opt = jt.optim.SGD(resnet18.parameters(), 0.1)
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)
            opt.step(y**2)
            jt.sync_all()
            self._dump_and_clear()

    def test_resnet_train_profile(self):
        # Profiling run with trace_py_var=1; no trace data is dumped here.
        with jt.profile_scope(trace_py_var=1):
            resnet18 = resnet.Resnet18()
            opt = jt.optim.SGD(resnet18.parameters(), 0.1)
            x = jt.float32(np.random.rand(2, 3, 224, 224))
            y = resnet18(x)
            opt.step(y**2)
            jt.sync_all()
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,23 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
def fill_module_name(m, name):
    """Name every jt.Var attribute under module `m` and tag each module.

    Walks `m` with `m.dfs`, starting from the key `name`. For each visited
    module `v` reached under key `k`:
      * every attribute in `v.__dict__` that is a `jt.Var` is given a dotted
        name made of the keys on the current path (excluding the root key)
        followed by the attribute name;
      * `v._trace_name` is set to `str(k)`.

    Args:
        m: root module; must provide `dfs(parents, k, callback, callback_leave)`.
        name: key used for the root module.
    """
    # Keys of the modules on the current DFS path, root first.
    path = []

    def callback(parents, k, v, n):
        path.append(str(k))
        # path[1:] drops the root key so names are relative to the root.
        prefix = path[1:]
        for attr_name, attr in v.__dict__.items():
            if isinstance(attr, jt.Var):
                attr.name(".".join(prefix + [str(attr_name)]))
        v._trace_name = str(k)

    def callback_leave(parents, k, v, n):
        path.pop()

    m.dfs([], name, callback, callback_leave)

View File

@ -426,6 +426,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
#endif
last_is_cuda = is_cuda;
op->do_run_after_prepare(jkl);
// record trace data
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var==2)) {
trace_data.record_execution(op, is_fused_op, jkl);
}
LOGvvv << "Finished Op(" >> op->name() << rid >>
"/" >> queue.size() >> ") output:" << op->outputs();
if (is_fused_op) {
@ -458,7 +462,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
display_memory_info(__FILELINE__, false, true);
// log jit_key and file location
op->do_prepare(jkl);
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
LOGe << "[Error] source file location:" << jit_src_path;
if (is_fused_op) {
LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"

View File

@ -84,7 +84,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
if (grads.size()) {
grads[0] = make_number(1.f, loss);
assign_attrs(grads[0].ptr, loss);
registe_node_trace_grad(grads[0].ptr, loss, 0);
}
vector<pair<Node*, int64>> id_buffer;
@ -154,6 +153,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
} else
douts[i] = nullptr;
}
trace_grad_op = op;
op->grads(douts, dins);
// dump "for (Var* in : op->inputs())"
for (int i=0; i<n_i; i++,j++) {
@ -175,8 +175,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
auto out = id_buffer[j].first->var();
if (id<0) continue;
Var* dout = grads[id];
trace_grad_op = op;
VarPtr dvar = make_grad(op, out, dout, var, index);
registe_node_trace_grad(dvar.ptr, op, index);
if (dvar && dvar->num>=0 && var->num)
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
<< "dvar" << dvar << "var" << var;
@ -194,12 +194,12 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
}
#endif
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, index);
}
}
}
}
}
trace_grad_op = nullptr;
// set zero grad
for (size_t i=0; i<results.size(); i++) {
Var* var = targets[i];
@ -211,7 +211,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
grad = make_number(0.f, var);
assign_attrs(grad.ptr, var);
registe_node_trace_grad(grad.ptr, var, 0);
}
}
return results;

View File

@ -50,7 +50,8 @@ struct RingBuffer {
inline ~Cond() {
// a dirty hack
// ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
cv.__data.__wrefs = 0;
// cv.__data.__wrefs = 0;
cv.__data = {0};
pthread_cond_destroy(&cv);
}

View File

@ -120,16 +120,15 @@ struct Node {
#ifdef NODE_MEMCHECK
inline Node() {
lived_nodes[(void*)this] = ++total_node;
registe_node_trace(this);
}
inline virtual ~Node() {
lived_nodes.erase((void*)this);
unregiste_node_trace(this);
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);
}
#else
inline Node() {};
inline virtual ~Node() {};
inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);};
#endif
inline Var* var() { return (Var*)this; }
inline Op* op() { return (Op*)this; }

View File

@ -30,6 +30,7 @@ Op::Op() {
flags.set(NodeFlags::_var, 0);
flags.set(NodeFlags::_cpu, 1);
number_of_lived_ops++;
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this);
}
Op::~Op() {

View File

@ -125,9 +125,9 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
}
LOGvvv << "Check op needs compile:" << op;
op->do_prepare(jkl);
if (jk.empty()) continue;
if (jkl.empty()) continue;
const char* jit_key = jk.to_cstring();
const char* jit_key = jkl.to_cstring();
auto iter = jit_key_mapper.find(jit_key);
if (iter != jit_key_mapper.end()) continue;

View File

@ -62,6 +62,28 @@ unique_ptr<MemoryChecker>* load_memory_checker(string name) {
return mm;
}
extern string _get_stack_info(Op* op);
// Build the "stack info:" text attached to a profiler record.
// For an op named "fused", the stack strings of all member ops are collected
// through a map (so duplicates collapse and output is key-ordered); for any
// other op the single stack string of that op is appended.
static string get_stack_info(Op* op) {
    string text = "stack info:\n";
    if (string("fused") != op->name()) {
        text += _get_stack_info(op);
        text += '\n';
        return text;
    }
    // Only the keys matter; the map is used as an ordered set.
    map<string, int> unique_stacks;
    auto fused = (FusedOp*)op;
    for (Op* member : fused->ops)
        unique_stacks[_get_stack_info(member)] = 1;
    for (auto& entry : unique_stacks) {
        text += entry.first;
        text += '\n';
    }
    return text;
}
void Profiler::record_and_run(
jit_op_entry_t jit_entry,
Op* op,
@ -82,6 +104,9 @@ void Profiler::record_and_run(
0, 0, 0
};
iter = profiler.records.find(key);
if (trace_py_var) {
iter->second.stack_info = get_stack_info(op);
}
}
bool is_fused = op->name() == string("fused");
int loop = (is_fused &&
@ -141,6 +166,10 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
for (auto& kv : profiler.records) {
names.push_back(kv.first);
fnames.push_back(Op::get_filename_from_jit_key(kv.first, ".cc"));
if (kv.second.stack_info.size()) {
fnames.back() += '\n';
fnames.back() += kv.second.stack_info.c_str();
}
auto& kinfo = kv.second;
order.push_back(order.size());
// do not count relay op time

View File

@ -7,6 +7,7 @@
#include "common.h"
#include "profiler/cache_info.h"
#include "op_compiler.h"
#include "misc/cstr.h"
namespace jittor {
@ -23,6 +24,7 @@ struct Profiler {
uint64_t compute_total;
// cache test info
unique_ptr<CacheInfo> cache_info;
cstr stack_info;
void update(int c, uint64_t t, uint64_t in, uint64_t out, uint64_t comp) {
count += 1<<c;

View File

@ -1,83 +1,333 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <Python.h>
#include <frameobject.h>
#include "pyjt/py_obj_holder.h"
#include "pyjt/py_converter.h"
#include "pybind/py_var_tracer.h"
#include "misc/str_utils.h"
#include "op.h"
#include "var.h"
namespace py = pybind11;
using namespace pybind11::literals;
#include "fused_op.h"
namespace jittor {
DEFINE_FLAG(int, trace_py_var, 0, "Trace py stack max depth for debug.");
Op* trace_grad_op = nullptr;
unordered_map<const Node*, string> trace_data;
TraceData trace_data;
int64 cnt = 0;
void __registe_node_trace(Node* node) {
auto py_stacks =
py::module::import("traceback")
.attr("extract_stack")(nullptr, trace_py_var);
auto len = py_stacks.attr("__len__")().cast<int>();
string info;
for (int i=0; i<len; i++) {
auto py_stack = py_stacks.attr("__getitem__")(i);
auto filename = py_stack.attr("filename").cast<string>();
if (len==1)
info += split(filename, "/").back();
else {
info += "\n ";
info += filename;
}
info += ':';
info += py_stack.attr("name").cast<string>();
info += ':';
info += S(py_stack.attr("lineno").cast<int>());
info += ':';
info += py_stack.attr("line").cast<string>();
}
trace_data[node] = info;
// Import `module_name` and return its attribute `attr`.
//
// The returned pointer is an owned reference: ownership is taken out of the
// holder with release() instead of returning the raw pointer of a holder that
// is about to be destroyed. Callers in this file cache the result in function
// local statics, so the reference is intentionally kept for the process
// lifetime.
//
// NOTE(review): assumes PyObjHolder reports a failed import / attribute
// lookup (null result) itself — confirm against its constructor.
static PyObject* my_import(const char* module_name, const char* attr) {
    PyObjHolder module(PyImport_ImportModule(module_name));
    PyObjHolder attribute(PyObject_GetAttrString(module.obj, attr));
    return attribute.release();
}
void __unregiste_node_trace(Node* node) {
trace_data.erase(node);
}
static PyObject* find_obj_name(PyFrameObject* f, PyObject* obj, const char* default_name="_model") {
auto co = f->f_code;
auto map = co->co_varnames;
void __registe_node_trace_grad(Node* g, Node* node, int x_index) {
if (!g) return;
string& gname = trace_data.at(g);
string name = "grad(";
if (startswith(gname, "grad("))
return;
if (!node->is_var()) {
name += node->op()->name_ex();
name += ':';
name += S(x_index);
}
name += ":" + gname;
name += "):";
name += trace_data.at(node);
gname = name;
std::function<void(Node*)> dfs = [&] (Node* node) {
for (Node* i : node->inputs()) {
string& iname = trace_data[i];
if (iname.find("__init__.py:grad:") != string::npos && !startswith(iname, "grad(")) {
iname = name;
dfs(i);
auto fast = f->f_localsplus;
auto j = PyTuple_GET_SIZE(map);
if (j > co->co_nlocals)
j = co->co_nlocals;
if (co->co_nlocals) {
for (int i=0; i<j; i++) {
if (fast[i] == obj) {
auto s = PyTuple_GET_ITEM(map, i);
Py_INCREF(s);
return s;
}
}
};
dfs(g);
}
auto ncells = PyTuple_GET_SIZE(co->co_cellvars);
auto nfreevars = PyTuple_GET_SIZE(co->co_freevars);
if (ncells || nfreevars) {
for (int i=0; i<ncells; i++) {
if (fast[i+co->co_nlocals] == obj) {
auto s = PyTuple_GET_ITEM(co->co_cellvars, i);
Py_INCREF(s);
return s;
}
}
for (int i=0; i<nfreevars; i++) {
if (fast[i+co->co_nlocals+ncells] == obj) {
auto s = PyTuple_GET_ITEM(co->co_freevars, i);
Py_INCREF(s);
return s;
}
}
}
// LOGw << "not found name" << map << co->co_cellvars << co->co_freevars << (PyObject*)f;
return PyUnicode_FromString(default_name);
}
void __print_node_trace(const Node* node, std::ostream& os) {
if (trace_data.count(node))
os << '{' << trace_data.at(node) << '}';
// Convert a Python unicode object into a UTF-8 std string.
//
// Returns an empty string when `obj` is null or when the UTF-8 conversion
// fails (PyUnicode_AsUTF8AndSize returns NULL with an exception set, e.g.
// for a non-str object). The pending Python error is cleared in that case so
// it does not leak into unrelated later C-API calls.
static string to_string(PyObject* obj) {
    Py_ssize_t size = 0;
    const char* utf8 = obj ? PyUnicode_AsUTF8AndSize(obj, &size) : nullptr;
    if (!utf8) {
        PyErr_Clear();
        return string();
    }
    return string(utf8, size);
}
// Collect the chain of jittor Module / Optimizer scopes found on the current
// Python call stack, outermost frame first.
//
// For every frame that has local variable names, the first fast-local slot is
// inspected. If the second-to-last entry of its type's MRO is
// jittor.optim.Optimizer, one Stack entry is pushed and the walk stops; if it
// is jittor.Module, one Stack entry is pushed with the module's `_trace_name`
// (or a name derived from the frame when the attribute is missing).
// The file path and line number of each entry come from the previous (outer)
// frame, or from the frame itself for the outermost one.
static vector<Stack> get_stack_info() {
    vector<Stack> stacks;
    static auto getframe = my_import("sys", "_getframe");
    static auto jt_module = my_import("jittor", "Module");
    static auto jt_optimizer = my_import("jittor.optim", "Optimizer");
    static auto fill_module_name = my_import("jittor.utils.tracer", "fill_module_name");
    static auto _trace_name = PyUnicode_FromString("_trace_name");

    PyObjHolder top(PyObject_CallFunctionObjArgs(getframe, nullptr));

    // Count the frames, then store them outermost-first. A vector replaces
    // the variable-length array, which is not standard C++.
    int n = 0;
    for (auto frame = (PyFrameObject*)top.obj; frame; frame = frame->f_back)
        n++;
    vector<PyFrameObject*> frames(n);
    {
        auto frame = (PyFrameObject*)top.obj;
        for (int pos = n; pos; frame = frame->f_back)
            frames[--pos] = frame;
    }

    PyObject* prev_obj = nullptr;
    for (int i = 0; i < n; i++) {
        auto f = frames[i];
        if (!Py_SIZE(f->f_code->co_varnames))
            continue;
        auto obj = f->f_localsplus[0];
        // Consecutive frames on the same object produce a single entry.
        if (obj == prev_obj) continue;
        prev_obj = obj;
        if (obj == nullptr)
            // normal function first argument is null
            continue;
        auto tp_mro = obj->ob_type->tp_mro;
        // The MRO needs at least two entries for the size-2 lookup below;
        // a plain `object` instance has only one.
        if (!tp_mro || Py_SIZE(tp_mro) < 2)
            continue;
        auto base_type = PyTuple_GET_ITEM(tp_mro, Py_SIZE(tp_mro) - 2);
        auto prev_f = i ? frames[i - 1] : f;
        if (base_type == jt_optimizer) {
            // The outermost frame has no caller to search for a variable
            // name; fall back to the default name in that case.
            PyObjHolder opt_name(f->f_back
                ? find_obj_name(f->f_back, obj, "_opt")
                : PyUnicode_FromString("_opt"));
            stacks.emplace_back(Stack{
                to_string(opt_name.obj),
                string(obj->ob_type->tp_name),
                to_string(prev_f->f_code->co_filename),
                (int)PyFrame_GetLineNumber(prev_f)});
            break;
        }
        if (base_type != jt_module)
            continue;
        PyObjHolder name_obj;
        if (_PyObject_LookupAttr(obj, _trace_name, &name_obj.obj) < 0) {
            // Lookup raised something other than AttributeError: treat the
            // attribute as missing and do not leave the error pending.
            PyErr_Clear();
        }
        string scope_name;
        if (!name_obj.obj) {
            // find base name
            auto co_name = to_string(f->f_code->co_name);
            if (co_name == "__init__") {
                scope_name = string(obj->ob_type->tp_name) + "_init";
            } else if (co_name == "__call__") {
                if (i) {
                    name_obj.assign(find_obj_name(f->f_back, obj));
                    scope_name = to_string(name_obj.obj);
                } else {
                    name_obj.assign(PyUnicode_FromString("_model"));
                    scope_name = "_model";
                }
                // Let the python helper name the vars and sub-modules.
                PyObjHolder _(PyObject_CallFunctionObjArgs(
                    fill_module_name, obj, name_obj.obj, nullptr));
            }
        } else {
            scope_name = to_string(name_obj.obj);
        }
        stacks.emplace_back(Stack{
            move(scope_name),
            string(obj->ob_type->tp_name),
            to_string(prev_f->f_code->co_filename),
            (int)PyFrame_GetLineNumber(prev_f)});
    }
    return stacks;
}
// Register `node` under a fresh id and, for ops, optionally attach a stack.
//
// Does nothing when `thread_name` is non-empty. For a non-var node with
// `record_stack` set: while a gradient op is being traced (trace_grad_op is
// set) the stack starts with a synthetic "grad" entry followed by a copy of
// the stack recorded for that op, if any; otherwise the current Python stack
// is captured. Var nodes get no stack.
void TraceData::record_node(Node* node, bool record_stack) {
    if (thread_name.size()) return;
    NodeData data;
    data.id = node_data_cnt++;
    id_map[node] = data.id;

    const bool wants_stack = !node->is_var() && record_stack;
    if (wants_stack && trace_grad_op) {
        auto origin = trace_data.id_map.find(trace_grad_op);
        data.stacks.emplace_back(Stack{"grad", "Grad", "", 0});
        if (origin != trace_data.id_map.end()) {
            auto& origin_stack = trace_data.node_data[origin->second].stacks;
            for (auto& entry : origin_stack)
                data.stacks.push_back(entry);
        }
    } else if (wants_stack) {
        data.stacks = get_stack_info();
    }

    const auto slot = data.id;
    node_data[slot] = move(data);
}
// Return the trace id of `node`, registering it (without a stack) when it
// has not been seen yet.
//
// NOTE(review): after registering, the id is taken as node_data_cnt - 1.
// record_node returns early without registering when thread_name is
// non-empty, in which case this value would not belong to `node` — confirm
// that path cannot be reached from here.
static int64 get_node_id(Node* node) {
    auto known = trace_data.id_map.find(node);
    if (known == trace_data.id_map.end()) {
        trace_data.record_node(node, false);
        return trace_data.node_data_cnt - 1;
    }
    return known->second;
}
// Forget the pointer -> id mapping of a node that is being destroyed.
// Does nothing when `thread_name` is non-empty or the node was never
// registered. The node's recorded data is erased as well only when
// trace_py_var == 1; for other values it is kept.
void TraceData::release_node(Node* node) {
    if (thread_name.size()) return;
    auto known = trace_data.id_map.find(node);
    if (known != trace_data.id_map.end()) {
        const auto released_id = known->second;
        id_map.erase(node);
        if (trace_py_var == 1)
            node_data.erase(released_id);
    }
}
// Fill in graph links and attributes for a node that takes part in an
// execution.
//
// The record is (re)built when its stored input count differs from the
// node's current input count or when no attributes were stored yet.
// Rebuilding rewrites `inputs`, appends this node's id to the `outputs` of
// every input, and stores the attributes: shape/ndim/dtype/dsize/name for a
// var, name for an op, plus "is_var".
void TraceData::record_exe_node(Node* node) {
    auto node_id = get_node_id(node);
    auto& data = node_data[node_id];
    const bool already_recorded =
        data.inputs.size() == node->inputs().size() && data.attrs.size() != 0;
    if (already_recorded)
        return;

    data.inputs.clear();
    data.inputs.reserve(node->inputs().size());
    for (auto input : node->inputs()) {
        auto input_id = get_node_id(input);
        data.inputs.push_back(input_id);
        node_data[input_id].outputs.push_back(node_id);
    }

    if (!node->is_var()) {
        auto op = node->op();
        data.attrs["name"] = op->name_ex();
        data.attrs["is_var"] = "0";
        // TODO: add other op attrs
        return;
    }

    auto var = node->var();
    std::stringstream shape_text;
    shape_text << var->shape;
    data.attrs["shape"] = shape_text.str();
    data.attrs["ndim"] = S(var->shape.size());
    data.attrs["dtype"] = var->dtype().to_cstring();
    data.attrs["dsize"] = S(var->dtype().dsize());
    data.attrs["name"] = var->name.c_str();
    data.attrs["is_var"] = "1";
}
// Record an executed op together with each of its output nodes.
void TraceData::record_op(Op* op) {
    record_exe_node(op);
    for (auto output : op->outputs()) {
        record_exe_node(output);
    }
}
// Append one execution record for `op`.
//
// The record lists the node ids of the executed op, or of every member op
// when `is_fused_op` is set. After preparing the jit key, the record also
// gets the jit key (replaced by its jit_key_mapper entry when one exists)
// and the matching ".cc" source path, and it is indexed in jit_key_map.
// When the jit key is empty only the op ids are stored.
void TraceData::record_execution(Op* op, bool is_fused_op, JK& jk) {
    const auto exec_id = execute_op_info_cnt++;
    ExecuteOpInfo& einfo = execute_op_info[exec_id];

    auto track = [&](Op* executed) {
        record_op(executed);
        einfo.fused_ops.push_back(get_node_id(executed));
    };
    if (is_fused_op) {
        FusedOp* fused = (FusedOp*)op;
        for (auto member : fused->ops)
            track(member);
    } else {
        track(op);
    }

    op->do_prepare(jk);
    if (jk.empty()) return;
    const char* jit_key = jk.to_cstring();
    auto mapped = jit_key_mapper.find(jit_key);
    if (mapped != jit_key_mapper.end())
        einfo.jit_key = mapped->second;
    else
        einfo.jit_key = jit_key;
    jit_key_map[einfo.jit_key].push_back(exec_id);
    einfo.file_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
}
// Store `value` in the Python dict `dict` under `key` (converted with
// to_py_object<T>). Both the converted key and `value` are placed in holders
// for the duration of the call, so the caller hands over its reference to
// `value`.
template<class T>
static void fill_dict(PyObject* dict, T key, PyObject* value) {
    PyObjHolder key_holder(to_py_object<T>(key));
    PyObjHolder value_holder(value);
    PyDict_SetItem(dict, key_holder.obj, value_holder.obj);
}
// template<>
// PyObject* to_py_object(const Stack& stack) {
// return nullptr;
// }
// Convert a Stack entry into a Python dict with the keys
// "name", "type", "file_path" and "lineno".
DEF_IS(Stack, PyObject*) to_py_object(const T& a) {
    PyObjHolder entry(PyDict_New());
    PyObject* target = entry.obj;
    fill_dict(target, string("name"), to_py_object<string>(a.module_name));
    fill_dict(target, string("type"), to_py_object<string>(a.module_type));
    fill_dict(target, string("file_path"), to_py_object<string>(a.file_path));
    fill_dict(target, string("lineno"), to_py_object<int64>(a.lineno));
    return entry.release();
}
// Export the collected trace data as a Python dict with two entries:
//   "node_data":       id -> {id, inputs, outputs, stacks, attrs}
//   "execute_op_info": id -> {fused_ops, jit_key, file_path, attrs}
// Nodes without any stored attributes are left out of "node_data".
PyObject* dump_trace_data() {
    PyObjHolder result(PyDict_New());
    PyObjHolder nodes(PyDict_New());
    PyObjHolder executions(PyDict_New());

    for (auto& kv : trace_data.node_data) {
        auto& node = kv.second;
        if (node.attrs.size() == 0)
            continue;
        PyObjHolder entry(PyDict_New());
        fill_dict(entry.obj, string("id"), to_py_object(node.id));
        fill_dict(entry.obj, string("inputs"), to_py_object(node.inputs));
        fill_dict(entry.obj, string("outputs"), to_py_object(node.outputs));
        fill_dict(entry.obj, string("stacks"), to_py_object(node.stacks));
        fill_dict(entry.obj, string("attrs"), to_py_object(node.attrs));
        fill_dict(nodes.obj, kv.first, entry.release());
    }

    for (auto& kv : trace_data.execute_op_info) {
        auto& info = kv.second;
        PyObjHolder entry(PyDict_New());
        fill_dict(entry.obj, string("fused_ops"), to_py_object(info.fused_ops));
        fill_dict(entry.obj, string("jit_key"), to_py_object<string>(info.jit_key));
        fill_dict(entry.obj, string("file_path"), to_py_object<string>(info.file_path));
        fill_dict(entry.obj, string("attrs"), to_py_object(info.attrs));
        fill_dict(executions.obj, kv.first, entry.release());
    }

    fill_dict(result.obj, string("node_data"), nodes.release());
    fill_dict(result.obj, string("execute_op_info"), executions.release());
    return result.release();
}
// Drop everything recorded so far: node records, the node pointer -> id map,
// execution records and the jit-key index. The id counters are not touched
// here, so ids handed out afterwards continue from their current values.
void clear_trace_data() {
    auto& store = trace_data;
    store.node_data.clear();
    store.id_map.clear();
    store.jit_key_map.clear();
    store.execute_op_info.clear();
}
// Return the recorded scope names of `op` joined as "a -> b -> ".
// The result is empty when the op is not registered or has no stored record.
string _get_stack_info(Op* op) {
    string joined = "";
    auto id_entry = trace_data.id_map.find(op);
    if (id_entry != trace_data.id_map.end()) {
        auto record = trace_data.node_data.find(id_entry->second);
        if (record != trace_data.node_data.end()) {
            for (auto& scope : record->second.stacks)
                joined += scope.module_name + " -> ";
        }
    }
    return joined;
}
void print_node_trace(const Node* node, std::ostream& os) {
if (!node->is_var())
os << _get_stack_info((((Node*)node))->op());
}
} // jittor

View File

@ -1,5 +1,6 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
@ -9,27 +10,62 @@
namespace jittor {
DECLARE_FLAG(int, trace_py_var);
extern Op* trace_grad_op;
struct JitKey;
// One entry of the module scope chain recorded for an op.
struct Stack {
    // scope name of the module / optimizer (e.g. its `_trace_name`)
    string module_name;
    // python type name of the object
    string module_type;
    // source file of the frame the entry was taken from
    string file_path;
    // line number inside file_path
    int lineno;
};
#ifdef NODE_MEMCHECK
void __registe_node_trace(Node* node);
void __unregiste_node_trace(Node* node);
void __registe_node_trace_grad(Node* g, Node* node, int x_index);
void __print_node_trace(const Node* node, std::ostream& os);
// Trace record of a single graph node (var or op).
struct NodeData {
    // trace id of this node
    int64 id;
    // trace ids of the nodes this node reads from
    vector<int64> inputs;
    // trace ids of the nodes that list this node as an input
    vector<int64> outputs;
    // module scope chain captured when the node was recorded
    vector<Stack> stacks;
    /*
    if is var, then contain:
    is_var: 1
    shape: [a,b,c,d]
    ndim: x
    dtype: floatxx
    dsize: 4 or 8
    name: xxx
    if is op, then contain:
    is_var: 0
    name: xxx
    other op attr
    */
    unordered_map<string,string> attrs;
};
inline void registe_node_trace(Node* node)
{ if (trace_py_var) __registe_node_trace(node); }
inline void unregiste_node_trace(Node* node)
{ if (trace_py_var) __unregiste_node_trace(node); }
inline void registe_node_trace_grad(Node* g, Node* node, int x_index)
{ if (trace_py_var) __registe_node_trace_grad(g, node, x_index); }
inline void print_node_trace(const Node* node, std::ostream& os)
{ if (trace_py_var) __print_node_trace(node, os); }
#else
inline void registe_node_trace(Node* node) {}
inline void unregiste_node_trace(Node* node) {}
inline void registe_node_trace_grad(Node* g, Node* node, int x_index) {}
inline void print_node_trace(const Node* node, std::ostream& os) {}
#endif
// Trace record of one executed (possibly fused) operator.
struct ExecuteOpInfo {
    // trace ids of the executed op, or of every member op of a fused op
    vector<int64> fused_ops;
    // jit key of the executed kernel
    string jit_key;
    // path of the ".cc" source file derived from the jit key
    string file_path;
    // free-form string attributes of this execution
    unordered_map<string,string> attrs;
};
// Container for everything the python-variable tracer records: per-node
// data, per-execution data and the lookup tables connecting them.
struct TraceData {
    // Next node id to hand out. Initialised explicitly so that an instance
    // with automatic or dynamic storage does not start from an indeterminate
    // value.
    int64 node_data_cnt = 0;
    // Next execution-record id to hand out; initialised for the same reason.
    int64 execute_op_info_cnt = 0;
    // node id -> recorded node data
    unordered_map<int64, NodeData> node_data;
    // execution id -> recorded execution data
    unordered_map<int64, ExecuteOpInfo> execute_op_info;
    // jit_key map to id of execute_op_info
    unordered_map<string, vector<int64>> jit_key_map;
    // live node pointer -> node id
    unordered_map<Node*, int64> id_map;

    // Register a node under a fresh id; optionally capture its stack.
    void record_node(Node* node, bool record_stack=true);
    // Drop the pointer -> id mapping of a node being destroyed.
    void release_node(Node*);
    // Record an executed op and its outputs.
    void record_op(Op* op);
    // Fill in links and attributes of a node that takes part in an execution.
    void record_exe_node(Node* node);
    // Append one execution record for op.
    void record_execution(Op* op, bool is_fused_op, JitKey& jk);
};
extern TraceData trace_data;
void print_node_trace(const Node* node, std::ostream& os);
} // jittor

View File

@ -0,0 +1,18 @@
// ***************************************************************
// Copyright (c) 2020 Jittor. All Rights Reserved.
// Authors: Dun Liang <randonlang@gmail.com>.
// This file is subject to the terms and conditions defined in
// file 'LICENSE.txt', which is part of this source code package.
// ***************************************************************
#pragma once
#include <Python.h>
namespace jittor {
// @pyjt(dump_trace_data)
PyObject* dump_trace_data();
// @pyjt(clear_trace_data)
void clear_trace_data();
} // jittor

View File

@ -26,17 +26,6 @@ namespace jittor {
#define GET_PY_NONE(code) ((code), Py_INCREF(Py_None), Py_None)
inline Log& operator<<(Log& os, PyObject* objp) {
PyObjHolder repr_obj(PyObject_Repr(objp));
if (PyUnicode_CheckExact(repr_obj.obj)) {
return os << Py_TYPE(objp)->tp_name <<
PyUnicode_AsUTF8(repr_obj.obj);
} else {
return os << "unknown(" >> (void*)objp >> ")";
}
}
// string
DEF_IS(string, bool) is_type(PyObject* obj) {
return PyUnicode_CheckExact(obj);

View File

@ -34,6 +34,18 @@ struct PyObjHolder {
}
};
inline Log& operator<<(Log& os, PyObject* objp) {
PyObjHolder repr_obj(PyObject_Repr(objp));
if (PyUnicode_CheckExact(repr_obj.obj)) {
return os << Py_TYPE(objp)->tp_name <<
PyUnicode_AsUTF8(repr_obj.obj);
} else {
return os << "unknown(" >> (void*)objp >> ")";
}
}
}
#define PYJF_MODULE_INIT(name) \

View File

@ -29,6 +29,7 @@ Var::Var(NanoVector shape, NanoString dtype)
ASSERT(ns.is_dtype());
number_of_lived_vars++;
numel();
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this);
}
string Var::to_string() {