optimize cumsum with unroll

This commit is contained in:
Dun Liang 2021-09-02 17:22:46 +08:00
parent 546860e19e
commit 4e712f283f
3 changed files with 18 additions and 15 deletions

View File

@ -9,7 +9,7 @@
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.3.96'
__version__ = '1.2.3.97'
from jittor_utils import lock
with lock.lock_scope():
ori_int = int

View File

@ -67,25 +67,26 @@ __global__ void BlockScanKernel(Tx* __restrict__ xp, Ty* __restrict__ yp, int ba
items = num_items - block_offset;
}
Tx thread_data[ITEMS_PER_THREAD];
#if reverse
for (int i = 0; i < items; ++i) {
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (i<items)
#if reverse
thread_data[i] = xp[batch_id * num_items + (num_items - 1 - (block_offset + i))];
}
#else
for (int i = 0; i < items; ++i) {
#else
thread_data[i] = xp[batch_id * num_items + block_offset + i];
}
#endif
#endif
}
BlockScanT(temp_storage).InclusiveSum(thread_data, thread_data);
#if reverse
for (int i = 0; i < items; ++i) {
__syncthreads();
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; ++i) {
if (i<items)
#if reverse
yp[batch_id * num_items + (num_items - 1 - (block_offset + i))] = thread_data[i];
}
#else
for (int i = 0; i < items; ++i) {
#else
yp[batch_id * num_items + block_offset + i] = thread_data[i];
}
#endif
#endif
}
}
}

View File

@ -82,6 +82,7 @@ class TestCubCumsumOp(unittest.TestCase):
test_forward([16,14,14,2048], 2)
test_forward([16,14,14,2048], 3)
test_forward([16,14,14,2048], -1)
test_forward([16,14,14,2047], 3)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
@ -96,6 +97,7 @@ class TestCubCumsumOp(unittest.TestCase):
test_backward([16,14,14,2048], 2)
test_backward([16,14,14,2048], 3)
test_backward([16,14,14,2048], -1)
test_backward([16,14,14,2047], 3)
if __name__ == "__main__":
unittest.main()