fix switch flags

This commit is contained in:
Dun Liang 2020-11-02 22:41:26 +08:00
parent 143bb01d8e
commit 9ddd16126c
7 changed files with 37 additions and 9 deletions

View File

@ -149,7 +149,9 @@ class Dataset(object):
'''
if hasattr(self, "workers"):
for w in self.workers:
w.p.terminate()
w.buffer.stop()
w.p.join()
w.p.close()
def _worker_main(self, worker_id, buffer, status):
import time

View File

@ -75,7 +75,11 @@ class TestRingBuffer(unittest.TestCase):
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
.set_attrs(batch_size=300, shuffle=True)
self.train_loader.num_workers = 1
import time
for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
# time.sleep(5)
# print("break")
# break
# self.train_loader.display_worker_status()
if batch_idx > 30:
break

View File

@ -70,6 +70,8 @@ class TestWhereOpCuda(TestWhereOp):
jt.flags.use_cuda = 0
@unittest.skipIf(not jt.has_cuda, "No Torch found")
class TestWhereOpCub(TestWhereOpCuda):
def setUp(self):
self.where = jt.compile_extern.cub_ops.cub_where

View File

@ -14,6 +14,8 @@ namespace jittor {
DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0,
"Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda.");
extern void sync_all(bool device_sync);
void setter_use_cuda(int value) {
#ifdef HAS_CUDA
if (value) {
@ -27,6 +29,8 @@ void setter_use_cuda(int value) {
#else
CHECK(value==0) << "No CUDA found.";
#endif
if (use_cuda != value)
sync_all(0);
}
} // jittor

View File

@ -18,11 +18,18 @@ RingBuffer::RingBuffer(uint64 size, bool multiprocess) : m(multiprocess), cv(mul
size_mask = (1ll<<i)-1;
this->size = size_mask+1;
size_bit = i;
l = r = is_wait = 0;
l = r = is_wait = is_stop = 0;
is_multiprocess = multiprocess;
}
void RingBuffer::stop() {
MutexScope _(m);
is_stop = 1;
cv.notify();
}
RingBuffer::~RingBuffer() {
stop();
}

View File

@ -71,7 +71,8 @@ struct RingBuffer {
uint64 size_bit;
volatile uint64 l;
volatile uint64 r;
volatile int is_wait;
volatile bool is_wait;
volatile bool is_stop;
bool is_multiprocess;
Mutex m;
Cond cv;
@ -79,17 +80,23 @@ struct RingBuffer {
RingBuffer(uint64 size, bool multiprocess=false);
~RingBuffer();
void stop();
static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess);
static void free_ring_buffer(RingBuffer* rb);
inline void wait() {
MutexScope _(m);
if (is_wait) {
cv.notify();
is_wait = 0;
if (is_stop) {
abort();
}
{
MutexScope _(m);
if (is_wait) {
cv.notify();
is_wait = 0;
}
is_wait = 1;
cv.wait(m);
}
is_wait = 1;
cv.wait(m);
}
inline void notify() {

View File

@ -23,6 +23,8 @@ struct PyMultiprocessRingBuffer {
PyObject* pop();
// @pyjt(clear)
inline void clear() { rb->l = rb->r = 0; }
// @pyjt(stop)
inline void stop() { rb->stop(); }
// @pyjt(total_pop)
inline uint64 total_pop() { return rb->l; }