polish pool interface

This commit is contained in:
Dun Liang 2021-04-15 14:33:51 +08:00
parent dff7835847
commit fba55c7e31
4 changed files with 16 additions and 7 deletions

View File

@ -14,7 +14,7 @@ deb [trusted=yes] https://mirrors.tuna.tsinghua.edu.cn/ubuntu/ bionic-security m
RUN apt update && apt install wget \
python3.7 python3.7-dev \
g++ build-essential -y
g++ build-essential openssh-server -y
WORKDIR /usr/src/jittor

View File

@ -8,7 +8,7 @@
# This file is subject to the terms and conditions defined in
# file 'LICENSE.txt', which is part of this source code package.
# ***************************************************************
__version__ = '1.2.2.59'
__version__ = '1.2.2.60'
from . import lock
with lock.lock_scope():
ori_int = int

View File

@ -21,7 +21,7 @@ class Pool(Module):
assert return_indices == None
self.kernel_size = kernel_size if isinstance(kernel_size, tuple) else (kernel_size, kernel_size)
self.op = op
self.stride = stride if stride else kernel_size
stride = stride if stride else kernel_size
self.stride = stride if isinstance(stride, tuple) else (stride, stride)
self.padding = padding if isinstance(padding, tuple) else (padding, padding)
self.ceil_mode = ceil_mode
@ -207,10 +207,10 @@ def avg_pool2d(x, kernel_size, stride=None, padding=0, ceil_mode=False, count_in
class MaxPool2d(Module):
def __init__(self, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
self.layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode, op="maximum")
self._layer = Pool(kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, return_indices=return_indices, ceil_mode=ceil_mode, op="maximum")
def execute(self, x):
return self.layer(x)
return self._layer(x)
def max_pool2d(x, kernel_size, stride=None, padding=0, dilation=None, return_indices=None, ceil_mode=False):
return MaxPool2d(kernel_size, stride, padding, dilation, return_indices, ceil_mode)(x)

View File

@ -31,7 +31,7 @@ rand_hooked = False
def hook_pt_rand(*shape, device=None):
import torch
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], torch.Size):
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], (torch.Size, tuple, list)):
shape = tuple(shape[0])
np.random.seed(0)
res = torch.from_numpy(np.random.rand(*tuple(shape)).astype("float32"))
@ -41,9 +41,10 @@ def hook_pt_rand(*shape, device=None):
def hook_pt_randn(*shape, device=None):
import torch
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], torch.Size):
if isinstance(shape, tuple) and len(shape)==1 and isinstance(shape[0], (torch.Size, tuple, list)):
shape = tuple(shape[0])
np.random.seed(0)
print(shape)
res = torch.from_numpy(np.random.randn(*tuple(shape)).astype("float32"))
if device is not None:
return res.to(device)
@ -269,6 +270,14 @@ class Hook:
names = []
for name, module in mod.named_modules():
ns = name.split('.')
skip = 0
for n in ns:
if n.startswith('_'):
skip = 1
if skip:
LOG.i("skip", name)
continue
name = mod_name + name
module.__ad_mod_name__ = name
names.append(name)