# Mirror of https://github.com/Jittor/Jittor (Python, 55 lines, 1.9 KiB)
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors: Dun Liang <randonlang@gmail.com>. All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
|
|
import unittest
|
|
import jittor as jt
|
|
import numpy as np
|
|
from .test_grad import ngrad
|
|
from .test_cuda import test_cuda
|
|
|
|
def _resolve_op(module, name):
    """Return ``module.<name>``, following dots (e.g. ``"linalg.norm"``).

    Replaces the previous ``eval`` lookup without evaluating arbitrary text,
    while still accepting dotted attribute paths the way ``eval`` did.
    """
    obj = module
    for part in name.split("."):
        obj = getattr(obj, part)
    return obj


def check(op, *args):
    """Assert that ``jt.<op>(*args)`` matches ``np.<op>(*args)``.

    The NumPy result and the Jittor result (``.data``) must agree on dtype,
    shape and printed representation.

    Args:
        op: Name of a function present on both ``numpy`` and ``jittor``
            (e.g. ``"abs"``, ``"log"``); dotted paths are also accepted.
        *args: Positional arguments forwarded unchanged to both functions.

    Raises:
        AssertionError: If dtype, shape or printed values differ. The
            message names the op and shows both results.
    """
    expected = _resolve_op(np, op)(*args)
    actual = _resolve_op(jt, op)(*args).data

    def as_comparable(arr):
        # Boolean results are compared as uint8. NOTE(review): presumably this
        # normalizes how the two libraries represent bool outputs — confirm.
        return arr.astype("uint8") if arr.dtype == "bool" else arr

    expected = as_comparable(expected)
    actual = as_comparable(actual)
    # Compare printed forms so that nan/inf entries match (nan != nan under ==).
    assert (
        expected.dtype == actual.dtype
        and expected.shape == actual.shape
        and str(expected) == str(actual)
    ), f"{op}:\n{expected}\n{actual}"
|
|
|
|
class TestUnaryOp(unittest.TestCase):
    """Compare Jittor unary ops with NumPy (forward values) and with
    numerical differentiation via ``ngrad`` (gradients)."""

    def test_unary_op(self):
        """Forward results of each op must match NumPy in dtype, shape and values."""
        # Scalar sanity checks: float64 construction, jt.abs on an int, and
        # the built-in abs() applied to a negated Jittor value.
        assert jt.float64(1).data.dtype == "float64"
        assert (jt.abs(-1) == 1).data.all()
        assert (abs(-jt.float64(1)) == 1).data.all()

        # -1 and 0 are included on purpose: log(-1) and sqrt(-1) give nan and
        # log(0) gives -inf, which check() compares through the string form.
        a = [-1, 2, 3, 0]
        ops = ["abs", "negative", "logical_not", "bitwise_not",
               "log", "exp", "sqrt"]
        for op in ops:
            # subTest reports the failing op and keeps checking the rest.
            with self.subTest(op=op):
                check(op, a)

    def test_grad(self):
        """Jittor gradients must match numerical gradients from ``ngrad``."""
        ops = ["abs", "negative", "log", "exp", "sqrt"]
        a = [1.1, 2.2, 3.3, 4.4]
        for op in ops:
            with self.subTest(op=op):
                # Only abs gets a negative sample: log/sqrt of -1 is nan.
                b = np.array((a + [-1]) if op == "abs" else a)

                def func(x, np_op=getattr(np, op)):
                    # Bind the op as a default to avoid late-binding on `op`.
                    return np_op(x[0]).sum()

                _, (da,) = ngrad(func, [b], 1e-8)
                ja = jt.array(b)
                jb = getattr(jt, op)(ja)
                jda = jt.grad(jb, ja)
                assert (np.abs(jda.data - da) < 1e-5).all(), (jda.data, da, op)
|
|
|
|
class TestUnaryOpCuda(TestUnaryOp, test_cuda(2)):
    """Re-run every ``TestUnaryOp`` test with the extra base from ``test_cuda(2)``.

    NOTE(review): ``test_cuda`` is defined in ``test_cuda.py``; presumably it
    switches the tests to CUDA (and skips when CUDA is unavailable) — confirm
    there, including what the argument ``2`` means.
    """
|
|
|
|
if __name__ == "__main__":
    # Allow running this test module directly as a script.
    unittest.main()