fix ring buffer stuck

This commit is contained in:
Dun Liang 2021-01-25 12:14:23 +08:00
parent 4297ba6fde
commit ccd9598540
5 changed files with 63 additions and 17 deletions

View File

@ -57,8 +57,8 @@ class TestSearchSorted(unittest.TestCase):
s_tc = torch.from_numpy(s)
v_tc = torch.from_numpy(v)
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
y_tc = torch.searchsorted(s_tc, v_tc, right=True)
y_jt = jt.searchsorted(s_jt, v_jt, right=True)
assert np.allclose(y_jt.numpy(), y_tc.data)
y_jt = jt.searchsorted(s_jt, v_jt, right=False)
y_tc = torch.searchsorted(s_tc, v_tc, right=False)

View File

@ -34,6 +34,18 @@ class Model(Module):
x = self.relu1(x)
return self.linear2(x)
def print_stack_tree(data):
tree = {}
for n in data["node_data"].values():
p = tree
for s in n["stacks"]:
name = s['name']
if name not in p:
p[name] = {}
p = p[name]
from pprint import pprint
pprint(tree)
class Linear(Module):
def __init__(self, in_features, out_features, bias=True):
self.w = (jt.random((in_features, out_features))-0.5) / in_features**0.5
@ -58,8 +70,8 @@ class TestTraceVar(unittest.TestCase):
data = jt.dump_trace_data()
jt.clear_trace_data()
# with open("/tmp/simple_model.pkl", "wb") as f:
# pickle.dump(data, f)
with open(f"{jt.flags.cache_path}/simple_model.pkl", "wb") as f:
pickle.dump(data, f)
def test_simple_model_train(self):
with jt.flag_scope(trace_py_var=2):
@ -75,10 +87,20 @@ class TestTraceVar(unittest.TestCase):
data = jt.dump_trace_data()
jt.clear_trace_data()
# with open("/tmp/simple_model_train.pkl", "wb") as f:
# pickle.dump(data, f)
# print_stack_tree(data)
for k,v in data["execute_op_info"].items():
for i in v['fused_ops']:
if i not in data["node_data"]:
assert 0, (i, "not found")
def test_resnet(self):
for k,v in list(data["node_data"].items()):
if v["attrs"]["name"] == "unname":
assert 0
print(len(data["node_data"]))
with open(f"{jt.flags.cache_path}/simple_model_train.pkl", "wb") as f:
pickle.dump(data, f)
def test_resnet_infer(self):
with jt.flag_scope(trace_py_var=2):
resnet18 = resnet.Resnet18()
@ -88,10 +110,14 @@ class TestTraceVar(unittest.TestCase):
data = jt.dump_trace_data()
jt.clear_trace_data()
# with open("/tmp/resnet.pkl", "wb") as f:
# pickle.dump(data, f)
with open(f"{jt.flags.cache_path}/resnet.pkl", "wb") as f:
pickle.dump(data, f)
for k,v in data["execute_op_info"].items():
for i in v['fused_ops']:
if i not in data["node_data"]:
assert 0, (i, "not found")
def test_resnet_train(self):
def test_resnet_trainx(self):
with jt.flag_scope(trace_py_var=2):
resnet18 = resnet.Resnet18()
@ -104,8 +130,19 @@ class TestTraceVar(unittest.TestCase):
data = jt.dump_trace_data()
jt.clear_trace_data()
# with open("/tmp/resnet_train.pkl", "wb") as f:
# pickle.dump(data, f)
with open(f"{jt.flags.cache_path}/resnet_train.pkl", "wb") as f:
pickle.dump(data, f)
for k,v in data["execute_op_info"].items():
for i in v['fused_ops']:
if i not in data["node_data"]:
assert 0, (i, "not found")
for k,v in data["node_data"].items():
if 'name' not in v["attrs"]:
print(v)
# assert 'name' in v["attrs"], v
# for s in v["stacks"]:
# if "_opt" in s["name"] or "_model" in s["name"]:
# assert 0, v
def test_resnet_train_profile(self):
with jt.profile_scope(trace_py_var=1):

View File

@ -9,6 +9,6 @@ docker build --tag jittor/converter_server -f /tmp/converter_server.dockerfile .
# docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.utils.converter_server"
while true; do
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 -v ~/https:/https jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/fullchain.pem --key=/https/privkey.pem --host=0.0.0.0"
timeout --foreground 24h docker run --rm -it -m 16g --cpus=8 -p 0.0.0.0:5000:5000 -v ~/https:/https jittor/converter_server bash -c "python3.7 -m pip install -U git+https://github.com/Jittor/jittor.git && python3.7 -m jittor.test.test_core && FLASK_APP=/usr/local/lib/python3.7/dist-packages/jittor/utils/converter_server python3.7 -m flask run --cert=/https/fullchain.pem --key=/https/privkey.pem --host=0.0.0.0"
sleep 10
done

View File

@ -57,7 +57,6 @@ void RingBuffer::free_ring_buffer(RingBuffer* rb) {
if (is_multiprocess) {
munmap(rb, total_size);
} else {
rb->~RingBuffer();
free((void*)rb);
}
}

View File

@ -128,7 +128,8 @@ static vector<Stack> get_stack_info() {
auto base_type = PyTuple_GET_ITEM(tp_mro, Py_SIZE(tp_mro)-2);
auto prev_f = i? frames[i-1] : f;
if (base_type == jt_optimizer) {
PyObjHolder ret(find_obj_name(f->f_back, obj, "_opt"));
string init_name = string(obj->ob_type->tp_name) + "_init";
PyObjHolder ret(find_obj_name(f->f_back, obj, init_name.c_str()));
stacks.emplace_back(Stack{
to_string(ret.obj),
string(obj->ob_type->tp_name),
@ -189,7 +190,7 @@ void TraceData::record_node(Node* node, bool record_stack) {
NodeData data;
data.id = node_data_cnt++;
id_map[node] = data.id;
if (!node->is_var() || trace_py_var>=3) {
if (trace_py_var) {
if (record_stack) {
if (trace_grad_op) {
auto iter = trace_data.id_map.find(trace_grad_op);
@ -205,6 +206,10 @@ void TraceData::record_node(Node* node, bool record_stack) {
}
} else {
}
if (node->__id())
data.attrs["__id"] = S(node->__id());
data.attrs["is_var"] = node->is_var() ? "1" : "0";
data.attrs["name"] = "unname";
node_data[data.id] = move(data);
}
@ -223,7 +228,7 @@ void TraceData::release_node(Node* node) {
return;
auto node_id = iter->second;
id_map.erase(node);
if (trace_py_var >= 1) {
if (trace_py_var < 2) {
node_data.erase(node_id);
}
}
@ -231,7 +236,8 @@ void TraceData::release_node(Node* node) {
void TraceData::record_exe_node(Node* node) {
auto node_id = get_node_id(node);
auto& data = node_data[node_id];
if (data.inputs.size() != node->inputs().size() || data.attrs.size() == 0) {
auto name_iter = data.attrs.find("name");
if (data.inputs.size() != node->inputs().size() || data.attrs.size() == 0 || name_iter == data.attrs.end() || name_iter->second == "unname") {
data.inputs.clear();
data.inputs.reserve(node->inputs().size());
for (auto i : node->inputs()) {
@ -316,6 +322,10 @@ PyObject* dump_trace_data() {
for (auto& kv : trace_data.node_data) {
if (kv.second.attrs.size() == 0)
continue;
auto name_iter = kv.second.attrs.find("name");
// if don't have name, this node is not executed
if (name_iter == kv.second.attrs.end() || name_iter->second == "unname")
continue;
PyObjHolder dict(PyDict_New());
fill_dict(dict.obj, string("id"), to_py_object(kv.second.id));
fill_dict(dict.obj, string("inputs"), to_py_object(kv.second.inputs));