mirror of https://github.com/Jittor/Jittor
Merge branch 'master' of https://github.com/Jittor/jittor
commit 3df3bfc35c
@@ -404,35 +404,180 @@ def argmin(x, dim, keepdims:bool=False):
    return x.arg_reduce("min", dim, keepdims)
Var.argmin = argmin

-def randn(*size, dtype="float32", requires_grad=True):
+def randn(*size, dtype="float32", requires_grad=True) -> Var:
+    ''' samples random numbers from a standard normal distribution.
+
+    :param size: shape of the output.
+    :type size: int or a sequence of int
+
+    :param dtype: data type, defaults to "float32".
+    :type dtype: str, optional
+
+    :param requires_grad: whether to enable gradient back-propagation, defaults to True.
+    :type requires_grad: bool, optional
+
+    Example:
+
+        >>> jt.randn(3)
+        jt.Var([-1.019889 -0.30377278 -1.4948598 ], dtype=float32)
+        >>> jt.randn(2, 3)
+        jt.Var([[-0.15989183 -1.5010914 0.5476955 ]
+         [-0.612632 -1.1471151 -1.1879086 ]], dtype=float32)
+    '''
    if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
    arr = jt.random(size, dtype, "normal")
    if not requires_grad: return arr.stop_grad()
    return arr

-def rand(*size, dtype="float32", requires_grad=True):
+def rand(*size, dtype="float32", requires_grad=True) -> Var:
+    ''' samples random numbers from a uniform distribution on the interval [0, 1).
+
+    :param size: shape of the output.
+    :type size: int or a sequence of int
+
+    :param dtype: data type, defaults to "float32".
+    :type dtype: str, optional
+
+    :param requires_grad: whether to enable gradient back-propagation, defaults to True.
+    :type requires_grad: bool, optional
+
+    Example:
+
+        >>> jt.rand(3)
+        jt.Var([0.31005102 0.02765604 0.8150749 ], dtype=float32)
+        >>> jt.rand(2, 3)
+        jt.Var([[0.96414304 0.3519264 0.8268017 ]
+         [0.05658621 0.04449705 0.86190987]], dtype=float32)
+    '''
    if isinstance(size, tuple) and isinstance(size[0], (tuple, list, NanoVector)): size = size[0]
    arr = jt.random(size, dtype)
    if not requires_grad: return arr.stop_grad()
    return arr
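Not part of the patch: a minimal usage sketch of the size handling above. The `isinstance` check lets a single tuple or list stand in for the unpacked sizes; output values are random.

```python
import jittor as jt

# both call styles produce a 2x3 Var; the isinstance check above unpacks
# a single tuple or list argument into the size
a = jt.rand(2, 3)                        # sizes as separate positional arguments
b = jt.rand((2, 3))                      # a single tuple works the same way
c = jt.randn(2, 3, requires_grad=False)  # sampling with gradients disabled

print(a.shape, b.shape, c.shape)         # all report the same (2, 3) shape
```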

-def rand_like(x, dtype=None):
+def rand_like(x, dtype=None) -> Var:
+    ''' samples random values from a standard uniform distribution with the same shape as x.
+
+    :param x: reference variable.
+    :type x: jt.Var
+
+    :param dtype: if None, the dtype of the output is the same as x.
+        Otherwise, use the specified dtype. Defaults to None.
+    :type dtype: str, optional
+
+    Example:
+
+        >>> x = jt.zeros((2, 3))
+        >>> jt.rand_like(x)
+        jt.Var([[0.6164821 0.21476883 0.61959815]
+         [0.58626485 0.35345772 0.5638483 ]], dtype=float32)
+    '''
    if dtype is None: dtype = x.dtype
    return jt.random(x.shape, dtype)

-def randn_like(x, dtype=None):
+def randn_like(x, dtype=None) -> Var:
+    ''' samples random values from a standard normal distribution with the same shape as x.
+
+    :param x: reference variable.
+    :type x: jt.Var
+
+    :param dtype: if None, the dtype of the output is the same as x.
+        Otherwise, use the specified dtype. Defaults to None.
+    :type dtype: str, optional
+
+    Example:
+
+        >>> x = jt.zeros((2, 3))
+        >>> jt.randn_like(x)
+        jt.Var([[-1.1647032 0.34847224 -1.3061888 ]
+         [ 1.068085 -0.34366122 0.13172573]], dtype=float32)
+    '''
    if dtype is None: dtype = x.dtype
    return jt.random(x.shape, dtype, "normal")
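For reference, the `*_like` helpers are thin wrappers over `jt.random` using the reference Var's shape; a minimal sketch (assuming the `dtype` override applied in the bodies above):

```python
import jittor as jt

x = jt.zeros((2, 3))

u = jt.rand_like(x)                      # uniform on [0, 1), same shape and dtype as x
n = jt.randn_like(x, dtype="float32")    # standard normal, dtype given explicitly

# equivalent direct calls using the reference shape, as in the bodies above
u2 = jt.random(x.shape, x.dtype)
n2 = jt.random(x.shape, x.dtype, "normal")
print(u.shape, n.shape, u2.shape, n2.shape)
```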

-def randint(low, high=None, shape=(1,), dtype="int32"):
+def randint(low, high=None, shape=(1,), dtype="int32") -> Var:
+    ''' samples random integers from a uniform distribution on the interval [low, high).
+
+    :param low: lowest integer to be drawn from the distribution, defaults to 0.
+    :type low: int, optional
+
+    :param high: one above the highest integer to be drawn from the distribution.
+    :type high: int
+
+    :param shape: shape of the output, defaults to (1,).
+    :type shape: tuple, optional
+
+    :param dtype: data type of the output, defaults to "int32".
+    :type dtype: str, optional
+
+    Example:
+
+        >>> jt.randint(3, shape=(3, 3))
+        jt.Var([[2 0 2]
+         [2 1 2]
+         [2 0 1]], dtype=int32)
+        >>> jt.randint(1, 3, shape=(3, 3))
+        jt.Var([[2 2 2]
+         [1 1 2]
+         [1 1 1]], dtype=int32)
+    '''
    if high is None: low, high = 0, low
    v = (jt.random(shape) * (high - low) + low).clamp(low, high-0.5)
    v = jt.floor(v)
    return v.astype(dtype)
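A small pure-Python sketch (independent of Jittor, names are illustrative) of why the clamp-then-floor mapping above always stays inside [low, high):

```python
import math
import random

def randint_scalar(low, high=None):
    # one scalar sample following the same arithmetic as the Var version above
    if high is None:
        low, high = 0, low
    u = random.random()                  # uniform on [0, 1)
    v = u * (high - low) + low           # uniform on [low, high)
    v = min(max(v, low), high - 0.5)     # clamp(low, high - 0.5)
    return math.floor(v)                 # floor; the caller casts to the int dtype

samples = [randint_scalar(1, 3) for _ in range(1000)]
assert all(1 <= s <= 2 for s in samples)
```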

-def randint_like(x, low, high=None):
+def randint_like(x, low, high=None) -> Var:
+    ''' samples random integers with the same shape and dtype as x, drawn uniformly from the interval [low, high).
+
+    :param x: reference variable.
+    :type x: jt.Var
+
+    :param low: lowest integer to be drawn from the distribution, defaults to 0.
+    :type low: int, optional
+
+    :param high: one above the highest integer to be drawn from the distribution.
+    :type high: int
+
+    Example:
+
+        >>> x = jt.zeros((2, 3))
+        >>> jt.randint_like(x, 10)
+        jt.Var([[9. 3. 4.]
+         [4. 8. 5.]], dtype=float32)
+        >>> jt.randint_like(x, 10, 20)
+        jt.Var([[17. 11. 18.]
+         [14. 17. 15.]], dtype=float32)
+    '''
    return randint(low, high, x.shape, x.dtype)
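Note that `randint_like` reuses `x.dtype`, which is why the docstring example returns floats; a hedged sketch of getting integer output instead by calling `randint` directly:

```python
import jittor as jt

x = jt.zeros((2, 3))                                  # float32 reference Var
f = jt.randint_like(x, 10)                            # values 0..9, but float32 like x
i = jt.randint(0, 10, shape=x.shape, dtype="int32")   # int32 values instead
print(f.dtype, i.dtype)                               # float32 int32
```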

-def normal(mean, std, size=None, dtype="float32"):
+def normal(mean, std, size=None, dtype="float32") -> Var:
+    ''' samples random values from a normal distribution.
+
+    :param mean: means of the normal distributions.
+    :type mean: int or jt.Var
+
+    :param std: standard deviations of the normal distributions.
+    :type std: int or jt.Var
+
+    :param size: shape of the output. If not specified, the output shape is
+        determined by the shape of mean or std; an exception is raised if both
+        mean and std are integers or if their shapes differ. Defaults to None.
+    :type size: tuple, optional
+
+    :param dtype: data type of the output, defaults to "float32".
+    :type dtype: str, optional
+
+    Example:
+
+        >>> jt.normal(5, 3, size=(2,3))
+        jt.Var([[ 8.070848 7.654219 10.252696 ]
+         [ 6.383718 7.8817277 3.0786133]], dtype=float32)
+        >>> mean = jt.randint(low=0, high=10, shape=(10,))
+        >>> jt.normal(mean, 0.1)
+        jt.Var([1.9524184 1.0749301 7.9864206 5.9407325 8.1596155 4.824019 7.955083
+         8.972998 6.0674286 8.88026 ], dtype=float32)
+    '''
    if size is None:
        if isinstance(mean, Var) and isinstance(std, Var):
            assert mean.shape == std.shape
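The hunk is cut off here, but the docstring pins down the behaviour; a hedged sketch relating `normal` to `randn` (an illustration of the semantics, not the actual implementation):

```python
import jittor as jt

# scalar mean and std with an explicit output shape, as in the docstring
a = jt.normal(5, 3, size=(2, 3))

# roughly the same thing expressed through a standard normal sample
b = 5 + 3 * jt.randn(2, 3)

# per-element means taken from a Var, std broadcast as a scalar
mean = jt.randint(low=0, high=10, shape=(10,))
c = jt.normal(mean, 0.1)
print(a.shape, b.shape, c.shape)
```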
@@ -507,7 +652,9 @@ def display_memory_info():
    fileline = f"{os.path.basename(fileline.filename)}:{fileline.lineno}"
    core.display_memory_info(fileline)

-def load(path):
+def load(path: str):
+    ''' loads an object from a file.
+
+    :param path: file path.
+    :type path: str
+    '''
    if path.endswith(".pth"):
        try:
            dirty_fix_pytorch_runtime_error()
@@ -519,7 +666,14 @@ def load(path):
    model_dict = safeunpickle(path)
    return model_dict

-def save(params_dict, path):
+def save(params_dict, path: str):
+    ''' saves the parameter dictionary to a file.
+
+    :param params_dict: parameters to be saved.
+    :type params_dict: list or dictionary
+    :param path: file path.
+    :type path: str
+    '''
    def dfs(x):
        if isinstance(x, list):
            for i in range(len(x)):
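Not shown in the truncated hunk, but `save` and `load` above form a simple round trip; a minimal usage sketch (the file name is hypothetical):

```python
import jittor as jt

params = {"w": jt.randn(3, 3).numpy(), "step": 10}

jt.save(params, "checkpoint.pkl")        # params_dict may be a dict or a list
restored = jt.load("checkpoint.pkl")
print(restored["step"], restored["w"].shape)
```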
@@ -695,6 +849,10 @@ class Module:
            func(m)

    def load_parameters(self, params):
+        ''' loads parameters into the Module.
+
+        :param params: dictionary of parameter names and parameters.
+        '''
        n_failed = 0
        for key in params.keys():
            v = self
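Since `Module.load` (later in this diff) is just `load_parameters` composed with the file loader, a state dictionary can also be fed in directly; a minimal sketch, assuming a small `Net` built from `nn.Linear` and a hypothetical file name:

```python
import jittor as jt
from jittor import nn

class Net(nn.Module):
    def __init__(self):
        self.fc = nn.Linear(3, 2)
    def execute(self, x):
        return self.fc(x)

net = Net()
net.save("net.pkl")              # writes {parameter name: value}, see Module.save below
state = jt.load("net.pkl")       # plain dictionary
net.load_parameters(state)       # equivalent to net.load("net.pkl")
```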
@@ -740,14 +898,54 @@ class Module:
        if n_failed:
            LOG.w(f"load total {len(params)} params, {n_failed} failed")

-    def save(self, path):
+    def save(self, path: str):
+        ''' saves parameters to a file.
+
+        :param path: path to save to.
+        :type path: str
+
+        Example:
+
+            >>> class Net(nn.Module):
+            >>> ...
+            >>> net = Net()
+            >>> net.save('net.pkl')
+            >>> net.load('net.pkl')
+        '''
        params = self.parameters()
        params_dict = {}
        for p in params:
            params_dict[p.name()] = p.data
        safepickle(params_dict, path)

-    def load(self, path):
+    def load(self, path: str):
+        ''' loads parameters from a file.
+
+        :param path: path to load from.
+        :type path: str
+
+        Example:
+
+            >>> class Net(nn.Module):
+            >>> ...
+            >>> net = Net()
+            >>> net.save('net.pkl')
+            >>> net.load('net.pkl')
+
+        .. note::
+            If the loaded parameters do not match the model definition, jittor
+            prints an error message but does not raise an exception.
+            If a loaded parameter name does not exist in the model definition,
+            a message like the following is printed and the parameter is ignored:
+
+            >>> [w 0205 21:49:39.962762 96 __init__.py:723] load parameter w failed ...
+
+            If the shape of a loaded parameter does not match the model definition,
+            a message like the following is printed and the parameter is ignored:
+
+            >>> [e 0205 21:49:39.962822 96 __init__.py:739] load parameter w failed: expect the shape of w to be [1000,100,], but got [3,100,100,]
+
+            If anything goes wrong during loading, jittor prints a summary line;
+            please check the reported errors carefully:
+
+            >>> [w 0205 21:49:39.962906 96 __init__.py:741] load total 100 params, 3 failed
+        '''
        self.load_parameters(load(path))
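Because `load()` (earlier hunk) recognises the `.pth` suffix, the same method can, when parameter names and shapes match, ingest checkpoints exported from PyTorch; a hedged sketch reusing the hypothetical `Net` from the earlier example:

```python
# hypothetical checkpoint path; mismatched names or shapes only produce the
# warnings described in the note above, they do not raise exceptions
net = Net()                  # the Module from the earlier sketch
net.load("model.pth")        # .pth files are routed through the pytorch-compatible loader
```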

    def eval(self):
@@ -776,7 +974,7 @@ class Module:
            if id(p) in self.backup_grad_state and self.backup_grad_state[id(p)]:
                p.start_grad()

-    def is_training(self):
+    def is_training(self) -> bool:
        if not hasattr(self, "is_train"):
            self.is_train = True
        return self.is_train
@@ -49,23 +49,33 @@ struct VarHolder {
    // @pyjt(fetch_sync,numpy)
    ArrayArgs fetch_sync();

+    /**
+     * assign the data from another Var.
+     */
    // @pyjt(assign)
    // @attrs(return_self)
    VarHolder* assign(VarHolder* v);

-    /* update parameter and global variable,
-    different from assign, it will
-    stop grad between origin var and assigned var, and
-    will update in the background
+    /**
+     * update parameter and global variable.
+     * Different from assign, it stops the gradient between the
+     * original Var and the assigned Var, and performs the update
+     * in the background.
     */
    // @pyjt(update)
    // @attrs(return_self)
    VarHolder* update(VarHolder* v);

-    /* update parameter without set attribute */
+    /**
+     * update the parameter without setting the attribute.
+     */
    // @pyjt(_update)
    // @attrs(return_self)
    VarHolder* _update(VarHolder* v);

+    /**
+     * swap the data with another Var.
+     */
    // @pyjt(swap)
    // @attrs(return_self)
    inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
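Seen from Python, the practical difference between the three methods documented above is whether the written value stays connected to the autograd graph; a minimal sketch (variable names are illustrative):

```python
import jittor as jt

w = jt.randn(3, 3)
w2 = jt.randn(3, 3)

w.assign(w * 0.5)     # plain data write; the expression stays in the autograd graph
w.update(w * 0.5)     # parameter-style write; grad is stopped between old and new value
w.swap(w2)            # exchange the underlying data of the two Vars
```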
@@ -74,6 +84,9 @@ struct VarHolder {

    static list<VarHolder*> hold_vars;

+    /**
+     * set the name of the Var.
+     */
    // @pyjt(name)
    // @attrs(return_self)
    inline VarHolder* name(const char* s) {
@@ -81,17 +94,26 @@ struct VarHolder {
        return this;
    }

+    /**
+     * return the name of the Var.
+     */
    // @pyjt(name)
    inline const char* name() {
        return var->name.c_str();
    }

+    /**
+     * return the number of elements in the Var.
+     */
    // @pyjt(numel)
    inline int64 numel() {
        if (var->num<0) sync();
        return var->num;
    }

+    /**
+     * disable the gradient calculation for the Var.
+     */
    // @pyjt(stop_grad)
    // @attrs(return_self)
    inline VarHolder* stop_grad() {
@@ -99,6 +121,9 @@ struct VarHolder {
        return this;
    }

+    /**
+     * return True if the gradient is stopped.
+     */
    // @pyjt(is_stop_grad)
    inline bool is_stop_grad() {
        return var->is_stop_grad();
@@ -111,6 +136,9 @@ struct VarHolder {
    }


+    /**
+     * stop operator fusion.
+     */
    // @pyjt(stop_fuse)
    // @attrs(return_self)
    inline VarHolder* stop_fuse() {
@@ -118,22 +146,36 @@ struct VarHolder {
        return this;
    }

+    /**
+     * return True if operator fusion is stopped.
+     */
    // @pyjt(is_stop_fuse)
    inline bool is_stop_fuse() {
        return var->flags.get(NodeFlags::_stop_fuse);
    }

+    /**
+     * return the shape of the Var.
+     */
    // @pyjt(__get__shape)
    inline NanoVector shape() {
        if (var->num<0) sync();
        return var->shape;
    }

+    /**
+     * return True if the Var requires gradient calculation.
+     * @see is_stop_grad
+     */
    // @pyjt(__get__requires_grad)
    inline bool get_requires_grad() {
        return !var->is_stop_grad();
    }

+    /**
+     * enable or disable gradient calculation.
+     * @see stop_grad
+     */
    // @pyjt(__set__requires_grad)
    inline void set_requires_grad(bool flag) {
        if (flag == get_requires_grad()) return;
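These accessors surface in Python as the `requires_grad` property plus `stop_grad()` / `is_stop_grad()`; a minimal sketch:

```python
import jittor as jt

x = jt.randn(4)
print(x.requires_grad)    # True: randn enables gradients by default

x.stop_grad()             # returns self, so it can be chained
print(x.is_stop_grad())   # True
print(x.requires_grad)    # False; requires_grad is simply !is_stop_grad()

x.requires_grad = True    # __set__requires_grad re-enables gradient calculation
```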
@@ -149,6 +191,9 @@ struct VarHolder {
        return var->shape;
    }

+    /**
+     * return the data type of the Var.
+     */
    // @pyjt(__get__dtype)
    inline NanoString dtype() {
        return var->dtype();
@@ -164,7 +209,9 @@ struct VarHolder {
        var->loop_options = move(options);
    }

-    /** Get a numpy array which share the data with the var. */
+    /**
+     * get a numpy array which shares the data with the Var.
+     */
    // @pyjt(__get__data)
    inline DataView data() {
        sync(true);
@@ -174,10 +221,16 @@ struct VarHolder {
        return {this, var->mem_ptr, var->shape, var->dtype()};
    }

-    /** Get one item data */
+    /**
+     * returns the Python number if the Var contains only one element.
+     * For other cases, see data().
+     */
    // @pyjt(item)
    ItemData item();

+    /**
+     * return the number of dimensions.
+     */
    // @pyjt(__get__ndim)
    inline int ndim() {
        return var->shape.size();
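The metadata accessors above map onto familiar Python attributes; a minimal sketch (printed values are illustrative):

```python
import jittor as jt

x = jt.rand(2, 3)
print(x.shape)       # (2, 3); synced lazily when the shape is not known yet
print(x.ndim)        # 2
print(x.numel())     # 6
print(x.dtype)       # float32

view = x.data        # numpy array sharing memory with the Var
one = jt.rand(1).item()   # Python scalar; only valid for single-element Vars
```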
@@ -206,6 +259,9 @@ struct VarHolder {
        return this;
    }

+    /**
+     * return a string with debugging information about the Var.
+     */
    // @pyjt(debug_msg)
    string debug_msg();
};