mirror of https://github.com/Jittor/Jittor
detach clone
This commit is contained in:
parent
32abf3db0b
commit
db275ed436
|
@ -214,11 +214,6 @@ def zeros(shape, dtype="float32"):
|
|||
|
||||
flags = core.flags()
|
||||
|
||||
def detach(x):
|
||||
"""return detached var"""
|
||||
return x.clone().stop_grad().clone()
|
||||
Var.detach = detach
|
||||
|
||||
def std(x):
|
||||
matsize=1
|
||||
for i in x.shape:
|
||||
|
|
|
@ -31,4 +31,11 @@ void CloneOp::infer_shape() {
|
|||
y->set_shape(x->shape);
|
||||
y->share_with(x);
|
||||
}
|
||||
|
||||
VarPtr detach(Var* x) {
|
||||
auto y = make_clone(x);
|
||||
y->input()->set_stop_grad();
|
||||
return y;
|
||||
}
|
||||
|
||||
} // jittor
|
|
@ -18,4 +18,7 @@ struct CloneOp : Op {
|
|||
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
|
||||
void infer_shape() override;
|
||||
};
|
||||
|
||||
VarPtr detach(Var* x);
|
||||
|
||||
} // jittor
|
|
@ -28,8 +28,9 @@ VarHolder::VarHolder(Var* v) : var(v) {
|
|||
var->own_both_liveness();
|
||||
}
|
||||
|
||||
VarHolder::VarHolder(VarPtr&& v) : VarHolder(v.ptr) {
|
||||
v.free_liveness();
|
||||
VarHolder::VarHolder(VarPtr&& v) {
|
||||
add_hold_vars(this);
|
||||
var = v.ptr;
|
||||
v.ptr = nullptr;
|
||||
}
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@
|
|||
namespace jittor {
|
||||
|
||||
struct VarHolder;
|
||||
VarPtr detach(Var* x);
|
||||
|
||||
struct DataView {
|
||||
VarHolder* vh;
|
||||
|
@ -93,6 +94,13 @@ struct VarHolder {
|
|||
return var->is_stop_grad();
|
||||
}
|
||||
|
||||
/* detach the grad */
|
||||
// @pyjt(detach)
|
||||
inline VarHolder* detach() {
|
||||
return new VarHolder(move(jittor::detach(var)));
|
||||
}
|
||||
|
||||
|
||||
// @pyjt(stop_fuse)
|
||||
// @attrs(return_self)
|
||||
inline VarHolder* stop_fuse() {
|
||||
|
|
Loading…
Reference in New Issue