polish dataset

This commit is contained in:
li-xl 2021-08-04 16:48:55 +08:00
parent 1f0ea3b796
commit aab4bda835
2 changed files with 4 additions and 4 deletions

View File

@ -147,9 +147,7 @@ class Dataset(object):
return batch
new_batch = []
for a in batch:
if isinstance(a, np.ndarray) or \
isinstance(a, int) or \
isinstance(a, float):
if isinstance(a, np.ndarray):
new_batch.append(to_jt(a))
else:
new_batch.append(self.to_jittor(a))
@ -525,7 +523,9 @@ Example::
batch_data.append(self[int(idx)])
if len(batch_data) == self.real_batch_size:
batch_data = self.collate_batch(batch_data)
tmp = batch_data
batch_data = self.to_jittor(batch_data)
# breakpoint()
yield batch_data
self.batch_id += 1
if CHECK_MEMORY and self.batch_id % CHECK_MEMORY == 0:

View File

@ -105,7 +105,7 @@ void CuttTransposeOp::jit_run() {
cuttExecute(iter->second, xp, yp);
} else {
cuttHandle plan;
checkCudaErrors(cudaDeviceSynchronize());
// checkCudaErrors(cudaDeviceSynchronize());
auto ret = cuttPlan(&plan, dim, x_shape.data(), reverse.data(), x->dtype().dsize(), 0);
CHECK(0==ret) << ret << jk.to_string() << x << y;
cutt_plan_cache[jk.to_string()] = plan;