add sparse

This commit is contained in:
li-xl 2020-12-15 15:18:22 +08:00 committed by Dun Liang
parent 1f8ea20f01
commit b4a01c9b57
4 changed files with 117 additions and 1 deletions

53
python/jittor/sparse.py Normal file
View File

@ -0,0 +1,53 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Dun Liang <randonlang@gmail.com>.
# Xiangli Li <190569238@qq.com>
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np
class SparseVar:
def __init__(self,indices,values,shape):
assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector)
self.indices = indices
self.values = values
self.shape = shape
self.ndim = len(shape)
def _indices(self):
return self.indices
def _values(self):
return self.values
def t(self):
indices = list(self.indices.split(1,dim=0))
indices[-1],indices[-2] = indices[-2],indices[-1]
indices = jt.contrib.concat(indices,dim=0)
shape = list(self.shape)
shape[-1],shape[-2] = shape[-2],shape[-1]
shape = jt.NanoVector(shape)
return SparseVar(indices,self.values,shape)
def to_dense(self):
ret = jt.zeros(self.shape,self.values.dtype)
indices = tuple(self.indices.split(1,dim=0))
ret[indices]=self.values
return ret
def sparse_array(indices,values,shape):
return SparseVar(indices,values,shape)
def spmm(spase_x,y):
assert isinstance(spase_x,SparseVar) and isinstance(y,jt.Var)
assert spase_x.ndim==2 and y.ndim==2 and spase_x.shape[-1]==y.shape[0]
# TODO
x = spase_x.to_dense()
return jt.matmul(x,y)

View File

@ -40,6 +40,21 @@ class TestLoss(unittest.TestCase):
jt_y=jt_loss(jt.array(output), jt.array(target)) jt_y=jt_loss(jt.array(output), jt.array(target))
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target)) tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
assert np.allclose(jt_y.numpy(), tc_y.numpy()) assert np.allclose(jt_y.numpy(), tc_y.numpy())
def test_nll_loss(self):
tc_loss = tnn.functional.nll_loss
jt_loss = jnn.nll_loss
output=np.random.randn(10,10).astype(np.float32)
target=np.random.randint(10, size=(10))
jt_y=jt_loss(jt.array(output), jt.array(target),reduction='mean')
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),reduction='mean')
assert np.allclose(jt_y.numpy(), tc_y.numpy())
output=np.random.randn(10,10).astype(np.float32)
target=np.random.randint(10, size=(10))
weight=np.random.randn(10,).astype(np.float32)
jt_y=jt_loss(jt.array(output), jt.array(target),jt.array(weight),reduction='mean')
tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),torch.from_numpy(weight),reduction='mean')
assert np.allclose(jt_y.numpy(), tc_y.numpy())
def test_cross_entropy_loss(self): def test_cross_entropy_loss(self):
jt_loss=jnn.CrossEntropyLoss() jt_loss=jnn.CrossEntropyLoss()

View File

@ -28,7 +28,7 @@ def check_equal(arr, j_layer, p_layer):
pytorch_arr = torch.Tensor(arr) pytorch_arr = torch.Tensor(arr)
jittor_result = j_layer(jittor_arr) jittor_result = j_layer(jittor_arr)
pytorch_result = p_layer(pytorch_arr) pytorch_result = p_layer(pytorch_arr)
assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy()) assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(),rtol=1e-5,atol=1e-5)
@unittest.skipIf(skip_this_test, "No Torch found") @unittest.skipIf(skip_this_test, "No Torch found")
class TestRelu(unittest.TestCase): class TestRelu(unittest.TestCase):
@ -61,6 +61,15 @@ class TestRelu(unittest.TestCase):
check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU()) check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2)) check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9)) check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))
# ***************************************************************
# Test ELU Layer
# ***************************************************************
arr = np.random.randn(16,10,224,224)
check_equal(arr, jnn.ELU(), tnn.ELU())
check_equal(arr, jnn.ELU(0.3), tnn.ELU(0.3))
check_equal(arr, jnn.ELU(2), tnn.ELU(2))
check_equal(arr, jnn.ELU(99.9), tnn.ELU(99.9))
# *************************************************************** # ***************************************************************
# Test GELU Layer # Test GELU Layer

View File

@ -0,0 +1,39 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
# Xiangli Li <1905692338@qq.com>
# Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn
skip_this_test = False
try:
jt.dirty_fix_pytorch_runtime_error()
import torch
import torch.nn as tnn
except:
torch = None
tnn = None
skip_this_test = True
@unittest.skipIf(skip_this_test, "No Torch found")
class TestSparse(unittest.TestCase):
def test_sparse_var(self):
indices = np.array([[0,1,1],[2,0,2]])
values = np.array([3,4,5]).astype(np.float32)
shape = [2,3]
jt_array = jt.sparse.sparse_array(jt.array(indices),jt.array(values),jt.NanoVector(shape))
torch_tensor = torch.sparse.FloatTensor(torch.from_numpy(indices),torch.from_numpy(values),torch.Size(shape))
jt_numpy = jt_array.to_dense().numpy()
torch_numpy = torch_tensor.to_dense().numpy()
assert np.allclose(jt_numpy,torch_numpy)
if __name__ == "__main__":
unittest.main()