mirror of https://github.com/Jittor/Jittor
add enable_grad
This commit is contained in:
parent
655f3cc090
commit
9a4bc9d183
|
@ -113,6 +113,22 @@ Example::
|
|||
self.jt_flags = jt_flags
|
||||
jt_flags["no_grad"] = 1
|
||||
|
||||
class enable_grad(flag_scope):
|
||||
''' enable_grad scope, all variable created inside this
|
||||
scope will start grad.
|
||||
|
||||
Example::
|
||||
|
||||
import jittor as jt
|
||||
|
||||
with jt.enable_grad():
|
||||
...
|
||||
|
||||
'''
|
||||
def __init__(self, **jt_flags):
|
||||
self.jt_flags = jt_flags
|
||||
jt_flags["no_grad"] = 0
|
||||
|
||||
single_log_capture = None
|
||||
|
||||
class log_capture_scope(_call_no_record_scope):
|
||||
|
|
Loading…
Reference in New Issue