feat(core): Task 新增 pre、post、regular、manual 四种 run_at 类型

This commit is contained in:
XcantloadX 2025-07-26 20:37:14 +08:00
parent 09252c5aa1
commit 0b7054e897
2 changed files with 44 additions and 9 deletions

View File

@ -15,6 +15,13 @@ from kotonebot.backend.context import task_registry, action_registry, Task, Acti
from kotonebot.errors import StopCurrentTask, UserFriendlyError
from kotonebot.interop.win.task_dialog import TaskDialog
@dataclass
class PostTaskContext:
has_error: bool
exception: Exception | None
log_stream = io.StringIO()
stream_handler = logging.StreamHandler(log_stream)
stream_handler.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] [%(name)s] [%(filename)s:%(lineno)d] - %(message)s'))
@ -174,20 +181,37 @@ class KotoneBot:
self._on_after_init_context()
vars.flow.clear_interrupt()
pre_tasks = [task for task in tasks if task.run_at == 'pre']
regular_tasks = [task for task in tasks if task.run_at == 'regular']
post_tasks = [task for task in tasks if task.run_at == 'post']
if by_priority:
tasks = sorted(tasks, key=lambda x: x.priority, reverse=True)
for task in tasks:
pre_tasks = sorted(pre_tasks, key=lambda x: x.priority, reverse=True)
regular_tasks = sorted(regular_tasks, key=lambda x: x.priority, reverse=True)
post_tasks = sorted(post_tasks, key=lambda x: x.priority, reverse=True)
all_tasks = pre_tasks + regular_tasks + post_tasks
for task in all_tasks:
self.events.task_status_changed.trigger(task, 'pending')
for task in tasks:
has_error = False
exception: Exception | None = None
for task in all_tasks:
logger.info(f'Task started: {task.name}')
self.events.task_status_changed.trigger(task, 'running')
if self.debug:
task.func()
if task.run_at == 'post':
task.func(PostTaskContext(has_error, exception))
else:
task.func()
else:
try:
task.func()
if task.run_at == 'post':
task.func(PostTaskContext(has_error, exception))
else:
task.func()
self.events.task_status_changed.trigger(task, 'finished')
except StopCurrentTask:
logger.info(f'Task skipped/stopped: {task.name}')
@ -195,7 +219,7 @@ class KotoneBot:
# 用户中止
except KeyboardInterrupt as e:
logger.exception('Keyboard interrupt detected.')
for task1 in tasks[tasks.index(task):]:
for task1 in all_tasks[all_tasks.index(task):]:
self.events.task_status_changed.trigger(task1, 'cancelled')
vars.flow.clear_interrupt()
break
@ -203,6 +227,8 @@ class KotoneBot:
except UserFriendlyError as e:
logger.error(f'Task failed: {task.name}')
logger.exception(f'Error: ')
has_error = True
exception = e
dialog = TaskDialog(
title='琴音小助手',
common_buttons=0,
@ -217,12 +243,14 @@ class KotoneBot:
except Exception as e:
logger.error(f'Task failed: {task.name}')
logger.exception(f'Error: ')
has_error = True
exception = e
report_path = None
if self.auto_save_error_report:
raise NotImplementedError
self.events.task_status_changed.trigger(task, 'error')
if not self.resume_on_error:
for task1 in tasks[tasks.index(task)+1:]:
for task1 in all_tasks[all_tasks.index(task)+1:]:
self.events.task_status_changed.trigger(task1, 'cancelled')
break
logger.info(f'Task ended: {task.name}')

View File

@ -1,5 +1,5 @@
import logging
from typing import Callable, ParamSpec, TypeVar, overload
from typing import Callable, ParamSpec, TypeVar, overload, Literal
from dataclasses import dataclass
@ -10,6 +10,9 @@ P = ParamSpec('P')
R = TypeVar('R')
logger = logging.getLogger(__name__)
TaskRunAtType = Literal['pre', 'post', 'manual', 'regular'] | str
@dataclass
class Task:
name: str
@ -20,6 +23,8 @@ class Task:
"""
任务优先级数字越大优先级越高
"""
run_at: TaskRunAtType = 'regular'
@dataclass
class Action:
@ -47,6 +52,7 @@ def task(
pass_through: bool = False,
priority: int = 0,
screenshot_mode: ScreenshotMode = 'auto',
run_at: TaskRunAtType = 'regular'
):
"""
`task` 装饰器用于标记一个函数为任务函数
@ -58,6 +64,7 @@ def task(
默认情况下 @task 装饰器会包裹任务函数跟踪其执行情况
如果不想跟踪则设置此参数为 False
:param priority: 任务优先级数字越大优先级越高
:param run_at: 任务运行时间
"""
# 设置 ID
# 获取 caller 信息
@ -66,7 +73,7 @@ def task(
description = description or func.__doc__ or ''
# TODO: task_id 冲突检测
task_id = task_id or func.__name__
task = Task(name, task_id, description, _placeholder, priority)
task = Task(name, task_id, description, _placeholder, priority, run_at)
task_registry[name] = task
logger.debug(f'Task "{name}" registered.')
if pass_through: