mirror of https://github.com/Jittor/Jittor
doc: var_holder
This commit is contained in:
parent
da9a4a0232
commit
4db2dcdaa1
|
@ -49,23 +49,33 @@ struct VarHolder {
|
|||
// @pyjt(fetch_sync,numpy)
|
||||
ArrayArgs fetch_sync();
|
||||
|
||||
/**
|
||||
* assign the data from another Var.
|
||||
*/
|
||||
// @pyjt(assign)
|
||||
// @attrs(return_self)
|
||||
VarHolder* assign(VarHolder* v);
|
||||
|
||||
/* update parameter and global variable,
|
||||
different from assign, it will
|
||||
stop grad between origin var and assigned var, and
|
||||
will update in the background
|
||||
/**
|
||||
* update parameter and global variable,
|
||||
* different from assign, it will
|
||||
* stop grad between origin var and assigned var, and
|
||||
* will update in the background
|
||||
*/
|
||||
// @pyjt(update)
|
||||
// @attrs(return_self)
|
||||
VarHolder* update(VarHolder* v);
|
||||
/* update parameter without set attribute */
|
||||
|
||||
/**
|
||||
* update parameter without set attribute.
|
||||
*/
|
||||
// @pyjt(_update)
|
||||
// @attrs(return_self)
|
||||
VarHolder* _update(VarHolder* v);
|
||||
|
||||
/**
|
||||
* swap the data with another Var.
|
||||
*/
|
||||
// @pyjt(swap)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* swap(VarHolder* v) { std::swap(var, v->var); return this; };
|
||||
|
@ -74,6 +84,9 @@ struct VarHolder {
|
|||
|
||||
static list<VarHolder*> hold_vars;
|
||||
|
||||
/**
|
||||
* set the name of the Var.
|
||||
*/
|
||||
// @pyjt(name)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* name(const char* s) {
|
||||
|
@ -81,17 +94,26 @@ struct VarHolder {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* return the name of the Var.
|
||||
*/
|
||||
// @pyjt(name)
|
||||
inline const char* name() {
|
||||
return var->name.c_str();
|
||||
}
|
||||
|
||||
/**
|
||||
* return the number of elements in the Var.
|
||||
*/
|
||||
// @pyjt(numel)
|
||||
inline int64 numel() {
|
||||
if (var->num<0) sync();
|
||||
return var->num;
|
||||
}
|
||||
|
||||
/**
|
||||
* disable the gradient calculation for the Var.
|
||||
*/
|
||||
// @pyjt(stop_grad)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* stop_grad() {
|
||||
|
@ -99,6 +121,9 @@ struct VarHolder {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* return True if the gradient is stopped.
|
||||
*/
|
||||
// @pyjt(is_stop_grad)
|
||||
inline bool is_stop_grad() {
|
||||
return var->is_stop_grad();
|
||||
|
@ -111,6 +136,9 @@ struct VarHolder {
|
|||
}
|
||||
|
||||
|
||||
/**
|
||||
* stop operator fusion.
|
||||
*/
|
||||
// @pyjt(stop_fuse)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* stop_fuse() {
|
||||
|
@ -118,22 +146,36 @@ struct VarHolder {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* return True if operator fusion is stopped.
|
||||
*/
|
||||
// @pyjt(is_stop_fuse)
|
||||
inline bool is_stop_fuse() {
|
||||
return var->flags.get(NodeFlags::_stop_fuse);
|
||||
}
|
||||
|
||||
/**
|
||||
* return the shape of the Var.
|
||||
*/
|
||||
// @pyjt(__get__shape)
|
||||
inline NanoVector shape() {
|
||||
if (var->num<0) sync();
|
||||
return var->shape;
|
||||
}
|
||||
|
||||
/**
|
||||
* return True if the Var requires gradient calculation.
|
||||
* @see is_stop_grad
|
||||
*/
|
||||
// @pyjt(__get__requires_grad)
|
||||
inline bool get_requires_grad() {
|
||||
return !var->is_stop_grad();
|
||||
}
|
||||
|
||||
/**
|
||||
* enable or disable gradient calculation.
|
||||
* @see stop_grad
|
||||
*/
|
||||
// @pyjt(__set__requires_grad)
|
||||
inline void set_requires_grad(bool flag) {
|
||||
if (flag == get_requires_grad()) return;
|
||||
|
@ -149,6 +191,9 @@ struct VarHolder {
|
|||
return var->shape;
|
||||
}
|
||||
|
||||
/**
|
||||
* return the data type of the Var.
|
||||
*/
|
||||
// @pyjt(__get__dtype)
|
||||
inline NanoString dtype() {
|
||||
return var->dtype();
|
||||
|
@ -164,7 +209,9 @@ struct VarHolder {
|
|||
var->loop_options = move(options);
|
||||
}
|
||||
|
||||
/** Get a numpy array which share the data with the var. */
|
||||
/**
|
||||
* get a numpy array which shares the data with the Var.
|
||||
*/
|
||||
// @pyjt(__get__data)
|
||||
inline DataView data() {
|
||||
sync(true);
|
||||
|
@ -174,10 +221,16 @@ struct VarHolder {
|
|||
return {this, var->mem_ptr, var->shape, var->dtype()};
|
||||
}
|
||||
|
||||
/** Get one item data */
|
||||
/**
|
||||
* returns the Python number if the Var contains only one element.
|
||||
* For other cases, see data().
|
||||
*/
|
||||
// @pyjt(item)
|
||||
ItemData item();
|
||||
|
||||
/**
|
||||
* return the number of dimensions.
|
||||
*/
|
||||
// @pyjt(__get__ndim)
|
||||
inline int ndim() {
|
||||
return var->shape.size();
|
||||
|
@ -206,6 +259,9 @@ struct VarHolder {
|
|||
return this;
|
||||
}
|
||||
|
||||
/**
|
||||
* print the information of the Var to debug.
|
||||
*/
|
||||
// @pyjt(debug_msg)
|
||||
string debug_msg();
|
||||
};
|
||||
|
|
Loading…
Reference in New Issue