Merge pull request #163 from Jittor/zwy5

add register_pre_forward_hook
This commit is contained in:
zhouwy19 2020-12-09 21:56:02 +08:00 committed by GitHub
commit 3c95c6d100
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 0 deletions

View File

@ -560,6 +560,21 @@ class Module:
return ret
self.__class__.__call__ = new_call
def register_pre_forward_hook(self, func):
cls = self.__class__
self.__fhook2__ = func
if hasattr(cls, "__hooked2__"):
return
cls.__hooked2__ = True
origin_call = cls.__call__
def new_call(self, *args, **kw):
if hasattr(self, "__fhook2__"):
if len(kw):
self.__fhook2__(self, args, kw)
else:
self.__fhook2__(self, args)
return origin_call(self, *args, **kw)
self.__class__.__call__ = new_call
def children(self):
cd = []