serialization and random sampling

This commit is contained in:
lzhengning 2021-02-07 21:27:28 +08:00
parent 6ada098f60
commit 0278356f31
1 changed file with 206 additions and 12 deletions


@@ -404,35 +404,176 @@ def argmin(x, dim, keepdims:bool=False):
return x.arg_reduce("min", dim, keepdims)
Var.argmin = argmin
def randn(*size, dtype="float32", requires_grad=True):
def randn(*size, dtype="float32", requires_grad=True) -> Var:
''' samples random numbers from a standard normal distribution.
:param size: shape of the output.
:type size: int or a sequence of int
:param dtype: data type, defaults to "float32".
:type dtype: str, optional
:param requires_grad: whether to enable gradient back-propagation, defaults to True.
:type requires_grad: bool, optional
Example:
>>> jt.randn(3)
jt.Var([-1.019889 -0.30377278 -1.4948598 ], dtype=float32)
>>> jt.randn(2, 3)
jt.Var([[-0.15989183 -1.5010914 0.5476955 ]
[-0.612632 -1.1471151 -1.1879086 ]], dtype=float32)
'''
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
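# A minimal usage sketch: both call styles below are accepted, since the
# tuple-unpacking line above normalizes a single tuple/list argument into size.
#     a = jt.randn(2, 3)                       # shape given as separate ints
#     b = jt.randn((2, 3))                     # shape given as one tuple; same shape
#     c = jt.randn(2, 3, requires_grad=False)  # returned Var is stop_grad'ed
#     assert a.shape == b.shape == c.shape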
def rand(*size, dtype="float32", requires_grad=True):
def rand(*size, dtype="float32", requires_grad=True) -> Var:
''' samples random numbers from a uniform distribution on the interval [0, 1).
:param size: shape of the output.
:type size: int or a sequence of int
:param dtype: data type, defaults to "float32".
:type dtype: str, optional
:param requires_grad: whether to enable gradient back-propagation, defaults to True.
:type requires_grad: bool, optional
Example:
>>> jt.rand(3)
jt.Var([0.31005102 0.02765604 0.8150749 ], dtype=float32)
>>> jt.rand(2, 3)
jt.Var([[0.96414304 0.3519264 0.8268017 ]
[0.05658621 0.04449705 0.86190987]], dtype=float32)
'''
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
arr = jt.random(size, dtype)
if not requires_grad: return arr.stop_grad()
return arr
def rand_like(x, dtype=None):
def rand_like(x, dtype=None) -> Var:
''' samples random values from a standard uniform distribution with the same shape as x.
:param x: reference variable.
:type x: jt.Var
:param dtype: if None, the dtype of the output is the same as x.
Otherwise, use the specified dtype. Defaults to None.
:type dtype: str, optional
Example:
>>> x = jt.zeros((2, 3))
>>> jt.rand_like(x)
jt.Var([[0.6164821 0.21476883 0.61959815]
[0.58626485 0.35345772 0.5638483 ]], dtype=float32)
'''
if dtype is None: dtype = x.dtype
return jt.random(x.shape, dtype)
def randn_like(x, dtype=None):
def randn_like(x, dtype=None) -> Var:
''' samples random values from a standard normal distribution with the same shape as x.
:param x: reference variable.
:type x: jt.Var
:param dtype: if None, the dtype of the output is the same as x.
Otherwise, use the specified dtype. Defaults to None.
:type dtype: str, optional
Example:
>>> x = jt.zeros((2, 3))
>>> jt.randn_like(x)
jt.Var([[-1.1647032 0.34847224 -1.3061888 ]
[ 1.068085 -0.34366122 0.13172573]], dtype=float32)
'''
if dtype is None: dtype = x.dtype
return jt.random(x.shape, dtype, "normal")
def randint(low, high=None, shape=(1,), dtype="int32"):
def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
''' samples random integers from a uniform distribution on the interval [low, high).
:param low: lowest integer to be drawn from the distribution, defaults to 0.
:type low: int, optional
:param high: One above the highest integer to be drawn from the distribution.
:type high: int
:param shape: shape of the output, defaults to (1,).
:type shape: tuple, optional
:param dtype: data type of the output, defaults to "int32".
:type dtype: str, optional
Example:
>>> jt.randint(3, shape=(3, 3))
jt.Var([[2 0 2]
[2 1 2]
[2 0 1]], dtype=int32)
>>> jt.randint(1, 3, shape=(3, 3))
jt.Var([[2 2 2]
[1 1 2]
[1 1 1]], dtype=int32)
'''
if high is None: low, high = 0, low
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
return v.astype(dtype)
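# Why clamp(low, high - 0.5): jt.random(shape) draws floats in [0, 1), so the
# affine map lands in [low, high), but float rounding could push a value to
# exactly high; the clamp caps it at high - 0.5 so the astype truncation below
# always stays inside [low, high). A quick sanity sketch:
#     v = jt.randint(0, 3, shape=(1000,))
#     assert v.data.min() >= 0 and v.data.max() <= 2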
def randint_like(x, low, high=None):
def randint_like(x, low, high=None) -> Var:
''' samples random integers with the same shape as x, drawn from a uniform distribution on the interval [low, high).
:param x: reference variable.
:type x: jt.Var
:param low: lowest integer to be drawn from the distribution, defaults to 0.
:type low: int, optional
:param high: One above the highest integer to be drawn from the distribution.
:type high: int
Example:
>>> x = jt.zeros((2, 3))
>>> jt.randint_like(x, 10)
jt.Var([[9. 3. 4.]
 [4. 8. 5.]], dtype=float32)
'''
return randint(low, high, x.shape, x.dtype)
def normal(mean, std, size=None, dtype="float32"):
def normal(mean, std, size=None, dtype="float32") -> Var:
''' samples random values from a normal distribution.
:param mean: means of the normal distributions.
:type mean: int or jt.Var
:param std: standard deviations of the normal distributions.
:type std: int or jt.Var
:param size: shape of the output. If not specified, the output
shape is determined by the shape of mean or std. An exception is
raised if both mean and std are scalars, or if their shapes
differ. Defaults to None.
:type size: tuple, optional
:param dtype: data type of the output, defaults to "float32".
:type dtype: str, optional
Example:
>>> jt.normal(5, 3, size=(2,3))
jt.Var([[ 8.070848 7.654219 10.252696 ]
[ 6.383718 7.8817277 3.0786133]], dtype=float32)
>>> mean = jt.randint(low=0, high=10, shape=(10,))
>>> jt.normal(mean, 0.1)
jt.Var([1.9524184 1.0749301 7.9864206 5.9407325 8.1596155 4.824019 7.955083
8.972998 6.0674286 8.88026 ], dtype=float32)
'''
if size is None:
if isinstance(mean, Var) and isinstance(std, Var):
assert mean.shape == std.shape
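A short sketch of the size-inference rule described in the docstring above: with size=None the output shape is borrowed from whichever of mean/std is a Var, while two plain scalars without an explicit size are rejected.
import jittor as jt
mean = jt.randint(low=0, high=10, shape=(10,))
out = jt.normal(mean, 0.1)            # size inferred from mean: shape (10,)
out2 = jt.normal(5, 3, size=(2, 3))   # explicit size with scalar mean and std
# jt.normal(5, 3)                     # size=None and no Var to infer from: raises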
@@ -507,7 +648,9 @@ def display_memory_info():
fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
core.display_memory_info(fileline)
def load(path):
def load(path: str):
''' loads an object from a file.
'''
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
@@ -519,7 +662,14 @@ def load(path):
model_dict = safeunpickle(path)
return model_dict
def save(params_dict, path):
def save(params_dict, path: str):
''' saves the parameter dictionary to a file.
:param params_dict: parameters to be saved
:type params_dict: list or dictionary
:param path: file path
:type path: str
'''
def dfs(x):
if isinstance(x, list):
for i in range(len(x)):
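A minimal round-trip sketch for save and load; the file name is illustrative.
import jittor as jt
params = {"weight": jt.randn(3, 3), "step": 100}
jt.save(params, "checkpoint.pkl")     # dfs above normalizes nested lists/dicts before pickling
restored = jt.load("checkpoint.pkl")
assert restored["step"] == 100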
@@ -695,6 +845,10 @@ class Module:
func(m)
def load_parameters(self, params):
''' loads parameters to the Module.
:param params: dictionary of parameter names and parameters.
'''
n_failed = 0
for key in params.keys():
v = self
@@ -735,14 +889,54 @@ class Module:
if n_failed:
LOG.w(f"load total {len(params)} params, {n_failed} failed")
def save(self, path):
def save(self, path: str):
''' saves parameters to a file.
:param path: path to save.
:type path: str
Example:
>>> class Net(nn.Module):
>>> ...
>>> net = Net()
>>> net.save('net.pkl')
>>> net.load('net.pkl')
'''
params = self.parameters()
params_dict = {}
for p in params:
params_dict[p.name()] = p.data
safepickle(params_dict, path)
def load(self, path):
def load(self, path: str):
''' loads parameters from a file.
:param path: path to load.
:type path: str
Example:
>>> class Net(nn.Module):
>>> ...
>>> net = Net()
>>> net.save('net.pkl')
>>> net.load('net.pkl')
.. note::
If the loaded parameters do not match the model definition, jittor prints an error message but does not raise an exception.
If a loaded parameter name does not exist in the model definition, the following message is printed and that parameter is ignored:
>>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ...
If the shape of a loaded parameter does not match the model definition, the following message is printed and that parameter is ignored:
>>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,]
If any error occurs during loading, jittor prints a summary; check the error messages carefully:
>>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed
'''
self.load_parameters(load(path))
def eval(self):
@@ -771,7 +965,7 @@ class Module:
if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
p.start_grad()
def is_training(self):
def is_training(self) -> bool:
if not hasattr(self, "is_train"):
self.is_train = True
return self.is_train
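A small sketch of the training flag: is_train is created lazily on first query and defaults to True; Net here is a stand-in module.
import jittor as jt
from jittor import nn
class Net(nn.Module):
    def execute(self, x):
        return x
net = Net()
assert net.is_training()      # lazily initialized to True
net.eval()                    # assumed to flip is_train to False
assert not net.is_training()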