mirror of https://github.com/Jittor/Jittor
add dataset keep numpy array
This commit is contained in:
parent
bb562728d3
commit
37d4ffca0b
|
@ -27,8 +27,9 @@ mpi = jt.mpi
|
|||
img_open_hook = HookTimer(Image, "open")
|
||||
|
||||
class Worker:
|
||||
def __init__(self, target, args, buffer_size):
|
||||
def __init__(self, target, args, buffer_size, keep_numpy_array=False):
|
||||
self.buffer = jt.RingBuffer(buffer_size)
|
||||
self.buffer.keep_numpy_array(keep_numpy_array)
|
||||
|
||||
self.status = mp.Array('f', 5, lock=False)
|
||||
self.p = mp.Process(target=target, args=args+(self.buffer,self.status))
|
||||
|
@ -67,7 +68,8 @@ class Dataset(object):
|
|||
drop_last = False,
|
||||
num_workers = 0,
|
||||
buffer_size = 512*1024*1024,
|
||||
stop_grad = True):
|
||||
stop_grad = True,
|
||||
keep_numpy_array = False):
|
||||
super().__init__()
|
||||
self.total_len = None
|
||||
self.batch_size = batch_size
|
||||
|
@ -76,6 +78,7 @@ class Dataset(object):
|
|||
self.num_workers = num_workers
|
||||
self.buffer_size = buffer_size
|
||||
self.stop_grad = stop_grad
|
||||
self.keep_numpy_array = keep_numpy_array
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -117,6 +120,7 @@ class Dataset(object):
|
|||
'''
|
||||
Change batch data to jittor array, such as np.ndarray, int, and float.
|
||||
'''
|
||||
if self.keep_numpy_array: return batch
|
||||
if isinstance(batch, jt.Var): return batch
|
||||
to_jt = lambda x: jt.array(x).stop_grad() \
|
||||
if self.stop_grad else jt.array(x)
|
||||
|
@ -300,7 +304,8 @@ Example::
|
|||
self.num_idle_c = mp.Condition(self.gid.get_lock())
|
||||
for i in range(self.num_workers):
|
||||
w = Worker(target=self._worker_main, args=(i,),
|
||||
buffer_size=self.buffer_size)
|
||||
buffer_size=self.buffer_size,
|
||||
keep_numpy_array=self.keep_numpy_array)
|
||||
workers.append(w)
|
||||
self.workers = workers
|
||||
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
|
||||
|
|
|
@ -142,7 +142,7 @@ static PyObject* to_py_object3(ArrayArgs&& a) {
|
|||
return to_py_object(jit_op_maker::array_(move(a)));
|
||||
}
|
||||
|
||||
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
||||
static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset, bool keep_numpy_array) {
|
||||
auto t = rb->pop_t<uint8>(offset);
|
||||
if (t==0) {
|
||||
auto x = rb->pop_t<int64>(offset);
|
||||
|
@ -161,7 +161,7 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
|||
auto size = rb->pop_t<int64>(offset);
|
||||
PyObjHolder list(PyList_New(size));
|
||||
for (uint i=0; i<size; i++) {
|
||||
PyObject* o = pop_py_object(rb, offset);
|
||||
PyObject* o = pop_py_object(rb, offset, keep_numpy_array);
|
||||
PyList_SET_ITEM(list.obj, i, o);
|
||||
}
|
||||
return list.release();
|
||||
|
@ -170,8 +170,8 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
|||
auto size = rb->pop_t<int64>(offset);
|
||||
PyObjHolder dict(PyDict_New());
|
||||
for (int64 i=0; i<size; i++) {
|
||||
PyObject* key = pop_py_object(rb, offset);
|
||||
PyObject* value = pop_py_object(rb, offset);
|
||||
PyObject* key = pop_py_object(rb, offset, keep_numpy_array);
|
||||
PyObject* value = pop_py_object(rb, offset, keep_numpy_array);
|
||||
PyDict_SetItem(dict.obj, key, value);
|
||||
}
|
||||
return dict.release();
|
||||
|
@ -185,7 +185,10 @@ static PyObject* pop_py_object(RingBuffer* rb, uint64& __restrict__ offset) {
|
|||
size *= args.shape[i];
|
||||
rb->pop(size, offset);
|
||||
args.ptr = rb->get_ptr(size, offset);
|
||||
return to_py_object3(move(args));
|
||||
if (!keep_numpy_array)
|
||||
return to_py_object3(move(args));
|
||||
else
|
||||
return to_py_object<ArrayArgs>(args);
|
||||
}
|
||||
if (t==6) {
|
||||
return pop_py_object_pickle(rb, offset);
|
||||
|
@ -212,7 +215,7 @@ void PyMultiprocessRingBuffer::push(PyObject* obj) {
|
|||
|
||||
PyObject* PyMultiprocessRingBuffer::pop() {
|
||||
auto offset = rb->l;
|
||||
auto obj = pop_py_object(rb, offset);
|
||||
auto obj = pop_py_object(rb, offset, _keep_numpy_array);
|
||||
rb->commit_pop(offset);
|
||||
return obj;
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ namespace jittor {
|
|||
// @pyjt(RingBuffer)
|
||||
struct PyMultiprocessRingBuffer {
|
||||
RingBuffer* rb;
|
||||
bool _keep_numpy_array = false;
|
||||
// @pyjt(__init__)
|
||||
PyMultiprocessRingBuffer(uint64 size);
|
||||
// @pyjt(__dealloc__)
|
||||
|
@ -23,6 +24,8 @@ struct PyMultiprocessRingBuffer {
|
|||
PyObject* pop();
|
||||
// @pyjt(clear)
|
||||
inline void clear() { rb->clear(); }
|
||||
// @pyjt(keep_numpy_array)
|
||||
inline void keep_numpy_array(bool keep) { _keep_numpy_array = keep; }
|
||||
// @pyjt(stop)
|
||||
inline void stop() { rb->stop(); }
|
||||
// @pyjt(is_stop)
|
||||
|
|
Loading…
Reference in New Issue