mirror of https://github.com/Jittor/Jittor
dataset and dataloader interface polish
This commit is contained in:
parent
1f06bbf22e
commit
5bc160b19c
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.3.5.32'
|
||||
__version__ = '1.3.5.33'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -129,6 +129,7 @@ class Dataset(object):
|
|||
self.sampler = None
|
||||
self._disable_workers = False
|
||||
self._shuffle_rng = np.random.default_rng(1)
|
||||
self.dataset = self
|
||||
|
||||
def __getitem__(self, index):
|
||||
raise NotImplementedError
|
||||
|
@ -585,6 +586,8 @@ Example::
|
|||
self.batch_id += 1
|
||||
yield batch_data
|
||||
|
||||
def DataLoader(dataset: Dataset, *args, **kargs):
|
||||
return dataset.set_attrs(*args, **kargs)
|
||||
|
||||
class ImageFolder(Dataset):
|
||||
"""
|
||||
|
|
|
@ -1977,3 +1977,11 @@ def isneginf(x): return _simple_for(x, "x<0 && isinf(x)")
|
|||
jt.Var.isneginf = isneginf
|
||||
def isposinf(x): return _simple_for(x, "x>0 && isinf(x)")
|
||||
jt.Var.isposinf = isposinf
|
||||
|
||||
# fake torch interface
|
||||
def contiguous(x): return x.clone()
|
||||
jt.Var.contiguous = contiguous
|
||||
def cpu(x): return x.clone()
|
||||
jt.Var.cpu = cpu
|
||||
def to(x, *args, **kargs): return x.clone()
|
||||
jt.Var.to = to
|
||||
|
|
|
@ -1372,7 +1372,7 @@ class ConvTranspose(Module):
|
|||
b = self.bias.broadcast(y.shape, [0,2,3])
|
||||
y = y + b
|
||||
return y
|
||||
|
||||
ConvTranspose2d = ConvTranspose
|
||||
|
||||
class ConvTranspose3d(Module):
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, \
|
||||
|
|
Loading…
Reference in New Issue