fix dataset to_jittor bug

This commit is contained in:
zhouwy19 2020-12-16 15:56:34 +08:00
parent 73eb05b36e
commit 23de860d68
1 changed files with 1 additions and 0 deletions

View File

@ -117,6 +117,7 @@ class Dataset(object):
'''
Change batch data to jittor array, such as np.ndarray, int, and float.
'''
if isinstance(batch, jt.Var): return batch
to_jt = lambda x: jt.array(x).stop_grad() \
if self.stop_grad else jt.array(x)
if isinstance(batch, np.ndarray):