mirror of https://github.com/Jittor/Jittor
fix switch flags
This commit is contained in:
parent
143bb01d8e
commit
9ddd16126c
|
@ -149,7 +149,9 @@ class Dataset(object):
|
|||
'''
|
||||
if hasattr(self, "workers"):
|
||||
for w in self.workers:
|
||||
w.p.terminate()
|
||||
w.buffer.stop()
|
||||
w.p.join()
|
||||
w.p.close()
|
||||
|
||||
def _worker_main(self, worker_id, buffer, status):
|
||||
import time
|
||||
|
|
|
@ -75,7 +75,11 @@ class TestRingBuffer(unittest.TestCase):
|
|||
self.train_loader = MNIST(train=True, transform=trans.Resize(224)) \
|
||||
.set_attrs(batch_size=300, shuffle=True)
|
||||
self.train_loader.num_workers = 1
|
||||
import time
|
||||
for batch_idx, (data, target) in tqdm(enumerate(self.train_loader)):
|
||||
# time.sleep(5)
|
||||
# print("break")
|
||||
# break
|
||||
# self.train_loader.display_worker_status()
|
||||
if batch_idx > 30:
|
||||
break
|
||||
|
|
|
@ -70,6 +70,8 @@ class TestWhereOpCuda(TestWhereOp):
|
|||
jt.flags.use_cuda = 0
|
||||
|
||||
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "No Torch found")
|
||||
class TestWhereOpCub(TestWhereOpCuda):
|
||||
def setUp(self):
|
||||
self.where = jt.compile_extern.cub_ops.cub_where
|
||||
|
|
|
@ -14,6 +14,8 @@ namespace jittor {
|
|||
DEFINE_FLAG_WITH_SETTER(int, use_cuda, 0,
|
||||
"Use cuda or not. 1 for trying to use cuda, 2 for forcing to use cuda.");
|
||||
|
||||
extern void sync_all(bool device_sync);
|
||||
|
||||
void setter_use_cuda(int value) {
|
||||
#ifdef HAS_CUDA
|
||||
if (value) {
|
||||
|
@ -27,6 +29,8 @@ void setter_use_cuda(int value) {
|
|||
#else
|
||||
CHECK(value==0) << "No CUDA found.";
|
||||
#endif
|
||||
if (use_cuda != value)
|
||||
sync_all(0);
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -18,11 +18,18 @@ RingBuffer::RingBuffer(uint64 size, bool multiprocess) : m(multiprocess), cv(mul
|
|||
size_mask = (1ll<<i)-1;
|
||||
this->size = size_mask+1;
|
||||
size_bit = i;
|
||||
l = r = is_wait = 0;
|
||||
l = r = is_wait = is_stop = 0;
|
||||
is_multiprocess = multiprocess;
|
||||
}
|
||||
|
||||
void RingBuffer::stop() {
|
||||
MutexScope _(m);
|
||||
is_stop = 1;
|
||||
cv.notify();
|
||||
}
|
||||
|
||||
RingBuffer::~RingBuffer() {
|
||||
stop();
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -71,7 +71,8 @@ struct RingBuffer {
|
|||
uint64 size_bit;
|
||||
volatile uint64 l;
|
||||
volatile uint64 r;
|
||||
volatile int is_wait;
|
||||
volatile bool is_wait;
|
||||
volatile bool is_stop;
|
||||
bool is_multiprocess;
|
||||
Mutex m;
|
||||
Cond cv;
|
||||
|
@ -79,17 +80,23 @@ struct RingBuffer {
|
|||
|
||||
RingBuffer(uint64 size, bool multiprocess=false);
|
||||
~RingBuffer();
|
||||
void stop();
|
||||
static RingBuffer* make_ring_buffer(uint64 size, bool multiprocess);
|
||||
static void free_ring_buffer(RingBuffer* rb);
|
||||
|
||||
inline void wait() {
|
||||
MutexScope _(m);
|
||||
if (is_wait) {
|
||||
cv.notify();
|
||||
is_wait = 0;
|
||||
if (is_stop) {
|
||||
abort();
|
||||
}
|
||||
{
|
||||
MutexScope _(m);
|
||||
if (is_wait) {
|
||||
cv.notify();
|
||||
is_wait = 0;
|
||||
}
|
||||
is_wait = 1;
|
||||
cv.wait(m);
|
||||
}
|
||||
is_wait = 1;
|
||||
cv.wait(m);
|
||||
}
|
||||
|
||||
inline void notify() {
|
||||
|
|
|
@ -23,6 +23,8 @@ struct PyMultiprocessRingBuffer {
|
|||
PyObject* pop();
|
||||
// @pyjt(clear)
|
||||
inline void clear() { rb->l = rb->r = 0; }
|
||||
// @pyjt(stop)
|
||||
inline void stop() { rb->stop(); }
|
||||
|
||||
// @pyjt(total_pop)
|
||||
inline uint64 total_pop() { return rb->l; }
|
||||
|
|
Loading…
Reference in New Issue