Add keep_numpy_array option to Dataset so batches can be kept as numpy arrays

This commit is contained in:
li-xl 2020-12-17 12:01:18 +08:00
parent bb562728d3
commit 37d4ffca0b
3 changed files with 20 additions and 9 deletions

View File

@ -27,8 +27,9 @@ mpi = jt.mpi
img_open_hook = HookTimer(Image, "open")
class Worker:
def __init__(self, target, args, buffer_size):
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
self.buffer = jt.RingBuffer(buffer_size)
self.buffer.keep_numpy_array(keep_numpy_array)
self.status = mp.Array('f', 5, lock=False)
self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
@ -67,7 +68,8 @@ class Dataset(object):
drop_last = False,
num_workers = 0,
buffer_size = 512*1024*1024,
stop_grad = True):
stop_grad = True,
keep_numpy_array = False):
super().__init__()
self.total_len = None
self.batch_size = batch_size
@ -76,6 +78,7 @@ class Dataset(object):
self.num_workers = num_workers
self.buffer_size = buffer_size
self.stop_grad = stop_grad
self.keep_numpy_array = keep_numpy_array
def __getitem__(self, index):
raise NotImplementedError
@ -117,6 +120,7 @@ class Dataset(object):
'''
Change batch data to jittor array, such as np.ndarray, int, and float.
'''
if self.keep_numpy_array: return batch
if isinstance(batch, jt.Var): return batch
to_jt = lambda x: jt.array(x).stop_grad() \
if self.stop_grad else jt.array(x)
@ -300,7 +304,8 @@ Example::
self.num_idle_c = mp.Condition(self.gid.get_lock())
for i in range(self.num_workers):
w = Worker(target=self._worker_main, args=(i,),
buffer_size=self.buffer_size)
buffer_size=self.buffer_size,
keep_numpy_array=self.keep_numpy_array)
workers.append(w)
self.workers = workers
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)

View File

@ -142,7 +142,7 @@ static PyObject* to_py_object3(ArrayArgs&& a) {
return to_py_object(jit_op_maker::array_(move(a)));
}
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset, bool keep_numpy_array) {
auto t = rb->pop_t<uint8>(offset);
if (t==0) {
auto x = rb->pop_t<int64>(offset);
@ -161,7 +161,7 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
auto size = rb->pop_t<int64>(offset);
PyObjHolder list(PyList_New(size));
for (uint i=0; i<size; i++) {
PyObject* o = pop_py_object(rb, offset);
PyObject* o = pop_py_object(rb, offset, keep_numpy_array);
PyList_SET_ITEM(list.obj, i, o);
}
return list.release();
@ -170,8 +170,8 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
auto size = rb->pop_t<int64>(offset);
PyObjHolder dict(PyDict_New());
for (int64 i=0; i<size; i++) {
PyObject* key = pop_py_object(rb, offset);
PyObject* value = pop_py_object(rb, offset);
PyObject* key = pop_py_object(rb, offset, keep_numpy_array);
PyObject* value = pop_py_object(rb, offset, keep_numpy_array);
PyDict_SetItem(dict.obj, key, value);
}
return dict.release();
@ -185,7 +185,10 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
size *= args.shape[i];
rb->pop(size, offset);
args.ptr = rb->get_ptr(size, offset);
return to_py_object3(move(args));
if (!keep_numpy_array)
return to_py_object3(move(args));
else
return to_py_object<ArrayArgs>(args);
}
if (t==6) {
return pop_py_object_pickle(rb, offset);
@ -212,7 +215,7 @@ void PyMultiprocessRingBuffer::push(PyObject* obj) {
PyObject* PyMultiprocessRingBuffer::pop() {
auto offset = rb->l;
auto obj = pop_py_object(rb, offset);
auto obj = pop_py_object(rb, offset, _keep_numpy_array);
rb->commit_pop(offset);
return obj;
}

View File

@ -13,6 +13,7 @@ namespace jittor {
// @pyjt(RingBuffer)
struct PyMultiprocessRingBuffer {
RingBuffer* rb;
bool _keep_numpy_array = false;
// @pyjt(__init__)
PyMultiprocessRingBuffer(uint64 size);
// @pyjt(__dealloc__)
@ -23,6 +24,8 @@ struct PyMultiprocessRingBuffer {
PyObject* pop();
// @pyjt(clear)
inline void clear() { rb->clear(); }
// @pyjt(keep_numpy_array)
inline void keep_numpy_array(bool keep) { _keep_numpy_array = keep; }
// @pyjt(stop)
inline void stop() { rb->stop(); }
// @pyjt(is_stop)