Merge pull request #179 from lzhengning/doc

[Doc]: Serialization and random sampling
Jittor 2021-02-20 17:57:23 +08:00 committed by GitHub
commit 76eeaeb961
2 changed files with 273 additions and 19 deletions


@@ -404,35 +404,180 @@ def argmin(x, dim, keepdims:bool=False):
return x.arg_reduce("min", dim, keepdims)
Var.argmin = argmin
def randn(*size, dtype="float32", requires_grad=True):
def randn(*size, dtype="float32", requires_grad=True) -> Var:
''' samples random numbers from a standard normal distribution.
:param size: shape of the output.
:type size: int or a sequence of int
:param dtype: data type, defaults to "float32".
:type dtype: str, optional
:param requires_grad: whether to enable gradient back-propagation, defaults to True.
:type requires_grad: bool, optional
Example:
>>> jt.randn(3)
jt.Var([-1.019889 -0.30377278 -1.4948598 ], dtype=float32)
>>> jt.randn(2, 3)
jt.Var([[-0.15989183 -1.5010914 0.5476955 ]
[-0.612632 -1.1471151 -1.1879086 ]], dtype=float32)
'''
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
arr = jt.random(size, dtype, "normal")
if not requires_grad: return arr.stop_grad()
return arr
def rand(*size, dtype="float32", requires_grad=True):
def rand(*size, dtype="float32", requires_grad=True) -> Var:
''' samples random numbers from a uniform distribution on the interval [0, 1).
:param size: shape of the output.
:type size: int or a sequence of int
:param dtype: data type, defaults to "float32".
:type dtype: str, optional
:param requires_grad: whether to enable gradient back-propagation, defaults to True.
:type requires_grad: bool, optional
Example:
>>> jt.rand(3)
jt.Var([0.31005102 0.02765604 0.8150749 ], dtype=float32)
>>> jt.rand(2, 3)
jt.Var([[0.96414304 0.3519264 0.8268017 ]
[0.05658621 0.04449705 0.86190987]], dtype=float32)
'''
if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
arr = jt.random(size, dtype)
if not requires_grad: return arr.stop_grad()
return arr
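# Illustrative sketch (not part of the diff): because the first positional
# argument may itself be a tuple/list/NanoVector, both calling styles below
# produce a Var of the same shape.
import jittor as jt

a = jt.rand(2, 3)        # sizes passed as separate ints
b = jt.rand((2, 3))      # size passed as a single tuple
print(a.shape, b.shape)  # both print as [2,3,]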
def rand_like(x, dtype=None):
def rand_like(x, dtype=None) -> Var:
''' samples random values from a standard uniform distribution with the same shape as x.
:param x: reference variable.
:type x: jt.Var
:param dtype: if None, the dtype of the output is the same as x.
Otherwise, use the specified dtype. Defaults to None.
:type dtype: str, optional
Example:
>>> x = jt.zeros((2, 3))
>>> jt.rand_like(x)
jt.Var([[0.6164821 0.21476883 0.61959815]
[0.58626485 0.35345772 0.5638483 ]], dtype=float32)
'''
if dtype is None: dtype = x.dtype
return jt.random(x.shape, dtype)
def randn_like(x, dtype=None):
def randn_like(x, dtype=None) -> Var:
''' samples random values from a standard normal distribution with the same shape as x.
:param x: reference variable.
:type x: jt.Var
:param dtype: if None, the dtype of the output is the same as x.
Otherwise, use the specified dtype. Defaults to None.
:type dtype: str, optional
Example:
>>> x = jt.zeros((2, 3))
>>> jt.randn_like(x)
jt.Var([[-1.1647032 0.34847224 -1.3061888 ]
[ 1.068085 -0.34366122 0.13172573]], dtype=float32)
'''
if dtype is None: dtype = x.dtype
return jt.random(x.shape, dtype, "normal")
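# Illustrative sketch of the dtype handling in rand_like/randn_like after the
# fix above: when dtype is given it overrides x.dtype, otherwise x.dtype is
# used. Assumes "float64" random sampling is available on the current device.
import jittor as jt

x = jt.zeros((2, 3), dtype="float64")
print(jt.rand_like(x).dtype)                   # float64, inherited from x
print(jt.rand_like(x, dtype="float32").dtype)  # float32, explicit override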
def randint(low, high=None, shape=(1,), dtype="int32"):
def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
''' samples random integers from a uniform distribution on the interval [low, high).
:param low: lowest integer to be drawn from the distribution, defaults to 0.
:type low: int, optional
:param high: One above the highest integer to be drawn from the distribution.
:type high: int
:param shape: shape of the output, defaults to (1,).
:type shape: tuple, optional
:param dtype: data type of the output, defaults to "int32".
:type dtype: str, optional
Example:
>>> jt.randint(3, shape=(3, 3))
jt.Var([[2 0 2]
[2 1 2]
[2 0 1]], dtype=int32)
>>> jt.randint(1, 3, shape=(3, 3))
jt.Var([[2 2 2]
[1 1 2]
[1 1 1]], dtype=int32)
'''
if high is None: low, high = 0, low
v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
v = jt.floor(v)
return v.astype(dtype)
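# Worked illustration (plain Python, not part of the diff) of the mapping used
# above: a uniform float u in [0, 1) is scaled to [low, high), clamped just
# below high so that floor() can never return high itself, then truncated.
import math

def map_uniform_to_int(u, low, high):
    v = u * (high - low) + low
    v = min(max(v, low), high - 0.5)   # mirrors .clamp(low, high-0.5)
    return int(math.floor(v))

print(map_uniform_to_int(0.999999, 1, 3))  # 2, never 3
print(map_uniform_to_int(0.0, 1, 3))       # 1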
def randint_like(x, low, high=None):
def randint_like(x, low, high=None) -> Var:
''' samples random integers from a uniform distribution on the interval [low, high), with the same shape as x.
:param x: reference variable.
:type x: jt.Var
:param low: lowest integer to be drawn from the distribution, defaults to 0.
:type low: int, optional
:param high: One above the highest integer to be drawn from the distribution.
:type high: int
Example:
>>> x = jt.zeros((2, 3))
>>> jt.randint_like(x, 10)
jt.Var([[9. 3. 4.]
[4. 8. 5.]], dtype=float32)
>>> jt.randint_like(x, 10, 20)
jt.Var([[17. 11. 18.]
[14. 17. 15.]], dtype=float32)
'''
return randint(low, high, x.shape, x.dtype)
def normal(mean, std, size=None, dtype="float32"):
def normal(mean, std, size=None, dtype="float32") -> Var:
''' samples random values from a normal distribution.
:param mean: means of the normal distributions.
:type mean: int or jt.Var
:param std: standard deviations of the normal distributions.
:type std: int or jt.Var
:param size: shape of the output. If not specified, the
shape of the output is determined by mean or std. In that case, an
exception is raised if both mean and std are integers or if their
shapes differ. Defaults to None.
:type size: tuple, optional
:param dtype: data type of the output, defaults to "float32".
:type dtype: str, optional
Example:
>>> jt.normal(5, 3, size=(2,3))
jt.Var([[ 8.070848 7.654219 10.252696 ]
[ 6.383718 7.8817277 3.0786133]], dtype=float32)
>>> mean = jt.randint(low=0, high=10, shape=(10,))
>>> jt.normal(mean, 0.1)
jt.Var([1.9524184 1.0749301 7.9864206 5.9407325 8.1596155 4.824019 7.955083
8.972998 6.0674286 8.88026 ], dtype=float32)
'''
if size is None:
if isinstance(mean, Var) and isinstance(std, Var):
assert mean.shape == std.shape
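# Hedged sketch of the usual mean + std * standard_normal reparameterization;
# the rest of normal() is truncated in this hunk, so this only illustrates the
# broadcasting idea, not the exact Jittor implementation.
import jittor as jt

mean = jt.randint(low=0, high=10, shape=(10,))
sample = mean + 0.1 * jt.randn(mean.shape)   # comparable to jt.normal(mean, 0.1)
print(sample)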
@@ -507,7 +652,9 @@ def display_memory_info():
fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
core.display_memory_info(fileline)
def load(path):
def load(path: str):
''' loads an object from a file.
:param path: file path.
:type path: str
'''
if path.endswith(".pth"):
try:
dirty_fix_pytorch_runtime_error()
@@ -519,7 +666,14 @@ def load(path):
model_dict = safeunpickle(path)
return model_dict
def save(params_dict, path):
def save(params_dict, path: str):
''' saves the parameter dictionary to a file.
:param params_dict: parameters to be saved
:type params_dict: list or dictionary
:param path: file path
:type path: str
'''
def dfs(x):
if isinstance(x, list):
for i in range(len(x)):
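# Hedged usage sketch (not part of the diff) for the module-level jt.save /
# jt.load pair documented above; "checkpoint.pkl" is an assumed path, and
# ".pth" files go through the PyTorch-compatible branch shown in load().
import jittor as jt

params = {"w": jt.randn(3, 3).data}    # plain dict of numpy arrays
jt.save(params, "checkpoint.pkl")      # serialize the parameter dict
restored = jt.load("checkpoint.pkl")   # returns the same dict structure
print(restored["w"].shape)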
@@ -695,6 +849,10 @@ class Module:
func(m)
def load_parameters(self, params):
''' loads parameters to the Module.
:param params: dictionary of parameter names and parameters.
'''
n_failed = 0
for key in params.keys():
v = self
@@ -740,14 +898,54 @@ class Module:
if n_failed:
LOG.w(f"load total {len(params)} params, {n_failed} failed")
def save(self, path):
def save(self, path: str):
''' saves parameters to a file.
:param path: path to save.
:type path: str
Example:
>>> class Net(nn.Module):
>>> ...
>>> net = Net()
>>> net.save('net.pkl')
>>> net.load('net.pkl')
'''
params = self.parameters()
params_dict = {}
for p in params:
params_dict[p.name()] = p.data
safepickle(params_dict, path)
def load(self, path):
def load(self, path: str):
''' loads parameters from a file.
:param path: path to load.
:type path: str
Example:
>>> class Net(nn.Module):
>>> ...
>>> net = Net()
>>> net.save('net.pkl')
>>> net.load('net.pkl')
.. note::
If the loaded parameters do not match the model definition, jittor prints an error message but does not raise an exception.
If a loaded parameter name does not exist in the model definition, the following message is printed and that parameter is ignored:
>>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ...
If the shape of a loaded parameter does not match the model definition, the following message is printed and that parameter is ignored:
>>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,]
If any errors occur during loading, jittor prints a summary; check the error messages carefully:
>>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed
'''
self.load_parameters(load(path))
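# Hedged sketch of the tolerant loading behaviour described in the note above:
# unknown or mismatched parameters are warned about and skipped rather than
# raising. "Net" is a hypothetical Module with a single Linear layer.
import jittor as jt
from jittor import nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(3, 4)
    def execute(self, x):
        return self.fc(x)

net = Net()
net.save('net.pkl')
net.load('net.pkl')   # round trip restores fc.weight and fc.bias
net.load_parameters({'missing.weight': jt.randn(2, 2).data})  # warns, does not raise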
def eval(self):
@@ -776,7 +974,7 @@ class Module:
if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
p.start_grad()
def is_training(self):
def is_training(self) -> bool:
if not hasattr(self, "is_train"):
self.is_train = True
return self.is_train


@@ -49,23 +49,33 @@ struct VarHolder {
// @pyjt(fetch_sync,numpy)
ArrayArgs fetch_sync();
/**
* assign the data from another Var.
*/
// @pyjt(assign)
// @attrs(return_self)
VarHolder* assign(VarHolder* v);
/* update parameter and global variable,
different from assign, it will
stop grad between origin var and assigned var, and
will update in the background
/**
* update a parameter or global variable.
* Unlike assign, it stops the gradient between
* the original var and the assigned var, and
* applies the update in the background.
*/
// @pyjt(update)
// @attrs(return_self)
VarHolder* update(VarHolder* v);
/* update parameter without set attribute */
/**
* update the parameter without setting the attribute.
*/
// @pyjt(_update)
// @attrs(return_self)
VarHolder* _update(VarHolder* v);
/**
* swap the data with another Var.
*/
// @pyjt(swap)
// @attrs(return_self)
inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
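# Python-side sketch (not part of the diff) contrasting the documented
# assign/update semantics: assign copies the data, while update also stops
# the gradient between the two vars and applies the copy in the background.
import jittor as jt

p = jt.rand(3)
new_p = p * 0.9
p.assign(new_p)   # p now holds new_p's data
p.update(new_p)   # same data, but grad to new_p is stopped and the copy is deferred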
@@ -74,6 +84,9 @@ struct VarHolder {
static list<VarHolder*> hold_vars;
/**
* set the name of the Var.
*/
// @pyjt(name)
// @attrs(return_self)
inline VarHolder* name(const char* s) {
@@ -81,17 +94,26 @@
return this;
}
/**
* return the name of the Var.
*/
// @pyjt(name)
inline const char* name() {
return var->name.c_str();
}
/**
* return the number of elements in the Var.
*/
// @pyjt(numel)
inline int64 numel() {
if (var->num<0) sync();
return var->num;
}
/**
* disable the gradient calculation for the Var.
*/
// @pyjt(stop_grad)
// @attrs(return_self)
inline VarHolder* stop_grad() {
@@ -99,6 +121,9 @@ struct VarHolder {
return this;
}
/**
* return True if the gradient is stopped.
*/
// @pyjt(is_stop_grad)
inline bool is_stop_grad() {
return var->is_stop_grad();
@@ -111,6 +136,9 @@ struct VarHolder {
}
/**
* stop operator fusion.
*/
// @pyjt(stop_fuse)
// @attrs(return_self)
inline VarHolder* stop_fuse() {
@@ -118,22 +146,36 @@
return this;
}
/**
* return True if operator fusion is stopped.
*/
// @pyjt(is_stop_fuse)
inline bool is_stop_fuse() {
return var->flags.get(NodeFlags::_stop_fuse);
}
/**
* return the shape of the Var.
*/
// @pyjt(__get__shape)
inline NanoVector shape() {
if (var->num<0) sync();
return var->shape;
}
/**
* return True if the Var requires gradient calculation.
* @see is_stop_grad
*/
// @pyjt(__get__requires_grad)
inline bool get_requires_grad() {
return !var->is_stop_grad();
}
/**
* enable or disable gradient calculation.
* @see stop_grad
*/
// @pyjt(__set__requires_grad)
inline void set_requires_grad(bool flag) {
if (flag == get_requires_grad()) return;
@@ -149,6 +191,9 @@ struct VarHolder {
return var->shape;
}
/**
* return the data type of the Var.
*/
// @pyjt(__get__dtype)
inline NanoString dtype() {
return var->dtype();
@@ -164,7 +209,9 @@ struct VarHolder {
var->loop_options = move(options);
}
/** Get a numpy array which share the data with the var. */
/**
* get a numpy array which shares the data with the Var.
*/
// @pyjt(__get__data)
inline DataView data() {
sync(true);
@@ -174,10 +221,16 @@ struct VarHolder {
return {this, var->mem_ptr, var->shape, var->dtype()};
}
/** Get one item data */
/**
* returns the Python number if the Var contains only one element.
* For other cases, see data().
*/
// @pyjt(item)
ItemData item();
/**
* return the number of dimensions.
*/
// @pyjt(__get__ndim)
inline int ndim() {
return var->shape.size();
@@ -206,6 +259,9 @@ struct VarHolder {
return this;
}
/**
* print debugging information of the Var.
*/
// @pyjt(debug_msg)
string debug_msg();
};
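# Python-side sketch (not part of the diff) exercising the VarHolder accessors
# documented above; each is exposed to Python through its @pyjt annotation.
import jittor as jt

x = jt.rand(2, 3).name("x")   # name() returns self, so calls can be chained
print(x.name(), x.shape, x.dtype, x.ndim, x.numel())
print(x.requires_grad)        # property backed by __get__requires_grad
x.stop_grad()                 # disable gradient calculation for this Var
print(x.is_stop_grad())       # True
a = x.data                    # numpy array sharing the Var's memory
s = jt.array([3.14]).item()   # Python number for a one-element Var
print(x.debug_msg())          # debugging information about the Var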