mirror of https://github.com/Jittor/Jittor
signal handler for ctrl-c
This commit is contained in:
parent
cd5731970a
commit
5713130b03
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.0.9'
|
||||
__version__ = '1.2.1.0'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -149,11 +149,11 @@ class Dataset(object):
|
|||
'''
|
||||
if hasattr(self, "workers"):
|
||||
for w in self.workers:
|
||||
w.buffer.stop()
|
||||
w.p.join()
|
||||
w.p.close()
|
||||
w.p.terminate()
|
||||
|
||||
def _worker_main(self, worker_id, buffer, status):
|
||||
import jittor_utils
|
||||
jittor_utils.cc.init_subprocess()
|
||||
import time
|
||||
try:
|
||||
gid_obj = self.gid.get_obj()
|
||||
|
@ -162,7 +162,7 @@ class Dataset(object):
|
|||
while True:
|
||||
# get id
|
||||
with gid_lock:
|
||||
while gid_obj.value >= self.batch_len:
|
||||
while gid_obj.value >= self.batch_len or buffer.is_stop():
|
||||
self.num_idle.value += 1
|
||||
self.num_idle_c.notify()
|
||||
self.gidc.wait()
|
||||
|
@ -189,7 +189,12 @@ class Dataset(object):
|
|||
# send data to main process
|
||||
if mp_log_v:
|
||||
print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
|
||||
buffer.send(batch)
|
||||
try:
|
||||
buffer.send(batch)
|
||||
except:
|
||||
if buffer.is_stop():
|
||||
continue
|
||||
raise
|
||||
now = time.time()
|
||||
send_time = now - start
|
||||
start = now
|
||||
|
|
|
@ -84,6 +84,14 @@ class TestRingBuffer(unittest.TestCase):
|
|||
if batch_idx > 30:
|
||||
break
|
||||
pass
|
||||
for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
|
||||
# time.sleep(5)
|
||||
# print("break")
|
||||
# break
|
||||
# self.train_loader.display_worker_status()
|
||||
if batch_idx > 300:
|
||||
break
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
@ -57,6 +57,7 @@ class TestWhereOp(unittest.TestCase):
|
|||
assert "Where Operator" in jt.where.__doc__
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Torch found")
|
||||
class TestWhereOpCuda(TestWhereOp):
|
||||
def setUp(self):
|
||||
self.where = jt.where
|
||||
|
|
|
@ -154,6 +154,9 @@ def pool_cleanup():
|
|||
p.__exit__(None, None, None)
|
||||
del p
|
||||
|
||||
def pool_initializer():
|
||||
cc.init_subprocess()
|
||||
|
||||
def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
||||
global pool_size, p
|
||||
bk = mp.current_process()._config.get('daemon')
|
||||
|
@ -163,7 +166,7 @@ def run_cmds(cmds, cache_path, jittor_path, msg="run_cmds"):
|
|||
mem_gib = mem_bytes/(1024.**3)
|
||||
pool_size = min(16,max(int(mem_gib // 3), 1))
|
||||
LOG.i(f"Total mem: {mem_gib:.2f}GB, using {pool_size} procs for compiling.")
|
||||
p = Pool(pool_size)
|
||||
p = Pool(pool_size, initializer=pool_initializer)
|
||||
p.__enter__()
|
||||
import atexit
|
||||
atexit.register(pool_cleanup)
|
||||
|
|
|
@ -48,6 +48,9 @@ struct RingBuffer {
|
|||
}
|
||||
|
||||
inline ~Cond() {
|
||||
// a dirty hack
|
||||
// ref: https://stackoverflow.com/questions/20439404/pthread-conditions-and-process-termination
|
||||
cv.__data.__wrefs = 0;
|
||||
pthread_cond_destroy(&cv);
|
||||
}
|
||||
|
||||
|
@ -86,7 +89,7 @@ struct RingBuffer {
|
|||
|
||||
inline void wait() {
|
||||
if (is_stop) {
|
||||
abort();
|
||||
throw std::runtime_error("stop");
|
||||
}
|
||||
{
|
||||
MutexScope _(m);
|
||||
|
|
|
@ -22,9 +22,11 @@ struct PyMultiprocessRingBuffer {
|
|||
// @pyjt(pop,recv)
|
||||
PyObject* pop();
|
||||
// @pyjt(clear)
|
||||
inline void clear() { rb->l = rb->r = 0; }
|
||||
inline void clear() { rb->l = rb->r = rb->is_stop = 0; }
|
||||
// @pyjt(stop)
|
||||
inline void stop() { rb->stop(); }
|
||||
// @pyjt(is_stop)
|
||||
inline bool is_stop() { return rb->is_stop; }
|
||||
|
||||
// @pyjt(total_pop)
|
||||
inline uint64 total_pop() { return rb->l; }
|
||||
|
|
|
@ -13,6 +13,16 @@
|
|||
#ifdef __GNUC__
|
||||
#endif
|
||||
#include <pybind11/iostream.h>
|
||||
#include <sys/prctl.h>
|
||||
#include <signal.h>
|
||||
|
||||
namespace jittor {
|
||||
|
||||
void init_subprocess() {
|
||||
prctl(PR_SET_PDEATHSIG, SIGKILL);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
PYBIND11_MODULE(jit_utils_core, m) {
|
||||
pybind11::add_ostream_redirect(m, "ostream_redirect");
|
||||
|
@ -39,4 +49,5 @@ PYBIND11_MODULE(jit_utils_core, m) {
|
|||
m.def("log_capture_start", &jittor::log_capture_start);
|
||||
m.def("log_capture_stop", &jittor::log_capture_stop);
|
||||
m.def("log_capture_read", &jittor::log_capture_read);
|
||||
m.def("init_subprocess", &jittor::init_subprocess);
|
||||
}
|
||||
|
|
|
@ -219,7 +219,7 @@ int register_sigaction() {
|
|||
sigaction(SIGKILL, &sa, NULL);
|
||||
sigaction(SIGSTOP, &sa, NULL);
|
||||
sigaction(SIGFPE, &sa, NULL);
|
||||
// sigaction(SIGINT, &sa, NULL);
|
||||
sigaction(SIGINT, &sa, NULL);
|
||||
sigaction(SIGILL, &sa, NULL);
|
||||
sigaction(SIGBUS, &sa, NULL);
|
||||
sigaction(SIGQUIT, &sa, NULL);
|
||||
|
|
Loading…
Reference in New Issue