mirror of https://github.com/Jittor/Jittor
add remove forward hook
This commit is contained in:
parent
433cae596e
commit
83775f0427
|
@ -8,7 +8,7 @@
|
|||
# This file is subject to the terms and conditions defined in
|
||||
# file 'LICENSE.txt', which is part of this source code package.
|
||||
# ***************************************************************
|
||||
__version__ = '1.2.2.53'
|
||||
__version__ = '1.2.2.54'
|
||||
from . import lock
|
||||
with lock.lock_scope():
|
||||
ori_int = int
|
||||
|
@ -831,6 +831,13 @@ class Module:
|
|||
self.__fhook__(self, args, ret)
|
||||
return ret
|
||||
self.__class__.__call__ = new_call
|
||||
|
||||
def remove_forward_hook(self):
|
||||
cls = self.__class__
|
||||
if hasattr(cls,"__hooked__"):
|
||||
delattr(cls,"__hooked__")
|
||||
if hasattr(self,"__fhook__"):
|
||||
delattr(self,"__fhook__")
|
||||
|
||||
def register_pre_forward_hook(self, func):
|
||||
cls = self.__class__
|
||||
|
@ -848,6 +855,13 @@ class Module:
|
|||
return origin_call(self, *args, **kw)
|
||||
self.__class__.__call__ = new_call
|
||||
|
||||
def remove_pre_forward_hook(self):
|
||||
cls = self.__class__
|
||||
if hasattr(cls,"__hooked2__"):
|
||||
delattr(cls,"__hooked2__")
|
||||
if hasattr(self,"__fhook2__"):
|
||||
delattr(self,"__fhook2__")
|
||||
|
||||
def children(self):
|
||||
cd = []
|
||||
def callback(parents, k, v, n):
|
||||
|
|
Loading…
Reference in New Issue