diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index fab38fb1..c89e4c27 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -394,6 +394,58 @@ def zeros_like(x): flags = core.Flags() +def var(x, dim=None, dims=None, unbiased=False, keepdims=False): + """ return the sample variance. If unbiased is True, Bessel's correction will be used. + + :param x: the input jittor Var. + :type x: jt.Var. + :param dim: the dimension to compute the variance. If both dim and dims are None, the variance of the whole tensor will be computed. + :type dim: int. + :param dims: the dimensions to compute the variance. If both dim and dims are None, the variance of the whole tensor will be computed. + :type dims: tuple of int. + :param unbiased: if True, Bessel's correction will be used. + :type unbiased: bool. + :param keepdim: if True, the output shape is same as input shape except for the dimension in dim. + :type keepdim: bool. + + Example: + + >>> a = jt.rand(3) + >>> a + jt.Var([0.79613626 0.29322362 0.19785859], dtype=float32) + >>> a.var() + jt.Var([0.06888353], dtype=float32) + >>> a.var(unbiased=True) + jt.Var([0.10332529], dtype=float32) + """ + shape = x.shape + new_shape = list(x.shape) + + assert dim is None or dims is None, "dim and dims can not be both set" + if dim is None and dims is None: + dims = list(range(len(shape))) + elif dim is not None: + dims = [dim] + + mean = jt.mean(x, dims, keepdims=True) + mean = jt.broadcast(mean, shape) + + n = 1 + for d in dims: + n *= shape[d] + new_shape[d] = 1 + + sqr = (x - mean) ** 2 + sqr = jt.sum(sqr, dims=dims, keepdims=False) + if unbiased: + n -= 1 + sqr /= n + + if keepdims: + sqr = sqr.view(new_shape) + return sqr +Var.var = var + def std(x): matsize=1 for i in x.shape: diff --git a/python/jittor/test/test_std.py b/python/jittor/test/test_std.py deleted file mode 100644 index 7eebc2ef..00000000 --- a/python/jittor/test/test_std.py +++ /dev/null @@ -1,43 +0,0 @@ -# *************************************************************** -# Copyright (c) 2021 Jittor. All Rights Reserved. -# Maintainers: -# Dun Liang . -# -# This file is subject to the terms and conditions defined in -# file 'LICENSE.txt', which is part of this source code package. -# *************************************************************** -import unittest -import jittor as jt -import os -import numpy as np -import jittor.nn as jnn - -from jittor.test.test_log import find_log_with_re -skip_this_test = False - -try: - jt.dirty_fix_pytorch_runtime_error() - import torch - import torch.nn as tnn -except: - skip_this_test = True - -@unittest.skipIf(skip_this_test, "No Torch found") -class TestStd(unittest.TestCase): - def test_std(self): - x=np.random.randn(100,1000).astype(np.float32) - jt_x=jt.array(x) - tc_x=torch.from_numpy(x) - assert np.allclose(jt_x.std().numpy(), tc_x.std().numpy(), 1e-4) ,(x, jt_x.std().numpy(), tc_x.std().numpy()) - - def test_norm(self): - x=np.random.randn(100,1000).astype(np.float32) - jt_x=jt.array(x) - tc_x=torch.from_numpy(x) - assert np.allclose(jt_x.norm(1,1).numpy(), tc_x.norm(1,1).numpy()) - assert np.allclose(jt_x.norm(1,0).numpy(), tc_x.norm(1,0).numpy()) - assert np.allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy()) - assert np.allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy()) - -if __name__ == "__main__": - unittest.main() diff --git a/python/jittor/test/test_var.py b/python/jittor/test/test_var.py new file mode 100644 index 00000000..e9e047d8 --- /dev/null +++ b/python/jittor/test/test_var.py @@ -0,0 +1,51 @@ +# *************************************************************** +# Copyright (c) 2021 Jittor. All Rights Reserved. +# Maintainers: +# Dun Liang . +# Zheng-Ning Liu +# This file is subject to the terms and conditions defined in +# file 'LICENSE.txt', which is part of this source code package. +# *************************************************************** +import unittest +import jittor as jt +import numpy as np +import jittor.nn as jnn + +skip_this_test = False + +try: + jt.dirty_fix_pytorch_runtime_error() + import torch + import torch.nn as tnn +except: + skip_this_test = True + +@unittest.skipIf(skip_this_test, "No Torch found") +class TestVarFunctions(unittest.TestCase): + def test_var(self): + x = np.random.randn(100, 1000).astype(np.float32) + + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + np.testing.assert_allclose(jt_x.var().numpy(), tc_x.var().numpy(), rtol=1e-3, atol=1e-4) + np.testing.assert_allclose(jt_x.var(dim=1).numpy(), tc_x.var(dim=1).numpy(), rtol=1e-3, atol=1e-4) + np.testing.assert_allclose(jt_x.var(dim=0, unbiased=True).numpy(), tc_x.var(dim=0, unbiased=True).numpy(), rtol=1e-3, atol=1e-4) + + def test_std(self): + x=np.random.randn(100, 1000).astype(np.float32) + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + np.testing.assert_allclose(jt_x.std().numpy(), tc_x.std().numpy(), 1e-4) + + def test_norm(self): + x = np.random.randn(100, 1000).astype(np.float32) + jt_x = jt.array(x) + tc_x = torch.from_numpy(x) + np.testing.assert_allclose(jt_x.norm(1,1).numpy(), tc_x.norm(1,1).numpy(), atol=1e-6) + np.testing.assert_allclose(jt_x.norm(1,0).numpy(), tc_x.norm(1,0).numpy(), atol=1e-6) + np.testing.assert_allclose(jt_x.norm(2,1).numpy(), tc_x.norm(2,1).numpy(), atol=1e-6) + np.testing.assert_allclose(jt_x.norm(2,0).numpy(), tc_x.norm(2,0).numpy(), atol=1e-6) + + +if __name__ == "__main__": + unittest.main()