add endless dataset

Dun Liang 2021-07-01 17:42:46 +08:00
parent 8fcdb63236
commit 4828cdc896
4 changed files with 118 additions and 85 deletions

View File

@@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.49'
__version__ = '1.2.3.50'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@@ -48,6 +48,8 @@ class Dataset(object):
[in] drop_last(bool): if false, the last batch of the dataset may be smaller than batch_size, default True.
[in] num_workers(int): number of workers for loading data.
[in] buffer_size(int): buffer size for each worker in bytes, default(512MB).
[in] keep_numpy_array(bool): return numpy arrays rather than jittor arrays, default(False).
[in] endless(bool): if true, the dataset yields data forever instead of stopping after one epoch, default(False).
Example::
@@ -70,7 +72,8 @@ class Dataset(object):
num_workers = 0,
buffer_size = 512*1024*1024,
stop_grad = True,
keep_numpy_array = False):
keep_numpy_array = False,
endless = False):
super().__init__()
if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
num_workers = 0
@@ -82,6 +85,8 @@
self.buffer_size = buffer_size
self.stop_grad = stop_grad
self.keep_numpy_array = keep_numpy_array
self.endless = endless
self.epoch_id = 0
self.sampler = None
def __getitem__(self, index):
@@ -182,14 +187,22 @@ class Dataset(object):
while True:
# get id
with gid_lock:
while gid_obj.value >= self.batch_len or buffer.is_stop():
while buffer.is_stop() or self.idqueue.is_stop():
self.num_idle.value += 1
self.num_idle_c.notify()
self.gidc.wait()
self.num_idle.value -= 1
cid = gid_obj.value
self.idmap[cid] = worker_id
gid_obj.value += 1
if cid == 0:
index_list = self._get_index_list()
self.index_list_numpy[:] = index_list
batch_index_list = self.index_list_numpy[
cid*self.real_batch_size:
min(self.real_len, (cid+1)*self.real_batch_size)
].copy()
# print(batch_index_list, cid, self.real_batch_size, self.real_len)
self.idqueue.push(worker_id)
gid_obj.value = (gid_obj.value + 1) % self.batch_len
self.gidc.notify()
now = time.time()
other_time = now - start
@@ -199,8 +212,8 @@
batch = []
if mp_log_v:
print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)):
batch.append(self[self.index_list[i]])
for i in batch_index_list:
batch.append(self[i])
batch = self.collate_batch(batch)
now = time.time()
data_time = now - start
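
The rewritten claim loop above changes the handshake between the workers and the main process: each worker pushes its own id onto idqueue at the moment it claims batch cid (replacing the fixed idmap array), the worker that claims batch 0 regenerates the shuffled index list for the new epoch, and the global counter wraps modulo batch_len so production never has to stop. Below is a minimal sketch of that handshake, using plain Python threads and queue.Queue as stand-ins for Jittor's worker processes and jt.RingBuffer; it is illustrative only, not the library's implementation.

import queue
import threading

batch_len = 4                        # batches per epoch
idqueue = queue.Queue()              # claim order == batch order -> worker id
gid = 0                              # shared batch counter, wraps every epoch
gid_lock = threading.Lock()

def worker_main(worker_id, claims):
    global gid
    for _ in range(claims):
        with gid_lock:
            cid = gid
            if cid == 0:
                pass                 # epoch boundary: reshuffle the index list here
            idqueue.put(worker_id)   # announce: this worker owns batch cid
            gid = (gid + 1) % batch_len   # wrap instead of stopping at batch_len
        # ... load batch cid and send it through this worker's own buffer ...

workers = [threading.Thread(target=worker_main, args=(i, 4)) for i in range(2)]
for w in workers:
    w.start()
for i in range(2 * batch_len):       # consume two epochs' worth of claims
    print(f"batch {i % batch_len} comes from worker {idqueue.get()}")
for w in workers:
    w.join()
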
@@ -278,10 +291,10 @@ Example::
if not hasattr(self, "workers"):
return
msg = [""]
msg.append(f"progress:{self.last_id}/{self.batch_len}")
msg.append(f"progress:{self.batch_id}/{self.batch_len}")
msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
msg.append(f"last 10 workers: {self.last_ids}")
msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
for i in range(self.num_workers):
w = self.workers[i]
@@ -293,6 +306,7 @@ Example::
# stop workers
for w in self.workers:
w.buffer.stop()
self.idqueue.stop()
# wait until all workers idle
if self.num_idle.value < self.num_workers:
with self.gid.get_lock():
@@ -306,29 +320,32 @@ Example::
# clean workers' buffer
for w in self.workers:
w.buffer.clear()
self.idqueue.clear()
self.gid_obj.value = 0
def _init_workers(self):
jt.clean()
jt.gc()
self.index_list = mp.Array('i', self.real_len, lock=False)
workers = []
# batch id to worker id
self.idmap = mp.Array('i', self.batch_len, lock=False)
# get worker id
self.idqueue = jt.RingBuffer(2048)
# global token index
self.gid = mp.Value('i', self.batch_len)
self.gid.value = 0
# global token index condition
self.gidc = mp.Condition(self.gid.get_lock())
# number of idle workers
self.num_idle = mp.Value('i', 0, lock=False)
# number of idle workers condition
self.num_idle_c = mp.Condition(self.gid.get_lock())
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
for i in range(self.num_workers):
w = Worker(target=self._worker_main, args=(i,),
buffer_size=self.buffer_size,
keep_numpy_array=self.keep_numpy_array)
workers.append(w)
self.workers = workers
self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
def reset(self):
if not hasattr(self, "workers"):
@@ -353,8 +370,8 @@ Example::
if self.total_len is None:
self.total_len = len(self)
return self.total_len
def __iter__(self):
def _get_index_list(self):
if self.total_len is None:
self.total_len = len(self)
# maybe rewrite by sampler
@@ -418,71 +435,79 @@ Example::
else:
self.real_len = self.total_len
self.real_batch_size = self.batch_size
self.batch_len = self.__batch_len__()
return index_list
def _epochs(self):
if self.endless:
while True:
yield
self.epoch_id += 1
else:
yield
def __iter__(self):
index_list = self._get_index_list()
if not hasattr(self, "workers") and self.num_workers:
self._init_workers()
self.last_ids = [-1] * 10
if self.num_workers:
self._stop_all_workers()
self.index_list_numpy[:] = index_list
gid_obj = self.gid.get_obj()
gid_lock = self.gid.get_lock()
with gid_lock:
gid_obj.value = 0
if self.num_idle.value:
self.gidc.notify_all()
start = time.time()
self.batch_time = 0
for i in range(self.batch_len):
# try not get lock first
if gid_obj.value <= i:
with gid_lock:
if gid_obj.value <= i:
if mp_log_v:
print("wait")
self.gidc.wait()
now = time.time()
self.wait_time = now - start
start = now
self.last_id = i
worker_id = self.idmap[i]
w = self.workers[worker_id]
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
batch = w.buffer.recv()
now = time.time()
self.recv_time = now - start
start = now
for _ in self._epochs():
for i in range(self.batch_len):
# get which worker has this batch
worker_id = self.idqueue.pop()
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
batch = self.to_jittor(batch)
now = time.time()
self.to_jittor_time = now - start
start = now
now = time.time()
self.wait_time = now - start
start = now
yield batch
self.last_ids[i%10] = worker_id
self.batch_id = i
w = self.workers[worker_id]
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
batch = w.buffer.recv()
now = time.time()
self.batch_time = now - start
start = now
now = time.time()
self.recv_time = now - start
start = now
if mp_log_v:
print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
batch = self.to_jittor(batch)
now = time.time()
self.to_jittor_time = now - start
start = now
yield batch
now = time.time()
self.batch_time = now - start
start = now
else:
batch_data = []
for idx in index_list:
batch_data.append(self[int(idx)])
if len(batch_data) == self.real_batch_size:
for _ in self._epochs():
batch_data = []
for idx in index_list:
batch_data.append(self[int(idx)])
if len(batch_data) == self.real_batch_size:
batch_data = self.collate_batch(batch_data)
batch_data = self.to_jittor(batch_data)
yield batch_data
batch_data = []
# depend on drop_last
if not self.drop_last and len(batch_data) > 0:
batch_data = self.collate_batch(batch_data)
batch_data = self.to_jittor(batch_data)
yield batch_data
batch_data = []
# depend on drop_last
if not self.drop_last and len(batch_data) > 0:
batch_data = self.collate_batch(batch_data)
batch_data = self.to_jittor(batch_data)
yield batch_data
class ImageFolder(Dataset):
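
With the new endless flag, the epoch loop moves inside a single iteration over the dataset: the loader keeps reshuffling and yielding batches, and the caller decides when to stop by watching epoch_id (and batch_id when workers are enabled), exactly as the resnet test below does. A minimal usage sketch, assuming a hypothetical YourDataset subclass; the total_len and batch_size values are arbitrary:

import numpy as np
from jittor.dataset import Dataset

class YourDataset(Dataset):
    def __init__(self):
        super().__init__()
        self.set_attrs(total_len=600, batch_size=16, shuffle=True, endless=True)

    def __getitem__(self, k):
        # placeholder sample; a real dataset would load item k here
        return np.random.rand(2)

dataset = YourDataset()
for batch in dataset:
    # epoch_id counts completed passes over the data; stop once two passes are done
    if dataset.epoch_id >= 2:
        break
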

View File

@@ -117,12 +117,13 @@ class TestDatasetSeed(unittest.TestCase):
return np.random.rand(2)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
dd = []
for d in dataset:
dd.append(d.numpy())
for i in range(len(d)):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
for _ in range(10):
dd = []
for d in dataset:
dd.append(d.numpy())
for i in range(len(d)):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
def test_py_native(self):
import random
@@ -136,12 +137,13 @@ class TestDatasetSeed(unittest.TestCase):
jt.set_global_seed(0)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
dd = []
for d in dataset:
dd.append(d.numpy())
for i in range(len(d)):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
for _ in range(10):
dd = []
for d in dataset:
dd.append(d.numpy())
for i in range(len(d)):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
def test_jtrand(self):
import random
@@ -155,12 +157,13 @@ class TestDatasetSeed(unittest.TestCase):
jt.set_global_seed(0)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
dd = []
for d in dataset:
dd.append(d.numpy())
for i in range(len(d)):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
for _ in range(10):
dd = []
for d in dataset:
dd.append(d.numpy())
for i in range(len(d)):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
def test_cifar(self):
from jittor.dataset.cifar import CIFAR10

View File

@@ -63,8 +63,11 @@ class TestResnet(unittest.TestCase):
global prev
prev = time.time()
SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
self.train_loader.endless = True
for batch_idx, (data, target) in enumerate(self.train_loader):
for data, target in self.train_loader:
batch_id = self.train_loader.batch_id
epoch_id = self.train_loader.epoch_id
# train step
with jt.log_capture_scope(
@@ -74,7 +77,7 @@ class TestResnet(unittest.TestCase):
output = mnist_net(data)
loss = nn.cross_entropy_loss(output, target)
SGD.step(loss)
def callback(batch_idx, loss, output, target):
def callback(epoch_id, batch_id, loss, output, target):
# print train info
global prev
pred = np.argmax(output, axis=1)
@@ -82,15 +85,15 @@ class TestResnet(unittest.TestCase):
loss_list.append(loss[0])
acc_list.append(acc)
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
.format(0, batch_idx, 600,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
.format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev))
# prev = time.time()
jt.fetch(batch_idx, loss, output, target, callback)
jt.fetch(epoch_id, batch_id, loss, output, target, callback)
log_conv = find_log_with_re(logs,
"Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
log_matmul = find_log_with_re(logs,
"Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
if batch_idx > 2:
if batch_id > 2:
assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul))
mem_used = jt.flags.stat_allocator_total_alloc_byte \
@@ -119,6 +122,8 @@ class TestResnet(unittest.TestCase):
assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
else:
assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars()
if self.train_loader.epoch_id >= 2:
break
jt.sync_all(True)
assert np.mean(loss_list[-50:])<0.5