mirror of https://github.com/Jittor/Jittor
Merge pull request #171 from Jittor/randn
add more pool & random method
This commit is contained in:
commit
43bc415710
|
@ -7,7 +7,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.10'
|
||||
__version__ = '1.2.2.11'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -960,3 +960,26 @@ from . import numpy2cupy
|
|||
from .contrib import concat
|
||||
from .misc import *
|
||||
from . import sparse
|
||||
|
||||
|
||||
def randn(*size, dtype="float32", requires_grad=False):
|
||||
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
|
||||
arr = jt.random(size, dtype, "normal")
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def rand(*size, dtype="float32", requires_grad=False):
|
||||
if isinstance(size, tuple) and isinstance(size[0], tuple): size = size[0]
|
||||
arr = jt.random(size, dtype)
|
||||
if not requires_grad: return arr.stop_grad()
|
||||
return arr
|
||||
|
||||
def normal(mean, std, size=None, dtype="float32"):
|
||||
if size is None:
|
||||
if isinstance(mean, Var) and isinstance(std, Var):
|
||||
assert mean.shape == std.shape
|
||||
size = mean.shape
|
||||
else:
|
||||
if isinstance(mean, Var): size = mean.shape
|
||||
if isinstance(std, Var): size = std.shape
|
||||
return jt.init.gauss(size, dtype, mean, std)
|
|
@ -16,7 +16,7 @@ import numpy as np
|
|||
import collections
|
||||
import math
|
||||
from collections import OrderedDict
|
||||
from jittor.pool import Pool, pool, AdaptiveAvgPool2d
|
||||
from jittor.pool import *
|
||||
from jittor.optim import *
|
||||
from jittor.misc import _pair
|
||||
|
||||
|
|
|
@ -191,4 +191,24 @@ class AdaptiveAvgPool2d(Module):
|
|||
return xx.reduce("mean", [4,5])
|
||||
|
||||
def pool(x, kernel_size, op, padding=0, stride=None):
|
||||
return Pool(kernel_size, stride, padding, op=op)(x)
|
||||
return Pool(kernel_size, stride, padding, op=op)(x)
|
||||
|
||||
class AvgPool2d(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
||||
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode, count_include_pad=count_include_pad, op="mean")
|
||||
|
||||
def execute(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
def avg_pool2d(x, kernel_size, stride=None, padding=0, ceil_mode=False, count_include_pad=True):
|
||||
return AvgPool2d(kernel_size, stride, padding, ceil_mode, count_include_pad)(x)
|
||||
|
||||
class MaxPool2d(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
|
||||
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode, op="maximum")
|
||||
|
||||
def execute(self, x):
|
||||
return self.layer(x)
|
||||
|
||||
def max_pool2d(x, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
|
||||
return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)(x)
|
|
@ -5,7 +5,9 @@
|
|||
# ***************************************************************
|
||||
import unittest
|
||||
import jittor as jt
|
||||
from jittor.nn import Pool, pool
|
||||
from jittor.nn import Pool, pool, AvgPool2d, avg_pool2d
|
||||
from jittor.nn import MaxPool2d as j_MaxPool2d
|
||||
from jittor.nn import max_pool2d as j_max_pool2d
|
||||
import numpy as np
|
||||
from .test_core import expect_error
|
||||
from .test_grad import ngrad
|
||||
|
@ -101,7 +103,7 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
check(jt_model, torch_model, shape, False)
|
||||
for i in range(10):
|
||||
check(jt_model, torch_model, [1,1,300,300], True)
|
||||
|
||||
|
||||
def test_cpu_(self):
|
||||
# x = jt.random([32, 128, 157, 300])
|
||||
x = jt.random([4, 128, 157, 300])
|
||||
|
@ -138,5 +140,50 @@ class TestArgPoolOp(unittest.TestCase):
|
|||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
|
||||
def test_AvgPool2d(self):
|
||||
from torch.nn import AvgPool2d as t_AvgPool2d
|
||||
jt_model = AvgPool2d(3, 1, 1, ceil_mode=True)
|
||||
torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True)
|
||||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
|
||||
jt_model = AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
torch_model = t_AvgPool2d(3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
shape = (2, 16, 100, 100)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
print('finish')
|
||||
|
||||
def test_avg_pool2d(self):
|
||||
from torch.nn.functional import avg_pool2d as t_avg_pool2d
|
||||
arr = np.random.random((2, 16, 33, 33))
|
||||
jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True)
|
||||
torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
|
||||
jt_model = avg_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
torch_model = t_avg_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True, count_include_pad=False)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
print('finish')
|
||||
|
||||
def test_MaxPool2d(self):
|
||||
from torch.nn import MaxPool2d
|
||||
jt_model = j_MaxPool2d(3, 1, 1, ceil_mode=True)
|
||||
torch_model = MaxPool2d(3, 1, 1, ceil_mode=True)
|
||||
shape = (2, 16, 33, 33)
|
||||
check(jt_model, torch_model, shape, False)
|
||||
print('finish')
|
||||
|
||||
def test_max_pool2d(self):
|
||||
from torch.nn.functional import max_pool2d
|
||||
arr = np.random.random((2, 16, 33, 33))
|
||||
jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1, ceil_mode=True)
|
||||
torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1, ceil_mode=True)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
|
||||
jt_model = j_max_pool2d(jt.array(arr), 3, 1, 1)
|
||||
torch_model = max_pool2d(torch.Tensor(arr), 3, 1, 1)
|
||||
assert np.allclose(jt_model.numpy(), torch_model.numpy())
|
||||
print('finish')
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -18,6 +18,12 @@ import unittest
|
|||
from .test_reorder_tuner import simple_parser
|
||||
from .test_log import find_log_with_re
|
||||
|
||||
try:
|
||||
jt.dirty_fix_pytorch_runtime_error()
|
||||
import torch
|
||||
except:
|
||||
skip_this_test = True
|
||||
|
||||
class TestRandomOp(unittest.TestCase):
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
|
@ -56,6 +62,26 @@ class TestRandomOp(unittest.TestCase):
|
|||
assert (np.abs((data<(1)).mean() - 0.5) < 0.1)
|
||||
assert (np.abs((data<(1+3)).mean() - (1-r)) < 0.1)
|
||||
|
||||
np_res = np.random.normal(1, 0.1, (100, 100))
|
||||
jt_res = jt.normal(1., 0.1, (100, 100))
|
||||
assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1)
|
||||
assert (np.abs(np_res.std() - jt_res.data.std()) < 0.1)
|
||||
|
||||
np_res = torch.normal(torch.arange(1., 10000.), 1)
|
||||
jt_res = jt.normal(jt.arange(1, 10000), 1)
|
||||
assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1)
|
||||
assert (np.abs(np_res.std() - jt_res.data.std()) < 1)
|
||||
|
||||
np_res = np.random.randn(100, 100)
|
||||
jt_res = jt.randn(100, 100)
|
||||
assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1)
|
||||
assert (np.abs(np_res.std() - jt_res.data.std()) < 0.1)
|
||||
|
||||
np_res = np.random.rand(100, 100)
|
||||
jt_res = jt.rand(100, 100)
|
||||
assert (np.abs(np_res.mean() - jt_res.data.mean()) < 0.1)
|
||||
assert (np.abs(np_res.std() - jt_res.data.std()) < 0.1)
|
||||
|
||||
@unittest.skipIf(not jt.has_cuda, "Cuda not found")
|
||||
@jt.flag_scope(use_cuda=1)
|
||||
def test_normal_cuda(self):
|
||||
|
|
Loading…
Reference in New Issue