signal handler for ctrl-c

This commit is contained in:
Dun Liang 2020-11-03 12:36:04 +08:00
parent cd5731970a
commit 5713130b03
9 changed files with 43 additions and 10 deletions

View File

@ -7,7 +7,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.0.9'
__version__ = '1.2.1.0'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -149,11 +149,11 @@ class Dataset(object):
'''
if hasattr(self, "workers"):
for w in self.workers:
w.buffer.stop()
w.p.join()
w.p.close()
w.p.terminate()
def _worker_main(self, worker_id, buffer, status):
import jittor_utils
jittor_utils.cc.init_subprocess()
import time
try:
gid_obj = self.gid.get_obj()
@ -162,7 +162,7 @@ class Dataset(object):
while True:
# get id
with gid_lock:
while gid_obj.value >= self.batch_len:
while gid_obj.value >= self.batch_len or buffer.is_stop():
self.num_idle.value += 1
self.num_idle_c.notify()
self.gidc.wait()
@ -189,7 +189,12 @@ class Dataset(object):
# send data to main process
if mp_log_v:
print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
buffer.send(batch)
try:
buffer.send(batch)
except:
if buffer.is_stop():
continue
raise
now = time.time()
send_time = now - start
start = now

View File

@ -84,6 +84,14 @@ class TestRingBuffer(unittest.TestCase):
if batch_idx > 30:
break
pass
for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
# time.sleep(5)
# print("break")
# break
# self.train_loader.display_worker_status()
if batch_idx > 300:
break
pass
if __name__ == "__main__":

View File

@ -57,6 +57,7 @@ class TestWhereOp(unittest.TestCase):
assert "Where Operator" in jt.where.__doc__
@unittest.skipIf(not jt.has_cuda, "No Torch found")
class TestWhereOpCuda(TestWhereOp):
def setUp(self):
self.where = jt.where

View File

@ -154,6 +154,9 @@ def pool_cleanup():
p.__exit__(None, None, None)
del p
def pool_initializer():
    """Per-worker start-up hook for the compile pool.

    Passed as ``initializer=`` when the ``Pool`` is created, so it runs once
    in every worker process; it simply delegates to ``cc.init_subprocess()``.
    """
    cc.init_subprocess()
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
global pool_size, p
bk = mp.current_process()._config.get('daemon')
@ -163,7 +166,7 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
mem_gib = mem_bytes/(1024.**3)
pool_size = min(16,max(int(mem_gib // 3), 1))
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
p = Pool(pool_size)
p = Pool(pool_size, initializer=pool_initializer)
p.__enter__()
import atexit
atexit.register(pool_cleanup)

View File

@ -48,6 +48,9 @@ struct RingBuffer {
}
inline ~Cond() {
// Destructor: tears down the condition variable `cv`.
// HACK: the waiter-reference field inside `cv` is zeroed before the destroy
// call. NOTE(review): `__data.__wrefs` is a private glibc member of
// pthread_cond_t; presumably clearing it keeps pthread_cond_destroy from
// blocking on waiters left behind by a process that has already terminated
// (see the link below) -- confirm, and note that this is not portable to
// other libc implementations or glibc versions with a different layout.
// ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
cv.__data.__wrefs = 0;
pthread_cond_destroy(&cv);
}
@ -86,7 +89,7 @@ struct RingBuffer {
inline void wait() {
if (is_stop) {
abort();
throw std::runtime_error("stop");
}
{
MutexScope _(m);

View File

@ -22,9 +22,11 @@ struct PyMultiprocessRingBuffer {
// @pyjt(pop,recv)
PyObject* pop();
// @pyjt(clear)
inline void clear() { rb->l = rb->r = 0; }
inline void clear() { rb->l = rb->r = rb->is_stop = 0; }
// Forwards to the underlying ring buffer's stop().
// @pyjt(stop)
inline void stop() { rb->stop(); }
// Returns the underlying ring buffer's `is_stop` flag.
// @pyjt(is_stop)
inline bool is_stop() { return rb->is_stop; }
// Returns the underlying ring buffer's `l` counter.
// NOTE(review): presumably `l` is the running count of popped items -- confirm
// against the RingBuffer definition.
// @pyjt(total_pop)
inline uint64 total_pop() { return rb->l; }

View File

@ -13,6 +13,16 @@
#ifdef __GNUC__
#endif
#include <pybind11/iostream.h>
#include <sys/prctl.h>
#include <signal.h>
namespace jittor {
// Subprocess start-up hook (bound into the Python module as `init_subprocess`).
// Sets this process's parent-death signal to SIGKILL via
// prctl(PR_SET_PDEATHSIG), so the kernel delivers SIGKILL to this process
// when its parent terminates.
// NOTE(review): prctl/PR_SET_PDEATHSIG is Linux-specific and the return value
// is not checked -- confirm this translation unit is only built on Linux and
// that a silent failure is acceptable.
void init_subprocess() {
prctl(PR_SET_PDEATHSIG, SIGKILL);
}
}
PYBIND11_MODULE(jit_utils_core, m) {
pybind11::add_ostream_redirect(m, "ostream_redirect");
@ -39,4 +49,5 @@ PYBIND11_MODULE(jit_utils_core, m) {
m.def("log_capture_start", &jittor::log_capture_start);
m.def("log_capture_stop", &jittor::log_capture_stop);
m.def("log_capture_read", &jittor::log_capture_read);
m.def("init_subprocess", &jittor::init_subprocess);
}

View File

@ -219,7 +219,7 @@ int register_sigaction() {
sigaction(SIGKILL, &sa, NULL);
sigaction(SIGSTOP, &sa, NULL);
sigaction(SIGFPE, &sa, NULL);
// sigaction(SIGINT, &sa, NULL);
sigaction(SIGINT, &sa, NULL);
sigaction(SIGILL, &sa, NULL);
sigaction(SIGBUS, &sa, NULL);
sigaction(SIGQUIT, &sa, NULL);