mirror of https://github.com/Jittor/Jittor
add emnist dataset and mac polish
This commit is contained in:
parent
4e712f283f
commit
69325deb45
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.97'
|
||||
__version__ = '1.2.3.98'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -7,6 +7,8 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
import os
|
||||
import string
|
||||
import numpy as np
|
||||
import gzip
|
||||
from PIL import Image
|
||||
|
@ -94,3 +96,105 @@ class MNIST(Dataset):
|
|||
for url, md5 in resources:
|
||||
filename = url.rpartition('/')[2]
|
||||
download_url_to_local(url, filename, self.data_root, md5)
|
||||
|
||||
class EMNIST(Dataset):
    '''
    Jittor's own class for loading EMNIST dataset.

    Args::

        [in] data_root(str): your data root.
        [in] split(str): one of 'byclass', 'bymerge', 'balanced', 'letters', 'digits', 'mnist'.
        [in] train(bool): load the train files if True, otherwise the t10k (test) files.
        [in] download(bool): Download data automatically if download is True.
        [in] batch_size(int): Data batch size.
        [in] shuffle(bool): Shuffle data if true.
        [in] transform(jittor.transform): transform data.

    Example::

        from jittor.dataset.mnist import EMNIST
        train_loader = EMNIST(train=True).set_attrs(batch_size=16, shuffle=True)
        for i, (imgs, target) in enumerate(train_loader):
            ...
    '''

    # Letters that are removed from the full class set for the
    # 'bymerge' and 'balanced' splits.
    _merged_classes = {'c', 'i', 'j', 'k', 'l', 'm', 'o', 'p', 's', 'u', 'v', 'w', 'x', 'y', 'z'}
    _all_classes = set(string.digits + string.ascii_letters)
    # Class names available for every supported split.
    classes_split_dict = {
        'byclass': sorted(_all_classes),
        'bymerge': sorted(_all_classes - _merged_classes),
        'balanced': sorted(_all_classes - _merged_classes),
        'letters': ['N/A'] + list(string.ascii_lowercase),
        'digits': list(string.digits),
        'mnist': list(string.digits),
    }

    def __init__(self, data_root=dataset_root+"/emnist_data/",
                 split='byclass',
                 train=True,
                 download=True,
                 batch_size = 16,
                 shuffle = False,
                 transform=None):
        # NOTE: __getitem__ always converts images to RGB, so samples
        # are produced from 3-channel images.
        super().__init__()
        self.data_root = data_root
        self.is_train = train
        self.transform = transform
        self.batch_size = batch_size
        self.shuffle = shuffle
        if download:
            self.download_url()

        # The archive is extracted into a "gzip" sub-directory of data_root.
        gzip_dir = os.path.join(data_root, "gzip")
        # Only the file-name part differs between the train and test subsets.
        subset = "train" if self.is_train else "t10k"
        images_path = os.path.join(
            gzip_dir, f"emnist-{split}-{subset}-images-idx3-ubyte.gz")
        labels_path = os.path.join(
            gzip_dir, f"emnist-{split}-{subset}-labels-idx1-ubyte.gz")

        self.mnist = {}
        with gzip.open(images_path, 'rb') as f:
            # Skip the first 16 bytes (presumably the idx3 header), view the
            # rest as 28x28 uint8 images and swap their two spatial axes.
            self.mnist["images"] = np.frombuffer(
                f.read(), np.uint8, offset=16
            ).reshape(-1, 28, 28).transpose(0, 2, 1)
        with gzip.open(labels_path, 'rb') as f:
            # Skip the first 8 bytes (presumably the idx1 header).
            self.mnist["labels"] = np.frombuffer(f.read(), np.uint8, offset=8)
        assert(self.mnist["images"].shape[0] == self.mnist["labels"].shape[0])
        self.total_len = self.mnist["images"].shape[0]
        # this function must be called
        self.set_attrs(total_len = self.total_len)

    def __getitem__(self, index):
        '''
        Return the (image tensor, label) pair stored at ``index``.

        The image is converted to RGB, passed through ``self.transform``
        when one is set, and then through ``trans.to_tensor``.
        '''
        img = Image.fromarray(self.mnist['images'][index]).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return trans.to_tensor(img), self.mnist['labels'][index]

    def download_url(self):
        '''
        Download the EMNIST archive into ``self.data_root`` and extract it there,
        this function will be called when download is True.
        '''
        import zipfile

        resources = [
            ("https://www.itl.nist.gov/iaui/vip/cs_links/EMNIST/gzip.zip", "58c8d27c78d21e728a6bc7b3cc06412e"),
        ]

        for url, md5 in resources:
            filename = "emnist.zip"
            download_url_to_local(url, filename, self.data_root, md5)
            # The context manager closes the archive even when extraction
            # fails; previously the handle leaked on error.
            with zipfile.ZipFile(os.path.join(self.data_root, filename)) as zf:
                try:
                    zf.extractall(path=self.data_root)
                except RuntimeError as e:
                    print(e)
                    raise
|
||||
|
||||
|
|
|
@ -16,12 +16,19 @@ const char* AlignedAllocator::name() const {return "aligned";}
|
|||
void* AlignedAllocator::alloc(size_t size, size_t& allocation) {
|
||||
#ifdef __APPLE__
|
||||
size += 32-size%32;
|
||||
#endif
|
||||
// low version of mac don't have aligned_alloc
|
||||
return new char[size];
|
||||
#else
|
||||
return aligned_alloc(alignment, size);
|
||||
#endif
|
||||
}
|
||||
|
||||
// Release memory referenced by `mem_ptr`.
// `size` and `allocation` are not used. The Apple branch uses delete[]
// to match the new[] allocation done by AlignedAllocator::alloc there.
void AlignedAllocator::free(void* mem_ptr, size_t size, const size_t& allocation) {
    #ifndef __APPLE__
    ::free(mem_ptr);
    #else
    delete[] static_cast<char*>(mem_ptr);
    #endif
}
|
||||
|
||||
} // jittor
|
|
@ -0,0 +1,31 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
from jittor.dataset.mnist import EMNIST, MNIST
|
||||
import jittor.transform as transform
|
||||
|
||||
# Manual, visual check of the EMNIST loader; it is always skipped.
@unittest.skip("skip emnist test")
class TestEMNIST(unittest.TestCase):
    def test_emnist(self):
        # Imported lazily so the plotting package is only needed
        # when the test actually runs.
        import pylab as pl

        emnist_dataset = EMNIST()
        for imgs, labels in emnist_dataset:
            print(imgs.shape, labels.shape)
            print(labels.max(), labels.min())
            # Take the first slice after moving the batch axis and lay the
            # images side by side in a single 28-row strip.
            imgs = imgs.transpose(1, 2, 0, 3)[0].reshape(28, -1)
            print(labels)
            pl.imshow(imgs)
            pl.show()
            # Only the first batch is displayed.
            break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue