detach clone

This commit is contained in:
Dun Liang 2020-06-19 21:22:08 +08:00
parent 32abf3db0b
commit db275ed436
5 changed files with 21 additions and 7 deletions

View File

@ -214,11 +214,6 @@ def zeros(shape, dtype="float32"):
flags = core.flags()
def detach(x):
"""return detached var"""
return x.clone().stop_grad().clone()
Var.detach = detach
def std(x):
matsize=1
for i in x.shape:

View File

@ -31,4 +31,11 @@ void CloneOp::infer_shape() {
y->set_shape(x->shape);
y->share_with(x);
}
VarPtr detach(Var* x) {
auto y = make_clone(x);
y->input()->set_stop_grad();
return y;
}
} // jittor

View File

@ -18,4 +18,7 @@ struct CloneOp : Op {
VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override;
void infer_shape() override;
};
VarPtr detach(Var* x);
} // jittor

View File

@ -28,8 +28,9 @@ VarHolder::VarHolder(Var* v) : var(v) {
var->own_both_liveness();
}
VarHolder::VarHolder(VarPtr&& v) : VarHolder(v.ptr) {
v.free_liveness();
VarHolder::VarHolder(VarPtr&& v) {
add_hold_vars(this);
var = v.ptr;
v.ptr = nullptr;
}

View File

@ -13,6 +13,7 @@
namespace jittor {
struct VarHolder;
VarPtr detach(Var* x);
struct DataView {
VarHolder* vh;
@ -93,6 +94,13 @@ struct VarHolder {
return var->is_stop_grad();
}
/* detach the grad */
// @pyjt(detach)
inline VarHolder* detach() {
return new VarHolder(move(jittor::detach(var)));
}
// @pyjt(stop_fuse)
// @attrs(return_self)
inline VarHolder* stop_fuse() {