mirror of https://github.com/Jittor/Jittor
Optimize cumsum by unrolling the per-thread load/store loops (`#pragma unroll` over ITEMS_PER_THREAD)
This commit is contained in:
parent
546860e19e
commit
4e712f283f
|
@ -9,7 +9,7 @@
|
|||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
|
||||
__version__ = '1.2.3.96'
|
||||
__version__ = '1.2.3.97'
|
||||
from jittor_utils import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -67,25 +67,26 @@ __global__ void BlockScanKernel(Tx* __restrict__ xp, Ty* __restrict__ yp, int ba
|
|||
items = num_items - block_offset;
|
||||
}
|
||||
Tx thread_data[ITEMS_PER_THREAD];
|
||||
#if reverse
|
||||
for (int i = 0; i < items; ++i) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
|
||||
if (i<items)
|
||||
#if reverse
|
||||
thread_data[i] = xp[batch_id * num_items + (num_items - 1 - (block_offset + i))];
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < items; ++i) {
|
||||
#else
|
||||
thread_data[i] = xp[batch_id * num_items + block_offset + i];
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
BlockScanT(temp_storage).InclusiveSum(thread_data, thread_data);
|
||||
#if reverse
|
||||
for (int i = 0; i < items; ++i) {
|
||||
__syncthreads();
|
||||
#pragma unroll
|
||||
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
|
||||
if (i<items)
|
||||
#if reverse
|
||||
yp[batch_id * num_items + (num_items - 1 - (block_offset + i))] = thread_data[i];
|
||||
}
|
||||
#else
|
||||
for (int i = 0; i < items; ++i) {
|
||||
#else
|
||||
yp[batch_id * num_items + block_offset + i] = thread_data[i];
|
||||
}
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -82,6 +82,7 @@ class TestCubCumsumOp(unittest.TestCase):
|
|||
test_forward([16,14,14,2048], 2)
|
||||
test_forward([16,14,14,2048], 3)
|
||||
test_forward([16,14,14,2048], -1)
|
||||
test_forward([16,14,14,2047], 3)
|
||||
|
||||
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
|
@ -96,6 +97,7 @@ class TestCubCumsumOp(unittest.TestCase):
|
|||
test_backward([16,14,14,2048], 2)
|
||||
test_backward([16,14,14,2048], 3)
|
||||
test_backward([16,14,14,2048], -1)
|
||||
test_backward([16,14,14,2047], 3)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue