mirror of https://github.com/Jittor/Jittor
polish code
This commit is contained in:
parent
581acecc8e
commit
18fc4baa3d
|
@ -49,16 +49,19 @@ class Dataset(object):
|
|||
for x, y in dataset:
|
||||
......
|
||||
'''
|
||||
def __init__(self):
|
||||
def __init__(self,
|
||||
batch_size = 16,
|
||||
shuffle = False,
|
||||
drop_last = False,
|
||||
num_workers = 0,
|
||||
buffer_size = 512*1024*1024):
|
||||
super().__init__()
|
||||
self.batch_size = 16
|
||||
self.total_len = None
|
||||
self.shuffle = False
|
||||
self.drop_last = False
|
||||
self.num_workers = 0
|
||||
self.buffer_size = 512*1024*1024
|
||||
if "num_workers" in os.environ:
|
||||
self.num_workers = int(os.environ["num_workers"])
|
||||
self.batch_size = batch_size
|
||||
self.shuffle = shuffle
|
||||
self.drop_last = drop_last
|
||||
self.num_workers = num_workers
|
||||
self.buffer_size = buffer_size
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -76,10 +79,12 @@ class Dataset(object):
|
|||
Attrs:
|
||||
batch_size(int): batch size, default 16.
|
||||
total_len(int): total length.
|
||||
num_workers: number of workers for loading data
|
||||
shuffle(bool): shuffle at each epoch, default False.
|
||||
drop_last(bool): if true, the last batch of dataset
|
||||
might be smaller than batch_size, default False.
|
||||
num_workers: number of workers for loading data
|
||||
buffer_size: buffer size for each worker in bytes,
|
||||
default(512MB).
|
||||
'''
|
||||
for k,v in kw.items():
|
||||
assert hasattr(self, k), k
|
||||
|
@ -156,9 +161,13 @@ class Dataset(object):
|
|||
workers = []
|
||||
# batch id to worker id
|
||||
self.idmap = mp.Array('i', self.batch_len, lock=False)
|
||||
# global token index
|
||||
self.gid = mp.Value('i', self.batch_len)
|
||||
# global token index condition
|
||||
self.gidc = mp.Condition(self.gid.get_lock())
|
||||
# number of idle workers
|
||||
self.num_idle = mp.Value('i', 0, lock=False)
|
||||
# number of idle workers condition
|
||||
self.num_idle_c = mp.Condition(self.gid.get_lock())
|
||||
for i in range(self.num_workers):
|
||||
w = Worker(target=self._worker_main, args=(i,),
|
||||
|
|
|
@ -19,7 +19,6 @@ def test_ring_buffer():
|
|||
assert (recv == data).all()
|
||||
else:
|
||||
assert data == recv
|
||||
print(buffer)
|
||||
test_send_recv("float32")
|
||||
test_send_recv("")
|
||||
test_send_recv("xxxxxxxxxx")
|
||||
|
@ -53,7 +52,6 @@ def test_ring_buffer_allocator(p=0.7):
|
|||
for j in range(location, location+size):
|
||||
a[j] = r
|
||||
r += 1
|
||||
# print(r-l, buffer)
|
||||
continue
|
||||
assert l<r
|
||||
size = sizes[l]
|
||||
|
@ -62,7 +60,6 @@ def test_ring_buffer_allocator(p=0.7):
|
|||
for j in range(location, location+size):
|
||||
assert a[j] == l
|
||||
l += 1
|
||||
# print(r-l, buffer)
|
||||
|
||||
|
||||
class TestReindexOp(unittest.TestCase):
|
||||
|
|
|
@ -20,6 +20,7 @@ class RingBufferAllocator:
|
|||
self.is_full = mp.Value(ctypes.c_bool, False, lock=False)
|
||||
self.lock = mp.Lock()
|
||||
self.cv = mp.Condition(self.lock)
|
||||
|
||||
def __repr__(self):
|
||||
l = self.l.value
|
||||
r = self.r.value
|
||||
|
@ -36,7 +37,6 @@ class RingBufferAllocator:
|
|||
while True:
|
||||
location = self.alloc(size)
|
||||
if location is not None: break
|
||||
# print("alloc wait", size, self)
|
||||
self.cv.wait()
|
||||
return location
|
||||
|
||||
|
@ -204,7 +204,6 @@ class RingBuffer:
|
|||
def send_pickle(self, data):
|
||||
# pickle: int64[1] char[len]
|
||||
# len data
|
||||
# print("send pickle")
|
||||
data = pickle.dumps(data)
|
||||
data = np.frombuffer(data, dtype='c')
|
||||
self.send_int(data.nbytes)
|
||||
|
@ -220,32 +219,25 @@ class RingBuffer:
|
|||
|
||||
def send_raw(self, data):
|
||||
assert isinstance(data, np.ndarray) # and data.flags.c_contiguous
|
||||
# print("send raw", data.shape, data.dtype, data.nbytes)
|
||||
with self.allocator.lock:
|
||||
location = self.allocator.alloc(data.nbytes)
|
||||
while location is None:
|
||||
# print("send wait")
|
||||
self.allocator.cv.wait()
|
||||
location = self.allocator.alloc(data.nbytes)
|
||||
window = np.ndarray(shape=data.shape, dtype=data.dtype,
|
||||
buffer=self.buffer, offset=location)
|
||||
window[:] = data
|
||||
# print("send notify")
|
||||
self.allocator.cv.notify()
|
||||
assert window.nbytes == data.nbytes
|
||||
|
||||
def recv_raw(self, nbytes, shape, dtype):
|
||||
# print("recv raw", shape, dtype, nbytes)
|
||||
with self.allocator.lock:
|
||||
location = self.allocator.free(nbytes)
|
||||
# print("recv get location", location)
|
||||
while location is None:
|
||||
# print("recv wait")
|
||||
self.allocator.cv.wait()
|
||||
location = self.allocator.free(nbytes)
|
||||
data = np.ndarray(shape=shape, dtype=dtype,
|
||||
buffer=self.buffer, offset=location).copy()
|
||||
# print("recv notify")
|
||||
self.allocator.cv.notify()
|
||||
assert data.nbytes == nbytes
|
||||
return data
|
||||
|
@ -266,7 +258,6 @@ class RingBuffer:
|
|||
|
||||
def recv(self):
|
||||
ts = self.recv_fix_len_str()
|
||||
# print("recv", ts)
|
||||
recv = getattr(self, "recv_"+ts, self.recv_pickle)
|
||||
return recv()
|
||||
|
||||
|
|
Loading…
Reference in New Issue