mirror of https://github.com/Jittor/Jittor

add sparse

parent 1f8ea20f01
commit b4a01c9b57

@@ -0,0 +1,53 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Dun Liang <randonlang@gmail.com>.
#     Xiangli Li <190569238@qq.com>
#
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import jittor as jt
import numpy as np

class SparseVar:
    def __init__(self,indices,values,shape):
        assert isinstance(indices,jt.Var) and isinstance(values,jt.Var) and isinstance(shape,jt.NanoVector)
        self.indices = indices
        self.values = values
        self.shape = shape
        self.ndim = len(shape)

    def _indices(self):
        return self.indices

    def _values(self):
        return self.values

    def t(self):
        # Transpose the last two dims by swapping the corresponding index rows.
        indices = list(self.indices.split(1,dim=0))
        indices[-1],indices[-2] = indices[-2],indices[-1]
        indices = jt.contrib.concat(indices,dim=0)
        shape = list(self.shape)
        shape[-1],shape[-2] = shape[-2],shape[-1]
        shape = jt.NanoVector(shape)
        return SparseVar(indices,self.values,shape)

    def to_dense(self):
        # Scatter the COO values into a zero-initialized dense Var.
        ret = jt.zeros(self.shape,self.values.dtype)
        indices = tuple(self.indices.split(1,dim=0))
        ret[indices] = self.values
        return ret

def sparse_array(indices,values,shape):
    return SparseVar(indices,values,shape)

def spmm(sparse_x,y):
    assert isinstance(sparse_x,SparseVar) and isinstance(y,jt.Var)
    assert sparse_x.ndim==2 and y.ndim==2 and sparse_x.shape[-1]==y.shape[0]

    # TODO: use a dedicated sparse kernel; for now fall back to dense matmul.
    x = sparse_x.to_dense()
    return jt.matmul(x,y)
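
A minimal usage sketch (not part of the commit; it assumes this file is exposed as jt.sparse, which is how the new test below imports it):

import jittor as jt
import numpy as np

# Two nonzeros of a 2x3 matrix in COO layout: indices has shape (ndim, nnz),
# so column k of `indices` gives the coordinates of values[k].
indices = jt.array(np.array([[0,1],[2,0]]))
values = jt.array(np.array([3.0,4.0],dtype=np.float32))
x = jt.sparse.sparse_array(indices, values, jt.NanoVector([2,3]))
print(x.to_dense())          # [[0,0,3],[4,0,0]]
y = jt.array(np.random.randn(3,4).astype(np.float32))
print(jt.sparse.spmm(x, y))  # (2,3) @ (3,4) -> (2,4), via the dense fallback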

@@ -40,6 +40,21 @@ class TestLoss(unittest.TestCase):
        jt_y=jt_loss(jt.array(output), jt.array(target))
        tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target))
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_nll_loss(self):
        tc_loss = tnn.functional.nll_loss
        jt_loss = jnn.nll_loss
        output=np.random.randn(10,10).astype(np.float32)
        target=np.random.randint(10, size=(10))
        jt_y=jt_loss(jt.array(output), jt.array(target),reduction='mean')
        tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),reduction='mean')
        assert np.allclose(jt_y.numpy(), tc_y.numpy())
        output=np.random.randn(10,10).astype(np.float32)
        target=np.random.randint(10, size=(10))
        weight=np.random.randn(10,).astype(np.float32)
        jt_y=jt_loss(jt.array(output), jt.array(target),jt.array(weight),reduction='mean')
        tc_y=tc_loss(torch.from_numpy(output), torch.from_numpy(target),torch.from_numpy(weight),reduction='mean')
        assert np.allclose(jt_y.numpy(), tc_y.numpy())

    def test_cross_entropy_loss(self):
        jt_loss=jnn.CrossEntropyLoss()
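
For reference, a NumPy sketch of the semantics the nll_loss test above checks (a hand-written reading of PyTorch's weighted 'mean' reduction, not code from the commit): with log-probability rows x, targets t, and optional per-class weights w, the loss is -sum_i w[t_i]*x[i,t_i] / sum_i w[t_i].

import numpy as np

def nll_loss_ref(output, target, weight=None):
    # output: (N, C) log-probabilities; target: (N,) class indices
    n = output.shape[0]
    w = weight[target] if weight is not None else np.ones(n, dtype=output.dtype)
    picked = output[np.arange(n), target]   # x[i, target[i]] for each row
    return -(w * picked).sum() / w.sum()    # reduction='mean'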

@@ -28,7 +28,7 @@ def check_equal(arr, j_layer, p_layer):
    pytorch_arr = torch.Tensor(arr)
    jittor_result = j_layer(jittor_arr)
    pytorch_result = p_layer(pytorch_arr)
-   assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy())
+   assert np.allclose(pytorch_result.detach().numpy(), jittor_result.numpy(),rtol=1e-5,atol=1e-5)

@unittest.skipIf(skip_this_test, "No Torch found")
class TestRelu(unittest.TestCase):

@@ -61,6 +61,15 @@ class TestRelu(unittest.TestCase):
        check_equal(arr, jnn.LeakyReLU(), tnn.LeakyReLU())
        check_equal(arr, jnn.LeakyReLU(2), tnn.LeakyReLU(2))
        check_equal(arr, jnn.LeakyReLU(99.9), tnn.LeakyReLU(99.9))

        # ***************************************************************
        # Test ELU Layer
        # ***************************************************************
        arr = np.random.randn(16,10,224,224)
        check_equal(arr, jnn.ELU(), tnn.ELU())
        check_equal(arr, jnn.ELU(0.3), tnn.ELU(0.3))
        check_equal(arr, jnn.ELU(2), tnn.ELU(2))
        check_equal(arr, jnn.ELU(99.9), tnn.ELU(99.9))

        # ***************************************************************
        # Test GELU Layer
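
The new ELU checks above compare Jittor against PyTorch for several alpha values; as a reminder of the formula being exercised, a NumPy reference (an illustrative sketch, not part of the commit):

import numpy as np

def elu_ref(x, alpha=1.0):
    # ELU: identity for positive inputs, alpha*(exp(x)-1) for negative ones
    return np.where(x > 0, x, alpha * (np.exp(x) - 1))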

@@ -0,0 +1,39 @@
# ***************************************************************
# Copyright (c) 2020 Jittor. Authors:
#     Xiangli Li <1905692338@qq.com>
#     Dun Liang <randonlang@gmail.com>.
# All Rights Reserved.
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
import unittest
import jittor as jt
import numpy as np
import jittor.nn as jnn

skip_this_test = False

try:
    jt.dirty_fix_pytorch_runtime_error()
    import torch
    import torch.nn as tnn
except:
    torch = None
    tnn = None
    skip_this_test = True

@unittest.skipIf(skip_this_test, "No Torch found")
class TestSparse(unittest.TestCase):
    def test_sparse_var(self):
        indices = np.array([[0,1,1],[2,0,2]])
        values = np.array([3,4,5]).astype(np.float32)
        shape = [2,3]
        jt_array = jt.sparse.sparse_array(jt.array(indices),jt.array(values),jt.NanoVector(shape))
        torch_tensor = torch.sparse.FloatTensor(torch.from_numpy(indices),torch.from_numpy(values),torch.Size(shape))
        jt_numpy = jt_array.to_dense().numpy()
        torch_numpy = torch_tensor.to_dense().numpy()
        assert np.allclose(jt_numpy,torch_numpy)

if __name__ == "__main__":
    unittest.main()
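
For this fixture, both libraries should densify to the same 2x3 matrix: the index columns place 3 at (0,2), 4 at (1,0), and 5 at (1,2), i.e.

import numpy as np
expected = np.array([[0., 0., 3.],
                     [4., 0., 5.]], dtype=np.float32)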