add sampler (#187)

* fix bicubic,add fold.

* add eye.

* add test for qr,bicubic,fold,unfold.

* fix bicubic and fold to code_op ver.add grad test.

* add docs.update pinv to support (..,M,N) shape

* edit maintainer and testfunc's name.

* fix nn

* fix nn

* add sampler and subset.

* fix subset,change jittor op to np op.

* fix.

* fix space.

* fix?.

* copy.

* copy.

* delete .idea

* add sampler hook in dataset

* add doc

* add sampler

Co-authored-by: Exusial <2247838039@qq.com>
Co-authored-by: Gword <471184555@qq.com>
Co-authored-by: Dun Liang <randonlang@gmail.com>
This commit is contained in:
Exusial 2021-03-23 19:08:16 +08:00 committed by GitHub
parent 7ace88d718
commit cd3319edf4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 191 additions and 7 deletions

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.55'
__version__ = '1.2.2.56'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -1,4 +1,5 @@
from .dataset import Dataset, ImageFolder
from .mnist import MNIST
from .voc import VOC
from .voc import VOC
from .sampler import *

View File

@ -80,6 +80,7 @@ class Dataset(object):
self.buffer_size = buffer_size
self.stop_grad = stop_grad
self.keep_numpy_array = keep_numpy_array
self.sampler = None
def __getitem__(self, index):
raise NotImplementedError
@ -343,10 +344,23 @@ Example::
print("dataset deleted")
self.terminate()
def __real_len__(self):
if self.total_len is None:
self.total_len = len(self)
return self.total_len
def __iter__(self):
if self.total_len is None:
self.total_len = len(self)
if self.shuffle == False:
# maybe rewrite by sampler
total_len = self.total_len
if self.sampler:
index_list = list(self.sampler.__iter__())
total_len = len(index_list)
# check is not batch sampler
if len(index_list):
assert not isinstance(index_list[0], (list,tuple)), "Batch sampler not support yet."
elif self.shuffle == False:
index_list = get_order_list(self.total_len)
else:
index_list = get_random_list(self.total_len)
@ -373,8 +387,8 @@ Example::
LOG.w("Batch size is not divisible by MPI world size, "
"The distributed version may be different from "
"the single-process version.")
fix_batch = self.total_len // self.batch_size
last_batch = self.total_len - fix_batch * self.batch_size
fix_batch = total_len // self.batch_size
last_batch = total_len - fix_batch * self.batch_size
fix_batch_l = index_list[0:fix_batch*self.batch_size] \
.reshape(-1,self.batch_size)
fix_batch_l = fix_batch_l[
@ -394,8 +408,8 @@ Example::
self.real_len = len(index_list)
self.real_batch_size = real_batch_size
assert self.total_len // self.batch_size == \
self.real_len // self.real_batch_size, f"Number of batches({self.total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {self.total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}"
assert total_len // self.batch_size == \
self.real_len // self.real_batch_size, f"Number of batches({total_len // self.batch_size}!={self.real_len // self.real_batch_size}) not match, total_len: {total_len}, batch_size: {self.batch_size}, real_len: {self.real_len}, real_batch_size: {self.real_batch_size}"
else:
self.real_len = self.total_len
self.real_batch_size = self.batch_size

View File

@ -0,0 +1,110 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Hao-Yang Peng
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
from .dataset import Dataset
import numpy as np
from PIL import Image
class Sampler():
def __init__(self, dataset):
self.dataset = dataset
# MUST set sampler here
dataset.sampler = self
def __iter__(self):
raise NotImplementedError
def __len__(self):
raise NotImplementedError
class SequentialSampler(Sampler):
def __init__(self, dataset):
# MUST set sampler here
dataset.sampler = self
self.dataset = dataset
def __iter__(self):
return iter(range(self.dataset.__real_len__()))
def __len__(self):
return self.dataset.__real_len__()
class RandomSampler(Sampler):
def __init__(self, dataset, replacement=False, num_samples=None):
# MUST set sampler here
dataset.sampler = self
self.dataset = dataset
self.rep = replacement
self._num_samples = num_samples
@property
def num_samples(self):
if self._num_samples is None:
return self.dataset.__real_len__()
return self._num_samples
def __len__(self):
return self.num_samples
def __iter__(self):
n = self.dataset.__real_len__()
if self.rep:
return iter(np.random.randint(low=0, high=n, size=(self.num_samples,), dtype=np.int64).tolist())
return iter(np.random.permutation(n).tolist())
class SubsetRandomSampler(Sampler):
def __init__(self, dataset, indice):
'''
testdataset = TestSamplerDataset()
subsetsampler = SubsetRandomSampler(testdataset, (20, 30))
for i, data in enumerate(testdataset):
# data between 20 ~ 29
......
'''
# MUST set sampler here
dataset.sampler = self
self.dataset = dataset
self.indices = indice
assert indice[0] >= 0 and indice[1] < dataset.__real_len__() and indice[0] < indice[1]
def __iter__(self):
return (int(i) + self.indices[0] for i in np.random.permutation(self.indices[1] - self.indices[0]))
def __len__(self):
return self.indices[1] - self.indices[0]
class BatchSampler(Sampler):
def __init__(self, sampler, batch_size, drop_last):
self.sampler = sampler
self.batch_size = batch_size
self.drop_last = drop_last
def __iter__(self):
batch = []
for idx in self.sampler:
batch.append(idx)
if len(batch) == self.batch_size:
yield batch
batch = []
if len(batch) > 0 and not self.drop_last:
yield batch
def __len__(self):
if self.drop_last:
return len(self.sampler) // self.batch_size
else:
return (len(self.sampler) + self.batch_size - 1) // self.batch_size

View File

@ -0,0 +1,59 @@
import jittor as jt
from jittor.dataset import *
from PIL import Image
import numpy as np
import unittest
class TestSamplerDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=40, batch_size=1)
def __getitem__(self, idx):
return idx**2
class TestSampler(unittest.TestCase):
def test_sequential_sampler(self):
testdataset = TestSamplerDataset()
seqsampler = SequentialSampler(testdataset)
assert len(seqsampler) == 40
for idx, batch in enumerate(seqsampler):
assert idx == batch
for i, data in enumerate(testdataset):
assert data.item() == i**2
def test_random_sampler(self):
testdataset = TestSamplerDataset()
randomsampler = RandomSampler(testdataset)
assert len(randomsampler) == 40
diff = 0
for i, data in enumerate(testdataset):
diff += data.item() == i**2
assert diff < 10
def test_subset_random_sampler(self):
testdataset = TestSamplerDataset()
subsetsampler = SubsetRandomSampler(testdataset, (20, 30))
assert len(subsetsampler) == 10
s = 0
for i, data in enumerate(testdataset):
s += data.item()
s2 = 0
for i in range(20,30):
s2 += i**2
assert s == s2, (s, s2)
def test_batch_sampler(self):
testdataset = TestSamplerDataset()
seqforbatch = SequentialSampler(testdataset)
batchsampler = BatchSampler(seqforbatch, 4, drop_last=False)
assert len(batchsampler) == 10
for batch in batchsampler:
assert len(batch) == 4
if __name__ == "__main__":
unittest.main()