mirror of https://github.com/inclusionAI/AReaL
PullRequest: 336 add wrapper
Merge branch lite-util-wrapper of git@code.alipay.com:inclusionAI/AReaL.git into lite https://code.alipay.com/inclusionAI/AReaL/pull_requests/336 Reviewed-by: 博惟 <bowei.fw@antgroup.com> * add wrapper
This commit is contained in:
parent
8771778995
commit
7a438c0650
|
@ -0,0 +1,48 @@
|
|||
from arealite.utils.wrapper import (
|
||||
wrap,
|
||||
wrap_get_method,
|
||||
wrap_get_method_name,
|
||||
wrap_remove_meta,
|
||||
wrapable,
|
||||
)
|
||||
|
||||
|
||||
class Calculator:
|
||||
def __init__(self):
|
||||
self.remaining = 2
|
||||
|
||||
@wrapable()
|
||||
def add(self, a, b):
|
||||
return a + b + self.remaining
|
||||
|
||||
@wrapable(name="multiply")
|
||||
def mul(self, x, y):
|
||||
return x * y * self.remaining
|
||||
|
||||
|
||||
class EnhancedCalculator:
|
||||
def __init__(self, obj: Calculator):
|
||||
super().__init__()
|
||||
self.prefix = "Result:"
|
||||
self.calculator = obj
|
||||
wrap(self, obj, self._wrap_call)
|
||||
|
||||
def _wrap_call(self, *args, **kwargs):
|
||||
method_name = wrap_get_method_name(kwargs)
|
||||
method = wrap_get_method(kwargs)
|
||||
kwargs = wrap_remove_meta(kwargs)
|
||||
print("wrap method: ", method_name)
|
||||
print("wrap method: ", method)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
|
||||
def test_wrapper():
|
||||
calc = Calculator()
|
||||
enhancer = EnhancedCalculator(calc)
|
||||
|
||||
assert enhancer.add(2, 3) == 7
|
||||
assert enhancer.multiply(4, 5) == 40
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_wrapper()
|
|
@ -0,0 +1,47 @@
|
|||
import inspect
|
||||
from functools import partial
|
||||
from typing import Callable, Dict, Optional
|
||||
|
||||
|
||||
def wrapable(name: Optional[str] = None):
|
||||
def decorator(func):
|
||||
func.__wrap_meta__ = {
|
||||
"name": name or func.__name__,
|
||||
}
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def wrap(target: object, source: object, transform: Optional[Callable] = None):
|
||||
for name, member in inspect.getmembers(source):
|
||||
if callable(member) and hasattr(member, "__wrap_meta__"):
|
||||
meta = member.__wrap_meta__
|
||||
method_name = meta["name"]
|
||||
|
||||
if transform:
|
||||
setattr(
|
||||
target,
|
||||
method_name,
|
||||
partial(
|
||||
transform,
|
||||
wrap_method_name=method_name,
|
||||
wrap_original_method=member,
|
||||
),
|
||||
)
|
||||
else:
|
||||
setattr(target, method_name, member)
|
||||
|
||||
|
||||
def wrap_get_method_name(kwargs) -> str:
|
||||
return kwargs["wrap_method_name"]
|
||||
|
||||
|
||||
def wrap_get_method(kwargs) -> Callable:
|
||||
return kwargs["wrap_original_method"]
|
||||
|
||||
|
||||
def wrap_remove_meta(kwargs) -> Dict:
|
||||
del kwargs["wrap_method_name"]
|
||||
del kwargs["wrap_original_method"]
|
||||
return kwargs
|
Loading…
Reference in New Issue