polish code

This commit is contained in:
Dun Liang 2020-03-29 17:13:00 +08:00
parent 581acecc8e
commit 18fc4baa3d
3 changed files with 19 additions and 22 deletions

View File

@ -49,16 +49,19 @@ class Dataset(object):
for x, y in dataset:
......
'''
def __init__(self):
def __init__(self,
batch_size = 16,
shuffle = False,
drop_last = False,
num_workers = 0,
buffer_size = 512*1024*1024):
super().__init__()
self.batch_size = 16
self.total_len = None
self.shuffle = False
self.drop_last = False
self.num_workers = 0
self.buffer_size = 512*1024*1024
if "num_workers" in os.environ:
self.num_workers = int(os.environ["num_workers"])
self.batch_size = batch_size
self.shuffle = shuffle
self.drop_last = drop_last
self.num_workers = num_workers
self.buffer_size = buffer_size
def __getitem__(self, index):
raise NotImplementedError
@ -76,10 +79,12 @@ class Dataset(object):
Attrs:
batch_size(int): batch size, default 16.
total_len(int): total length.
num_workers: number of workers for loading data
shuffle(bool): shuffle at each epoch, default False.
drop_last(bool): if true, the last batch of dataset
might be smaller than batch_size, default False.
num_workers: number of workers for loading data
buffer_size: buffer size for each worker in bytes,
default(512MB).
'''
for k,v in kw.items():
assert hasattr(self, k), k
@ -156,9 +161,13 @@ class Dataset(object):
workers = []
# batch id to worker id
self.idmap = mp.Array('i', self.batch_len, lock=False)
# global token index
self.gid = mp.Value('i', self.batch_len)
# global token index condition
self.gidc = mp.Condition(self.gid.get_lock())
# number of idle workers
self.num_idle = mp.Value('i', 0, lock=False)
# number of idle workers condition
self.num_idle_c = mp.Condition(self.gid.get_lock())
for i in range(self.num_workers):
w = Worker(target=self._worker_main, args=(i,),

View File

@ -19,7 +19,6 @@ def test_ring_buffer():
assert (recv == data).all()
else:
assert data == recv
print(buffer)
test_send_recv("float32")
test_send_recv("")
test_send_recv("xxxxxxxxxx")
@ -53,7 +52,6 @@ def test_ring_buffer_allocator(p=0.7):
for j in range(location, location+size):
a[j] = r
r += 1
# print(r-l, buffer)
continue
assert l<r
size = sizes[l]
@ -62,7 +60,6 @@ def test_ring_buffer_allocator(p=0.7):
for j in range(location, location+size):
assert a[j] == l
l += 1
# print(r-l, buffer)
class TestReindexOp(unittest.TestCase):

View File

@ -20,6 +20,7 @@ class RingBufferAllocator:
self.is_full = mp.Value(ctypes.c_bool, False, lock=False)
self.lock = mp.Lock()
self.cv = mp.Condition(self.lock)
def __repr__(self):
l = self.l.value
r = self.r.value
@ -36,7 +37,6 @@ class RingBufferAllocator:
while True:
location = self.alloc(size)
if location is not None: break
# print("alloc wait", size, self)
self.cv.wait()
return location
@ -204,7 +204,6 @@ class RingBuffer:
def send_pickle(self, data):
# pickle: int64[1] char[len]
# len data
# print("send pickle")
data = pickle.dumps(data)
data = np.frombuffer(data, dtype='c')
self.send_int(data.nbytes)
@ -220,32 +219,25 @@ class RingBuffer:
def send_raw(self, data):
    """Copy the contents of a NumPy array into the ring buffer.

    Blocks on the allocator's condition variable until ``alloc`` returns a
    location for ``data.nbytes`` bytes, writes the array through a view
    over ``self.buffer`` at that offset, then notifies one waiter.

    Args:
        data: the ``np.ndarray`` to send; any shape, including 0-d.

    Raises:
        AssertionError: if ``data`` is not an ``np.ndarray``.
    """
    assert isinstance(data, np.ndarray)  # and data.flags.c_contiguous
    with self.allocator.lock:
        location = self.allocator.alloc(data.nbytes)
        while location is None:
            # No room yet: wait until another party signals the condition.
            self.allocator.cv.wait()
            location = self.allocator.alloc(data.nbytes)
        window = np.ndarray(shape=data.shape, dtype=data.dtype,
                            buffer=self.buffer, offset=location)
        # Ellipsis assignment also works for 0-d arrays, where `window[:]`
        # would raise IndexError.
        window[...] = data
        self.allocator.cv.notify()
    assert window.nbytes == data.nbytes
def recv_raw(self, nbytes, shape, dtype):
    """Read an array of the given shape and dtype out of the ring buffer.

    Waits on the allocator's condition variable until ``free(nbytes)``
    returns a location, copies the array found at that offset of
    ``self.buffer``, notifies one waiter and returns the copy.

    Args:
        nbytes: number of bytes to release from the allocator.
        shape: shape of the array to read.
        dtype: dtype of the array to read.

    Returns:
        A new ``np.ndarray`` that does not share memory with the buffer.
    """
    allocator = self.allocator
    with allocator.lock:
        while True:
            offset = allocator.free(nbytes)
            if offset is not None:
                break
            # Nothing to read yet: wait for a notification and retry.
            allocator.cv.wait()
        view = np.ndarray(shape=shape, dtype=dtype,
                          buffer=self.buffer, offset=offset)
        # Copy so the result stays valid after the region is reused.
        data = view.copy()
        allocator.cv.notify()
    assert data.nbytes == nbytes
    return data
@ -266,7 +258,6 @@ class RingBuffer:
def recv(self):
    """Receive one value, dispatching on the type tag read first.

    Reads a type string with ``recv_fix_len_str`` and calls the method
    named ``recv_<type string>``; if no such attribute exists, falls back
    to ``recv_pickle``.
    """
    type_tag = self.recv_fix_len_str()
    fallback = self.recv_pickle
    handler = getattr(self, "recv_" + type_tag, fallback)
    return handler()