mirror of https://github.com/Jittor/Jittor
add sampler (#187)
* fix bicubic,add fold. * add eye. * add test for qr,bicubic,fold,unfold. * fix bicubic and fold to code_op ver.add grad test. * add docs.update pinv to support (..,M,N) shape * edit maintainer and testfunc's name. * fix nn * fix nn * add sampler and subset. * fix subset,change jittor op to np op. * fix. * fix space. * fix?. * copy. * copy. * delete .idea * add sampler hook in dataset * add doc * add sampler Co-authored-by: Exusial <2247838039@qq.com> Co-authored-by: Gword <471184555@qq.com> Co-authored-by: Dun Liang <randonlang@gmail.com>
This commit is contained in:
parent
7ace88d718
commit
cd3319edf4
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.55'
|
||||
__version__ = '1.2.2.56'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
|
||||
from .dataset import Dataset, ImageFolder
|
||||
from .mnist import MNIST
|
||||
from .voc import VOC
|
||||
from .voc import VOC
|
||||
from .sampler import *
|
|
@ -80,6 +80,7 @@ class Dataset(object):
|
|||
self.buffer_size = buffer_size
|
||||
self.stop_grad = stop_grad
|
||||
self.keep_numpy_array = keep_numpy_array
|
||||
self.sampler = None
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -343,10 +344,23 @@ Example::
|
|||
print("dataset deleted")
|
||||
self.terminate()
|
||||
|
||||
def __real_len__(self):
|
||||
if self.total_len is None:
|
||||
self.total_len = len(self)
|
||||
return self.total_len
|
||||
|
||||
def __iter__(self):
|
||||
if self.total_len is None:
|
||||
self.total_len = len(self)
|
||||
if self.shuffle == False:
|
||||
# maybe rewrite by sampler
|
||||
total_len = self.total_len
|
||||
if self.sampler:
|
||||
index_list = list(self.sampler.__iter__())
|
||||
total_len = len(index_list)
|
||||
# check is not batch sampler
|
||||
if len(index_list):
|
||||
assert not isinstance(index_list[0], (list,tuple)), "Batch sampler not support yet."
|
||||
elif self.shuffle == False:
|
||||
index_list = get_order_list(self.total_len)
|
||||
else:
|
||||
index_list = get_random_list(self.total_len)
|
||||
|
@ -373,8 +387,8 @@ Example::
|
|||
LOG.w("Batch size is not divisible by MPI world size, "
|
||||
"The distributed version may be different from "
|
||||
"the single-process version.")
|
||||
fix_batch = self.total_len // self.batch_size
|
||||
last_batch = self.total_len - fix_batch * self.batch_size
|
||||
fix_batch = total_len // self.batch_size
|
||||
last_batch = total_len - fix_batch * self.batch_size
|
||||
fix_batch_l = index_list[0:fix_batch*self.batch_size] \
|
||||
.reshape(-1,self.batch_size)
|
||||
fix_batch_l = fix_batch_l[
|
||||
|
@ -394,8 +408,8 @@ Example::
|
|||
|
||||
self.real_len = len(index_list)
|
||||
self.real_batch_size = real_batch_size
|
||||
assert self.total_len // self.batch_size == \
|
||||
self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}"
|
||||
assert total_len // self.batch_size == \
|
||||
self.real_len // self.real_batch_size, f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}"
|
||||
else:
|
||||
self.real_len = self.total_len
|
||||
self.real_batch_size = self.batch_size
|
||||
|
|
|
@ -0,0 +1,110 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Hao-Yang Peng
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import jittor as jt
|
||||
from .dataset import Dataset
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Sampler:
    """Base class for objects that decide the index order of a dataset.

    Constructing a sampler attaches it to ``dataset`` through the
    ``dataset.sampler`` attribute. Subclasses provide ``__iter__`` (the
    stream of indices) and ``__len__`` (how many indices are produced).
    """

    def __init__(self, dataset):
        """Attach this sampler to ``dataset``.

        Args:
            dataset: object whose ``sampler`` attribute is overwritten with
                this sampler instance.
        """
        # The dataset only learns about the sampler through this hook, so
        # every subclass constructor MUST perform this assignment.
        dataset.sampler = self
        self.dataset = dataset

    def __iter__(self):
        """Yield dataset indices; must be overridden by subclasses."""
        raise NotImplementedError()

    def __len__(self):
        """Return the number of indices; must be overridden by subclasses."""
        raise NotImplementedError()
|
||||
|
||||
|
||||
class SequentialSampler(Sampler):
    """Sampler that walks the dataset indices in ascending order."""

    def __init__(self, dataset):
        """Attach the sampler to ``dataset``.

        Args:
            dataset: object exposing ``__real_len__()``; its ``sampler``
                attribute is set to this instance.
        """
        # The base constructor stores the dataset and performs the
        # mandatory ``dataset.sampler = self`` hook.
        super().__init__(dataset)

    def __len__(self):
        """Return the dataset length as reported by ``__real_len__()``."""
        return self.dataset.__real_len__()

    def __iter__(self):
        """Return an iterator over ``0 .. dataset.__real_len__() - 1``."""
        total = self.dataset.__real_len__()
        return iter(range(total))
|
||||
|
||||
|
||||
class RandomSampler(Sampler):
    """Sampler that yields dataset indices in random order.

    Args:
        dataset: object exposing ``__real_len__()``; its ``sampler``
            attribute is set to this instance.
        replacement: when true, indices are drawn independently and
            uniformly, so duplicates may appear.
        num_samples: number of indices produced per iteration. ``None``
            (the default) means "as many as the dataset holds".
    """

    def __init__(self, dataset, replacement=False, num_samples=None):
        # The base constructor stores the dataset and performs the
        # mandatory ``dataset.sampler = self`` hook.
        super().__init__(dataset)
        self.rep = replacement
        self._num_samples = num_samples

    @property
    def num_samples(self):
        """Number of indices produced by one pass over the sampler."""
        # Resolved lazily so the dataset length is read at use time,
        # not at construction time.
        if self._num_samples is None:
            return self.dataset.__real_len__()
        return self._num_samples

    def __len__(self):
        """Return ``num_samples``; matches what ``__iter__`` yields."""
        return self.num_samples

    def __iter__(self):
        """Return an iterator over ``num_samples`` random indices.

        With replacement, indices are drawn uniformly from ``[0, n)``.
        Without replacement, indices come from random permutations of
        ``range(n)``: a single permutation is truncated when
        ``num_samples < n``, and further permutations are appended when
        ``num_samples > n``, so the number of yielded indices always
        equals ``len(self)``.
        """
        n = self.dataset.__real_len__()
        count = self.num_samples
        if self.rep:
            drawn = np.random.randint(
                low=0, high=n, size=(count,), dtype=np.int64)
            return iter(drawn.tolist())

        if n <= 0:
            # Nothing to permute; also keeps the loop below from spinning
            # forever on an empty dataset.
            return iter([])

        picks = []
        while len(picks) < count:
            picks.extend(np.random.permutation(n).tolist())
        # Previously ``num_samples`` was ignored here, so ``len(sampler)``
        # could disagree with the number of indices actually produced.
        return iter(picks[:count])
|
||||
|
||||
|
||||
class SubsetRandomSampler(Sampler):
    """Sampler that visits a contiguous index range in random order.

    Example::

        testdataset = TestSamplerDataset()
        subsetsampler = SubsetRandomSampler(testdataset, (20, 30))

        for i, data in enumerate(testdataset):
            # data between 20 ~ 29
            ......
    """

    def __init__(self, dataset, indice):
        """Validate the range and attach the sampler to ``dataset``.

        Args:
            dataset: object exposing ``__real_len__()``; its ``sampler``
                attribute is set to this instance.
            indice: pair ``(start, stop)`` describing the half-open range
                ``[start, stop)`` of indices to sample from.

        Raises:
            AssertionError: if the range is empty, negative, or reaches
                past the end of the dataset.
        """
        start, stop = indice[0], indice[1]
        total = dataset.__real_len__()
        # ``stop`` is exclusive (iteration yields start .. stop - 1), so
        # ``stop == total`` is valid; the former ``stop < total`` check
        # made the last element of the dataset unreachable.
        # Validation runs before the hook below so a rejected range does
        # not leave a half-built sampler attached to the dataset.
        assert 0 <= start < stop <= total, \
            f"invalid subset range {indice!r} for dataset of length {total}"
        # The base constructor stores the dataset and performs the
        # mandatory ``dataset.sampler = self`` hook.
        super().__init__(dataset)
        self.indices = indice

    def __iter__(self):
        """Return a generator over the range's indices, randomly ordered."""
        start = self.indices[0]
        span = self.indices[1] - start
        return (int(offset) + start for offset in np.random.permutation(span))

    def __len__(self):
        """Return the number of indices in the range."""
        return self.indices[1] - self.indices[0]
|
||||
|
||||
|
||||
class BatchSampler(Sampler):
    """Group the indices of another sampler into lists of ``batch_size``.

    Unlike the other samplers, this class does not register itself on a
    dataset; it only wraps the sampler it is given.
    """

    def __init__(self, sampler, batch_size, drop_last):
        """Store the wrapped sampler and the batching parameters.

        Args:
            sampler: iterable of indices; must support ``len()`` for
                ``__len__`` to work.
            batch_size: number of indices per emitted batch.
            drop_last: when true, a trailing batch shorter than
                ``batch_size`` is discarded instead of being yielded.
        """
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last

    def __iter__(self):
        """Yield lists of indices, each ``batch_size`` long.

        The final list may be shorter when ``drop_last`` is false.
        """
        pending = []
        for index in self.sampler:
            pending.append(index)
            if len(pending) != self.batch_size:
                continue
            yield pending
            # Start a fresh list so the batch already handed out is not
            # mutated by later appends.
            pending = []
        if pending and not self.drop_last:
            yield pending

    def __len__(self):
        """Return the number of batches one iteration produces."""
        count = len(self.sampler)
        if self.drop_last:
            return count // self.batch_size
        # Ceiling division: a partial trailing batch counts as one batch.
        return (count + self.batch_size - 1) // self.batch_size
|
|
@ -0,0 +1,59 @@
|
|||
import jittor as jt
|
||||
from jittor.dataset import *
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import unittest
|
||||
|
||||
|
||||
|
||||
class TestSamplerDataset(Dataset):
    """Synthetic 40-element dataset whose item at ``idx`` is ``idx**2``."""

    def __init__(self):
        super().__init__()
        # Batch size 1 so each iteration step carries exactly one item.
        attrs = dict(total_len=40, batch_size=1)
        self.set_attrs(**attrs)

    def __getitem__(self, idx):
        """Return the square of ``idx``."""
        squared = idx**2
        return squared
|
||||
|
||||
|
||||
class TestSampler(unittest.TestCase):
    """Checks for the sampler classes, using ``TestSamplerDataset``."""

    def test_sequential_sampler(self):
        """Sequential sampling keeps both sampler and dataset in order."""
        dataset = TestSamplerDataset()
        sampler = SequentialSampler(dataset)
        assert len(sampler) == 40
        # The sampler itself must emit 0, 1, 2, ... in order.
        for position, index in enumerate(sampler):
            assert position == index
        # Iterating the dataset must follow the same order.
        for position, item in enumerate(dataset):
            assert item.item() == position**2

    def test_random_sampler(self):
        """Random sampling leaves few items at their original position."""
        dataset = TestSamplerDataset()
        sampler = RandomSampler(dataset)
        assert len(sampler) == 40
        # Count the items that still sit at their sequential position.
        unmoved = 0
        for position, item in enumerate(dataset):
            unmoved += item.item() == position**2
        assert unmoved < 10

    def test_subset_random_sampler(self):
        """Subset sampling yields exactly the items of indices 20..29."""
        dataset = TestSamplerDataset()
        sampler = SubsetRandomSampler(dataset, (20, 30))
        assert len(sampler) == 10
        # Order is random, so compare order-independent sums of squares.
        s = 0
        for item in dataset:
            s += item.item()
        s2 = 0
        for index in range(20, 30):
            s2 += index**2
        assert s == s2, (s, s2)

    def test_batch_sampler(self):
        """Forty sequential indices split into ten batches of four."""
        dataset = TestSamplerDataset()
        inner = SequentialSampler(dataset)
        batches = BatchSampler(inner, 4, drop_last=False)
        assert len(batches) == 10
        for group in batches:
            assert len(group) == 4
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue