mirror of https://github.com/Jittor/Jittor
multiprocess dataset loader
This commit is contained in:
parent
fd9867ebb9
commit
581acecc8e
|
@ -6,17 +6,31 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import numpy as np
|
||||
from urllib import request
|
||||
import gzip
|
||||
import pickle
|
||||
import os
|
||||
from jittor.dataset.utils import get_random_list, get_order_list, collate_batch
|
||||
from collections.abc import Sequence, Mapping
|
||||
import pathlib
|
||||
from PIL import Image
|
||||
from jittor_utils.ring_buffer import RingBuffer
|
||||
import multiprocessing as mp
|
||||
import signal
|
||||
from jittor_utils import LOG
|
||||
import jittor as jt
|
||||
|
||||
dataset_root = os.path.join(pathlib.Path.home(), ".cache", "jittor", "dataset")
|
||||
mp_log_v = os.environ.get("mp_log_v", 0)
|
||||
|
||||
class Worker:
    """One dataset-loading subprocess paired with a shared ring buffer.

    The ring buffer lives in shared memory; it is appended as the last
    positional argument of *args* so the worker can ship loaded batches
    back to the parent process through it.
    """

    def __init__(self, target, args, buffer_size):
        shared_mem = mp.Array('c', buffer_size, lock=False)
        self.buffer = RingBuffer(shared_mem)
        worker_args = args + (self.buffer,)
        self.p = mp.Process(target=target, args=worker_args)
        # Daemonize so the worker dies together with the parent process.
        self.p.daemon = True
        self.p.start()
|
||||
|
||||
class Dataset(object):
|
||||
'''
|
||||
|
@ -41,6 +55,10 @@ class Dataset(object):
|
|||
self.total_len = None
|
||||
self.shuffle = False
|
||||
self.drop_last = False
|
||||
self.num_workers = 0
|
||||
self.buffer_size = 512*1024*1024
|
||||
if "num_workers" in os.environ:
|
||||
self.num_workers = int(os.environ["num_workers"])
|
||||
|
||||
    def __getitem__(self, index):
        """Return the sample at position *index*; subclasses must override."""
        raise NotImplementedError
|
||||
|
@ -48,6 +66,8 @@ class Dataset(object):
|
|||
def __len__(self):
|
||||
assert self.total_len >= 0
|
||||
assert self.batch_size > 0
|
||||
if self.drop_last:
|
||||
return self.total_len // self.batch_size
|
||||
return (self.total_len-1) // self.batch_size + 1
|
||||
|
||||
def set_attrs(self, **kw):
|
||||
|
@ -56,6 +76,7 @@ class Dataset(object):
|
|||
Attrs:
|
||||
batch_size(int): batch size, default 16.
|
||||
totol_len(int): totol lenght.
|
||||
num_workers: number of workers for loading data
|
||||
shuffle(bool): shuffle at each epoch, default False.
|
||||
drop_last(bool): if true, the last batch of dataset
|
||||
might smaller than batch_size, default True.
|
||||
|
@ -65,26 +86,146 @@ class Dataset(object):
|
|||
setattr(self, k, v)
|
||||
return self
|
||||
|
||||
def to_jittor(self, batch):
|
||||
if isinstance(batch, np.ndarray):
|
||||
return jt.array(batch)
|
||||
assert isinstance(batch, Sequence)
|
||||
new_batch = []
|
||||
for a in batch:
|
||||
if isinstance(a, np.ndarray) or \
|
||||
isinstance(a, int) or \
|
||||
isinstance(a, float):
|
||||
new_batch.append(jt.array(a))
|
||||
else:
|
||||
new_batch.append(a)
|
||||
return new_batch
|
||||
|
||||
def collate_batch(self, batch):
|
||||
return collate_batch(batch)
|
||||
|
||||
def terminate(self):
|
||||
if hasattr(self, "workers"):
|
||||
for w in self.workers:
|
||||
w.p.terminate()
|
||||
|
||||
    def _worker_main(self, worker_id, buffer):
        # Entry point of each loader subprocess: repeatedly claim the
        # next batch id, load and collate that batch, and push it to the
        # parent through this worker's ring buffer.
        try:
            gid_obj = self.gid.get_obj()
            gid_lock = self.gid.get_lock()
            while True:
                with gid_lock:
                    # All batch ids of this epoch taken: mark ourselves
                    # idle and sleep until the main process resets
                    # self.gid to 0 for the next epoch.
                    while gid_obj.value >= self.batch_len:
                        self.num_idle.value += 1
                        self.num_idle_c.notify()
                        self.gidc.wait()
                        self.num_idle.value -= 1
                    # Claim batch id `cid` and record which worker owns
                    # it so the consumer knows whose buffer to read.
                    cid = gid_obj.value
                    self.idmap[cid] = worker_id
                    gid_obj.value += 1
                    self.gidc.notify()
                batch = []
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} load batch", cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size))
                # Load the samples of this batch in shuffled order.
                for i in range(cid*self.batch_size, min(self.total_len, (cid+1)*self.batch_size)):
                    batch.append(self[self.index_list[i]])
                batch = self.collate_batch(batch)
                if mp_log_v:
                    print(f"#{worker_id} {os.getpid()} send", type(batch).__name__, [ type(b).__name__ for b in batch ], buffer)
                # Blocks until the parent has drained enough buffer space.
                buffer.send(batch)
        except:
            # Propagate any worker failure to the parent process so the
            # whole loader aborts instead of hanging on a dead worker.
            os.kill(os.getppid(), signal.SIGINT)
            raise
|
||||
|
||||
    def _stop_all_workers(self):
        # Drain the current epoch: make every worker park in its idle
        # wait-loop, then discard any batches still sitting in buffers.
        # wait until all workers idle
        if self.num_idle.value < self.num_workers:
            with self.gid.get_lock():
                # Jump the global batch id to the end so workers stop
                # claiming new batches and go idle.
                self.gid.get_obj().value = self.batch_len
                if mp_log_v:
                    print("idle num", self.num_idle.value)
                # num_idle_c shares the gid lock, so waiting here
                # releases it for the workers to update num_idle.
                while self.num_idle.value < self.num_workers:
                    self.num_idle_c.wait()
                    if mp_log_v:
                        print("idle num", self.num_idle.value)
        # clean workers' buffer
        for w in self.workers:
            w.buffer.clear()
|
||||
|
||||
    def _init_workers(self):
        # Shared sample order for the epoch, visible to every worker.
        self.index_list = mp.Array('i', self.total_len, lock=False)
        workers = []
        # batch id to worker id
        self.idmap = mp.Array('i', self.batch_len, lock=False)
        # Global batch-id counter; starts "exhausted" so freshly spawned
        # workers park idle until __iter__ resets it to 0.
        self.gid = mp.Value('i', self.batch_len)
        self.gidc = mp.Condition(self.gid.get_lock())
        # Count of workers currently waiting for a new epoch; its
        # condition shares the gid lock.
        self.num_idle = mp.Value('i', 0, lock=False)
        self.num_idle_c = mp.Condition(self.gid.get_lock())
        for i in range(self.num_workers):
            w = Worker(target=self._worker_main, args=(i,),
                       buffer_size=self.buffer_size)
            workers.append(w)
        self.workers = workers
        # numpy view over the shared index list for fast bulk assignment
        # in __iter__ (mp.Array 'i' maps to int32 here).
        self.index_list_numpy = np.ndarray(dtype='int32', shape=self.total_len, buffer=self.index_list)
|
||||
|
||||
    def __del__(self):
        """Best-effort cleanup: kill worker subprocesses on collection."""
        if mp_log_v:
            print("dataset deleted")
        self.terminate()
|
||||
|
||||
def __iter__(self):
|
||||
if self.shuffle == False:
|
||||
index_list = get_order_list(self.total_len)
|
||||
else:
|
||||
index_list = get_random_list(self.total_len)
|
||||
batch_data = []
|
||||
for idx in index_list:
|
||||
batch_data.append(self[int(idx)])
|
||||
if len(batch_data) == self.batch_size:
|
||||
batch_data = self.collate_batch(batch_data)
|
||||
yield batch_data
|
||||
batch_data = []
|
||||
|
||||
self.batch_len = len(self)
|
||||
if "batch_len" in os.environ:
|
||||
self.batch_len = int(os.environ["batch_len"])
|
||||
|
||||
if not hasattr(self, "workers") and self.num_workers:
|
||||
self._init_workers()
|
||||
|
||||
if self.num_workers:
|
||||
self._stop_all_workers()
|
||||
self.index_list_numpy[:] = index_list
|
||||
gid_obj = self.gid.get_obj()
|
||||
gid_lock = self.gid.get_lock()
|
||||
with gid_lock:
|
||||
gid_obj.value = 0
|
||||
self.gidc.notify_all()
|
||||
for i in range(self.batch_len):
|
||||
# try not get lock first
|
||||
if gid_obj.value <= i:
|
||||
with gid_lock:
|
||||
if gid_obj.value <= i:
|
||||
if mp_log_v:
|
||||
print("wait")
|
||||
self.gidc.wait()
|
||||
worker_id = self.idmap[i]
|
||||
w = self.workers[worker_id]
|
||||
if mp_log_v:
|
||||
print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
|
||||
batch = w.buffer.recv()
|
||||
if mp_log_v:
|
||||
print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
|
||||
batch = self.to_jittor(batch)
|
||||
yield batch
|
||||
else:
|
||||
batch_data = []
|
||||
for idx in index_list:
|
||||
batch_data.append(self[int(idx)])
|
||||
if len(batch_data) == self.batch_size:
|
||||
batch_data = self.collate_batch(batch_data)
|
||||
batch_data = self.to_jittor(batch_data)
|
||||
yield batch_data
|
||||
batch_data = []
|
||||
|
||||
# depend on drop_last
|
||||
if not self.drop_last and len(batch_data) > 0:
|
||||
batch_data = self.collate_batch(batch_data)
|
||||
batch_data = self.to_jittor(batch_data)
|
||||
yield batch_data
|
||||
|
||||
# depend on drop_last
|
||||
if not self.drop_last and len(batch_data) > 0:
|
||||
batch_data = self.collate_batch(batch_data)
|
||||
yield batch_data
|
||||
|
||||
class ImageFolder(Dataset):
|
||||
"""A image classify dataset, load image and label from directory:
|
||||
|
@ -120,7 +261,7 @@ class ImageFolder(Dataset):
|
|||
if os.path.splitext(fname)[-1].lower() in image_exts:
|
||||
path = os.path.join(class_dir, fname)
|
||||
self.imgs.append((path, i))
|
||||
print(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
|
||||
LOG.i(f"Found {len(self.classes)} classes and {len(self.imgs)} images.")
|
||||
self.set_attrs(total_len=len(self.imgs))
|
||||
|
||||
def __getitem__(self, k):
|
||||
|
|
|
@ -98,23 +98,26 @@ def collate_batch(batch):
|
|||
return jt.array(temp_data)
|
||||
if elem_type is np.ndarray:
|
||||
temp_data = np.stack([data for data in batch], 0)
|
||||
return jt.array(temp_data)
|
||||
return temp_data
|
||||
elif np.issubdtype(elem_type, np.integer):
|
||||
return jt.array(batch)
|
||||
return np.int32(batch)
|
||||
elif isinstance(elem, int):
|
||||
return jt.array(batch)
|
||||
return np.int32(batch)
|
||||
elif isinstance(elem, float):
|
||||
return jt.array(batch)
|
||||
return np.float32(batch)
|
||||
elif isinstance(elem, str):
|
||||
return batch
|
||||
elif isinstance(elem, Mapping):
|
||||
return {key: collate_batch([d[key] for d in batch]) for key in elem}
|
||||
elif isinstance(elem, tuple):
|
||||
transposed = zip(*batch)
|
||||
return tuple(collate_batch(samples) for samples in transposed)
|
||||
elif isinstance(elem, Sequence):
|
||||
transposed = zip(*batch)
|
||||
return [collate_batch(samples) for samples in transposed]
|
||||
elif isinstance(elem, Image.Image):
|
||||
temp_data = np.stack([np.array(data) for data in batch], 0)
|
||||
return jt.array(temp_data)
|
||||
return temp_data
|
||||
else:
|
||||
raise TypeError(f"Not support type <{elem_type.__name__}>")
|
||||
|
||||
|
|
|
@ -14,6 +14,7 @@ from jittor.test.test_log import find_log_with_re
|
|||
skip_this_test = False
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
from jittor.dataset.dataset import ImageFolder
|
||||
import jittor.transform as transform
|
||||
|
||||
import jittor as jt
|
||||
import unittest
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
# Decide whether the imagenet-backed tests can run on this machine:
# pick the train directory for the current host, and skip the whole
# suite if it is missing.
pass_this_test = False
msg = ""
mid = 1 if os.uname()[1] == "jittor-ce" else 0
try:
    traindir = ["/data1/cjld/imagenet/train/","/home/cjld/imagenet/train/"][mid]
    assert os.path.isdir(traindir)
except Exception as e:
    pass_this_test = True
    msg = str(e)
|
||||
|
||||
@unittest.skipIf(pass_this_test, f"can not run imagenet dataset test: {msg}")
class TestDataset(unittest.TestCase):
    """Checks the multiprocess loader against the single-process loader
    batch-for-batch, and that collate_batch stacks scalar tuples."""

    def test_multi_workers(self):
        n_check = 10

        def build_dataset():
            # shuffle=False keeps the sample order deterministic so the
            # worker and no-worker runs are directly comparable.
            ds = ImageFolder(traindir).set_attrs(batch_size=256, shuffle=False)
            pipeline = transform.Compose([
                transform.Resize(224),
                transform.ImageNormalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])
            ])
            ds.set_attrs(transform = pipeline, num_workers=0)
            return ds

        # Reference batches, produced without any worker processes.
        reference = []
        for i, data in enumerate(build_dataset()):
            print("get batch", i)
            reference.append(data)
            if i==n_check: break

        def check(num_workers, epoch=1):
            ds = build_dataset().set_attrs(num_workers=num_workers)
            random.seed(0)
            for _ in range(epoch):
                for i, (images, labels) in enumerate(ds):
                    print("compare", i)
                    ref_images = reference[i][0]
                    ref_labels = reference[i][1]
                    assert np.allclose(images.data, ref_images.data), \
                        (images.sum(), ref_images.sum(), images.shape,
                        ref_images.shape)
                    assert np.allclose(labels.data, ref_labels.data)
                    if i==n_check: break

        check(1)
        check(2)
        check(4,2)

    def test_collate_batch(self):
        from jittor.dataset.utils import collate_batch
        merged = collate_batch([(1,1),(1,2),(1,3)])
        assert isinstance(merged[0], np.ndarray)
        assert isinstance(merged[1], np.ndarray)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -20,6 +20,7 @@ if os.uname()[1] == "jittor-ce":
|
|||
try:
|
||||
# check can we run this test
|
||||
# test code
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torchvision.datasets as datasets
|
||||
import torchvision.transforms as transforms
|
||||
import torch
|
||||
|
@ -29,7 +30,7 @@ try:
|
|||
assert os.path.isdir(traindir)
|
||||
|
||||
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
|
||||
std=[0.229, 0.224, 0.225])
|
||||
std=[0.229, 0.224, 0.225])
|
||||
train_dataset = datasets.ImageFolder(
|
||||
traindir,
|
||||
transforms.Compose([
|
||||
|
|
|
@ -0,0 +1,78 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
from jittor_utils.ring_buffer import *
|
||||
import unittest
|
||||
|
||||
def test_ring_buffer():
    """Round-trip a value of every supported type through a RingBuffer."""
    shared = mp.Array('c', 8000, lock=False)
    rb = RingBuffer(shared)

    def roundtrip(data):
        print("test send recv", type(data))
        rb.send(data)
        received = rb.recv()
        if isinstance(received, np.ndarray):
            assert (received == data).all()
        else:
            assert data == received
        print(rb)

    # strings: typical, empty, longer
    for s in ("float32", "", "xxxxxxxxxx"):
        roundtrip(s)

    # ints, including one beyond 32 bits
    roundtrip(1)
    roundtrip(100000000000)

    # floats, tiny and large
    roundtrip(1e-5)
    roundtrip(100000000000.0)

    # containers (list / dict via pickle fallback)
    roundtrip([1,0.2])
    roundtrip({'asd':1})

    # ndarray payload
    roundtrip(np.random.rand(10,10))
|
||||
|
||||
def test_ring_buffer_allocator(p=0.7):
    """Stress the RingBufferAllocator with random alloc/free traffic.

    p: probability of attempting an allocation (vs. freeing the oldest
    live block) at each step.  Every byte of an allocation is tagged
    with its block id so that frees can verify nothing was overwritten.
    """
    print("test_ring_buffer_allocator", p)
    n = 1000
    buffer = RingBufferAllocator(n)
    m = 10000
    sizes = [0]*m
    a = [-1]*n
    l = 0
    r = 0
    for i in range(m):
        # BUG FIX: honor the `p` parameter — it was hard-coded as
        # `random.random()<0.7`, so calling with p=0.3 exercised exactly
        # the same distribution as p=0.7.
        if l==r or random.random()<p:
            size = random.randint(10, 20)
            location = buffer.alloc(size)
            if location is not None:
                sizes[r] = size
                # tag every allocated byte with its block id
                for j in range(location, location+size):
                    a[j] = r
                r += 1
                continue
        # allocation failed (or we chose to free): release the oldest
        # live block and check its tags survived intact
        assert l<r
        size = sizes[l]
        location = buffer.free(size)
        assert location is not None, buffer
        for j in range(location, location+size):
            assert a[j] == l
        l += 1
|
||||
|
||||
|
||||
class TestReindexOp(unittest.TestCase):
    """unittest wrapper driving the module-level ring-buffer checks.

    NOTE(review): the class name looks copy-pasted from a reindex-op
    test; it is kept unchanged so external references keep working.
    """

    def test_ring_buffer_allocator(self):
        for prob in (0.7, 0.3):
            test_ring_buffer_allocator(prob)

    def test_ring_buffer(self):
        test_ring_buffer()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,272 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2020 Jittor. Authors:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# All Rights Reserved.
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import multiprocessing as mp
|
||||
import numpy as np
|
||||
import ctypes
|
||||
import random
|
||||
import pickle
|
||||
import ctypes
|
||||
|
||||
class RingBufferAllocator:
    """FIFO ring allocator over a byte range of a given size.

    ``l`` is the allocation head and ``r`` the release head: live data
    occupies [r, l) modulo the buffer size, and ``is_full``
    disambiguates l == r (completely full vs. completely empty).  A
    single allocation never wraps: when it does not fit in the tail gap
    it is placed at offset 0 instead.  The raw ``alloc``/``free``
    methods assume the caller holds ``self.lock``; the ``*_with_lock``
    variants block on the shared condition until they succeed.
    """

    def __init__(self, size):
        self.size = size
        self.l = mp.Value(ctypes.c_longlong, 0, lock=False)
        self.r = mp.Value(ctypes.c_longlong, 0, lock=False)
        self.is_full = mp.Value(ctypes.c_bool, False, lock=False)
        self.lock = mp.Lock()
        self.cv = mp.Condition(self.lock)

    def __repr__(self):
        head = self.l.value
        tail = self.r.value
        if self.is_full.value:
            cap = 0
        else:
            # free fraction of the ring; <=0 means "wrapped", add 100
            cap = int((tail - head) / self.size * 100)
            if cap<=0: cap += 100
        return f"Buffer(free={int(cap)}%)"

    def alloc_with_lock(self, size):
        """Blocking alloc: sleep on the condition until space appears."""
        with self.lock:
            while True:
                location = self.alloc(size)
                if location is not None:
                    break
                self.cv.wait()
            return location

    def free_with_lock(self, size):
        """Free under the lock and wake a peer blocked in alloc."""
        with self.lock:
            location = self.free(size)
            self.cv.notify()
            return location

    def clear(self):
        """Reset to the empty state, discarding all live data."""
        with self.lock:
            self.l.value = 0
            self.r.value = 0
            self.is_full.value = False

    def alloc(self, size):
        """Reserve *size* contiguous bytes; return their offset or None."""
        if size > self.size:
            raise RuntimeError(f"Buffer size too small {self.size}<{size}")
        head = self.l.value
        tail = self.r.value
        if self.is_full.value:
            return None
        if head == tail and head > 0:
            # buffer is empty: rebase both heads to 0 for a maximal gap
            self.l.value = self.r.value = head = tail = 0
        if tail > head:
            # free space is the single region [head, tail)
            if tail - head < size:
                return None
            location = head
            head += size
            self.l.value = head
        else:
            # free space is [head, end) plus [0, tail)
            tail_gap = self.size - head
            if tail_gap < size:
                # no room at the end -> try the front gap [0, tail)
                if size > tail:
                    return None
                if size == tail:
                    self.is_full.value = True
                location = 0
                head = size
                self.l.value = head
            else:
                location = head
                # consuming the tail exactly wraps the head back to 0
                head = 0 if tail_gap == size else head + size
                self.l.value = head
        if head == tail:
            self.is_full.value = True
        return location

    def free(self, size):
        """Release *size* bytes at the oldest live offset; return it.

        Returns None when the buffer is empty.
        """
        tail = self.r.value
        if size==0: return tail
        if self.is_full.value:
            self.is_full.value = False
        elif self.l.value == tail:
            return None
        location = tail
        tail += size
        if tail > self.size:
            # the block did not fit at the end, so alloc placed it at 0
            location = 0
            tail = size
        elif tail == self.size:
            tail = 0
        self.r.value = tail
        return location
|
||||
|
||||
def str_to_char_array(s, array_len):
    """Encode *s* into a fixed-width numpy char ('c') array.

    Longer strings are truncated and shorter ones right-padded with
    spaces, so the result always holds exactly ``array_len`` bytes.
    """
    truncated = s[:array_len] if len(s) > array_len else s
    chars = np.array(truncated, dtype='c')
    pad = array_len - len(truncated)
    if pad > 0:
        chars = np.pad(chars, (0, pad), constant_values=' ')
    return chars
|
||||
|
||||
def char_array_to_str(a):
    """Decode a fixed-width numpy char array back to a stripped str.

    FIX: uses ``tobytes()`` — ``ndarray.tostring()`` was deprecated in
    NumPy 1.19 and removed in NumPy 2.0.
    """
    return str(a.tobytes(), 'ascii').strip()
|
||||
|
||||
class RingBuffer:
    """Typed, blocking message channel over a shared-memory char buffer.

    Wire format: every message is an 8-byte fixed-width type-name tag
    followed by a type-specific payload (``send_<type>`` /
    ``recv_<type>`` pairs); types without a dedicated codec fall back
    to pickle.  ``send_raw`` blocks while the allocator is full and
    ``recv_raw`` blocks while it is empty, synchronizing on the
    allocator's condition variable, so a producer and a consumer
    process can stream messages through the buffer.
    """

    def __init__(self, buffer):
        # buffer: a multiprocessing.Array('c', n, lock=False)
        self.allocator = RingBufferAllocator(len(buffer))
        self.buffer = buffer

    def clear(self): self.allocator.clear()

    def send_int(self, data):
        # int: int64[1]
        #      data
        self.send_raw(np.array([data], dtype='int64'))
    def recv_int(self):
        return int(self.recv_raw(8, (1,), 'int64')[0])

    def send_float(self, data):
        # float: float64[1]
        #        data
        self.send_raw(np.array([data], dtype='float64'))
    def recv_float(self):
        return float(self.recv_raw(8, (1,), 'float64')[0])

    def send_str(self, data):
        # str: int64[1] char[len]
        #      len      data
        data = np.array(data, dtype='c')
        self.send_int(data.nbytes)
        self.send_raw(data)
    def recv_str(self):
        nbytes = self.recv_int()
        data = self.recv_raw(nbytes, nbytes, 'c')
        # FIX: tobytes() — ndarray.tostring() was deprecated in NumPy
        # 1.19 and removed in NumPy 2.0.
        return str(data.tobytes(), 'ascii')

    def send_ndarray(self, data):
        # ndarray: int64[1] char[8] int64[1] int64[slen] char[nbytes]
        #          slen     dtype   nbytes   shape       data
        shape = data.shape
        slen = len(shape)
        self.send_int(slen)
        self.send_fix_len_str(str(data.dtype))
        self.send_int(data.nbytes)
        self.send_raw(np.array(shape, dtype='int64'))
        self.send_raw(data)

    def recv_ndarray(self):
        slen = self.recv_int()
        dtype = self.recv_fix_len_str()
        nbytes = self.recv_int()
        shape = self.recv_raw(slen*8, slen, 'int64')
        data = self.recv_raw(nbytes, shape, dtype)
        return data

    def send_tuple(self, data):
        # tuple: int64[1] ....
        #        len
        length = len(data)
        self.send_int(length)
        for a in data:
            self.send(a)
    def recv_tuple(self):
        length = self.recv_int()
        return tuple(self.recv() for i in range(length))

    def send_list(self, data):
        # list: int64[1] ....
        #       len
        length = len(data)
        self.send_int(length)
        for a in data:
            self.send(a)

    def recv_list(self):
        length = self.recv_int()
        return [self.recv() for i in range(length)]

    def send_pickle(self, data):
        # pickle: int64[1] char[len]
        #         len      data
        data = pickle.dumps(data)
        data = np.frombuffer(data, dtype='c')
        self.send_int(data.nbytes)
        self.send_raw(data)

    def recv_pickle(self):
        # NOTE(review): pickle.loads on buffer contents — safe only
        # because both endpoints are our own worker processes.
        nbytes = self.recv_int()
        data = self.recv_raw(nbytes, nbytes, 'c')
        # FIX: tobytes() instead of the removed tostring()
        return pickle.loads(data.tobytes())

    def __repr__(self):
        return f"{self.allocator}@0x{hex(ctypes.addressof(self.buffer))}"

    def send_raw(self, data):
        """Copy *data* (an ndarray) into the ring, blocking on space."""
        assert isinstance(data, np.ndarray) # and data.flags.c_contiguous
        with self.allocator.lock:
            location = self.allocator.alloc(data.nbytes)
            # block until the consumer frees enough room
            while location is None:
                self.allocator.cv.wait()
                location = self.allocator.alloc(data.nbytes)
            window = np.ndarray(shape=data.shape, dtype=data.dtype,
                                buffer=self.buffer, offset=location)
            window[:] = data
            self.allocator.cv.notify()
        assert window.nbytes == data.nbytes

    def recv_raw(self, nbytes, shape, dtype):
        """Pop the next *nbytes* from the ring as an ndarray, blocking."""
        with self.allocator.lock:
            location = self.allocator.free(nbytes)
            # block until the producer has written the next chunk
            while location is None:
                self.allocator.cv.wait()
                location = self.allocator.free(nbytes)
            # copy out before notifying: the region may be overwritten
            # by the producer as soon as it wakes up
            data = np.ndarray(shape=shape, dtype=dtype,
                              buffer=self.buffer, offset=location).copy()
            self.allocator.cv.notify()
        assert data.nbytes == nbytes
        return data

    def send_fix_len_str(self, s, array_len=8):
        data = str_to_char_array(s, array_len)
        self.send_raw(data)

    def recv_fix_len_str(self, array_len=8):
        # FIX: honor array_len — the byte count was hard-coded to 8,
        # silently ignoring the parameter (default behavior unchanged).
        data = self.recv_raw(array_len, array_len, 'c')
        return char_array_to_str(data)

    def send(self, data):
        """Send any value: 8-char type tag, then the matching codec
        (``send_pickle`` when no ``send_<type>`` exists)."""
        ts = type(data).__name__
        send = getattr(self, "send_"+ts, self.send_pickle)
        self.send_fix_len_str(ts)
        send(data)

    def recv(self):
        """Read the type tag and dispatch to the matching codec."""
        ts = self.recv_fix_len_str()
        recv = getattr(self, "recv_"+ts, self.recv_pickle)
        return recv()
|
||||
|
Loading…
Reference in New Issue