mirror of https://github.com/Jittor/Jittor
polish cifar
This commit is contained in:
parent
e4089ecc4a
commit
9c0f3cfdf4
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.30'
|
||||
__version__ = '1.2.3.31'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
|
||||
from .dataset import Dataset, ImageFolder
|
||||
from .dataset import Dataset, ImageFolder, dataset_root
|
||||
from .mnist import MNIST
|
||||
from .cifar import CIFAR10, CIFAR100
|
||||
from .voc import VOC
|
||||
|
|
|
@ -1,87 +1,10 @@
|
|||
from .dataset import Dataset, dataset_root
|
||||
|
||||
import os
|
||||
import gzip
|
||||
import tarfile
|
||||
import zipfile
|
||||
from jittor_utils.misc import download_url_to_local, check_md5
|
||||
from jittor_utils.misc import download_and_extract_archive, check_integrity
|
||||
from PIL import Image
|
||||
import sys, pickle
|
||||
import numpy as np
|
||||
|
||||
|
||||
def check_integrity(fpath, md5=None):
|
||||
if not os.path.isfile(fpath):
|
||||
return False
|
||||
if md5 is None:
|
||||
return True
|
||||
return check_md5(fpath, md5)
|
||||
|
||||
def _is_tarxz(filename):
|
||||
return filename.endswith(".tar.xz")
|
||||
|
||||
|
||||
def _is_tar(filename):
|
||||
return filename.endswith(".tar")
|
||||
|
||||
|
||||
def _is_targz(filename):
|
||||
return filename.endswith(".tar.gz")
|
||||
|
||||
|
||||
def _is_tgz(filename):
|
||||
return filename.endswith(".tgz")
|
||||
|
||||
|
||||
def _is_gzip(filename):
|
||||
return filename.endswith(".gz") and not filename.endswith(".tar.gz")
|
||||
|
||||
|
||||
def _is_zip(filename):
|
||||
return filename.endswith(".zip")
|
||||
|
||||
|
||||
def extract_archive(from_path, to_path=None, remove_finished=False):
|
||||
if to_path is None:
|
||||
to_path = os.path.dirname(from_path)
|
||||
|
||||
if _is_tar(from_path):
|
||||
with tarfile.open(from_path, 'r') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_targz(from_path) or _is_tgz(from_path):
|
||||
with tarfile.open(from_path, 'r:gz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_tarxz(from_path):
|
||||
# .tar.xz archive only supported in Python 3.x
|
||||
with tarfile.open(from_path, 'r:xz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_gzip(from_path):
|
||||
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
|
||||
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
|
||||
out_f.write(zip_f.read())
|
||||
elif _is_zip(from_path):
|
||||
with zipfile.ZipFile(from_path, 'r') as z:
|
||||
z.extractall(to_path)
|
||||
else:
|
||||
raise ValueError("Extraction of {} not supported".format(from_path))
|
||||
|
||||
if remove_finished:
|
||||
os.remove(from_path)
|
||||
|
||||
|
||||
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
|
||||
md5=None, remove_finished=False):
|
||||
download_root = os.path.expanduser(download_root)
|
||||
if extract_root is None:
|
||||
extract_root = download_root
|
||||
if not filename:
|
||||
filename = os.path.basename(url)
|
||||
|
||||
download_url_to_local(url, filename, download_root, md5)
|
||||
|
||||
archive = os.path.join(download_root, filename)
|
||||
print("Extracting {} to {}".format(archive, extract_root))
|
||||
extract_archive(archive, extract_root, remove_finished)
|
||||
|
||||
from jittor.dataset import Dataset, dataset_root
|
||||
|
||||
class CIFAR10(Dataset):
|
||||
"""`CIFAR10 <https://www.cs.toronto.edu/~kriz/cifar.html>`_ Dataset.
|
||||
|
|
|
@ -12,6 +12,9 @@ import hashlib
|
|||
import urllib.request
|
||||
from tqdm import tqdm
|
||||
from jittor_utils import lock
|
||||
import gzip
|
||||
import tarfile
|
||||
import zipfile
|
||||
|
||||
def ensure_dir(dir_path):
|
||||
if not os.path.isdir(dir_path):
|
||||
|
@ -69,3 +72,77 @@ def calculate_md5(file_path, chunk_size=1024 * 1024):
|
|||
def check_md5(file_path, md5, **kwargs):
|
||||
return md5 == calculate_md5(file_path, **kwargs)
|
||||
|
||||
|
||||
def check_integrity(fpath, md5=None):
|
||||
if not os.path.isfile(fpath):
|
||||
return False
|
||||
if md5 is None:
|
||||
return True
|
||||
return check_md5(fpath, md5)
|
||||
|
||||
|
||||
def _is_tarxz(filename):
|
||||
return filename.endswith(".tar.xz")
|
||||
|
||||
|
||||
def _is_tar(filename):
|
||||
return filename.endswith(".tar")
|
||||
|
||||
|
||||
def _is_targz(filename):
|
||||
return filename.endswith(".tar.gz")
|
||||
|
||||
|
||||
def _is_tgz(filename):
|
||||
return filename.endswith(".tgz")
|
||||
|
||||
|
||||
def _is_gzip(filename):
|
||||
return filename.endswith(".gz") and not filename.endswith(".tar.gz")
|
||||
|
||||
|
||||
def _is_zip(filename):
|
||||
return filename.endswith(".zip")
|
||||
|
||||
|
||||
def extract_archive(from_path, to_path=None, remove_finished=False):
|
||||
if to_path is None:
|
||||
to_path = os.path.dirname(from_path)
|
||||
|
||||
if _is_tar(from_path):
|
||||
with tarfile.open(from_path, 'r') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_targz(from_path) or _is_tgz(from_path):
|
||||
with tarfile.open(from_path, 'r:gz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_tarxz(from_path):
|
||||
# .tar.xz archive only supported in Python 3.x
|
||||
with tarfile.open(from_path, 'r:xz') as tar:
|
||||
tar.extractall(path=to_path)
|
||||
elif _is_gzip(from_path):
|
||||
to_path = os.path.join(to_path, os.path.splitext(os.path.basename(from_path))[0])
|
||||
with open(to_path, "wb") as out_f, gzip.GzipFile(from_path) as zip_f:
|
||||
out_f.write(zip_f.read())
|
||||
elif _is_zip(from_path):
|
||||
with zipfile.ZipFile(from_path, 'r') as z:
|
||||
z.extractall(to_path)
|
||||
else:
|
||||
raise ValueError("Extraction of {} not supported".format(from_path))
|
||||
|
||||
if remove_finished:
|
||||
os.remove(from_path)
|
||||
|
||||
|
||||
def download_and_extract_archive(url, download_root, extract_root=None, filename=None,
|
||||
md5=None, remove_finished=False):
|
||||
download_root = os.path.expanduser(download_root)
|
||||
if extract_root is None:
|
||||
extract_root = download_root
|
||||
if not filename:
|
||||
filename = os.path.basename(url)
|
||||
|
||||
download_url_to_local(url, filename, download_root, md5)
|
||||
|
||||
archive = os.path.join(download_root, filename)
|
||||
print("Extracting {} to {}".format(archive, extract_root))
|
||||
extract_archive(archive, extract_root, remove_finished)
|
||||
|
|
Loading…
Reference in New Issue