add TensorDataset

This commit is contained in:
Dun Liang 2021-07-23 13:30:16 +08:00
parent 87e639730c
commit b0a8943404
5 changed files with 71 additions and 4 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.74'
__version__ = '1.2.3.75'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int
@ -1316,6 +1316,9 @@ Var.float = Var.float32
double = float64
Var.double = Var.float64
def is_var(v):
return isinstance(v, Var)
# __array__ interface is used for np.array(jt_var)
Var.__array__ = Var.numpy
Var.__array_priority__ = 2000

View File

@ -1,5 +1,5 @@
from .dataset import Dataset, ImageFolder, dataset_root
from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset
from .mnist import MNIST
from .cifar import CIFAR10, CIFAR100
from .voc import VOC

View File

@ -88,6 +88,7 @@ class Dataset(object):
self.endless = endless
self.epoch_id = 0
self.sampler = None
self._disable_workers = False
def __getitem__(self, index):
raise NotImplementedError
@ -452,6 +453,8 @@ Example::
yield
def __iter__(self):
if self._disable_workers:
self.num_workers = 0
index_list = self._get_index_list()
if not hasattr(self, "workers") and self.num_workers:
@ -582,3 +585,45 @@ class ImageFolder(Dataset):
if self.transform:
img = self.transform(img)
return img, self.imgs[k][1]
class TensorDataset(Dataset):
""" Dataset using Tensor directly, Example::
import jittor as jt
from jittor.dataset import TensorDataset
x = jt.array([1,2,3])
y = jt.array([4,5,6])
z = jt.array([7,8,9])
dataset = TensorDataset(x, y, z)
dataset.set_attrs(batch_size=1)
for a,b,c in dataset:
print(a,b,c)
# will print
# 1,4,7
# 2,5,8
# 3,6,9
"""
def __init__(self, *args):
super().__init__()
self.args = args
self._disable_workers = True
assert len(args), "At lease one args"
l = len(args[0])
for a in args:
assert l == len(a), "Len should be the same"
self.set_attrs(total_len=l)
def __getitem__(self, idx):
return [ a[idx] for a in self.args ]
def collate_batch(self, batch):
b = collate_batch(batch)
for i in range(len(self.args)):
x = b[i]
if jt.is_var(self.args[i]) and self.args[i].ndim == 1:
x.assign(x.squeeze(-1))
return b

View File

@ -27,8 +27,7 @@ def collate_batch(batch):
elem = batch[0]
elem_type = type(elem)
if isinstance(elem, jt.Var):
# TODO: use jittor
temp_data = np.stack([data.data for data in batch], 0)
temp_data = jt.stack([data for data in batch], 0)
return temp_data
if elem_type is np.ndarray:
temp_data = np.stack([data for data in batch], 0)

View File

@ -194,6 +194,26 @@ class TestDatasetSeed(unittest.TestCase):
assert imgs.shape == [16,32,32,3,]
assert labels.shape == [16,]
break
def test_tensor_dataset(self):
import jittor as jt
from jittor.dataset import TensorDataset
x = jt.array([1,2,3])
y = jt.array([4,5,6])
z = jt.array([7,8,9])
dataset = TensorDataset(x, y, z)
# dataset.set_attrs(batch_size=2)
dataset.set_attrs(batch_size=1)
for i,(a,b,c) in enumerate(dataset):
# print(a,b,c)
# print(a.shape)
assert a.shape == [1]
assert x[i] == a
assert y[i] == b
assert z[i] == c
if __name__ == "__main__":