mirror of https://github.com/Jittor/Jittor
polish pool interface
This commit is contained in:
parent
dff7835847
commit
fba55c7e31
|
@ -14,7 +14,7 @@ deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security m
|
|||
|
||||
RUN apt update && apt install wget \
|
||||
python3.7 python3.7-dev \
|
||||
g++ build-essential -y
|
||||
g++ build-essential openssh-server -y
|
||||
|
||||
WORKDIR /usr/src/jittor
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.59'
|
||||
__version__ = '1.2.2.60'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
|
|
@ -21,7 +21,7 @@ class Pool(Module):
|
|||
assert return_indices == None
|
||||
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
|
||||
self.op = op
|
||||
self.stride = stride if stride else kernel_size
|
||||
stride = stride if stride else kernel_size
|
||||
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
|
||||
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
|
||||
self.ceil_mode = ceil_mode
|
||||
|
@ -207,10 +207,10 @@ def avg_pool2d(x, kernel_size, stride=None, padding=0, ceil_mode=False, count_in
|
|||
|
||||
class MaxPool2d(Module):
|
||||
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
|
||||
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode, op="maximum")
|
||||
self._layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode, op="maximum")
|
||||
|
||||
def execute(self, x):
|
||||
return self.layer(x)
|
||||
return self._layer(x)
|
||||
|
||||
def max_pool2d(x, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
|
||||
return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)(x)
|
|
@ -31,7 +31,7 @@ rand_hooked = False
|
|||
|
||||
def hook_pt_rand(*shape, device=None):
|
||||
import torch
|
||||
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], torch.Size):
|
||||
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], (torch.Size, tuple, list)):
|
||||
shape = tuple(shape[0])
|
||||
np.random.seed(0)
|
||||
res = torch.from_numpy(np.random.rand(*tuple(shape)).astype("float32"))
|
||||
|
@ -41,9 +41,10 @@ def hook_pt_rand(*shape, device=None):
|
|||
|
||||
def hook_pt_randn(*shape, device=None):
|
||||
import torch
|
||||
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], torch.Size):
|
||||
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], (torch.Size, tuple, list)):
|
||||
shape = tuple(shape[0])
|
||||
np.random.seed(0)
|
||||
print(shape)
|
||||
res = torch.from_numpy(np.random.randn(*tuple(shape)).astype("float32"))
|
||||
if device is not None:
|
||||
return res.to(device)
|
||||
|
@ -269,6 +270,14 @@ class Hook:
|
|||
|
||||
names = []
|
||||
for name, module in mod.named_modules():
|
||||
ns = name.split('.')
|
||||
skip = 0
|
||||
for n in ns:
|
||||
if n.startswith('_'):
|
||||
skip = 1
|
||||
if skip:
|
||||
LOG.i("skip", name)
|
||||
continue
|
||||
name = mod_name + name
|
||||
module.__ad_mod_name__ = name
|
||||
names.append(name)
|
||||
|
|
Loading…
Reference in New Issue