doc: var_holder

This commit is contained in:
lzhengning 2021-02-20 17:53:40 +08:00
parent da9a4a0232
commit 4db2dcdaa1
1 changed files with 63 additions and 7 deletions

View File

@ -49,23 +49,33 @@ struct VarHolder {
// @pyjt(fetch_sync,numpy)
ArrayArgs fetch_sync();
/**
* assign the data from another Var.
*/
// @pyjt(assign)
// @attrs(return_self)
VarHolder* assign(VarHolder* v);
/* update parameter and global variable,
different from assign, it will
stop grad between origin var and assigned var, and
will update in the background
/**
* update parameter and global variable,
* different from assign, it will
* stop grad between origin var and assigned var, and
* will update in the background
*/
// @pyjt(update)
// @attrs(return_self)
VarHolder* update(VarHolder* v);
/* update parameter without set attribute */
/**
* update parameter without set attribute.
*/
// @pyjt(_update)
// @attrs(return_self)
VarHolder* _update(VarHolder* v);
/**
* swap the data with another Var.
*/
// @pyjt(swap)
// @attrs(return_self)
inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
@ -74,6 +84,9 @@ struct VarHolder {
static list<VarHolder*> hold_vars;
/**
* set the name of the Var.
*/
// @pyjt(name)
// @attrs(return_self)
inline VarHolder* name(const char* s) {
@ -81,17 +94,26 @@ struct VarHolder {
return this;
}
/**
* return the name of the Var.
*/
// @pyjt(name)
inline const char* name() {
return var->name.c_str();
}
/**
* return the number of elements in the Var.
*/
// @pyjt(numel)
inline int64 numel() {
if (var->num<0) sync();
return var->num;
}
/**
* disable the gradient calculation for the Var.
*/
// @pyjt(stop_grad)
// @attrs(return_self)
inline VarHolder* stop_grad() {
@ -99,6 +121,9 @@ struct VarHolder {
return this;
}
/**
* return True if the gradient is stopped.
*/
// @pyjt(is_stop_grad)
inline bool is_stop_grad() {
return var->is_stop_grad();
@ -111,6 +136,9 @@ struct VarHolder {
}
/**
* stop operator fusion.
*/
// @pyjt(stop_fuse)
// @attrs(return_self)
inline VarHolder* stop_fuse() {
@ -118,22 +146,36 @@ struct VarHolder {
return this;
}
/**
* return True if operator fusion is stopped.
*/
// @pyjt(is_stop_fuse)
inline bool is_stop_fuse() {
return var->flags.get(NodeFlags::_stop_fuse);
}
/**
* return the shape of the Var.
*/
// @pyjt(__get__shape)
inline NanoVector shape() {
if (var->num<0) sync();
return var->shape;
}
/**
* return True if the Var requires gradient calculation.
* @see is_stop_grad
*/
// @pyjt(__get__requires_grad)
inline bool get_requires_grad() {
return !var->is_stop_grad();
}
/**
* enable or disable gradient calculation.
* @see stop_grad
*/
// @pyjt(__set__requires_grad)
inline void set_requires_grad(bool flag) {
if (flag == get_requires_grad()) return;
@ -149,6 +191,9 @@ struct VarHolder {
return var->shape;
}
/**
* return the data type of the Var.
*/
// @pyjt(__get__dtype)
inline NanoString dtype() {
return var->dtype();
@ -164,7 +209,9 @@ struct VarHolder {
var->loop_options = move(options);
}
/** Get a numpy array which share the data with the var. */
/**
* get a numpy array which shares the data with the Var.
*/
// @pyjt(__get__data)
inline DataView data() {
sync(true);
@ -174,10 +221,16 @@ struct VarHolder {
return {this, var->mem_ptr, var->shape, var->dtype()};
}
/** Get one item data */
/**
* returns the Python number if the Var contains only one element.
* For other cases, see data().
*/
// @pyjt(item)
ItemData item();
/**
* return the number of dimensions.
*/
// @pyjt(__get__ndim)
inline int ndim() {
return var->shape.size();
@ -206,6 +259,9 @@ struct VarHolder {
return this;
}
/**
* print the information of the Var to debug.
*/
// @pyjt(debug_msg)
string debug_msg();
};