diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index e49cdb8c..57277ad6 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -113,6 +113,22 @@ Example:: self.jt_flags = jt_flags jt_flags["no_grad"] = 1 +class enable_grad(flag_scope): + ''' enable_grad scope, all variable created inside this +scope will start grad. + +Example:: + + import jittor as jt + + with jt.enable_grad(): + ... + + ''' + def __init__(self, **jt_flags): + self.jt_flags = jt_flags + jt_flags["no_grad"] = 0 + single_log_capture = None class log_capture_scope(_call_no_record_scope):