mirror of https://github.com/Jittor/Jittor
add endless dataset
This commit is contained in:
parent 8fcdb63236 · commit 4828cdc896
python/jittor/__init__.py
@@ -9,7 +9,7 @@
 # file 'LICENSE.txt', which is part of this source code package.
 # ***************************************************************

-__version__ = '1.2.3.49'
+__version__ = '1.2.3.50'
 from jittor_utils import lock
 with lock.lock_scope():
     ori_int = int

python/jittor/dataset/dataset.py
@@ -48,6 +48,8 @@ class Dataset(object):
         [in] drop_last(bool): if true, the last batch of dataset might smaller than batch_size, default True.
         [in] num_workers(int): number of workers for loading data.
         [in] buffer_size(int): buffer size for each worker in bytes, default(512MB).
+        [in] keep_numpy_array(bool): return numpy array rather than jittor array, default(False).
+        [in] endless(bool): will this dataset yield data forever, default(False).

     Example::

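The two new docstring entries match constructor arguments added in the hunks below. A minimal sketch of the new flag in use, assuming the signature shown next (YourDataset and its body are illustrative, not part of this commit)::

    import numpy as np
    from jittor.dataset import Dataset

    class YourDataset(Dataset):
        def __init__(self):
            super().__init__(batch_size=2, endless=True)
            self.set_attrs(total_len=8)

        def __getitem__(self, k):
            return np.float32([k, k])

    loader = YourDataset()
    for batch in loader:            # never exhausts on its own when endless=True
        if loader.epoch_id >= 1:    # epoch_id increments after each completed pass
            break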
@@ -70,7 +72,8 @@ class Dataset(object):
                  num_workers = 0,
                  buffer_size = 512*1024*1024,
                  stop_grad = True,
-                 keep_numpy_array = False):
+                 keep_numpy_array = False,
+                 endless = False):
         super().__init__()
         if os.environ.get("DISABLE_MULTIPROCESSING", '0') == '1':
             num_workers = 0

@@ -82,6 +85,8 @@ class Dataset(object):
         self.buffer_size = buffer_size
         self.stop_grad = stop_grad
         self.keep_numpy_array = keep_numpy_array
+        self.endless = endless
+        self.epoch_id = 0
         self.sampler = None

     def __getitem__(self, index):

@@ -182,14 +187,22 @@ class Dataset(object):
         while True:
             # get id
             with gid_lock:
-                while gid_obj.value >= self.batch_len or buffer.is_stop():
+                while buffer.is_stop() or self.idqueue.is_stop():
                     self.num_idle.value += 1
                     self.num_idle_c.notify()
                     self.gidc.wait()
                     self.num_idle.value -= 1
                 cid = gid_obj.value
-                self.idmap[cid] = worker_id
-                gid_obj.value += 1
+                if cid == 0:
+                    index_list = self._get_index_list()
+                    self.index_list_numpy[:] = index_list
+                batch_index_list = self.index_list_numpy[
+                    cid*self.real_batch_size:
+                    min(self.real_len, (cid+1)*self.real_batch_size)
+                ].copy()
+                # print(batch_index_list, cid, self.real_batch_size, self.real_len)
+                self.idqueue.push(worker_id)
+                gid_obj.value = (gid_obj.value + 1) % self.batch_len
                 self.gidc.notify()
             now = time.time()
             other_time = now - start
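With this change a worker no longer writes into a shared idmap; it claims a batch id under the gid lock, recomputes the shuffled index list when a new epoch begins (cid == 0), and pushes its own id onto idqueue so the consumer can pop worker ids in batch order. The modulo wrap of gid is what lets the id stream run forever. The claim step in miniature, using stdlib stand-ins (mp.Queue here in place of jt.RingBuffer)::

    import multiprocessing as mp

    def claim_batch(worker_id, gid, idqueue, batch_len):
        # gid: mp.Value('i', 0); idqueue: mp.Queue() standing in for jt.RingBuffer
        with gid.get_lock():
            cid = gid.value                          # claim the next batch id
            idqueue.put(worker_id)                   # announce who will produce it
            gid.value = (gid.value + 1) % batch_len  # modulo wrap => endless epochs
        return cid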
@@ -199,8 +212,8 @@ class Dataset(object):
                 batch = []
                 if mp_log_v:
                     print(f"#{worker_id} {os.getpid()} load batch", cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size))
-                for i in range(cid*self.real_batch_size, min(self.real_len, (cid+1)*self.real_batch_size)):
-                    batch.append(self[self.index_list[i]])
+                for i in batch_index_list:
+                    batch.append(self[i])
                 batch = self.collate_batch(batch)
                 now = time.time()
                 data_time = now - start
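The worker now iterates over the index slice it captured under the lock rather than re-deriving the range from the shared index_list. For orientation, collate_batch turns the list of samples into one batch; a simplified stand-in (not Jittor's actual implementation)::

    import numpy as np

    def collate_batch_sketch(batch):
        # stack per-field arrays: [(x0, y0), (x1, y1)] -> (X, Y)
        if isinstance(batch[0], tuple):
            return tuple(np.stack(field) for field in zip(*batch))
        return np.stack(batch)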
@@ -278,10 +291,10 @@ Example::
         if not hasattr(self, "workers"):
             return
         msg = [""]
-        msg.append(f"progress:{self.last_id}/{self.batch_len}")
+        msg.append(f"progress:{self.batch_id}/{self.batch_len}")
         msg.append(f"batch(s): {self.batch_time:.3f}\twait(s):{self.wait_time:.3f}")
         msg.append(f"recv(s): {self.recv_time:.3f}\tto_jittor(s):{self.to_jittor_time:.3f}")
-        msg.append(f"last 10 workers: {self.idmap[max(0, self.last_id-9):self.last_id+1]}")
+        msg.append(f"last 10 workers: {self.last_ids}")
         msg.append(f"ID\twait(s)\topen(s)\tload(s)\tsend(s)\ttotal(s)")
         for i in range(self.num_workers):
             w = self.workers[i]
@@ -293,6 +306,7 @@ Example::
         # stop workers
         for w in self.workers:
             w.buffer.stop()
+        self.idqueue.stop()
         # wait until all workers idle
         if self.num_idle.value < self.num_workers:
            with self.gid.get_lock():
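Shutdown must now also stop idqueue, since workers can block on it as well as on their per-worker buffers. The stop/clear contract this diff relies on, inferred from the calls here rather than from RingBuffer's documentation, in miniature::

    class StoppableBuffer:
        # stop(): wake blocked producers/consumers; is_stop() becomes True
        # clear(): drop contents and re-arm for the next run
        def __init__(self):
            self.items, self.stopped = [], False
        def stop(self):
            self.stopped = True
        def is_stop(self):
            return self.stopped
        def clear(self):
            self.items.clear()
            self.stopped = False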
@@ -306,29 +320,32 @@ Example::
         # clean workers' buffer
         for w in self.workers:
             w.buffer.clear()
+        self.idqueue.clear()
+        self.gid_obj.value = 0

     def _init_workers(self):
         jt.clean()
+        jt.gc()
         self.index_list = mp.Array('i', self.real_len, lock=False)
         workers = []
-        # batch id to worker id
-        self.idmap = mp.Array('i', self.batch_len, lock=False)
+        # get worker id
+        self.idqueue = jt.RingBuffer(2048)
         # global token index
         self.gid = mp.Value('i', self.batch_len)
         self.gid.value = 0
         # global token index condition
         self.gidc = mp.Condition(self.gid.get_lock())
         # number of idle workers
         self.num_idle = mp.Value('i', 0, lock=False)
         # number of idle workers condition
         self.num_idle_c = mp.Condition(self.gid.get_lock())
+        self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)
         for i in range(self.num_workers):
             w = Worker(target=self._worker_main, args=(i,),
                        buffer_size=self.buffer_size,
                        keep_numpy_array=self.keep_numpy_array)
             workers.append(w)
         self.workers = workers
-        self.index_list_numpy = np.ndarray(dtype='int32', shape=self.real_len, buffer=self.index_list)

     def reset(self):
         if not hasattr(self, "workers"):
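index_list is a lock-free shared array and index_list_numpy is a numpy view over the same memory; moving the view's creation before the worker processes are forked means every worker inherits it, which is what lets the cid == 0 worker publish a fresh shuffle to all others each epoch. The mechanism in isolation::

    import multiprocessing as mp
    import numpy as np

    n = 8
    shared = mp.Array('i', n, lock=False)              # shared int32 storage
    view = np.ndarray(dtype='int32', shape=n, buffer=shared)
    view[:] = np.random.permutation(n)                 # one process writes...
    # ...processes forked after this point read the same memory via their own views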
@@ -353,8 +370,8 @@ Example::
         if self.total_len is None:
             self.total_len = len(self)
         return self.total_len

-    def __iter__(self):
+    def _get_index_list(self):
         if self.total_len is None:
             self.total_len = len(self)
         # maybe rewrite by sampler
@@ -418,71 +435,79 @@ Example::
         else:
             self.real_len = self.total_len
             self.real_batch_size = self.batch_size

         self.batch_len = self.__batch_len__()
+        return index_list

+    def _epochs(self):
+        if self.endless:
+            while True:
+                yield
+                self.epoch_id += 1
+        else:
+            yield
+
+    def __iter__(self):
+        index_list = self._get_index_list()

         if not hasattr(self, "workers") and self.num_workers:
             self._init_workers()
+            self.last_ids = [-1] * 10

         if self.num_workers:
-            self._stop_all_workers()
-            self.index_list_numpy[:] = index_list
-            gid_obj = self.gid.get_obj()
-            gid_lock = self.gid.get_lock()
-            with gid_lock:
-                gid_obj.value = 0
-                if self.num_idle.value:
-                    self.gidc.notify_all()
             start = time.time()
             self.batch_time = 0
-            for i in range(self.batch_len):
-                # try not get lock first
-                if gid_obj.value <= i:
-                    with gid_lock:
-                        if gid_obj.value <= i:
-                            if mp_log_v:
-                                print("wait")
-                            self.gidc.wait()
-                now = time.time()
-                self.wait_time = now - start
-                start = now
-
-                self.last_id = i
-                worker_id = self.idmap[i]
-                w = self.workers[worker_id]
-                if mp_log_v:
-                    print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
-                batch = w.buffer.recv()
-                now = time.time()
-                self.recv_time = now - start
-                start = now
+            for _ in self._epochs():
+                for i in range(self.batch_len):
+                    # get which worker has this batch
+                    worker_id = self.idqueue.pop()
+
-                if mp_log_v:
-                    print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
-                batch = self.to_jittor(batch)
-                now = time.time()
-                self.to_jittor_time = now - start
-                start = now
+                    now = time.time()
+                    self.wait_time = now - start
+                    start = now

-                yield batch
+                    self.last_ids[i%10] = worker_id
+                    self.batch_id = i
+                    w = self.workers[worker_id]
+                    if mp_log_v:
+                        print(f"#{worker_id} {os.getpid()} recv buffer", w.buffer)
+                    batch = w.buffer.recv()
+
-                now = time.time()
-                self.batch_time = now - start
-                start = now
+                    now = time.time()
+                    self.recv_time = now - start
+                    start = now
+
+                    if mp_log_v:
+                        print(f"#{worker_id} {os.getpid()} recv", type(batch).__name__, [ type(b).__name__ for b in batch ])
+                    batch = self.to_jittor(batch)
+
+                    now = time.time()
+                    self.to_jittor_time = now - start
+                    start = now
+
+                    yield batch
+
+                    now = time.time()
+                    self.batch_time = now - start
+                    start = now
         else:
-            batch_data = []
-            for idx in index_list:
-                batch_data.append(self[int(idx)])
-                if len(batch_data) == self.real_batch_size:
+            for _ in self._epochs():
+                batch_data = []
+                for idx in index_list:
+                    batch_data.append(self[int(idx)])
+                    if len(batch_data) == self.real_batch_size:
-                    batch_data = self.collate_batch(batch_data)
-                    batch_data = self.to_jittor(batch_data)
-                    yield batch_data
-                    batch_data = []
-
-            # depend on drop_last
-            if not self.drop_last and len(batch_data) > 0:
+                        batch_data = self.collate_batch(batch_data)
+                        batch_data = self.to_jittor(batch_data)
+                        yield batch_data
+                        batch_data = []
+
+                # depend on drop_last
+                if not self.drop_last and len(batch_data) > 0:
-                batch_data = self.collate_batch(batch_data)
-                batch_data = self.to_jittor(batch_data)
-                yield batch_data
+                    batch_data = self.collate_batch(batch_data)
+                    batch_data = self.to_jittor(batch_data)
+                    yield batch_data


 class ImageFolder(Dataset):

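The endless behavior hangs off _epochs, a small generator that ticks once per epoch: forever when endless is set, exactly once otherwise, bumping epoch_id at each boundary. The same pattern standalone::

    # one tick per epoch; infinite when endless, single-shot otherwise
    def epochs(endless):
        epoch_id = 0
        if endless:
            while True:
                yield epoch_id
                epoch_id += 1
        else:
            yield 0

    assert len(list(epochs(False))) == 1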
python/jittor/test/test_dataset.py
@@ -117,12 +117,13 @@ class TestDatasetSeed(unittest.TestCase):
             return np.random.rand(2)

         dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
-        dd = []
-        for d in dataset:
-            dd.append(d.numpy())
-        for i in range(len(d)):
-            for j in range(i+1, len(d)):
-                assert not np.allclose(dd[i], dd[j])
+        for _ in range(10):
+            dd = []
+            for d in dataset:
+                dd.append(d.numpy())
+            for i in range(len(d)):
+                for j in range(i+1, len(d)):
+                    assert not np.allclose(dd[i], dd[j])

     def test_py_native(self):
         import random

@@ -136,12 +137,13 @@ class TestDatasetSeed(unittest.TestCase):

         jt.set_global_seed(0)
         dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
-        dd = []
-        for d in dataset:
-            dd.append(d.numpy())
-        for i in range(len(d)):
-            for j in range(i+1, len(d)):
-                assert not np.allclose(dd[i], dd[j])
+        for _ in range(10):
+            dd = []
+            for d in dataset:
+                dd.append(d.numpy())
+            for i in range(len(d)):
+                for j in range(i+1, len(d)):
+                    assert not np.allclose(dd[i], dd[j])

     def test_jtrand(self):
         import random

@@ -155,12 +157,13 @@ class TestDatasetSeed(unittest.TestCase):

         jt.set_global_seed(0)
         dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
-        dd = []
-        for d in dataset:
-            dd.append(d.numpy())
-        for i in range(len(d)):
-            for j in range(i+1, len(d)):
-                assert not np.allclose(dd[i], dd[j])
+        for _ in range(10):
+            dd = []
+            for d in dataset:
+                dd.append(d.numpy())
+            for i in range(len(d)):
+                for j in range(i+1, len(d)):
+                    assert not np.allclose(dd[i], dd[j])

     def test_cifar(self):
         from jittor.dataset.cifar import CIFAR10

python/jittor/test/test_resnet.py
@@ -63,8 +63,11 @@ class TestResnet(unittest.TestCase):
         global prev
         prev = time.time()
         SGD = nn.SGD(mnist_net.parameters(), self.learning_rate, self.momentum, self.weight_decay)
+        self.train_loader.endless = True

-        for batch_idx, (data, target) in enumerate(self.train_loader):
+        for data, target in self.train_loader:
+            batch_id = self.train_loader.batch_id
+            epoch_id = self.train_loader.epoch_id

             # train step
             with jt.log_capture_scope(
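With an endless loader the for-loop never terminates by itself, so the exit condition moves into the loop body: the test reads batch_id and epoch_id off the loader and breaks explicitly (see the final hunk below). The pattern, with a hypothetical step() standing in for the real training code::

    train_loader.endless = True
    for data, target in train_loader:
        step(data, target)                  # hypothetical train step
        if train_loader.epoch_id >= 2:      # stop after two full passes
            break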
@@ -74,7 +77,7 @@ class TestResnet(unittest.TestCase):
                 output = mnist_net(data)
                 loss = nn.cross_entropy_loss(output, target)
                 SGD.step(loss)
-                def callback(batch_idx, loss, output, target):
+                def callback(epoch_id, batch_id, loss, output, target):
                     # print train info
                     global prev
                     pred = np.argmax(output, axis=1)
@@ -82,15 +85,15 @@ class TestResnet(unittest.TestCase):
                     loss_list.append(loss[0])
                     acc_list.append(acc)
                     print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tAcc: {:.6f} \tTime:{:.3f}'
-                        .format(0, batch_idx, 600,1. * batch_idx / 6.0, loss[0], acc, time.time()-prev))
+                        .format(epoch_id, batch_id, 600,1. * batch_id / 6.0, loss[0], acc, time.time()-prev))
                     # prev = time.time()
-                jt.fetch(batch_idx, loss, output, target, callback)
+                jt.fetch(epoch_id, batch_id, loss, output, target, callback)

             log_conv = find_log_with_re(logs,
                 "Jit op key (not )?found: ((mkl)|(cudnn))_conv.*")
             log_matmul = find_log_with_re(logs,
                 "Jit op key (not )?found: ((mkl)|(cublas))_matmul.*")
-            if batch_idx > 2:
+            if batch_id > 2:
                 assert len(log_conv)==59 and len(log_matmul)==6, (len(log_conv), len(log_matmul))

             mem_used = jt.flags.stat_allocator_total_alloc_byte \
@@ -119,6 +122,8 @@ class TestResnet(unittest.TestCase):
                 assert jt.core.number_of_lived_vars() < 8100, jt.core.number_of_lived_vars()
             else:
                 assert jt.core.number_of_lived_vars() < 7000, jt.core.number_of_lived_vars()
+            if self.train_loader.epoch_id >= 2:
+                break

         jt.sync_all(True)
         assert np.mean(loss_list[-50:])<0.5