feat(core): Task 新增 pre、post、regular、manual 四种 run_at 类型
This commit is contained in:
parent
09252c5aa1
commit
0b7054e897
|
@ -15,6 +15,13 @@ from kotonebot.backend.context import task_registry, action_registry, Task, Acti
|
|||
from kotonebot.errors import StopCurrentTask, UserFriendlyError
|
||||
from kotonebot.interop.win.task_dialog import TaskDialog
|
||||
|
||||
|
||||
@dataclass
|
||||
class PostTaskContext:
|
||||
has_error: bool
|
||||
exception: Exception | None
|
||||
|
||||
|
||||
log_stream = io.StringIO()
|
||||
stream_handler = logging.StreamHandler(log_stream)
|
||||
stream_handler.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] [%(name)s] [%(filename)s:%(lineno)d] - %(message)s'))
|
||||
|
@ -174,20 +181,37 @@ class KotoneBot:
|
|||
self._on_after_init_context()
|
||||
vars.flow.clear_interrupt()
|
||||
|
||||
pre_tasks = [task for task in tasks if task.run_at == 'pre']
|
||||
regular_tasks = [task for task in tasks if task.run_at == 'regular']
|
||||
post_tasks = [task for task in tasks if task.run_at == 'post']
|
||||
|
||||
if by_priority:
|
||||
tasks = sorted(tasks, key=lambda x: x.priority, reverse=True)
|
||||
for task in tasks:
|
||||
pre_tasks = sorted(pre_tasks, key=lambda x: x.priority, reverse=True)
|
||||
regular_tasks = sorted(regular_tasks, key=lambda x: x.priority, reverse=True)
|
||||
post_tasks = sorted(post_tasks, key=lambda x: x.priority, reverse=True)
|
||||
|
||||
all_tasks = pre_tasks + regular_tasks + post_tasks
|
||||
for task in all_tasks:
|
||||
self.events.task_status_changed.trigger(task, 'pending')
|
||||
|
||||
for task in tasks:
|
||||
has_error = False
|
||||
exception: Exception | None = None
|
||||
|
||||
for task in all_tasks:
|
||||
logger.info(f'Task started: {task.name}')
|
||||
self.events.task_status_changed.trigger(task, 'running')
|
||||
|
||||
if self.debug:
|
||||
task.func()
|
||||
if task.run_at == 'post':
|
||||
task.func(PostTaskContext(has_error, exception))
|
||||
else:
|
||||
task.func()
|
||||
else:
|
||||
try:
|
||||
task.func()
|
||||
if task.run_at == 'post':
|
||||
task.func(PostTaskContext(has_error, exception))
|
||||
else:
|
||||
task.func()
|
||||
self.events.task_status_changed.trigger(task, 'finished')
|
||||
except StopCurrentTask:
|
||||
logger.info(f'Task skipped/stopped: {task.name}')
|
||||
|
@ -195,7 +219,7 @@ class KotoneBot:
|
|||
# 用户中止
|
||||
except KeyboardInterrupt as e:
|
||||
logger.exception('Keyboard interrupt detected.')
|
||||
for task1 in tasks[tasks.index(task):]:
|
||||
for task1 in all_tasks[all_tasks.index(task):]:
|
||||
self.events.task_status_changed.trigger(task1, 'cancelled')
|
||||
vars.flow.clear_interrupt()
|
||||
break
|
||||
|
@ -203,6 +227,8 @@ class KotoneBot:
|
|||
except UserFriendlyError as e:
|
||||
logger.error(f'Task failed: {task.name}')
|
||||
logger.exception(f'Error: ')
|
||||
has_error = True
|
||||
exception = e
|
||||
dialog = TaskDialog(
|
||||
title='琴音小助手',
|
||||
common_buttons=0,
|
||||
|
@ -217,12 +243,14 @@ class KotoneBot:
|
|||
except Exception as e:
|
||||
logger.error(f'Task failed: {task.name}')
|
||||
logger.exception(f'Error: ')
|
||||
has_error = True
|
||||
exception = e
|
||||
report_path = None
|
||||
if self.auto_save_error_report:
|
||||
raise NotImplementedError
|
||||
self.events.task_status_changed.trigger(task, 'error')
|
||||
if not self.resume_on_error:
|
||||
for task1 in tasks[tasks.index(task)+1:]:
|
||||
for task1 in all_tasks[all_tasks.index(task)+1:]:
|
||||
self.events.task_status_changed.trigger(task1, 'cancelled')
|
||||
break
|
||||
logger.info(f'Task ended: {task.name}')
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import logging
|
||||
from typing import Callable, ParamSpec, TypeVar, overload
|
||||
from typing import Callable, ParamSpec, TypeVar, overload, Literal
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
|
@ -10,6 +10,9 @@ P = ParamSpec('P')
|
|||
R = TypeVar('R')
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TaskRunAtType = Literal['pre', 'post', 'manual', 'regular'] | str
|
||||
|
||||
|
||||
@dataclass
|
||||
class Task:
|
||||
name: str
|
||||
|
@ -20,6 +23,8 @@ class Task:
|
|||
"""
|
||||
任务优先级,数字越大优先级越高。
|
||||
"""
|
||||
run_at: TaskRunAtType = 'regular'
|
||||
|
||||
|
||||
@dataclass
|
||||
class Action:
|
||||
|
@ -47,6 +52,7 @@ def task(
|
|||
pass_through: bool = False,
|
||||
priority: int = 0,
|
||||
screenshot_mode: ScreenshotMode = 'auto',
|
||||
run_at: TaskRunAtType = 'regular'
|
||||
):
|
||||
"""
|
||||
`task` 装饰器,用于标记一个函数为任务函数。
|
||||
|
@ -58,6 +64,7 @@ def task(
|
|||
默认情况下, @task 装饰器会包裹任务函数,跟踪其执行情况。
|
||||
如果不想跟踪,则设置此参数为 False。
|
||||
:param priority: 任务优先级,数字越大优先级越高。
|
||||
:param run_at: 任务运行时间。
|
||||
"""
|
||||
# 设置 ID
|
||||
# 获取 caller 信息
|
||||
|
@ -66,7 +73,7 @@ def task(
|
|||
description = description or func.__doc__ or ''
|
||||
# TODO: task_id 冲突检测
|
||||
task_id = task_id or func.__name__
|
||||
task = Task(name, task_id, description, _placeholder, priority)
|
||||
task = Task(name, task_id, description, _placeholder, priority, run_at)
|
||||
task_registry[name] = task
|
||||
logger.debug(f'Task "{name}" registered.')
|
||||
if pass_through:
|
||||
|
|
Loading…
Reference in New Issue