mirror of https://github.com/Jittor/Jittor
polish dataset
This commit is contained in:
parent
1f0ea3b796
commit
aab4bda835
|
@ -147,9 +147,7 @@ class Dataset(object):
|
|||
return batch
|
||||
new_batch = []
|
||||
for a in batch:
|
||||
if isinstance(a, np.ndarray) or \
|
||||
isinstance(a, int) or \
|
||||
isinstance(a, float):
|
||||
if isinstance(a, np.ndarray):
|
||||
new_batch.append(to_jt(a))
|
||||
else:
|
||||
new_batch.append(self.to_jittor(a))
|
||||
|
@ -525,7 +523,9 @@ Example::
|
|||
batch_data.append(self[int(idx)])
|
||||
if len(batch_data) == self.real_batch_size:
|
||||
batch_data = self.collate_batch(batch_data)
|
||||
tmp = batch_data
|
||||
batch_data = self.to_jittor(batch_data)
|
||||
# breakpoint()
|
||||
yield batch_data
|
||||
self.batch_id += 1
|
||||
if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:
|
||||
|
|
|
@ -105,7 +105,7 @@ void CuttTransposeOp::jit_run() {
|
|||
cuttExecute(iter->second, xp, yp);
|
||||
} else {
|
||||
cuttHandle plan;
|
||||
checkCudaErrors(cudaDeviceSynchronize());
|
||||
// checkCudaErrors(cudaDeviceSynchronize());
|
||||
auto ret = cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
|
||||
CHECK(0==ret) << ret << jk.to_string() << x << y;
|
||||
cutt_plan_cache[jk.to_string()] = plan;
|
||||
|
|
Loading…
Reference in New Issue