mirror of https://github.com/Jittor/Jittor
serialization and random sampling
This commit is contained in:
parent
6ada098f60
commit
0278356f31
|
@ -404,35 +404,176 @@ def argmin(x, dim, keepdims:bool=False):
|
|||
return x.arg_reduce("min", dim, keepdims)
|
||||
Var.argmin = argmin
|
||||
|
||||
def randn(*size, dtype="float32", requires_grad=True) -> Var:
    ''' samples random numbers from a standard normal distribution.

    :param size: shape of the output.
    :type size: int or a sequence of int

    :param dtype: data type, defaults to "float32".
    :type dtype: str, optional

    :param requires_grad: whether to enable gradient back-propagation, defaults to True.
    :type requires_grad: bool, optional

    Example:

        >>> jt.randn(3)
        jt.Var([-1.019889   -0.30377278 -1.4948598 ], dtype=float32)
        >>> jt.randn(2, 3)
        jt.Var([[-0.15989183 -1.5010914   0.5476955 ]
         [-0.612632   -1.1471151  -1.1879086 ]], dtype=float32)
    '''
    # Accept both randn(2, 3) and randn((2, 3)) / randn([2, 3]).
    # `size` is always a tuple here (it comes from *size), so only the
    # sequence-unwrapping check is needed; the len() guard fixes an
    # IndexError previously raised by calling randn() with no arguments.
    if len(size) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
    arr = jt.random(size, dtype, "normal")
    # detach from the autograd graph when gradients are not wanted
    if not requires_grad: return arr.stop_grad()
    return arr
|
||||
|
||||
def rand(*size, dtype="float32", requires_grad=True) -> Var:
    ''' samples random numbers from a uniform distribution on the interval [0, 1).

    :param size: shape of the output.
    :type size: int or a sequence of int

    :param dtype: data type, defaults to "float32".
    :type dtype: str, optional

    :param requires_grad: whether to enable gradient back-propagation, defaults to True.
    :type requires_grad: bool, optional

    Example:

        >>> jt.rand(3)
        jt.Var([0.31005102 0.02765604 0.8150749 ], dtype=float32)
        >>> jt.rand(2, 3)
        jt.Var([[0.96414304 0.3519264  0.8268017 ]
         [0.05658621 0.04449705 0.86190987]], dtype=float32)
    '''
    # Accept both rand(2, 3) and rand((2, 3)) / rand([2, 3]).
    # `size` is always a tuple here (it comes from *size), so only the
    # sequence-unwrapping check is needed; the len() guard fixes an
    # IndexError previously raised by calling rand() with no arguments.
    if len(size) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
    arr = jt.random(size, dtype)
    # detach from the autograd graph when gradients are not wanted
    if not requires_grad: return arr.stop_grad()
    return arr
|
||||
|
||||
def rand_like(x, dtype=None) -> Var:
    ''' samples random values from standard uniform distribution with the same shape as x.

    :param x: reference variable.
    :type x: jt.Var

    :param dtype: if None, the dtype of the output is the same as x.
        Otherwise, use the specified dtype. Defaults to None.
    :type dtype: str, optional

    Example:

        >>> x = jt.zeros((2, 3))
        >>> jt.rand_like(x)
        jt.Var([[0.6164821  0.21476883 0.61959815]
         [0.58626485 0.35345772 0.5638483 ]], dtype=float32)
    '''
    if dtype is None: dtype = x.dtype
    # BUG FIX: previously passed x.dtype here, silently ignoring an
    # explicitly requested dtype; use the resolved `dtype` instead.
    return jt.random(x.shape, dtype)
|
||||
|
||||
def randn_like(x, dtype=None) -> Var:
    ''' samples random values from standard normal distribution with the same shape as x.

    :param x: reference variable.
    :type x: jt.Var

    :param dtype: if None, the dtype of the output is the same as x.
        Otherwise, use the specified dtype. Defaults to None.
    :type dtype: str, optional

    Example:

        >>> x = jt.zeros((2, 3))
        >>> jt.randn_like(x)
        jt.Var([[-1.1647032   0.34847224 -1.3061888 ]
         [ 1.068085   -0.34366122  0.13172573]], dtype=float32)
    '''
    if dtype is None: dtype = x.dtype
    # BUG FIX: previously passed x.dtype here, silently ignoring an
    # explicitly requested dtype; use the resolved `dtype` instead.
    return jt.random(x.shape, dtype, "normal")
|
||||
|
||||
def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
    ''' samples random integers from a uniform distribution on the interval [low, high).

    :param low: lowest integer to be drawn from the distribution; when ``high``
        is omitted, ``low`` is treated as the exclusive upper bound and the
        lower bound defaults to 0.
    :type low: int, optional

    :param high: one above the highest integer to be drawn from the distribution.
    :type high: int

    :param shape: shape of the output, defaults to (1,).
    :type shape: tuple, optional

    :param dtype: data type of the output, defaults to "int32".
    :type dtype: str, optional

    Example:

        >>> jt.randint(3, shape=(3, 3))
        jt.Var([[2 0 2]
         [2 1 2]
         [2 0 1]], dtype=int32)
        >>> jt.randint(1, 3, shape=(3, 3))
        jt.Var([[2 2 2]
         [1 1 2]
         [1 1 1]], dtype=int32)
    '''
    # single-argument form: randint(high) samples from [0, high)
    if high is None: low, high = 0, low
    # Sample uniformly in [low, high), then clamp below high - 0.5 so the
    # upper endpoint can never round into `high` when truncated to integer.
    v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
    return v.astype(dtype)
|
||||
|
||||
def randint_like(x, low, high=None) -> Var:
    ''' samples random integers uniformly from [low, high) with the same
    shape and dtype as x.

    :param x: reference variable.
    :type x: jt.Var

    :param low: lowest integer to be drawn from the distribution; when ``high``
        is omitted, ``low`` is treated as the exclusive upper bound and the
        lower bound defaults to 0.
    :type low: int, optional

    :param high: one above the highest integer to be drawn from the distribution.
    :type high: int

    Example:

        >>> x = jt.zeros((2, 3))
        >>> jt.randint_like(x, 10)
        jt.Var([[9. 3. 4.]
         [4. 8. 5.]], dtype=float32)
    '''
    # NOTE: the output dtype follows x.dtype, so integer samples are cast to
    # float when x is a float Var (matching the example above).
    return randint(low, high, x.shape, x.dtype)
|
||||
|
||||
def normal(mean, std, size=None, dtype="float32"):
|
||||
def normal(mean, std, size=None, dtype="float32") -> Var:
|
||||
''' samples random values from a normal distribution.
|
||||
|
||||
:param mean: means of the normal distributions.
|
||||
:type mean: int or jt.Var
|
||||
|
||||
:param std: standard deviations of the normal distributions.
|
||||
:type std: int or jt.Var
|
||||
|
||||
:param size: shape of the output size. if not specified, the
|
||||
shape of the output is determined by mean or std. Exception will be
|
||||
raised if mean and std are all integers or have different shape in
|
||||
this case. Defaults to None
|
||||
:type size: tuple, optional
|
||||
|
||||
:param dtype: data type of the output, defaults to "float32".
|
||||
:type dtype: str, optional
|
||||
|
||||
Example:
|
||||
|
||||
>>> jt.normal(5, 3, size=(2,3))
|
||||
jt.Var([[ 8.070848 7.654219 10.252696 ]
|
||||
[ 6.383718 7.8817277 3.0786133]], dtype=float32)
|
||||
>>> mean = jt.randint(low=0, high=10, shape=(10,))
|
||||
>>> jt.normal(mean, 0.1)
|
||||
jt.Var([1.9524184 1.0749301 7.9864206 5.9407325 8.1596155 4.824019 7.955083
|
||||
8.972998 6.0674286 8.88026 ], dtype=float32)
|
||||
'''
|
||||
if size is None:
|
||||
if isinstance(mean, Var) and isinstance(std, Var):
|
||||
assert mean.shape == std.shape
|
||||
|
@ -507,7 +648,9 @@ def display_memory_info():
|
|||
fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
|
||||
core.display_memory_info(fileline)
|
||||
|
||||
def load(path):
|
||||
def load(path: str):
|
||||
''' loads an object from a file.
|
||||
'''
|
||||
if path.endswith(".pth"):
|
||||
try:
|
||||
dirty_fix_pytorch_runtime_error()
|
||||
|
@ -519,7 +662,14 @@ def load(path):
|
|||
model_dict = safeunpickle(path)
|
||||
return model_dict
|
||||
|
||||
def save(params_dict, path):
|
||||
def save(params_dict, path: str):
|
||||
''' saves the parameter dictionary to a file.
|
||||
|
||||
:param params_dict: parameters to be saved
|
||||
:type params_dict: list or dictionary
|
||||
:param path: file path
|
||||
:type path: str
|
||||
'''
|
||||
def dfs(x):
|
||||
if isinstance(x, list):
|
||||
for i in range(len(x)):
|
||||
|
@ -695,6 +845,10 @@ class Module:
|
|||
func(m)
|
||||
|
||||
def load_parameters(self, params):
|
||||
''' loads parameters to the Module.
|
||||
|
||||
:param params: dictionary of parameter names and parameters.
|
||||
'''
|
||||
n_failed = 0
|
||||
for key in params.keys():
|
||||
v = self
|
||||
|
@ -735,14 +889,54 @@ class Module:
|
|||
if n_failed:
|
||||
LOG.w(f"load total {len(params)} params, {n_failed} failed")
|
||||
|
||||
def save(self, path: str):
    ''' saves parameters to a file.

    :param path: path to save.
    :type path: str

    Example:

        >>> class Net(nn.Module):
        >>> ...
        >>> net = Net()
        >>> net.save('net.pkl')
        >>> net.load('net.pkl')
    '''
    # Snapshot every parameter's data keyed by its name, then pickle.
    params_dict = {p.name(): p.data for p in self.parameters()}
    safepickle(params_dict, path)
|
||||
|
||||
def load(self, path: str):
    ''' loads parameters from a file.

    :param path: path to load.
    :type path: str

    Example:

        >>> class Net(nn.Module):
        >>> ...
        >>> net = Net()
        >>> net.save('net.pkl')
        >>> net.load('net.pkl')

    .. note::
        If the loaded parameters are inconsistent with the model definition,
        jittor prints an error message but does not raise an exception.
        If a loaded parameter name does not exist in the model definition,
        the following message is printed and that parameter is ignored:

        >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ...

        If the shape of a loaded parameter does not match the model
        definition, the following message is printed and that parameter is
        ignored:

        >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,]

        If any errors occurred while loading, jittor prints a summary line;
        check the error messages above it carefully:

        >>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed
    '''
    self.load_parameters(load(path))
|
||||
|
||||
def eval(self):
|
||||
|
@ -771,7 +965,7 @@ class Module:
|
|||
if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
|
||||
p.start_grad()
|
||||
|
||||
def is_training(self) -> bool:
    ''' returns whether the module is in training mode.

    Lazily initializes ``self.is_train`` to True on first query, so a
    module that was never switched with train()/eval() reports training.
    '''
    try:
        return self.is_train
    except AttributeError:
        # first query: default to training mode and remember it
        self.is_train = True
        return True
|
||||
|
|
Loading…
Reference in New Issue