PullRequest: 336 add wrapper

Merge branch lite-util-wrapper of git@code.alipay.com:inclusionAI/AReaL.git into lite
https://code.alipay.com/inclusionAI/AReaL/pull_requests/336

Reviewed-by: 博惟 <bowei.fw@antgroup.com>


* add wrapper
This commit is contained in:
郭唯 2025-07-09 15:44:27 +08:00
parent 8771778995
commit 7a438c0650
2 changed files with 95 additions and 0 deletions

View File

@ -0,0 +1,48 @@
from arealite.utils.wrapper import (
wrap,
wrap_get_method,
wrap_get_method_name,
wrap_remove_meta,
wrapable,
)
class Calculator:
def __init__(self):
self.remaining = 2
@wrapable()
def add(self, a, b):
return a + b + self.remaining
@wrapable(name="multiply")
def mul(self, x, y):
return x * y * self.remaining
class EnhancedCalculator:
def __init__(self, obj: Calculator):
super().__init__()
self.prefix = "Result:"
self.calculator = obj
wrap(self, obj, self._wrap_call)
def _wrap_call(self, *args, **kwargs):
method_name = wrap_get_method_name(kwargs)
method = wrap_get_method(kwargs)
kwargs = wrap_remove_meta(kwargs)
print("wrap method: ", method_name)
print("wrap method: ", method)
return method(*args, **kwargs)
def test_wrapper():
calc = Calculator()
enhancer = EnhancedCalculator(calc)
assert enhancer.add(2, 3) == 7
assert enhancer.multiply(4, 5) == 40
if __name__ == "__main__":
test_wrapper()

47
arealite/utils/wrapper.py Normal file
View File

@ -0,0 +1,47 @@
import inspect
from functools import partial
from typing import Callable, Dict, Optional
def wrapable(name: Optional[str] = None):
def decorator(func):
func.__wrap_meta__ = {
"name": name or func.__name__,
}
return func
return decorator
def wrap(target: object, source: object, transform: Optional[Callable] = None):
for name, member in inspect.getmembers(source):
if callable(member) and hasattr(member, "__wrap_meta__"):
meta = member.__wrap_meta__
method_name = meta["name"]
if transform:
setattr(
target,
method_name,
partial(
transform,
wrap_method_name=method_name,
wrap_original_method=member,
),
)
else:
setattr(target, method_name, member)
def wrap_get_method_name(kwargs) -> str:
return kwargs["wrap_method_name"]
def wrap_get_method(kwargs) -> Callable:
return kwargs["wrap_original_method"]
def wrap_remove_meta(kwargs) -> Dict:
del kwargs["wrap_method_name"]
del kwargs["wrap_original_method"]
return kwargs