dataset and dataloader interface polish

This commit is contained in:
Dun Liang 2022-11-16 11:31:55 +08:00
parent 1f06bbf22e
commit 5bc160b19c
4 changed files with 13 additions and 2 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.3.5.32'
__version__ = '1.3.5.33'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -129,6 +129,7 @@ class Dataset(object):
self.sampler = None
self._disable_workers = False
self._shuffle_rng = np.random.default_rng(1)
self.dataset = self
def __getitem__(self, index):
raise NotImplementedError
@ -585,6 +586,8 @@ Example::
self.batch_id += 1
yield batch_data
def DataLoader(dataset: Dataset, *args, **kargs):
return dataset.set_attrs(*args, **kargs)
class ImageFolder(Dataset):
"""

View File

@ -1977,3 +1977,11 @@ def isneginf(x): return _simple_for(x, "x<0 && isinf(x)")
jt.Var.isneginf = isneginf
def isposinf(x): return _simple_for(x, "x>0 && isinf(x)")
jt.Var.isposinf = isposinf
# fake torch interface
def contiguous(x): return x.clone()
jt.Var.contiguous = contiguous
def cpu(x): return x.clone()
jt.Var.cpu = cpu
def to(x, *args, **kargs): return x.clone()
jt.Var.to = to

View File

@ -1372,7 +1372,7 @@ class ConvTranspose(Module):
b = self.bias.broadcast(y.shape, [0,2,3])
y = y + b
return y
ConvTranspose2d = ConvTranspose
class ConvTranspose3d(Module):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \