mirror of https://github.com/Jittor/Jittor
allow dict convert in dataset.to_jittor
This commit is contained in:
parent
5d4912b6df
commit
085074b625
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.51'
|
||||
__version__ = '1.2.3.52'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -136,6 +136,11 @@ class Dataset(object):
|
|||
if self.stop_grad else jt.array(x)
|
||||
if isinstance(batch, np.ndarray):
|
||||
return to_jt(batch)
|
||||
if isinstance(batch, dict):
|
||||
new_batch = {}
|
||||
for k,v in batch.items():
|
||||
new_batch[k] = self.to_jittor(v)
|
||||
return new_batch
|
||||
if not isinstance(batch, (list, tuple)):
|
||||
return batch
|
||||
new_batch = []
|
||||
|
|
|
@ -165,6 +165,26 @@ class TestDatasetSeed(unittest.TestCase):
|
|||
for j in range(i+1, len(d)):
|
||||
assert not np.allclose(dd[i], dd[j])
|
||||
|
||||
def test_dict(self):
|
||||
import random
|
||||
class YourDataset(Dataset):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.set_attrs(total_len=160)
|
||||
|
||||
def __getitem__(self, k):
|
||||
return { "a":np.array([1,2,3]) }
|
||||
|
||||
jt.set_global_seed(0)
|
||||
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
|
||||
for _ in range(10):
|
||||
dd = []
|
||||
for d in dataset:
|
||||
# breakpoint()
|
||||
assert isinstance(d, dict)
|
||||
assert isinstance(d['a'], jt.Var)
|
||||
np.testing.assert_allclose(d['a'].numpy(), [[1,2,3]])
|
||||
|
||||
def test_cifar(self):
|
||||
from jittor.dataset.cifar import CIFAR10
|
||||
a = CIFAR10()
|
||||
|
|
Loading…
Reference in New Issue