mirror of https://github.com/Jittor/Jittor
fix memleak
This commit is contained in:
parent
a60c73ea39
commit
060686fafd
|
@ -656,7 +656,7 @@ can also be None)::
|
|||
for i, r in enumerate(ret):
|
||||
j = self.input_mask[i]
|
||||
if j<0:
|
||||
assert r is None, f"The {i}-th returned grad should be None, "\
|
||||
assert r is None, f"{type(self)}'s {i}-th returned grad should be None, "\
|
||||
"because the input value is not jittor variable."
|
||||
else:
|
||||
new_ret.append(r)
|
||||
|
|
|
@ -679,14 +679,16 @@ def compile_src(src, h, basename):
|
|||
])}
|
||||
LOGf << "Not a valid call.";
|
||||
}} catch (const std::exception& e) {{
|
||||
std::stringstream ss;
|
||||
ss {error_log_code};
|
||||
PyErr_Format(PyExc_RuntimeError,
|
||||
"%s\\n%s\\nFailed reason:%s",
|
||||
ss.str().c_str(),
|
||||
R""({decs})"",
|
||||
e.what()
|
||||
);
|
||||
if (!PyErr_Occurred()) {{
|
||||
std::stringstream ss;
|
||||
ss {error_log_code};
|
||||
PyErr_Format(PyExc_RuntimeError,
|
||||
"%s\\n%s\\nFailed reason:%s",
|
||||
ss.str().c_str(),
|
||||
R""({decs})"",
|
||||
e.what()
|
||||
);
|
||||
}}
|
||||
}}
|
||||
{func_return_failed};
|
||||
}}
|
||||
|
|
|
@ -203,8 +203,8 @@ class TestFunction(unittest.TestCase):
|
|||
def test_multi_grads_multi_out5(self):
|
||||
class MyFunc(Function):
|
||||
def execute(self, x, z, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
self.x = x.name("x")
|
||||
self.y = y.name("y")
|
||||
return x*y, "test", x/y
|
||||
|
||||
def grad(self, grad0, _, grad1):
|
||||
|
@ -212,12 +212,72 @@ class TestFunction(unittest.TestCase):
|
|||
res = (grad0 * self.y, 1, grad1 * self.x)
|
||||
print(res)
|
||||
return res
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
a = jt.array(3.0).name('a')
|
||||
b = jt.array(4.0).name('b')
|
||||
c,_,d = MyFunc()(a, "a", b)
|
||||
c.name('c'), d.name('d')
|
||||
expect_error(lambda : jt.grad(c+d*3, [a, b]))
|
||||
|
||||
def test_zz_last_test(self):
|
||||
def test_zmem_leak(self):
|
||||
def test():
|
||||
self.test_multi_grads_multi_out5()
|
||||
test()
|
||||
jt.clean()
|
||||
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
|
||||
|
||||
def test_zmem_leak2(self):
|
||||
def test():
|
||||
class MyFunc(Function):
|
||||
def execute(self, x, z, y):
|
||||
self.x = x.name("x")
|
||||
self.y = y.name("y")
|
||||
return x*y, "test", x/y
|
||||
|
||||
def grad(self, grad0, _, grad1):
|
||||
assert _ is None
|
||||
res = (grad0 * self.y, None, grad1 * self.x)
|
||||
return res
|
||||
a = jt.array(3.0).name('a')
|
||||
b = jt.array(4.0).name('b')
|
||||
c,_,d = MyFunc()(a, "a", b)
|
||||
c.name('c'), d.name('d')
|
||||
g = jt.grad(c+d*3, [a, b])
|
||||
test()
|
||||
jt.clean()
|
||||
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
|
||||
|
||||
@unittest.skipIf(True, "skip memleak test")
|
||||
def test_zmem_leak3(self):
|
||||
def test():
|
||||
class MyFunc(Function):
|
||||
def execute(self, x, z, y):
|
||||
self.x = x
|
||||
self.y = y
|
||||
return x*y, "test", x/y
|
||||
|
||||
def grad(self, grad0, _, grad1):
|
||||
assert _ is None
|
||||
res = (grad0 * self.y, None, grad1 * self.x)
|
||||
return res
|
||||
a = jt.array(3.0)
|
||||
b = jt.array(4.0)
|
||||
c,_,d = MyFunc()(a, "a", b)
|
||||
g = jt.grad(c+d*3, [a, b])
|
||||
jt.sync(g)
|
||||
import resource
|
||||
t1 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
for i in range(100000):
|
||||
test()
|
||||
if i % 10000 == 0:
|
||||
jt.clean()
|
||||
t2 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
for i in range(1000000):
|
||||
test()
|
||||
if i % 10000 == 0:
|
||||
jt.clean()
|
||||
t3 = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
|
||||
print(t1,t2,t3)
|
||||
assert t3 < t2 + 10, (t1,t2,t3)
|
||||
self.assertEqual(jt.liveness_info()["lived_vars"], 0)
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -18,9 +18,7 @@ static auto make_tape = get_op_info("tape")
|
|||
TapeOp::TapeOp(Var* x) {
|
||||
flags.set(NodeFlags::_cpu);
|
||||
flags.set(NodeFlags::_cuda);
|
||||
auto y = create_output(nullptr, x->dtype());
|
||||
if (x->name.ptr)
|
||||
y->name = x->name;
|
||||
create_output(nullptr, x->dtype());
|
||||
}
|
||||
|
||||
VarPtr TapeOp::grad(Var* out, Var* dout, Var* v, int v_index) {
|
||||
|
@ -35,7 +33,18 @@ void TapeOp::infer_shape() {
|
|||
}
|
||||
|
||||
void Tapes::grads(Var** douts, VarPtr* dins) {
|
||||
callback.func(_outputs.size(), douts, _inputs.size(), dins);
|
||||
CHECK(callback.deleter);
|
||||
try {
|
||||
callback.func(_outputs.size(), douts, _inputs.size(), dins);
|
||||
} catch (...) {
|
||||
// if error occur in callback, we need to
|
||||
// free it to prevent memory leak, but this is still
|
||||
// not enough, error may occur outside. please
|
||||
// find a better solution
|
||||
callback.deleter();
|
||||
callback.deleter = nullptr;
|
||||
throw;
|
||||
}
|
||||
}
|
||||
|
||||
Tapes::Tapes(
|
||||
|
|
|
@ -13,7 +13,6 @@ struct PyObjHolder {
|
|||
PyObject* obj;
|
||||
inline PyObjHolder(PyObject* obj) : obj(obj) {
|
||||
if (!obj) {
|
||||
PyErr_Print();
|
||||
LOGf << "Python error occur";
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue