multiprocess dataset loader

This commit is contained in:
Dun Liang 2020-03-28 13:53:34 +08:00
parent fd9867ebb9
commit 581acecc8e
7 changed files with 594 additions and 19 deletions

View File

@ -6,17 +6,31 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import numpy as np
from urllib import request
import gzip
import pickle
import os
from jittor.dataset.utils import get_random_list, get_order_list, collate_batch
from collections.abc import Sequence, Mapping
import pathlib
from PIL import Image
from jittor_utils.ring_buffer import RingBuffer
import multiprocessing as mp
import signal
from jittor_utils import LOG
import jittor as jt
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
mp_log_v = os.environ.get("mp_log_v", 0) # set (e.g. mp_log_v=1) to enable verbose multiprocess logging
class Worker:
def __init__(self, target, args, buffer_size):
buffer = mp.Array('c', buffer_size, lock=False)
self.buffer = RingBuffer(buffer)
self.p = mp.Process(target=target, args=args+(self.buffer,))
self.p.daemon = True
self.p.start()
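Each Worker pairs one daemon child process with one shared-memory ring buffer, and the target function receives that buffer as an extra trailing argument. A minimal usage sketch (the produce function is hypothetical; the default fork start method on Linux is assumed):

from jittor.dataset.dataset import Worker

def produce(worker_id, buffer):
    # runs in the child; buffer is the Worker's RingBuffer
    buffer.send(f"hello from worker {worker_id}")

if __name__ == "__main__":
    w = Worker(target=produce, args=(0,), buffer_size=1024)
    print(w.buffer.recv())   # "hello from worker 0"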
class Dataset(object):
'''
@ -41,6 +55,10 @@ class Dataset(object):
self.total_len = None
self.shuffle = False
self.drop_last = False
self.num_workers = 0
self.buffer_size = 512*1024*1024
if "num_workers" in os.environ:
self.num_workers = int(os.environ["num_workers"])
def __getitem__(self, index):
raise NotImplementedError
@ -48,6 +66,8 @@ class Dataset(object):
def __len__(self):
assert self.total_len >= 0
assert self.batch_size > 0
if self.drop_last:
return self.total_len // self.batch_size
return (self.total_len-1) // self.batch_size + 1
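A quick check of the arithmetic: with total_len=10 and batch_size=3, len(dataset) is 10 // 3 = 3 when drop_last is True, and (10 - 1) // 3 + 1 = 4 otherwise, the fourth batch holding the single leftover sample.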
def set_attrs(self, **kw):
@ -56,6 +76,7 @@ class Dataset(object):
Attrs:
batch_size(int): batch size, default 16.
total_len(int): total length.
num_workers(int): number of worker processes used to load data, default 0.
shuffle(bool): shuffle at each epoch, default False.
drop_last(bool): if True, drop the last batch when it is
smaller than batch_size, default False.
@ -65,26 +86,146 @@ class Dataset(object):
setattr(self, k, v)
return self
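Since set_attrs returns self, the attributes documented above are typically set in one chained call (values here are illustrative):

dataset = dataset.set_attrs(batch_size=64, shuffle=True, num_workers=4)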
def to_jittor(self, batch):
if isinstance(batch, np.ndarray):
return jt.array(batch)
assert isinstance(batch, Sequence)
new_batch = []
for a in batch:
if isinstance(a, np.ndarray) or \
isinstance(a, int) or \
isinstance(a, float):
new_batch.append(jt.array(a))
else:
new_batch.append(a)
return new_batch
def collate_batch(self, batch):
return collate_batch(batch)
def terminate(self):
if hasattr(self, "workers"):
for w in self.workers:
w.p.terminate()
def _worker_main(self, worker_id, buffer):
try:
gid_obj = self.gid.get_obj()
gid_lock = self.gid.get_lock()
while True:
with gid_lock:
while gid_obj.value >= self.batch_len:
self.num_idle.value += 1
self.num_idle_c.notify()
self.gidc.wait()
self.num_idle.value -= 1
cid = gid_obj.value
self.idmap[cid] = worker_id
gid_obj.value += 1
self.gidc.notify()
batch = []
if mp_log_v:
print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size))
for i in range(cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size)):
batch.append(self[self.index_list[i]])
batch = self.collate_batch(batch)
if mp_log_v:
print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
buffer.send(batch)
except:
# a failing worker interrupts the parent process, then re-raises the error
os.kill(os.getppid(), signal.SIGINT)
raise
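The loop above is a shared-counter handshake: workers claim batch ids from gid under its lock, while the two condition variables built on that same lock (gidc and num_idle_c) let the main process restart and drain the pool between epochs. A stripped-down, runnable sketch of just the claiming part (names and the batch count are illustrative):

import multiprocessing as mp

def claim_batches(worker_id, gid, gidc, batch_len):
    while True:
        with gid.get_lock():
            if gid.value >= batch_len:
                return                   # no batches left in this epoch
            cid = gid.value              # claim the next batch id
            gid.value += 1
            gidc.notify()
        print(f"worker {worker_id} would load batch {cid}")

if __name__ == "__main__":
    gid = mp.Value('i', 0)
    gidc = mp.Condition(gid.get_lock())
    ps = [mp.Process(target=claim_batches, args=(i, gid, gidc, 8)) for i in range(2)]
    for p in ps: p.start()
    for p in ps: p.join()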
def _stop_all_workers(self):
# wait until all workers idle
if self.num_idle.value < self.num_workers:
with self.gid.get_lock():
self.gid.get_obj().value = self.batch_len
if mp_log_v:
print("idle num", self.num_idle.value)
while self.num_idle.value < self.num_workers:
self.num_idle_c.wait()
if mp_log_v:
print("idle num", self.num_idle.value)
# clean workers' buffer
for w in self.workers:
w.buffer.clear()
def _init_workers(self):
self.index_list = mp.Array('i', self.total_len, lock=False)
workers = []
# batch id to worker id
self.idmap = mp.Array('i', self.batch_len, lock=False)
self.gid = mp.Value('i', self.batch_len)
self.gidc = mp.Condition(self.gid.get_lock())
self.num_idle = mp.Value('i', 0, lock=False)
self.num_idle_c = mp.Condition(self.gid.get_lock())
for i in range(self.num_workers):
w = Worker(target=self._worker_main, args=(i,),
buffer_size=self.buffer_size)
workers.append(w)
self.workers = workers
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.total_len, buffer=self.index_list)
def __del__(self):
if mp_log_v:
print("dataset deleted")
self.terminate()
def __iter__(self):
if not self.shuffle:
index_list = get_order_list(self.total_len)
else:
index_list = get_random_list(self.total_len)
batch_data = []
for idx in index_list:
batch_data.append(self[int(idx)])
if len(batch_data) == self.batch_size:
batch_data = self.collate_batch(batch_data)
yield batch_data
batch_data = []
self.batch_len = len(self)
if "batch_len" in os.environ:
self.batch_len = int(os.environ["batch_len"])
if not hasattr(self, "workers") and self.num_workers:
self._init_workers()
if self.num_workers:
self._stop_all_workers()
self.index_list_numpy[:] = index_list
gid_obj = self.gid.get_obj()
gid_lock = self.gid.get_lock()
with gid_lock:
gid_obj.value = 0
self.gidc.notify_all()
for i in range(self.batch_len):
# fast path: peek at the counter before taking the lock
if gid_obj.value <= i:
with gid_lock:
if gid_obj.value <= i:
if mp_log_v:
print("wait")
self.gidc.wait()
worker_id = self.idmap[i]
w = self.workers[worker_id]
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
batch = w.buffer.recv()
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
batch = self.to_jittor(batch)
yield batch
else:
batch_data = []
for idx in index_list:
batch_data.append(self[int(idx)])
if len(batch_data) == self.batch_size:
batch_data = self.collate_batch(batch_data)
batch_data = self.to_jittor(batch_data)
yield batch_data
batch_data = []
# depends on drop_last
if not self.drop_last and len(batch_data) > 0:
batch_data = self.collate_batch(batch_data)
batch_data = self.to_jittor(batch_data)
yield batch_data
# depends on drop_last
if not self.drop_last and len(batch_data) > 0:
batch_data = self.collate_batch(batch_data)
yield batch_data
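Taken together, a subclass only implements __getitem__ and declares total_len; batching, collation, worker dispatch, and the jt.array conversion all live in __iter__. An illustrative toy dataset (not part of this commit; Linux fork start method assumed for num_workers > 0):

from jittor.dataset.dataset import Dataset

class ToyDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=100, batch_size=16, num_workers=2)
    def __getitem__(self, index):
        return float(index) / 100, index % 10

if __name__ == "__main__":
    for x, y in ToyDataset():
        print(x.shape, y.shape)   # 7 jt batches; shapes (16,) except the final (4,)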
class ImageFolder(Dataset):
"""A image classify dataset, load image and label from directory:
@ -120,7 +261,7 @@ class ImageFolder(Dataset):
if os.path.splitext(fname)[-1].lower() in image_exts:
path = os.path.join(class_dir, fname)
self.imgs.append((path, i))
print(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
self.set_attrs(total_len=len(self.imgs))
def __getitem__(self, k):
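Usage is unchanged by this commit; for example (the path is a placeholder):

from jittor.dataset.dataset import ImageFolder
import jittor.transform as transform

dataset = ImageFolder("/path/to/train").set_attrs(
    batch_size=32, shuffle=True, num_workers=4,
    transform=transform.Resize(224))
for images, labels in dataset:
    pass   # images and labels arrive as jt arrays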

View File

@ -98,23 +98,26 @@ def collate_batch(batch):
return jt.array(temp_data)
if elem_type is np.ndarray:
temp_data = np.stack([data for data in batch], 0)
return jt.array(temp_data)
return temp_data
elif np.issubdtype(elem_type, np.integer):
return jt.array(batch)
return np.int32(batch)
elif isinstance(elem, int):
return jt.array(batch)
return np.int32(batch)
elif isinstance(elem, float):
return jt.array(batch)
return np.float32(batch)
elif isinstance(elem, str):
return batch
elif isinstance(elem, Mapping):
return {key: collate_batch([d[key] for d in batch]) for key in elem}
elif isinstance(elem, tuple):
transposed = zip(*batch)
return tuple(collate_batch(samples) for samples in transposed)
elif isinstance(elem, Sequence):
transposed = zip(*batch)
return [collate_batch(samples) for samples in transposed]
elif isinstance(elem, Image.Image):
temp_data = np.stack([np.array(data) for data in batch], 0)
return jt.array(temp_data)
return temp_data
else:
raise TypeError(f"Not support type <{elem_type.__name__}>")
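The net effect of this hunk: collate_batch now returns numpy containers instead of jt arrays, and Dataset.to_jittor performs the conversion in the main process, so worker processes never touch jittor state. For example:

from jittor.dataset.utils import collate_batch

batch = collate_batch([(1, 1), (1, 2), (1, 3)])
# a tuple of two np.int32 arrays: [1, 1, 1] and [1, 2, 3]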

View File

@ -14,6 +14,7 @@ from jittor.test.test_log import find_log_with_re
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
except:
skip_this_test = True

View File

@ -0,0 +1,79 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import os
import random
import numpy as np
import jittor as jt
from jittor.dataset.dataset import ImageFolder
import jittor.transform as transform
pass_this_test = False
msg = ""
mid = 0
if os.uname()[1] == "jittor-ce":
mid = 1
try:
traindir = ["/data1/cjld/imagenet/train/","/home/cjld/imagenet/train/"][mid]
assert os.path.isdir(traindir)
except Exception as e:
pass_this_test = True
msg = str(e)
@unittest.skipIf(pass_this_test, f"can not run imagenet dataset test: {msg}")
class TestDataset(unittest.TestCase):
def test_multi_workers(self):
check_num_batch = 10
tc_data = []
def get_dataset():
dataset = ImageFolder(traindir).set_attrs(batch_size=256, shuffle=False)
dataset.set_attrs(transform = transform.Compose([
transform.Resize(224),
transform.ImageNormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
]), num_workers=0)
return dataset
dataset = get_dataset()
for i, data in enumerate(dataset):
print("get batch", i)
tc_data.append(data)
if i==check_num_batch: break
def check(num_workers, epoch=1):
dataset = get_dataset().set_attrs(num_workers=num_workers)
random.seed(0)
for _ in range(epoch):
for i, (images, labels) in enumerate(dataset):
print("compare", i)
assert np.allclose(images.data, tc_data[i][0].data), \
(images.sum(), tc_data[i][0].sum(), images.shape,
tc_data[i][0].shape)
assert np.allclose(labels.data, tc_data[i][1].data)
if i==check_num_batch: break
# dataset.terminate()
check(1)
check(2)
check(4,2)
def test_collate_batch(self):
from jittor.dataset.utils import collate_batch
batch = collate_batch([(1,1),(1,2),(1,3)])
assert isinstance(batch[0], np.ndarray)
assert isinstance(batch[1], np.ndarray)
if __name__ == "__main__":
unittest.main()

View File

@ -20,6 +20,7 @@ if os.uname()[1] == "jittor-ce":
try:
# check can we run this test
# test code
jt.dirty_fix_pytorch_runtime_error()
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch
@ -29,7 +30,7 @@ try:
assert os.path.isdir(traindir)
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
std=[0.229, 0.224, 0.225])
train_dataset = datasets.ImageFolder(
traindir,
transforms.Compose([

View File

@ -0,0 +1,78 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
from jittor_utils.ring_buffer import *
import unittest
def test_ring_buffer():
buffer = mp.Array('c', 8000, lock=False)
buffer = RingBuffer(buffer)
def test_send_recv(data):
print("test send recv", type(data))
buffer.send(data)
recv = buffer.recv()
if isinstance(recv, np.ndarray):
assert (recv == data).all()
else:
assert data == recv
print(buffer)
test_send_recv("float32")
test_send_recv("")
test_send_recv("xxxxxxxxxx")
test_send_recv(1)
test_send_recv(100000000000)
test_send_recv(1e-5)
test_send_recv(100000000000.0)
test_send_recv([1,0.2])
test_send_recv({'asd':1})
test_send_recv(np.random.rand(10,10))
def test_ring_buffer_allocator(p=0.7):
print("test_ring_buffer_allocator", p)
n = 1000
buffer = RingBufferAllocator(n)
m = 10000
sizes = [0]*m
a = [-1]*n
l = 0
r = 0
for i in range(m):
if l==r or random.random()<p:
size = random.randint(10, 20)
location = buffer.alloc(size)
if location is not None:
sizes[r] = size
for j in range(location, location+size):
a[j] = r
r += 1
# print(r-l, buffer)
continue
assert l<r
size = sizes[l]
location = buffer.free(size)
assert location is not None, buffer
for j in range(location, location+size):
assert a[j] == l
l += 1
# print(r-l, buffer)
class TestRingBuffer(unittest.TestCase):
def test_ring_buffer_allocator(self):
test_ring_buffer_allocator(0.7)
test_ring_buffer_allocator(0.3)
def test_ring_buffer(self):
test_ring_buffer()
if __name__ == "__main__":
unittest.main()

View File

@ -0,0 +1,272 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import multiprocessing as mp
import numpy as np
import ctypes
import random
import pickle
import ctypes
class RingBufferAllocator:
def __init__(self, size):
self.size = size
self.l = mp.Value(ctypes.c_longlong, 0, lock=False)
self.r = mp.Value(ctypes.c_longlong, 0, lock=False)
self.is_full = mp.Value(ctypes.c_bool, False, lock=False)
self.lock = mp.Lock()
self.cv = mp.Condition(self.lock)
def __repr__(self):
l = self.l.value
r = self.r.value
is_full = self.is_full.value
if is_full:
cap = 0
else:
cap = int((r - l) / self.size * 100)
if cap<=0: cap += 100
return f"Buffer(free={int(cap)}%)"
def alloc_with_lock(self, size):
with self.lock:
while True:
location = self.alloc(size)
if location is not None: break
# print("alloc wait", size, self)
self.cv.wait()
return location
def free_with_lock(self, size):
with self.lock:
location = self.free(size)
self.cv.notify()
return location
def clear(self):
with self.lock:
self.l.value = 0
self.r.value = 0
self.is_full.value = False
def alloc(self, size):
if size > self.size:
raise RuntimeError(f"Buffer size too small {self.size}<{size}")
l = self.l.value
r = self.r.value
is_full = self.is_full.value
if is_full: return None
if l == r and l > 0:
self.l.value = self.r.value = l = r = 0
# [l, r)
if r > l:
freed = r - l
if freed < size:
# |----l......r---|
# |----#########--|
return None
# |----l......r---|
# |----#####------|
location = l
self.l.value = l = l + size
else:
freed = self.size - l
if freed < size:
# |.....r------l...|
# |------------#######
if size > r:
# |.....r------l...|
# |#######-----------
return None
# |.....r------l...|
# |#####-----------
if size == r:
self.is_full.value = is_full= True
location = 0
self.l.value = l = size
else:
# |.....r------l...|
# |------------##--|
location = l
if freed == size:
self.l.value = l = 0
else:
self.l.value = l = l + size
if l == r:
self.is_full.value = is_full = True
return location
def free(self, size):
l = self.l.value
r = self.r.value
is_full = self.is_full.value
if size==0: return r
if is_full:
self.is_full.value = is_full = False
elif l == r:
return None
location = r
self.r.value = r = r + size
if r > self.size:
location = 0
self.r.value = r = size
elif r == self.size:
self.r.value = r = 0
return location
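A worked example of the pointer logic above on a 10-byte buffer (a sketch; the returned values are byte offsets into the buffer):

buffer = RingBufferAllocator(10)
buffer.alloc(6)   # -> 0     bytes [0, 6) are now in flight
buffer.alloc(6)   # -> None  only 4 contiguous bytes remain at the tail
buffer.free(6)    # -> 0     the consumer releases the first block
buffer.alloc(6)   # -> 0     l == r afterwards, so both pointers reset to 0 first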
def str_to_char_array(s, array_len):
if len(s) > array_len: s = s[:array_len]
a = np.array(s, dtype='c')
if len(s) < array_len:
a = np.pad(a, (0,array_len-len(s)), constant_values=' ')
return a
def char_array_to_str(a):
return str(a.tostring(), 'ascii').strip()
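These two helpers implement the fixed-width string fields of the wire format below: str_to_char_array('float32', 8) yields the 8-byte char array b'float32 ' (space-padded, truncated if longer), and char_array_to_str strips the padding back off.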
class RingBuffer:
def __init__(self, buffer):
self.allocator = RingBufferAllocator(len(buffer))
self.buffer = buffer
def clear(self): self.allocator.clear()
def send_int(self, data):
# int: int64[1]
# data
self.send_raw(np.array([data], dtype='int64'))
def recv_int(self):
return int(self.recv_raw(8, (1,), 'int64')[0])
def send_float(self, data):
# float: float64[1]
# data
self.send_raw(np.array([data], dtype='float64'))
def recv_float(self):
return float(self.recv_raw(8, (1,), 'float64')[0])
def send_str(self, data):
# str: int64[1] char[len]
# len data
data = np.array(data, dtype='c')
self.send_int(data.nbytes)
self.send_raw(data)
def recv_str(self):
nbytes = self.recv_int()
data = self.recv_raw(nbytes, nbytes, 'c')
return str(data.tostring(), 'ascii')
def send_ndarray(self, data):
# ndarray: int64[1] char[8] int64[1] int64[slen] char[nbytes]
# slen dtype nbytes shape data
shape = data.shape
slen = len(shape)
self.send_int(slen)
self.send_fix_len_str(str(data.dtype))
self.send_int(data.nbytes)
self.send_raw(np.array(shape, dtype='int64'))
self.send_raw(data)
def recv_ndarray(self):
slen = self.recv_int()
dtype = self.recv_fix_len_str()
nbytes = self.recv_int()
shape = self.recv_raw(slen*8, slen, 'int64')
data = self.recv_raw(nbytes, shape, dtype)
return data
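As a worked example of the framing above, a float32 array of shape (2, 3) goes onto the wire as five consecutive send_raw writes, 64 bytes in total:

int64[1]  slen   = 2           (8 bytes)
char[8]   dtype  = 'float32 '  (8 bytes)
int64[1]  nbytes = 24          (8 bytes)
int64[2]  shape  = [2, 3]      (16 bytes)
char[24]  raw array data       (24 bytes)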
def send_tuple(self, data):
# tuple: int64[1] ....
# len
length = len(data)
self.send_int(length)
for a in data:
self.send(a)
def recv_tuple(self):
length = self.recv_int()
return tuple(self.recv() for i in range(length))
def send_list(self, data):
# list: int64[1] ....
# len
length = len(data)
self.send_int(length)
for a in data:
self.send(a)
def recv_list(self):
length = self.recv_int()
return [self.recv() for i in range(length)]
def send_pickle(self, data):
# pickle: int64[1] char[len]
# len data
# print("send pickle")
data = pickle.dumps(data)
data = np.frombuffer(data, dtype='c')
self.send_int(data.nbytes)
self.send_raw(data)
def recv_pickle(self):
nbytes = self.recv_int()
data = self.recv_raw(nbytes, nbytes, 'c')
return pickle.loads(data.tostring())
def __repr__(self):
return f"{self.allocator}@0x{hex(ctypes.addressof(self.buffer))}"
def send_raw(self, data):
assert isinstance(data, np.ndarray) # and data.flags.c_contiguous
# print("send raw", data.shape, data.dtype, data.nbytes)
with self.allocator.lock:
location = self.allocator.alloc(data.nbytes)
while location is None:
# print("send wait")
self.allocator.cv.wait()
location = self.allocator.alloc(data.nbytes)
window = np.ndarray(shape=data.shape, dtype=data.dtype,
buffer=self.buffer, offset=location)
window[:] = data
# print("send notify")
self.allocator.cv.notify()
assert window.nbytes == data.nbytes
def recv_raw(self, nbytes, shape, dtype):
# print("recv raw", shape, dtype, nbytes)
with self.allocator.lock:
location = self.allocator.free(nbytes)
# print("recv get location", location)
while location is None:
# print("recv wait")
self.allocator.cv.wait()
location = self.allocator.free(nbytes)
data = np.ndarray(shape=shape, dtype=dtype,
buffer=self.buffer, offset=location).copy()
# print("recv notify")
self.allocator.cv.notify()
assert data.nbytes == nbytes
return data
def send_fix_len_str(self, s, array_len=8):
data = str_to_char_array(s, array_len)
self.send_raw(data)
def recv_fix_len_str(self, array_len=8):
data = self.recv_raw(array_len, array_len, 'c')
return char_array_to_str(data)
def send(self, data):
ts = type(data).__name__
send = getattr(self, "send_"+ts, self.send_pickle)
self.send_fix_len_str(ts)
send(data)
def recv(self):
ts = self.recv_fix_len_str()
# print("recv", ts)
recv = getattr(self, "recv_"+ts, self.recv_pickle)
return recv()
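send and recv dispatch on the type name: an 8-char type tag is written first, then the matching send_* handler runs, with pickle as the fallback for any type without a dedicated handler. A single-process round-trip sketch (buffer size is illustrative):

import multiprocessing as mp
import numpy as np
from jittor_utils.ring_buffer import RingBuffer

buffer = RingBuffer(mp.Array('c', 8192, lock=False))
buffer.send(np.arange(6, dtype='float32').reshape(2, 3))   # tag 'ndarray'
buffer.send({"label": 3})                                  # no send_dict, falls back to pickle
print(buffer.recv())   # the (2, 3) float32 array
print(buffer.recv())   # {'label': 3}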