mirror of https://github.com/Jittor/Jittor
add support jt.var / Var.var to compute variance.
Acknowledgement: Thanks fangtiancheng https://discuss.jittor.org/t/topic/193/3 for a demo implementation.
This commit is contained in:
parent
9b50f6370c
commit
de8a193e9a
|
@ -394,6 +394,58 @@ def zeros_like(x):
|
|||
|
||||
flags = core.Flags()
|
||||
|
||||
def var(x, dim=None, dims=None, unbiased=False, keepdims=False):
|
||||
""" return the sample variance. If unbiased is True, Bessel's correction will be used.
|
||||
|
||||
:param x: the input jittor Var.
|
||||
:type x: jt.Var.
|
||||
:param dim: the dimension to compute the variance. If both dim and dims are None, the variance of the whole tensor will be computed.
|
||||
:type dim: int.
|
||||
:param dims: the dimensions to compute the variance. If both dim and dims are None, the variance of the whole tensor will be computed.
|
||||
:type dims: tuple of int.
|
||||
:param unbiased: if True, Bessel's correction will be used.
|
||||
:type unbiased: bool.
|
||||
:param keepdim: if True, the output shape is same as input shape except for the dimension in dim.
|
||||
:type keepdim: bool.
|
||||
|
||||
Example:
|
||||
|
||||
>>> a = jt.rand(3)
|
||||
>>> a
|
||||
jt.Var([0.79613626 0.29322362 0.19785859], dtype=float32)
|
||||
>>> a.var()
|
||||
jt.Var([0.06888353], dtype=float32)
|
||||
>>> a.var(unbiased=True)
|
||||
jt.Var([0.10332529], dtype=float32)
|
||||
"""
|
||||
shape = x.shape
|
||||
new_shape = list(x.shape)
|
||||
|
||||
assert dim is None or dims is None, "dim and dims can not be both set"
|
||||
if dim is None and dims is None:
|
||||
dims = list(range(len(shape)))
|
||||
elif dim is not None:
|
||||
dims = [dim]
|
||||
|
||||
mean = jt.mean(x, dims, keepdims=True)
|
||||
mean = jt.broadcast(mean, shape)
|
||||
|
||||
n = 1
|
||||
for d in dims:
|
||||
n *= shape[d]
|
||||
new_shape[d] = 1
|
||||
|
||||
sqr = (x - mean) ** 2
|
||||
sqr = jt.sum(sqr, dims=dims, keepdims=False)
|
||||
if unbiased:
|
||||
n -= 1
|
||||
sqr /= n
|
||||
|
||||
if keepdims:
|
||||
sqr = sqr.view(new_shape)
|
||||
return sqr
|
||||
Var.var = var
|
||||
|
||||
def std(x):
|
||||
matsize=1
|
||||
for i in x.shape:
|
||||
|
|
|
@ -1,43 +0,0 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
#
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import os
|
||||
import numpy as np
|
||||
import jittor.nn as jnn
|
||||
|
||||
from jittor.test.test_log import find_log_with_re
|
||||
skip_this_test = False
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestStd(unittest.TestCase):
|
||||
def test_std(self):
|
||||
x=np.random.randn(100,1000).astype(np.float32)
|
||||
jt_x=jt.array(x)
|
||||
tc_x=torch.from_numpy(x)
|
||||
assert np.allclose(jt_x.std().numpy(), tc_x.std().numpy(), 1e-4) ,(x, jt_x.std().numpy(), tc_x.std().numpy())
|
||||
|
||||
def test_norm(self):
|
||||
x=np.random.randn(100,1000).astype(np.float32)
|
||||
jt_x=jt.array(x)
|
||||
tc_x=torch.from_numpy(x)
|
||||
assert np.allclose(jt_x.norm(1,1).numpy(), tc_x.norm(1,1).numpy())
|
||||
assert np.allclose(jt_x.norm(1,0).numpy(), tc_x.norm(1,0).numpy())
|
||||
assert np.allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy())
|
||||
assert np.allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy())
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -0,0 +1,51 @@
|
|||
# ***************************************************************
|
||||
# Copyright (c) 2021 Jittor. All Rights Reserved.
|
||||
# Maintainers:
|
||||
# Dun Liang <randonlang@gmail.com>.
|
||||
# Zheng-Ning Liu <lzhengning@gmail.com>
|
||||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
import numpy as np
|
||||
import jittor.nn as jnn
|
||||
|
||||
skip_this_test = False
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
import torch.nn as tnn
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
||||
@unittest.skipIf(skip_this_test, "No Torch found")
|
||||
class TestVarFunctions(unittest.TestCase):
|
||||
def test_var(self):
|
||||
x = np.random.randn(100, 1000).astype(np.float32)
|
||||
|
||||
jt_x = jt.array(x)
|
||||
tc_x = torch.from_numpy(x)
|
||||
np.testing.assert_allclose(jt_x.var().numpy(), tc_x.var().numpy(), rtol=1e-3, atol=1e-4)
|
||||
np.testing.assert_allclose(jt_x.var(dim=1).numpy(), tc_x.var(dim=1).numpy(), rtol=1e-3, atol=1e-4)
|
||||
np.testing.assert_allclose(jt_x.var(dim=0, unbiased=True).numpy(), tc_x.var(dim=0, unbiased=True).numpy(), rtol=1e-3, atol=1e-4)
|
||||
|
||||
def test_std(self):
|
||||
x=np.random.randn(100, 1000).astype(np.float32)
|
||||
jt_x = jt.array(x)
|
||||
tc_x = torch.from_numpy(x)
|
||||
np.testing.assert_allclose(jt_x.std().numpy(), tc_x.std().numpy(), 1e-4)
|
||||
|
||||
def test_norm(self):
|
||||
x = np.random.randn(100, 1000).astype(np.float32)
|
||||
jt_x = jt.array(x)
|
||||
tc_x = torch.from_numpy(x)
|
||||
np.testing.assert_allclose(jt_x.norm(1,1).numpy(), tc_x.norm(1,1).numpy(), atol=1e-6)
|
||||
np.testing.assert_allclose(jt_x.norm(1,0).numpy(), tc_x.norm(1,0).numpy(), atol=1e-6)
|
||||
np.testing.assert_allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy(), atol=1e-6)
|
||||
np.testing.assert_allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy(), atol=1e-6)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue