mirror of https://github.com/Jittor/Jittor
add TensorDataset
This commit is contained in:
parent
87e639730c
commit
b0a8943404
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.74'
|
||||
__version__ = '1.2.3.75'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -1316,6 +1316,9 @@ Var.float = Var.float32
|
|||
double = float64
|
||||
Var.double = Var.float64
|
||||
|
||||
def is_var(v):
    """Tell whether ``v`` is an instance of ``Var``.

    Returns the boolean result of the ``isinstance`` check.
    """
    matches = isinstance(v, Var)
    return matches
|
||||
|
||||
# __array__ interface is used for np.array(jt_var)
|
||||
Var.__array__ = Var.numpy
|
||||
Var.__array_priority__ = 2000
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
from .dataset import Dataset, ImageFolder, dataset_root
|
||||
from .dataset import Dataset, ImageFolder, dataset_root, TensorDataset
|
||||
from .mnist import MNIST
|
||||
from .cifar import CIFAR10, CIFAR100
|
||||
from .voc import VOC
|
||||
|
|
|
@ -88,6 +88,7 @@ class Dataset(object):
|
|||
self.endless = endless
|
||||
self.epoch_id = 0
|
||||
self.sampler = None
|
||||
self._disable_workers = False
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -452,6 +453,8 @@ Example::
|
|||
yield
|
||||
|
||||
def __iter__(self):
|
||||
if self._disable_workers:
|
||||
self.num_workers = 0
|
||||
index_list = self._get_index_list()
|
||||
|
||||
if not hasattr(self, "workers") and self.num_workers:
|
||||
|
@ -582,3 +585,45 @@ class ImageFolder(Dataset):
|
|||
if self.transform:
|
||||
img = self.transform(img)
|
||||
return img, self.imgs[k][1]
|
||||
|
||||
class TensorDataset(Dataset):
    """ Dataset that serves samples directly from the tensors it is given.

    Every positional argument must report the same ``len()``. Sample ``idx``
    is the list ``[arg[idx] for arg in args]``, so iterating the dataset
    yields one batched entry per input tensor. Example::

        import jittor as jt
        from jittor.dataset import TensorDataset

        x = jt.array([1,2,3])
        y = jt.array([4,5,6])
        z = jt.array([7,8,9])

        dataset = TensorDataset(x, y, z)
        dataset.set_attrs(batch_size=1)

        for a,b,c in dataset:
            print(a,b,c)
        # iterates over the values
        # 1,4,7
        # 2,5,8
        # 3,6,9

    """
    def __init__(self, *args):
        """Store the input tensors and set the dataset length.

        Raises AssertionError when no tensor is given or when the tensors
        do not all have the same length.
        """
        super().__init__()
        self.args = args
        # Dataset.__iter__ resets num_workers to 0 when this flag is set.
        # NOTE(review): presumably because the tensors cannot be handed to
        # worker processes -- confirm.
        self._disable_workers = True
        assert len(args), "At least one arg"
        total = len(args[0])
        for a in args:
            assert total == len(a), "Len should be the same"
        self.set_attrs(total_len=total)

    def __getitem__(self, idx):
        """Return element ``idx`` of every stored tensor, as a list."""
        return [a[idx] for a in self.args]

    def collate_batch(self, batch):
        """Collate ``batch`` with the module-level ``collate_batch`` helper,
        then squeeze the last axis of every entry whose source tensor is a
        1-D Var.

        NOTE(review): the squeeze assumes that indexing a 1-D Var yields a
        one-element item, leaving a trailing axis of size 1 after
        stacking -- confirm.
        """
        b = collate_batch(batch)
        for i, src in enumerate(self.args):
            x = b[i]
            if jt.is_var(src) and src.ndim == 1:
                # replace the entry in place so ``b`` keeps the same objects
                x.assign(x.squeeze(-1))
        return b
|
||||
|
|
|
@ -27,8 +27,7 @@ def collate_batch(batch):
|
|||
elem = batch[0]
|
||||
elem_type = type(elem)
|
||||
if isinstance(elem, jt.Var):
|
||||
# TODO: use jittor
|
||||
temp_data = np.stack([data.data for data in batch], 0)
|
||||
temp_data = jt.stack([data for data in batch], 0)
|
||||
return temp_data
|
||||
if elem_type is np.ndarray:
|
||||
temp_data = np.stack([data for data in batch], 0)
|
||||
|
|
|
@ -194,6 +194,26 @@ class TestDatasetSeed(unittest.TestCase):
|
|||
assert imgs.shape == [16,32,32,3,]
|
||||
assert labels.shape == [16,]
|
||||
break
|
||||
|
||||
def test_tensor_dataset(self):
    """Iterate a TensorDataset built from three 1-D Vars with batch_size=1
    and check each batch against the source tensors."""
    import jittor as jt
    from jittor.dataset import TensorDataset

    x = jt.array([1,2,3])
    y = jt.array([4,5,6])
    z = jt.array([7,8,9])

    dataset = TensorDataset(x, y, z)
    dataset.set_attrs(batch_size=1)

    for row, batch in enumerate(dataset):
        a, b, c = batch
        assert a.shape == [1]
        assert x[row] == a
        assert y[row] == b
        assert z[row] == c
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
Loading…
Reference in New Issue