mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/jittor/jittor
This commit is contained in:
commit
32af287461
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.1.0'
|
||||
__version__ = '1.2.1.1'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1247,10 +1247,10 @@ class Sequential(Module):
|
|||
ret = callback(parents, k, self, n_children)
|
||||
if ret == False:
|
||||
return
|
||||
parents.append(self)
|
||||
for k,v in self.layers.items():
|
||||
parents.append(self)
|
||||
v.dfs(parents, k, callback, callback_leave)
|
||||
parents.pop()
|
||||
parents.pop()
|
||||
if callback_leave:
|
||||
callback_leave(parents, k, self, n_children)
|
||||
def append(self, mod):
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
from jittor import Module
|
||||
from jittor.models import resnet
|
||||
import pickle
|
||||
|
||||
f32 = jt.float32
|
||||
|
||||
def matmul(a, b):
|
||||
(n, m), k = a.shape, b.shape[-1]
|
||||
a = a.broadcast([n,m,k], dims=[2])
|
||||
b = b.broadcast([n,m,k], dims=[0])
|
||||
return (a*b).sum(dim=1)
|
||||
|
||||
|
||||
def relu(x):
|
||||
return jt.maximum(x, 0.0)
|
||||
Relu = jt.make_module(relu)
|
||||
|
||||
class Model(Module):
|
||||
def __init__(self, input_size):
|
||||
self.linear1 = Linear(input_size, 10)
|
||||
self.relu1 = Relu()
|
||||
self.linear2 = Linear(10, 1)
|
||||
def execute(self, x):
|
||||
x = self.linear1(x)
|
||||
x = self.relu1(x)
|
||||
return self.linear2(x)
|
||||
|
||||
class Linear(Module):
|
||||
def __init__(self, in_features, out_features, bias=True):
|
||||
self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5
|
||||
self.b = jt.random((out_features,))-0.5 if bias else None
|
||||
def execute(self, x):
|
||||
x = matmul(x, self.w)
|
||||
if self.b is not None:
|
||||
return x+self.b
|
||||
return x
|
||||
|
||||
|
||||
class TestTraceVar(unittest.TestCase):
|
||||
def test_simple_model(self):
|
||||
with jt.flag_scope(trace_py_var=2):
|
||||
|
||||
model = Model(input_size=1)
|
||||
batch_size = 10
|
||||
x = jt.float32(np.random.rand(batch_size, 1))
|
||||
y = model(x)
|
||||
y.sync()
|
||||
|
||||
|
||||
data = jt.dump_trace_data()
|
||||
jt.clear_trace_data()
|
||||
# with open("/tmp/simple_model.pkl", "wb") as f:
|
||||
# pickle.dump(data, f)
|
||||
|
||||
def test_simple_model_train(self):
|
||||
with jt.flag_scope(trace_py_var=2):
|
||||
|
||||
model = Model(input_size=1)
|
||||
opt = jt.optim.SGD(model.parameters(), 0.1)
|
||||
|
||||
batch_size = 10
|
||||
x = jt.float32(np.random.rand(batch_size, 1))
|
||||
y = model(x)
|
||||
opt.step(y**2)
|
||||
jt.sync_all()
|
||||
|
||||
data = jt.dump_trace_data()
|
||||
jt.clear_trace_data()
|
||||
# with open("/tmp/simple_model_train.pkl", "wb") as f:
|
||||
# pickle.dump(data, f)
|
||||
|
||||
def test_resnet(self):
|
||||
with jt.flag_scope(trace_py_var=2):
|
||||
|
||||
resnet18 = resnet.Resnet18()
|
||||
x = jt.float32(np.random.rand(2, 3, 224, 224))
|
||||
y = resnet18(x)
|
||||
y.sync()
|
||||
|
||||
data = jt.dump_trace_data()
|
||||
jt.clear_trace_data()
|
||||
# with open("/tmp/resnet.pkl", "wb") as f:
|
||||
# pickle.dump(data, f)
|
||||
|
||||
def test_resnet_train(self):
|
||||
with jt.flag_scope(trace_py_var=2):
|
||||
|
||||
resnet18 = resnet.Resnet18()
|
||||
opt = jt.optim.SGD(resnet18.parameters(), 0.1)
|
||||
x = jt.float32(np.random.rand(2, 3, 224, 224))
|
||||
y = resnet18(x)
|
||||
|
||||
opt.step(y**2)
|
||||
jt.sync_all()
|
||||
|
||||
data = jt.dump_trace_data()
|
||||
jt.clear_trace_data()
|
||||
# with open("/tmp/resnet_train.pkl", "wb") as f:
|
||||
# pickle.dump(data, f)
|
||||
|
||||
def test_resnet_train_profile(self):
|
||||
with jt.profile_scope(trace_py_var=1):
|
||||
|
||||
resnet18 = resnet.Resnet18()
|
||||
opt = jt.optim.SGD(resnet18.parameters(), 0.1)
|
||||
x = jt.float32(np.random.rand(2, 3, 224, 224))
|
||||
y = resnet18(x)
|
||||
|
||||
opt.step(y**2)
|
||||
jt.sync_all()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,23 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
|
||||
def fill_module_name(m, name):
|
||||
ps = []
|
||||
stack = []
|
||||
def callback(parents, k, v, n):
|
||||
stack.append(str(k))
|
||||
for k2, p in v.__dict__.items():
|
||||
if isinstance(p, jt.Var):
|
||||
ps.append(p)
|
||||
p.name(".".join(stack[1:]+[str(k2)]))
|
||||
v._trace_name = str(k)
|
||||
def callback_leave(parents, k, v, n):
|
||||
stack.pop()
|
||||
m.dfs([], name, callback, callback_leave)
|
|
@ -426,6 +426,10 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
#endif
|
||||
last_is_cuda = is_cuda;
|
||||
op->do_run_after_prepare(jkl);
|
||||
// record trace data
|
||||
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var==2)) {
|
||||
trace_data.record_execution(op, is_fused_op, jkl);
|
||||
}
|
||||
LOGvvv << "Finished Op(" >> op->name() << rid >>
|
||||
"/" >> queue.size() >> ") output:" << op->outputs();
|
||||
if (is_fused_op) {
|
||||
|
@ -458,7 +462,7 @@ void Executor::run_sync(vector<Var*> vars, bool device_sync) {
|
|||
display_memory_info(__FILELINE__, false, true);
|
||||
// log jit_key and file location
|
||||
op->do_prepare(jkl);
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
string jit_src_path = Op::get_filename_from_jit_key(jkl.to_cstring(), ".cc");
|
||||
LOGe << "[Error] source file location:" << jit_src_path;
|
||||
if (is_fused_op) {
|
||||
LOGf << "Execute fused operator(" >> rid >> '/' >> queue.size() >> ")"
|
||||
|
|
|
@ -84,7 +84,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
if (grads.size()) {
|
||||
grads[0] = make_number(1.f, loss);
|
||||
assign_attrs(grads[0].ptr, loss);
|
||||
registe_node_trace_grad(grads[0].ptr, loss, 0);
|
||||
}
|
||||
|
||||
vector<pair<Node*, int64>> id_buffer;
|
||||
|
@ -154,6 +153,7 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
} else
|
||||
douts[i] = nullptr;
|
||||
}
|
||||
trace_grad_op = op;
|
||||
op->grads(douts, dins);
|
||||
// dump "for (Var* in : op->inputs())"
|
||||
for (int i=0; i<n_i; i++,j++) {
|
||||
|
@ -175,8 +175,8 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
auto out = id_buffer[j].first->var();
|
||||
if (id<0) continue;
|
||||
Var* dout = grads[id];
|
||||
trace_grad_op = op;
|
||||
VarPtr dvar = make_grad(op, out, dout, var, index);
|
||||
registe_node_trace_grad(dvar.ptr, op, index);
|
||||
if (dvar && dvar->num>=0 && var->num)
|
||||
ASSERT(dvar->num==var->num && dvar->shape.size()==var->shape.size())
|
||||
<< "dvar" << dvar << "var" << var;
|
||||
|
@ -194,12 +194,12 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
}
|
||||
#endif
|
||||
assign_attrs(grad.ptr, var);
|
||||
registe_node_trace_grad(grad.ptr, var, index);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
trace_grad_op = nullptr;
|
||||
// set zero grad
|
||||
for (size_t i=0; i<results.size(); i++) {
|
||||
Var* var = targets[i];
|
||||
|
@ -211,7 +211,6 @@ vector<VarPtr> grad(Var* loss, vector<Var*> targets) {
|
|||
LOGw << "grads[">>i>>"] '">> var->name>>"' doesn't have gradient. It will be set to zero:" << var;
|
||||
grad = make_number(0.f, var);
|
||||
assign_attrs(grad.ptr, var);
|
||||
registe_node_trace_grad(grad.ptr, var, 0);
|
||||
}
|
||||
}
|
||||
return results;
|
||||
|
|
|
@ -50,7 +50,8 @@ struct RingBuffer {
|
|||
inline ~Cond() {
|
||||
// a dirty hack
|
||||
// ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
|
||||
cv.__data.__wrefs = 0;
|
||||
// cv.__data.__wrefs = 0;
|
||||
cv.__data = {0};
|
||||
pthread_cond_destroy(&cv);
|
||||
}
|
||||
|
||||
|
|
|
@ -120,16 +120,15 @@ struct Node {
|
|||
#ifdef NODE_MEMCHECK
|
||||
inline Node() {
|
||||
lived_nodes[(void*)this] = ++total_node;
|
||||
registe_node_trace(this);
|
||||
}
|
||||
|
||||
inline virtual ~Node() {
|
||||
lived_nodes.erase((void*)this);
|
||||
unregiste_node_trace(this);
|
||||
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);
|
||||
}
|
||||
#else
|
||||
inline Node() {};
|
||||
inline virtual ~Node() {};
|
||||
inline virtual ~Node() { if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.release_node(this);};
|
||||
#endif
|
||||
inline Var* var() { return (Var*)this; }
|
||||
inline Op* op() { return (Op*)this; }
|
||||
|
|
|
@ -30,6 +30,7 @@ Op::Op() {
|
|||
flags.set(NodeFlags::_var, 0);
|
||||
flags.set(NodeFlags::_cpu, 1);
|
||||
number_of_lived_ops++;
|
||||
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this);
|
||||
}
|
||||
|
||||
Op::~Op() {
|
||||
|
|
|
@ -125,9 +125,9 @@ void parallel_compile_all_ops(vector<int>& queue, vector<int>& range, FusedOp& f
|
|||
}
|
||||
LOGvvv << "Check op needs compile:" << op;
|
||||
op->do_prepare(jkl);
|
||||
if (jk.empty()) continue;
|
||||
if (jkl.empty()) continue;
|
||||
|
||||
const char* jit_key = jk.to_cstring();
|
||||
const char* jit_key = jkl.to_cstring();
|
||||
auto iter = jit_key_mapper.find(jit_key);
|
||||
if (iter != jit_key_mapper.end()) continue;
|
||||
|
||||
|
|
|
@ -62,6 +62,28 @@ unique_ptr<MemoryChecker>* load_memory_checker(string name) {
|
|||
return mm;
|
||||
}
|
||||
|
||||
extern string _get_stack_info(Op* op);
|
||||
|
||||
static string get_stack_info(Op* op) {
|
||||
string stack_info = "stack info:\n";
|
||||
if (string("fused") == op->name()) {
|
||||
auto fop = (FusedOp*)op;
|
||||
map<string, int> stacks;
|
||||
for (Op* op : fop->ops) {
|
||||
stacks[_get_stack_info(op)] = 1;
|
||||
}
|
||||
for (auto& kv : stacks) {
|
||||
stack_info += kv.first;
|
||||
stack_info += '\n';
|
||||
}
|
||||
return stack_info;
|
||||
} else {
|
||||
stack_info += _get_stack_info(op);
|
||||
stack_info += '\n';
|
||||
return stack_info;
|
||||
}
|
||||
}
|
||||
|
||||
void Profiler::record_and_run(
|
||||
jit_op_entry_t jit_entry,
|
||||
Op* op,
|
||||
|
@ -82,6 +104,9 @@ void Profiler::record_and_run(
|
|||
0, 0, 0
|
||||
};
|
||||
iter = profiler.records.find(key);
|
||||
if (trace_py_var) {
|
||||
iter->second.stack_info = get_stack_info(op);
|
||||
}
|
||||
}
|
||||
bool is_fused = op->name() == string("fused");
|
||||
int loop = (is_fused &&
|
||||
|
@ -141,6 +166,10 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
|
|||
for (auto& kv : profiler.records) {
|
||||
names.push_back(kv.first);
|
||||
fnames.push_back(Op::get_filename_from_jit_key(kv.first, ".cc"));
|
||||
if (kv.second.stack_info.size()) {
|
||||
fnames.back() += '\n';
|
||||
fnames.back() += kv.second.stack_info.c_str();
|
||||
}
|
||||
auto& kinfo = kv.second;
|
||||
order.push_back(order.size());
|
||||
// do not count relay op time
|
||||
|
|
|
@ -7,6 +7,7 @@
|
|||
#include "common.h"
|
||||
#include "profiler/cache_info.h"
|
||||
#include "op_compiler.h"
|
||||
#include "misc/cstr.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
|
@ -23,6 +24,7 @@ struct Profiler {
|
|||
uint64_t compute_total;
|
||||
// cache test info
|
||||
unique_ptr<CacheInfo> cache_info;
|
||||
cstr stack_info;
|
||||
|
||||
void update(int c, uint64_t t, uint64_t in, uint64_t out, uint64_t comp) {
|
||||
count += 1<<c;
|
||||
|
|
|
@ -1,83 +1,333 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
#include <Python.h>
|
||||
#include <frameobject.h>
|
||||
#include "pyjt/py_obj_holder.h"
|
||||
#include "pyjt/py_converter.h"
|
||||
#include "pybind/py_var_tracer.h"
|
||||
#include "misc/str_utils.h"
|
||||
#include "op.h"
|
||||
#include "var.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
using namespace pybind11::literals;
|
||||
#include "fused_op.h"
|
||||
|
||||
namespace jittor {
|
||||
|
||||
DEFINE_FLAG(int, trace_py_var, 0, "Trace py stack max depth for debug.");
|
||||
Op* trace_grad_op = nullptr;
|
||||
|
||||
unordered_map<const Node*, string> trace_data;
|
||||
TraceData trace_data;
|
||||
int64 cnt = 0;
|
||||
|
||||
void __registe_node_trace(Node* node) {
|
||||
auto py_stacks =
|
||||
py::module::import("traceback")
|
||||
.attr("extract_stack")(nullptr, trace_py_var);
|
||||
auto len = py_stacks.attr("__len__")().cast<int>();
|
||||
string info;
|
||||
for (int i=0; i<len; i++) {
|
||||
auto py_stack = py_stacks.attr("__getitem__")(i);
|
||||
auto filename = py_stack.attr("filename").cast<string>();
|
||||
if (len==1)
|
||||
info += split(filename, "/").back();
|
||||
else {
|
||||
info += "\n ";
|
||||
info += filename;
|
||||
}
|
||||
info += ':';
|
||||
info += py_stack.attr("name").cast<string>();
|
||||
info += ':';
|
||||
info += S(py_stack.attr("lineno").cast<int>());
|
||||
info += ':';
|
||||
info += py_stack.attr("line").cast<string>();
|
||||
}
|
||||
trace_data[node] = info;
|
||||
static PyObject* my_import(const char* module_name, const char* attr) {
|
||||
// LOGir << module_name << attr;
|
||||
PyObjHolder a(PyImport_ImportModule(module_name));
|
||||
PyObjHolder b(PyObject_GetAttrString(a.obj, attr));
|
||||
// LOGir << "Done";
|
||||
return b.obj;
|
||||
}
|
||||
|
||||
void __unregiste_node_trace(Node* node) {
|
||||
trace_data.erase(node);
|
||||
}
|
||||
static PyObject* find_obj_name(PyFrameObject* f, PyObject* obj, const char* default_name="_model") {
|
||||
auto co = f->f_code;
|
||||
auto map = co->co_varnames;
|
||||
|
||||
void __registe_node_trace_grad(Node* g, Node* node, int x_index) {
|
||||
if (!g) return;
|
||||
string& gname = trace_data.at(g);
|
||||
string name = "grad(";
|
||||
if (startswith(gname, "grad("))
|
||||
return;
|
||||
if (!node->is_var()) {
|
||||
name += node->op()->name_ex();
|
||||
name += ':';
|
||||
name += S(x_index);
|
||||
}
|
||||
name += ":" + gname;
|
||||
name += "):";
|
||||
name += trace_data.at(node);
|
||||
gname = name;
|
||||
std::function<void(Node*)> dfs = [&] (Node* node) {
|
||||
for (Node* i : node->inputs()) {
|
||||
string& iname = trace_data[i];
|
||||
if (iname.find("__init__.py:grad:") != string::npos && !startswith(iname, "grad(")) {
|
||||
iname = name;
|
||||
dfs(i);
|
||||
auto fast = f->f_localsplus;
|
||||
auto j = PyTuple_GET_SIZE(map);
|
||||
if (j > co->co_nlocals)
|
||||
j = co->co_nlocals;
|
||||
if (co->co_nlocals) {
|
||||
for (int i=0; i<j; i++) {
|
||||
if (fast[i] == obj) {
|
||||
auto s = PyTuple_GET_ITEM(map, i);
|
||||
Py_INCREF(s);
|
||||
return s;
|
||||
}
|
||||
}
|
||||
};
|
||||
dfs(g);
|
||||
}
|
||||
auto ncells = PyTuple_GET_SIZE(co->co_cellvars);
|
||||
auto nfreevars = PyTuple_GET_SIZE(co->co_freevars);
|
||||
if (ncells || nfreevars) {
|
||||
for (int i=0; i<ncells; i++) {
|
||||
if (fast[i+co->co_nlocals] == obj) {
|
||||
auto s = PyTuple_GET_ITEM(co->co_cellvars, i);
|
||||
Py_INCREF(s);
|
||||
return s;
|
||||
}
|
||||
}
|
||||
for (int i=0; i<nfreevars; i++) {
|
||||
if (fast[i+co->co_nlocals+ncells] == obj) {
|
||||
auto s = PyTuple_GET_ITEM(co->co_freevars, i);
|
||||
Py_INCREF(s);
|
||||
return s;
|
||||
}
|
||||
}
|
||||
}
|
||||
// LOGw << "not found name" << map << co->co_cellvars << co->co_freevars << (PyObject*)f;
|
||||
return PyUnicode_FromString(default_name);
|
||||
}
|
||||
|
||||
void __print_node_trace(const Node* node, std::ostream& os) {
|
||||
if (trace_data.count(node))
|
||||
os << '{' << trace_data.at(node) << '}';
|
||||
static string to_string(PyObject* obj) {
|
||||
Py_ssize_t size;
|
||||
const char* s = PyUnicode_AsUTF8AndSize(obj, &size);
|
||||
return string(s, size);
|
||||
}
|
||||
|
||||
static vector<Stack> get_stack_info() {
|
||||
// cnt ++;
|
||||
// if (cnt % 100 != 0) return {};
|
||||
vector<Stack> stacks;
|
||||
static auto getframe = my_import("sys", "_getframe");
|
||||
static auto jt_module = my_import("jittor", "Module");
|
||||
static auto jt_optimizer = my_import("jittor.optim", "Optimizer");
|
||||
static auto fill_module_name = my_import("jittor.utils.tracer", "fill_module_name");
|
||||
static auto _trace_name = PyUnicode_FromString("_trace_name");
|
||||
|
||||
PyObjHolder ret(PyObject_CallFunctionObjArgs(getframe, nullptr));
|
||||
|
||||
auto frame = (PyFrameObject*)ret.obj;
|
||||
int n=0;
|
||||
while (frame) n++, frame = frame->f_back;
|
||||
PyFrameObject* frames[n];
|
||||
frame = (PyFrameObject*)ret.obj;
|
||||
int i=n;
|
||||
while (i) frames[--i] = frame, frame = frame->f_back;
|
||||
PyObject* prev_obj = nullptr;
|
||||
for (int i=0; i<n; i++) {
|
||||
auto f = frames[i];
|
||||
if (Py_SIZE(f->f_code->co_varnames)) {
|
||||
auto fast = f->f_localsplus;
|
||||
auto obj = fast[0];
|
||||
if (obj == prev_obj) continue;
|
||||
prev_obj = obj;
|
||||
if (obj == nullptr)
|
||||
// normal function first argument is null
|
||||
continue;
|
||||
auto tp_mro = obj->ob_type->tp_mro;
|
||||
auto base_type = PyTuple_GET_ITEM(tp_mro, Py_SIZE(tp_mro)-2);
|
||||
auto prev_f = i? frames[i-1] : f;
|
||||
if (base_type == jt_optimizer) {
|
||||
PyObjHolder ret(find_obj_name(f->f_back, obj, "_opt"));
|
||||
stacks.emplace_back(Stack{
|
||||
to_string(ret.obj),
|
||||
string(obj->ob_type->tp_name),
|
||||
to_string(prev_f->f_code->co_filename),
|
||||
(int)PyFrame_GetLineNumber(prev_f)});
|
||||
break;
|
||||
}
|
||||
if (base_type != jt_module)
|
||||
continue;
|
||||
PyObjHolder ret;
|
||||
_PyObject_LookupAttr(obj, _trace_name, &ret.obj);
|
||||
string scope_name;
|
||||
if (!ret.obj) {
|
||||
// find base name
|
||||
auto co_name = to_string(f->f_code->co_name);
|
||||
if (co_name == "__init__") {
|
||||
scope_name = string(obj->ob_type->tp_name) + "_init";
|
||||
} else
|
||||
if (co_name == "__call__") {
|
||||
if (i) {
|
||||
ret.assign(find_obj_name(f->f_back, obj));
|
||||
scope_name = to_string(ret.obj);
|
||||
} else {
|
||||
ret.assign(PyUnicode_FromString("_model"));
|
||||
scope_name = "_model";
|
||||
}
|
||||
PyObjHolder _(PyObject_CallFunctionObjArgs(
|
||||
fill_module_name, obj, ret.obj, nullptr));
|
||||
}
|
||||
} else {
|
||||
scope_name = to_string(ret.obj);
|
||||
}
|
||||
stacks.emplace_back(Stack{
|
||||
move(scope_name),
|
||||
string(obj->ob_type->tp_name),
|
||||
to_string(prev_f->f_code->co_filename),
|
||||
(int)PyFrame_GetLineNumber(prev_f)});
|
||||
}
|
||||
}
|
||||
return stacks;
|
||||
}
|
||||
|
||||
void TraceData::record_node(Node* node, bool record_stack) {
|
||||
if (thread_name.size()) return;
|
||||
NodeData data;
|
||||
data.id = node_data_cnt++;
|
||||
id_map[node] = data.id;
|
||||
if (!node->is_var()) {
|
||||
if (record_stack) {
|
||||
if (trace_grad_op) {
|
||||
auto iter = trace_data.id_map.find(trace_grad_op);
|
||||
data.stacks.emplace_back(Stack{"grad", "Grad", "", 0});
|
||||
if (iter != trace_data.id_map.end()) {
|
||||
auto& prev_stack = trace_data.node_data[iter->second].stacks;
|
||||
for (auto& s : prev_stack)
|
||||
data.stacks.push_back(s);
|
||||
}
|
||||
} else
|
||||
data.stacks = get_stack_info();
|
||||
}
|
||||
} else {
|
||||
}
|
||||
node_data[data.id] = move(data);
|
||||
}
|
||||
|
||||
static int64 get_node_id(Node* node) {
|
||||
auto iter = trace_data.id_map.find(node);
|
||||
if (iter != trace_data.id_map.end())
|
||||
return iter->second;
|
||||
trace_data.record_node(node, false);
|
||||
return trace_data.node_data_cnt - 1;
|
||||
}
|
||||
|
||||
void TraceData::release_node(Node* node) {
|
||||
if (thread_name.size()) return;
|
||||
auto iter = trace_data.id_map.find(node);
|
||||
if (iter == trace_data.id_map.end())
|
||||
return;
|
||||
auto node_id = iter->second;
|
||||
id_map.erase(node);
|
||||
if (trace_py_var == 1) {
|
||||
node_data.erase(node_id);
|
||||
}
|
||||
}
|
||||
|
||||
void TraceData::record_exe_node(Node* node) {
|
||||
auto node_id = get_node_id(node);
|
||||
auto& data = node_data[node_id];
|
||||
if (data.inputs.size() != node->inputs().size() || data.attrs.size() == 0) {
|
||||
data.inputs.clear();
|
||||
data.inputs.reserve(node->inputs().size());
|
||||
for (auto i : node->inputs()) {
|
||||
auto iid = get_node_id(i);
|
||||
data.inputs.push_back(iid);
|
||||
node_data[iid].outputs.push_back(node_id);
|
||||
}
|
||||
if (node->is_var()) {
|
||||
auto v = node->var();
|
||||
std::stringstream ss;
|
||||
ss << v->shape;
|
||||
data.attrs["shape"] = ss.str();
|
||||
data.attrs["ndim"] = S(v->shape.size());
|
||||
data.attrs["dtype"] = v->dtype().to_cstring();
|
||||
data.attrs["dsize"] = S(v->dtype().dsize());
|
||||
data.attrs["name"] = v->name.c_str();
|
||||
data.attrs["is_var"] = "1";
|
||||
} else {
|
||||
auto op = node->op();
|
||||
data.attrs["name"] = op->name_ex();
|
||||
data.attrs["is_var"] = "0";
|
||||
// TODO: add other op attrs
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TraceData::record_op(Op* op) {
|
||||
record_exe_node(op);
|
||||
for (auto o : op->outputs())
|
||||
record_exe_node(o);
|
||||
}
|
||||
|
||||
void TraceData::record_execution(Op* op, bool is_fused_op, JK& jk) {
|
||||
ExecuteOpInfo& einfo = execute_op_info[execute_op_info_cnt++];
|
||||
if (is_fused_op) {
|
||||
FusedOp* fop = (FusedOp*)op;
|
||||
for (auto op : fop->ops) {
|
||||
record_op(op);
|
||||
einfo.fused_ops.push_back(get_node_id(op));
|
||||
}
|
||||
} else {
|
||||
record_op(op);
|
||||
einfo.fused_ops.push_back(get_node_id(op));
|
||||
}
|
||||
op->do_prepare(jk);
|
||||
if (jk.empty()) return;
|
||||
const char* jit_key = jk.to_cstring();
|
||||
auto iter = jit_key_mapper.find(jit_key);
|
||||
if (iter == jit_key_mapper.end())
|
||||
einfo.jit_key = jit_key;
|
||||
else
|
||||
einfo.jit_key = iter->second;
|
||||
jit_key_map[einfo.jit_key].push_back(execute_op_info_cnt-1);
|
||||
einfo.file_path = Op::get_filename_from_jit_key(jk.to_cstring(), ".cc");
|
||||
}
|
||||
|
||||
template<class T>
|
||||
static void fill_dict(PyObject* dict, T key, PyObject* value) {
|
||||
PyObjHolder k(to_py_object<T>(key));
|
||||
PyObjHolder v(value);
|
||||
PyDict_SetItem(dict, k.obj, value);
|
||||
}
|
||||
|
||||
// template<>
|
||||
// PyObject* to_py_object(const Stack& stack) {
|
||||
// return nullptr;
|
||||
// }
|
||||
|
||||
DEF_IS(Stack, PyObject*) to_py_object(const T& a) {
|
||||
PyObjHolder dict(PyDict_New());
|
||||
fill_dict(dict.obj, string("name"), to_py_object<string>(a.module_name));
|
||||
fill_dict(dict.obj, string("type"), to_py_object<string>(a.module_type));
|
||||
fill_dict(dict.obj, string("file_path"), to_py_object<string>(a.file_path));
|
||||
fill_dict(dict.obj, string("lineno"), to_py_object<int64>(a.lineno));
|
||||
return dict.release();
|
||||
}
|
||||
|
||||
PyObject* dump_trace_data() {
|
||||
PyObjHolder dict(PyDict_New());
|
||||
PyObjHolder node_data(PyDict_New());
|
||||
PyObjHolder execute_op_info(PyDict_New());
|
||||
for (auto& kv : trace_data.node_data) {
|
||||
if (kv.second.attrs.size() == 0)
|
||||
continue;
|
||||
PyObjHolder dict(PyDict_New());
|
||||
fill_dict(dict.obj, string("id"), to_py_object(kv.second.id));
|
||||
fill_dict(dict.obj, string("inputs"), to_py_object(kv.second.inputs));
|
||||
fill_dict(dict.obj, string("outputs"), to_py_object(kv.second.outputs));
|
||||
fill_dict(dict.obj, string("stacks"), to_py_object(kv.second.stacks));
|
||||
fill_dict(dict.obj, string("attrs"), to_py_object(kv.second.attrs));
|
||||
fill_dict(node_data.obj, kv.first, dict.release());
|
||||
}
|
||||
for (auto& kv : trace_data.execute_op_info) {
|
||||
PyObjHolder dict(PyDict_New());
|
||||
fill_dict(dict.obj, string("fused_ops"), to_py_object(kv.second.fused_ops));
|
||||
fill_dict(dict.obj, string("jit_key"), to_py_object<string>(kv.second.jit_key));
|
||||
fill_dict(dict.obj, string("file_path"), to_py_object<string>(kv.second.file_path));
|
||||
fill_dict(dict.obj, string("attrs"), to_py_object(kv.second.attrs));
|
||||
fill_dict(execute_op_info.obj, kv.first, dict.release());
|
||||
}
|
||||
fill_dict(dict.obj, string("node_data"), node_data.release());
|
||||
fill_dict(dict.obj, string("execute_op_info"), execute_op_info.release());
|
||||
return dict.release();
|
||||
}
|
||||
|
||||
void clear_trace_data() {
|
||||
trace_data.execute_op_info.clear();
|
||||
trace_data.jit_key_map.clear();
|
||||
trace_data.id_map.clear();
|
||||
trace_data.node_data.clear();
|
||||
}
|
||||
|
||||
string _get_stack_info(Op* op) {
|
||||
string stack_info = "";
|
||||
auto iter = trace_data.id_map.find(op);
|
||||
if (iter == trace_data.id_map.end())
|
||||
return stack_info;
|
||||
auto node_id = iter->second;
|
||||
auto iter2 = trace_data.node_data.find(node_id);
|
||||
if (iter2 == trace_data.node_data.end())
|
||||
return stack_info;
|
||||
for (auto& stack : iter2->second.stacks)
|
||||
stack_info += stack.module_name + " -> ";
|
||||
return stack_info;
|
||||
}
|
||||
|
||||
void print_node_trace(const Node* node, std::ostream& os) {
|
||||
if (!node->is_var())
|
||||
os << _get_stack_info((((Node*)node))->op());
|
||||
}
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
|
@ -9,27 +10,62 @@
|
|||
namespace jittor {
|
||||
|
||||
DECLARE_FLAG(int, trace_py_var);
|
||||
extern Op* trace_grad_op;
|
||||
struct JitKey;
|
||||
|
||||
struct Stack {
|
||||
string module_name;
|
||||
string module_type;
|
||||
string file_path;
|
||||
int lineno;
|
||||
};
|
||||
|
||||
#ifdef NODE_MEMCHECK
|
||||
void __registe_node_trace(Node* node);
|
||||
void __unregiste_node_trace(Node* node);
|
||||
void __registe_node_trace_grad(Node* g, Node* node, int x_index);
|
||||
void __print_node_trace(const Node* node, std::ostream& os);
|
||||
struct NodeData {
|
||||
int64 id;
|
||||
vector<int64> inputs;
|
||||
vector<int64> outputs;
|
||||
vector<Stack> stacks;
|
||||
/*
|
||||
if is var, then contain:
|
||||
is_var: 1
|
||||
shape: [a,b,c,d]
|
||||
ndim: x
|
||||
dtype: floatxx
|
||||
dsize: 4 or 8
|
||||
name: xxx
|
||||
if is op, then contain:
|
||||
is_var: 0
|
||||
name: xxx
|
||||
other op attr
|
||||
*/
|
||||
unordered_map<string,string> attrs;
|
||||
};
|
||||
|
||||
inline void registe_node_trace(Node* node)
|
||||
{ if (trace_py_var) __registe_node_trace(node); }
|
||||
inline void unregiste_node_trace(Node* node)
|
||||
{ if (trace_py_var) __unregiste_node_trace(node); }
|
||||
inline void registe_node_trace_grad(Node* g, Node* node, int x_index)
|
||||
{ if (trace_py_var) __registe_node_trace_grad(g, node, x_index); }
|
||||
inline void print_node_trace(const Node* node, std::ostream& os)
|
||||
{ if (trace_py_var) __print_node_trace(node, os); }
|
||||
#else
|
||||
inline void registe_node_trace(Node* node) {}
|
||||
inline void unregiste_node_trace(Node* node) {}
|
||||
inline void registe_node_trace_grad(Node* g, Node* node, int x_index) {}
|
||||
inline void print_node_trace(const Node* node, std::ostream& os) {}
|
||||
#endif
|
||||
struct ExecuteOpInfo {
|
||||
vector<int64> fused_ops;
|
||||
string jit_key;
|
||||
string file_path;
|
||||
unordered_map<string,string> attrs;
|
||||
};
|
||||
|
||||
struct TraceData {
|
||||
int64 node_data_cnt;
|
||||
int64 execute_op_info_cnt;
|
||||
unordered_map<int64, NodeData> node_data;
|
||||
unordered_map<int64, ExecuteOpInfo> execute_op_info;
|
||||
// jit_key map to id of execute_op_info
|
||||
unordered_map<string, vector<int64>> jit_key_map;
|
||||
unordered_map<Node*, int64> id_map;
|
||||
|
||||
void record_node(Node* node, bool record_stack=true);
|
||||
void release_node(Node*);
|
||||
void record_op(Op* op);
|
||||
void record_exe_node(Node* node);
|
||||
void record_execution(Op* op, bool is_fused_op, JitKey& jk);
|
||||
};
|
||||
|
||||
extern TraceData trace_data;
|
||||
|
||||
void print_node_trace(const Node* node, std::ostream& os);
|
||||
|
||||
} // jittor
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
// ***************************************************************
|
||||
// Copyright (c) 2020 Jittor. All Rights Reserved.
|
||||
// Authors: Dun Liang <randonlang@gmail.com>.
|
||||
// This file is subject to the terms and conditions defined in
|
||||
// file 'LICENSE.txt', which is part of this source code package.
|
||||
// ***************************************************************
|
||||
#pragma once
|
||||
#include <Python.h>
|
||||
|
||||
namespace jittor {
|
||||
|
||||
// @pyjt(dump_trace_data)
|
||||
PyObject* dump_trace_data();
|
||||
|
||||
// @pyjt(clear_trace_data)
|
||||
void clear_trace_data();
|
||||
|
||||
} // jittor
|
|
@ -26,17 +26,6 @@ namespace jittor {
|
|||
|
||||
#define GET_PY_NONE(code) ((code), Py_INCREF(Py_None), Py_None)
|
||||
|
||||
inline Log& operator<<(Log& os, PyObject* objp) {
|
||||
PyObjHolder repr_obj(PyObject_Repr(objp));
|
||||
|
||||
if (PyUnicode_CheckExact(repr_obj.obj)) {
|
||||
return os << Py_TYPE(objp)->tp_name <<
|
||||
PyUnicode_AsUTF8(repr_obj.obj);
|
||||
} else {
|
||||
return os << "unknown(" >> (void*)objp >> ")";
|
||||
}
|
||||
}
|
||||
|
||||
// string
|
||||
DEF_IS(string, bool) is_type(PyObject* obj) {
|
||||
return PyUnicode_CheckExact(obj);
|
||||
|
|
|
@ -34,6 +34,18 @@ struct PyObjHolder {
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
inline Log& operator<<(Log& os, PyObject* objp) {
|
||||
PyObjHolder repr_obj(PyObject_Repr(objp));
|
||||
|
||||
if (PyUnicode_CheckExact(repr_obj.obj)) {
|
||||
return os << Py_TYPE(objp)->tp_name <<
|
||||
PyUnicode_AsUTF8(repr_obj.obj);
|
||||
} else {
|
||||
return os << "unknown(" >> (void*)objp >> ")";
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
#define PYJF_MODULE_INIT(name) \
|
||||
|
|
|
@ -29,6 +29,7 @@ Var::Var(NanoVector shape, NanoString dtype)
|
|||
ASSERT(ns.is_dtype());
|
||||
number_of_lived_vars++;
|
||||
numel();
|
||||
if (PREDICT_BRANCH_NOT_TAKEN(trace_py_var)) trace_data.record_node(this);
|
||||
}
|
||||
|
||||
string Var::to_string() {
|
||||
|
|
Loading…
Reference in New Issue