From 9a4bc9d183feb63530517f1a4b7dbd53c13be539 Mon Sep 17 00:00:00 2001 From: li-xl <1905692338@qq.com> Date: Tue, 16 Mar 2021 16:15:09 +0800 Subject: [PATCH] add enable_grad --- python/jittor/__init__.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/jittor/__init__.py b/python/jittor/__init__.py index e49cdb8c..57277ad6 100644 --- a/python/jittor/__init__.py +++ b/python/jittor/__init__.py @@ -113,6 +113,22 @@ Example:: self.jt_flags = jt_flags jt_flags["no_grad"] = 1 +class enable_grad(flag_scope): + ''' enable_grad scope, all variable created inside this +scope will start grad. + +Example:: + + import jittor as jt + + with jt.enable_grad(): + ... + + ''' + def __init__(self, **jt_flags): + self.jt_flags = jt_flags + jt_flags["no_grad"] = 0 + single_log_capture = None class log_capture_scope(_call_no_record_scope):