mirror of https://github.com/Jittor/Jittor
commit
3c95c6d100
|
@ -560,6 +560,21 @@ class Module:
|
|||
return ret
|
||||
self.__class__.__call__ = new_call
|
||||
|
||||
def register_pre_forward_hook(self, func):
|
||||
cls = self.__class__
|
||||
self.__fhook2__ = func
|
||||
if hasattr(cls, "__hooked2__"):
|
||||
return
|
||||
cls.__hooked2__ = True
|
||||
origin_call = cls.__call__
|
||||
def new_call(self, *args, **kw):
|
||||
if hasattr(self, "__fhook2__"):
|
||||
if len(kw):
|
||||
self.__fhook2__(self, args, kw)
|
||||
else:
|
||||
self.__fhook2__(self, args)
|
||||
return origin_call(self, *args, **kw)
|
||||
self.__class__.__call__ = new_call
|
||||
|
||||
def children(self):
|
||||
cd = []
|
||||
|
|
Loading…
Reference in New Issue