add register_pre_forward_hook

This commit is contained in:
周文洋 2020-12-09 21:53:28 +08:00
parent 899bc4d9e8
commit 61fe7f1eae
1 changed files with 15 additions and 0 deletions

View File

@ -560,6 +560,21 @@ class Module:
return ret
self.__class__.__call__ = new_call
def register_pre_forward_hook(self, func):
cls = self.__class__
self.__fhook2__ = func
if hasattr(cls, "__hooked2__"):
return
cls.__hooked2__ = True
origin_call = cls.__call__
def new_call(self, *args, **kw):
if hasattr(self, "__fhook2__"):
if len(kw):
self.__fhook2__(self, args, kw)
else:
self.__fhook2__(self, args)
return origin_call(self, *args, **kw)
self.__class__.__call__ = new_call
def children(self):
cd = []