polish cifar

This commit is contained in:
Dun Liang 2021-06-16 20:14:27 +08:00
parent e160a83a7e
commit e4089ecc4a
3 changed files with 7 additions and 4 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.29'
__version__ = '1.2.3.30'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -6,6 +6,7 @@ import zipfile
from jittor_utils.misc import download_url_to_local, check_md5
from PIL import Image
import sys, pickle
import numpy as np
def check_integrity(fpath, md5=None):
@ -103,7 +104,7 @@ class CIFAR10(Dataset):
from jittor.dataset.cifar import CIFAR10
a = CIFAR10()
a.set_attr(batch_size=16)
a.set_attrs(batch_size=16)
for imgs, labels in a:
print(imgs.shape, labels.shape)
break
@ -242,7 +243,7 @@ class CIFAR100(CIFAR10):
from jittor.dataset.cifar import CIFAR100
a = CIFAR100()
a.set_attr(batch_size=16)
a.set_attrs(batch_size=16)
for imgs, labels in a:
print(imgs.shape, labels.shape)
break

View File

@ -165,9 +165,11 @@ class TestDatasetSeed(unittest.TestCase):
def test_cifar(self):
from jittor.dataset.cifar import CIFAR10
a = CIFAR10()
a.set_attr(batch_size=16)
a.set_attrs(batch_size=16)
for imgs, labels in a:
print(imgs.shape, labels.shape)
assert imgs.shape == [16,32,32,3,]
assert labels.shape == [16,]
break