kotones-auto-assistant/kotonebot/backend/context/task_action.py

176 lines
5.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import logging
from typing import Callable, ParamSpec, TypeVar, overload, Literal
from dataclasses import dataclass
from .context import ContextStackVars, ScreenshotMode
from ...errors import TaskNotFoundError
P = ParamSpec('P')
R = TypeVar('R')
logger = logging.getLogger(__name__)
TaskRunAtType = Literal['pre', 'post', 'manual', 'regular'] | str
@dataclass
class Task:
name: str
id: str
description: str
func: Callable
priority: int
"""
任务优先级,数字越大优先级越高。
"""
run_at: TaskRunAtType = 'regular'
@dataclass
class Action:
name: str
description: str
func: Callable
priority: int
"""
动作优先级,数字越大优先级越高。
"""
task_registry: dict[str, Task] = {}
action_registry: dict[str, Action] = {}
current_callstack: list[Task|Action] = []
def _placeholder():
raise NotImplementedError('Placeholder function')
def task(
name: str,
task_id: str|None = None,
description: str|None = None,
*,
pass_through: bool = False,
priority: int = 0,
screenshot_mode: ScreenshotMode = 'auto',
run_at: TaskRunAtType = 'regular'
):
"""
`task` 装饰器,用于标记一个函数为任务函数。
:param name: 任务名称
:param task_id: 任务 ID。如果为 None则使用函数名称作为 ID。
:param description: 任务描述。如果为 None则使用函数的 docstring 作为描述。
:param pass_through:
默认情况下, @task 装饰器会包裹任务函数,跟踪其执行情况。
如果不想跟踪,则设置此参数为 False。
:param priority: 任务优先级,数字越大优先级越高。
:param run_at: 任务运行时间。
"""
# 设置 ID
# 获取 caller 信息
def _task_decorator(func: Callable[P, R]) -> Callable[P, R]:
nonlocal description, task_id
description = description or func.__doc__ or ''
# TODO: task_id 冲突检测
task_id = task_id or func.__name__
task = Task(name, task_id, description, _placeholder, priority, run_at)
task_registry[name] = task
logger.debug(f'Task "{name}" registered.')
if pass_through:
return func
else:
def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
current_callstack.append(task)
vars = ContextStackVars.push(screenshot_mode=screenshot_mode)
ret = func(*args, **kwargs)
ContextStackVars.pop()
current_callstack.pop()
return ret
task.func = _wrapper
return _wrapper
return _task_decorator
@overload
def action(func: Callable[P, R]) -> Callable[P, R]: ...
@overload
def action(
name: str,
*,
description: str|None = None,
pass_through: bool = False,
priority: int = 0,
screenshot_mode: ScreenshotMode | None = None,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
"""
`action` 装饰器,用于标记一个函数为动作函数。
:param name: 动作名称。如果为 None则使用函数的名称作为名称。
:param description: 动作描述。如果为 None则使用函数的 docstring 作为描述。
:param pass_through:
默认情况下, @action 装饰器会包裹动作函数,跟踪其执行情况。
如果不想跟踪,则设置此参数为 False。
:param priority: 动作优先级,数字越大优先级越高。
:param screenshot_mode: 截图模式。
"""
...
# TODO: 需要找个地方统一管理这些属性名
ATTR_ORIGINAL_FUNC = '_kb_inner'
ATTR_ACTION_MARK = '__kb_action_mark'
def action(*args, **kwargs):
def _register(func: Callable, name: str, description: str|None = None, priority: int = 0) -> Action:
description = description or func.__doc__ or ''
action = Action(name, description, func, priority)
action_registry[name] = action
logger.debug(f'Action "{name}" registered.')
return action
if len(args) == 1 and isinstance(args[0], Callable):
func = args[0]
action = _register(_placeholder, func.__name__, func.__doc__)
def _wrapper(*args: P.args, **kwargs: P.kwargs):
current_callstack.append(action)
vars = ContextStackVars.push()
ret = func(*args, **kwargs)
ContextStackVars.pop()
current_callstack.pop()
return ret
setattr(_wrapper, ATTR_ORIGINAL_FUNC, func)
setattr(_wrapper, ATTR_ACTION_MARK, True)
action.func = _wrapper
return _wrapper
else:
name = args[0]
description = kwargs.get('description', None)
pass_through = kwargs.get('pass_through', False)
priority = kwargs.get('priority', 0)
screenshot_mode = kwargs.get('screenshot_mode', None)
def _action_decorator(func: Callable):
nonlocal pass_through
action = _register(_placeholder, name, description)
pass_through = kwargs.get('pass_through', False)
if pass_through:
return func
else:
def _wrapper(*args: P.args, **kwargs: P.kwargs):
current_callstack.append(action)
vars = ContextStackVars.push(screenshot_mode=screenshot_mode)
ret = func(*args, **kwargs)
ContextStackVars.pop()
current_callstack.pop()
return ret
setattr(_wrapper, ATTR_ORIGINAL_FUNC, func)
setattr(_wrapper, ATTR_ACTION_MARK, True)
action.func = _wrapper
return _wrapper
return _action_decorator
def tasks_from_id(task_ids: list[str]) -> list[Task]:
result = []
for tid in task_ids:
target = next(task for task in task_registry.values() if task.id == tid)
if target is None:
raise TaskNotFoundError(f'Task "{tid}" not found.')
result.append(target)
return result