fix memleak

This commit is contained in:
Dun Liang 2020-07-14 22:31:50 +08:00
parent a60c73ea39
commit 060686fafd
5 changed files with 89 additions and 19 deletions

View File

@ -656,7 +656,7 @@ can also be None)::
for i, r in enumerate(ret):
j = self.input_mask[i]
if j<0:
assert r is None, f"The {i}-th returned grad should be None, "\
assert r is None, f"{type(self)}'s {i}-th returned grad should be None, "\
"because the input value is not jittor variable."
else:
new_ret.append(r)

View File

@ -679,14 +679,16 @@ def compile_src(src, h, basename):
])}
LOGf << "Not a valid call.";
}} catch (const std::exception& e) {{
std::stringstream ss;
ss {error_log_code};
PyErr_Format(PyExc_RuntimeError,
"%s\\n%s\\nFailed reason:%s",
ss.str().c_str(),
R""({decs})"",
e.what()
);
if (!PyErr_Occurred()) {{
std::stringstream ss;
ss {error_log_code};
PyErr_Format(PyExc_RuntimeError,
"%s\\n%s\\nFailed reason:%s",
ss.str().c_str(),
R""({decs})"",
e.what()
);
}}
}}
{func_return_failed};
}}

View File

@ -203,8 +203,8 @@ class TestFunction(unittest.TestCase):
def test_multi_grads_multi_out5(self):
class MyFunc(Function):
def execute(self, x, z, y):
self.x = x
self.y = y
self.x = x.name("x")
self.y = y.name("y")
return x*y, "test", x/y
def grad(self, grad0, _, grad1):
@ -212,12 +212,72 @@ class TestFunction(unittest.TestCase):
res = (grad0 * self.y, 1, grad1 * self.x)
print(res)
return res
a = jt.array(3.0)
b = jt.array(4.0)
a = jt.array(3.0).name('a')
b = jt.array(4.0).name('b')
c,_,d = MyFunc()(a, "a", b)
c.name('c'), d.name('d')
expect_error(lambda : jt.grad(c+d*3, [a, b]))
def test_zz_last_test(self):
def test_zmem_leak(self):
def test():
self.test_multi_grads_multi_out5()
test()
jt.clean()
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
def test_zmem_leak2(self):
def test():
class MyFunc(Function):
def execute(self, x, z, y):
self.x = x.name("x")
self.y = y.name("y")
return x*y, "test", x/y
def grad(self, grad0, _, grad1):
assert _ is None
res = (grad0 * self.y, None, grad1 * self.x)
return res
a = jt.array(3.0).name('a')
b = jt.array(4.0).name('b')
c,_,d = MyFunc()(a, "a", b)
c.name('c'), d.name('d')
g = jt.grad(c+d*3, [a, b])
test()
jt.clean()
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
@unittest.skipIf(True, "skip memleak test")
def test_zmem_leak3(self):
def test():
class MyFunc(Function):
def execute(self, x, z, y):
self.x = x
self.y = y
return x*y, "test", x/y
def grad(self, grad0, _, grad1):
assert _ is None
res = (grad0 * self.y, None, grad1 * self.x)
return res
a = jt.array(3.0)
b = jt.array(4.0)
c,_,d = MyFunc()(a, "a", b)
g = jt.grad(c+d*3, [a, b])
jt.sync(g)
import resource
t1 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
for i in range(100000):
test()
if i % 10000 == 0:
jt.clean()
t2 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
for i in range(1000000):
test()
if i % 10000 == 0:
jt.clean()
t3 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
print(t1,t2,t3)
assert t3 < t2 + 10, (t1,t2,t3)
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
if __name__ == "__main__":

View File

@ -18,9 +18,7 @@ static auto make_tape = get_op_info("tape")
TapeOp::TapeOp(Var* x) {
flags.set(NodeFlags::_cpu);
flags.set(NodeFlags::_cuda);
auto y = create_output(nullptr, x->dtype());
if (x->name.ptr)
y->name = x->name;
create_output(nullptr, x->dtype());
}
VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
@ -35,7 +33,18 @@ void TapeOp::infer_shape() {
}
void Tapes::grads(Var** douts, VarPtr* dins) {
callback.func(_outputs.size(), douts, _inputs.size(), dins);
CHECK(callback.deleter);
try {
callback.func(_outputs.size(), douts, _inputs.size(), dins);
} catch (...) {
// if error occur in callback, we need to
// free it to prevent memory leak, but this is still
// not enough, error may occur outside. please
// find a better solution
callback.deleter();
callback.deleter = nullptr;
throw;
}
}
Tapes::Tapes(

View File

@ -13,7 +13,6 @@ struct PyObjHolder {
PyObject* obj;
inline PyObjHolder(PyObject* obj) : obj(obj) {
if (!obj) {
PyErr_Print();
LOGf << "Python error occur";
}
}