test cub cumsum

This commit is contained in:
cxjyxx_me 2021-09-02 04:02:40 -04:00
parent 5b830dcf66
commit 393b83e22c
1 changed files with 101 additions and 0 deletions

View File

@ -0,0 +1,101 @@
# ***************************************************************
# Copyright (c) 2021 Jittor. All Rights Reserved.
# Maintainers:
# Guoye Yang <498731903@qq.com>
# Dun Liang <randonlang@gmail.com>.
#
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
from jittor import compile_extern
if jt.has_cuda:
from jittor.compile_extern import cublas_ops, cudnn_ops, cub_ops
else:
cublas_ops = cudnn_ops = cub_ops = None
def test_forward(shape, dim=None):
x = jt.random(shape)
y = jt.numpy_cumsum(x)
y_ = jt.cub_cumsum(x)
assert(np.allclose(y.data, y_.data))
def test_backward(shape, dim=None):
x = jt.random(shape)
z = jt.random(shape)
y = jt.numpy_cumsum(x)
loss = (y * z).sum()
grad = jt.grad(loss, x)
y_ = jt.cub_cumsum(x)
loss_ = (y_ * z).sum()
grad_ = jt.grad(loss_, x)
assert(np.allclose(grad.data, grad_.data))
class TestCubCumsumOp(unittest.TestCase):
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_1d(self):
test_forward([20])
test_forward([3007])
test_forward([3007], 0)
test_forward([3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_1d_backward(self):
test_backward([20])
test_backward([3007])
test_backward([3007], 0)
test_backward([3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_2d(self):
test_forward([5,5])
test_forward([2000, 3007])
test_forward([2000, 3007], 1)
test_forward([2000, 3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_2d_backward(self):
test_backward([5,5])
test_backward([2000, 3007])
test_backward([2000, 3007], 1)
test_backward([2000, 3007], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_nd(self):
test_forward([5,6,7,8], 0)
test_forward([5,6,7,8], 1)
test_forward([5,6,7,8], 2)
test_forward([5,6,7,8], 3)
test_forward([5,6,7,8], -1)
test_forward([16,14,14,2048], 0)
test_forward([16,14,14,2048], 1)
test_forward([16,14,14,2048], 2)
test_forward([16,14,14,2048], 3)
test_forward([16,14,14,2048], -1)
@unittest.skipIf(cub_ops==None, "Not use cub, Skip")
@jt.flag_scope(use_cuda=1)
def test_nd_backward(self):
test_backward([5,6,7,8], 0)
test_backward([5,6,7,8], 1)
test_backward([5,6,7,8], 2)
test_backward([5,6,7,8], 3)
test_backward([5,6,7,8], -1)
test_backward([16,14,14,2048], 0)
test_backward([16,14,14,2048], 1)
test_backward([16,14,14,2048], 2)
test_backward([16,14,14,2048], 3)
test_backward([16,14,14,2048], -1)
if __name__ == "__main__":
unittest.main()