mirror of https://github.com/Jittor/Jittor
add enable_grad
This commit is contained in:
parent
655f3cc090
commit
9a4bc9d183
|
@ -113,6 +113,22 @@ Example::
|
||||||
self.jt_flags = jt_flags
|
self.jt_flags = jt_flags
|
||||||
jt_flags["no_grad"] = 1
|
jt_flags["no_grad"] = 1
|
||||||
|
|
||||||
|
class enable_grad(flag_scope):
|
||||||
|
''' enable_grad scope, all variable created inside this
|
||||||
|
scope will start grad.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
import jittor as jt
|
||||||
|
|
||||||
|
with jt.enable_grad():
|
||||||
|
...
|
||||||
|
|
||||||
|
'''
|
||||||
|
def __init__(self, **jt_flags):
|
||||||
|
self.jt_flags = jt_flags
|
||||||
|
jt_flags["no_grad"] = 0
|
||||||
|
|
||||||
single_log_capture = None
|
single_log_capture = None
|
||||||
|
|
||||||
class log_capture_scope(_call_no_record_scope):
|
class log_capture_scope(_call_no_record_scope):
|
||||||
|
|
Loading…
Reference in New Issue