allow dict convert in dataset.to_jittor

This commit is contained in:
Dun Liang 2021-07-01 20:41:12 +08:00
parent 5d4912b6df
commit 085074b625
3 changed files with 26 additions and 1 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.51'
__version__ = '1.2.3.52'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -136,6 +136,11 @@ class Dataset(object):
if self.stop_grad else jt.array(x)
if isinstance(batch, np.ndarray):
return to_jt(batch)
if isinstance(batch, dict):
new_batch = {}
for k,v in batch.items():
new_batch[k] = self.to_jittor(v)
return new_batch
if not isinstance(batch, (list, tuple)):
return batch
new_batch = []

View File

@ -165,6 +165,26 @@ class TestDatasetSeed(unittest.TestCase):
for j in range(i+1, len(d)):
assert not np.allclose(dd[i], dd[j])
def test_dict(self):
import random
class YourDataset(Dataset):
def __init__(self):
super().__init__()
self.set_attrs(total_len=160)
def __getitem__(self, k):
return { "a":np.array([1,2,3]) }
jt.set_global_seed(0)
dataset = YourDataset().set_attrs(batch_size=1, shuffle=True, num_workers=4)
for _ in range(10):
dd = []
for d in dataset:
# breakpoint()
assert isinstance(d, dict)
assert isinstance(d['a'], jt.Var)
np.testing.assert_allclose(d['a'].numpy(), [[1,2,3]])
def test_cifar(self):
from jittor.dataset.cifar import CIFAR10
a = CIFAR10()