mirror of https://github.com/Jittor/Jittor
polish cifar
This commit is contained in:
parent
e160a83a7e
commit
e4089ecc4a
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.29'
|
||||
__version__ = '1.2.3.30'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -6,6 +6,7 @@ import zipfile
|
|||
from jittor_utils.misc import download_url_to_local, check_md5
|
||||
from PIL import Image
|
||||
import sys, pickle
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_integrity(fpath, md5=None):
|
||||
|
@ -103,7 +104,7 @@ class CIFAR10(Dataset):
|
|||
|
||||
from jittor.dataset.cifar import CIFAR10
|
||||
a = CIFAR10()
|
||||
a.set_attr(batch_size=16)
|
||||
a.set_attrs(batch_size=16)
|
||||
for imgs, labels in a:
|
||||
print(imgs.shape, labels.shape)
|
||||
break
|
||||
|
@ -242,7 +243,7 @@ class CIFAR100(CIFAR10):
|
|||
|
||||
from jittor.dataset.cifar import CIFAR100
|
||||
a = CIFAR100()
|
||||
a.set_attr(batch_size=16)
|
||||
a.set_attrs(batch_size=16)
|
||||
for imgs, labels in a:
|
||||
print(imgs.shape, labels.shape)
|
||||
break
|
||||
|
|
|
@ -165,9 +165,11 @@ class TestDatasetSeed(unittest.TestCase):
|
|||
def test_cifar(self):
|
||||
from jittor.dataset.cifar import CIFAR10
|
||||
a = CIFAR10()
|
||||
a.set_attr(batch_size=16)
|
||||
a.set_attrs(batch_size=16)
|
||||
for imgs, labels in a:
|
||||
print(imgs.shape, labels.shape)
|
||||
assert imgs.shape == [16,32,32,3,]
|
||||
assert labels.shape == [16,]
|
||||
break
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue