From b4a01c9b5752eff536b34f2f2fda3770bdfe12c8 Mon Sep 17 00:00:00 2001 From: li-xl <1905692338@qq.com> Date: Tue, 15 Dec 2020 15:18:22 +0800 Subject: [PATCH] add sparse --- python/jittor/sparse.py | 53 +++++++++++++++++++++++++++++++ python/jittor/test/test_loss.py | 15 +++++++++ python/jittor/test/test_relu.py | 11 ++++++- python/jittor/test/test_sparse.py | 39 +++++++++++++++++++++++ 4 files changed, 117 insertions(+), 1 deletion(-) create mode 100644 python/jittor/sparse.py create mode 100644 python/jittor/test/test_sparse.py diff --git a/python/jittor/sparse.py b/python/jittor/sparse.py new file mode 100644 index 00000000..d8f206e2 --- /dev/null +++ b/python/jittor/sparse.py @@ -0,0 +1,53 @@ +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Dun Liang . +# Xiangli Li <190569238@qq.com> +# +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** + +import jittor as jt +import numpy as np + +class SparseVar: + def __init__(self,indices,values,shape): + assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector) + self.indices = indices + self.values = values + self.shape = shape + self.ndim = len(shape) + + def _indices(self): + return self.indices + + def _values(self): + return self.values + + def t(self): + indices = list(self.indices.split(1,dim=0)) + indices[-1],indices[-2] = indices[-2],indices[-1] + indices = jt.contrib.concat(indices,dim=0) + shape = list(self.shape) + shape[-1],shape[-2] = shape[-2],shape[-1] + shape = jt.NanoVector(shape) + return SparseVar(indices,self.values,shape) + + def to_dense(self): + ret = jt.zeros(self.shape,self.values.dtype) + indices = tuple(self.indices.split(1,dim=0)) + ret[indices]=self.values + return ret + +def sparse_array(indices,values,shape): + return SparseVar(indices,values,shape) + +def spmm(spase_x,y): + assert isinstance(spase_x,SparseVar) and isinstance(y,jt.Var) + assert spase_x.ndim==2 and y.ndim==2 and spase_x.shape[-1]==y.shape[0] + + # TODO + x = spase_x.to_dense() + return jt.matmul(x,y) + \ No newline at end of file diff --git a/python/jittor/test/test_loss.py b/python/jittor/test/test_loss.py index 9a96dc66..0cea7264 100644 --- a/python/jittor/test/test_loss.py +++ b/python/jittor/test/test_loss.py @@ -40,6 +40,21 @@ class TestLoss(unittest.TestCase): jt_y=jt_loss(jt.array(output), jt.array(target)) tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) assert np.allclose(jt_y.numpy(), tc_y.numpy()) + + def test_nll_loss(self): + tc_loss = tnn.functional.nll_loss + jt_loss = jnn.nll_loss + output=np.random.randn(10,10).astype(np.float32) + target=np.random.randint(10, size=(10)) + jt_y=jt_loss(jt.array(output), jt.array(target),reduction='mean') + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),reduction='mean') + assert np.allclose(jt_y.numpy(), tc_y.numpy()) + output=np.random.randn(10,10).astype(np.float32) + target=np.random.randint(10, size=(10)) + weight=np.random.randn(10,).astype(np.float32) + jt_y=jt_loss(jt.array(output), jt.array(target),jt.array(weight),reduction='mean') + tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),torch.from_numpy(weight),reduction='mean') + assert np.allclose(jt_y.numpy(), tc_y.numpy()) def test_cross_entropy_loss(self): jt_loss=jnn.CrossEntropyLoss() diff --git a/python/jittor/test/test_relu.py b/python/jittor/test/test_relu.py index 6c260d26..f382c347 100644 --- a/python/jittor/test/test_relu.py +++ b/python/jittor/test/test_relu.py @@ -28,7 +28,7 @@ def check_equal(arr, j_layer, p_layer): pytorch_arr = torch.Tensor(arr) jittor_result = j_layer(jittor_arr) pytorch_result = p_layer(pytorch_arr) - assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy()) + assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(),rtol=1e-5,atol=1e-5) @unittest.skipIf(skip_this_test, "No Torch found") class TestRelu(unittest.TestCase): @@ -61,6 +61,15 @@ class TestRelu(unittest.TestCase): check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU()) check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2)) check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9)) + + # *************************************************************** + # Test ELU Layer + # *************************************************************** + arr = np.random.randn(16,10,224,224) + check_equal(arr, jnn.ELU(), tnn.ELU()) + check_equal(arr, jnn.ELU(0.3), tnn.ELU(0.3)) + check_equal(arr, jnn.ELU(2), tnn.ELU(2)) + check_equal(arr, jnn.ELU(99.9), tnn.ELU(99.9)) # *************************************************************** # Test GELU Layer diff --git a/python/jittor/test/test_sparse.py b/python/jittor/test/test_sparse.py new file mode 100644 index 00000000..ae5206c5 --- /dev/null +++ b/python/jittor/test/test_sparse.py @@ -0,0 +1,39 @@ + +# *************************************************************** +# Copyright (c) 2020 Jittor. Authors: +# Xiangli Li <1905692338@qq.com> +# Dun Liang . +# All Rights Reserved. +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + torch = None + tnn = None + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestSparse(unittest.TestCase): + def test_sparse_var(self): + indices = np.array([[0,1,1],[2,0,2]]) + values = np.array([3,4,5]).astype(np.float32) + shape = [2,3] + jt_array = jt.sparse.sparse_array(jt.array(indices),jt.array(values),jt.NanoVector(shape)) + torch_tensor = torch.sparse.FloatTensor(torch.from_numpy(indices),torch.from_numpy(values),torch.Size(shape)) + jt_numpy = jt_array.to_dense().numpy() + torch_numpy = torch_tensor.to_dense().numpy() + assert np.allclose(jt_numpy,torch_numpy) + +if __name__ == "__main__": + unittest.main() \ No newline at end of file