add enable_grad

This commit is contained in:
li-xl 2021-03-16 16:15:09 +08:00
parent 655f3cc090
commit 9a4bc9d183
1 changed files with 16 additions and 0 deletions

View File

@ -113,6 +113,22 @@ Example::
self.jt_flags = jt_flags
jt_flags["no_grad"] = 1
class enable_grad(flag_scope):
''' enable_grad scope, all variable created inside this
scope will start grad.
Example::
import jittor as jt
with jt.enable_grad():
...
'''
def __init__(self, **jt_flags):
self.jt_flags = jt_flags
jt_flags["no_grad"] = 0
single_log_capture = None
class log_capture_scope(_call_no_record_scope):