refactor(task): 移除配置文件路径的硬编码
This commit is contained in:
parent
a967732f31
commit
6fa9250950
|
@ -1,22 +1,17 @@
|
|||
import io
|
||||
import os
|
||||
import logging
|
||||
import pkgutil
|
||||
import importlib
|
||||
import threading
|
||||
from typing_extensions import Self
|
||||
from dataclasses import dataclass, field
|
||||
import threading
|
||||
import traceback
|
||||
import os
|
||||
import zipfile
|
||||
import cv2
|
||||
from datetime import datetime
|
||||
import io
|
||||
from typing import Any, Literal, Callable, Generic, TypeVar, ParamSpec
|
||||
|
||||
from kotonebot.backend.context import Task, Action
|
||||
from kotonebot.backend.context import init_context, vars
|
||||
from kotonebot.backend.context import task_registry, action_registry, current_callstack, Task, Action
|
||||
from kotonebot.client.host.protocol import Instance
|
||||
from kotonebot.ui import user
|
||||
from kotonebot.client.host.protocol import Instance
|
||||
from kotonebot.backend.context import init_context, vars
|
||||
from kotonebot.backend.context import task_registry, action_registry, Task, Action
|
||||
|
||||
log_stream = io.StringIO()
|
||||
stream_handler = logging.StreamHandler(log_stream)
|
||||
|
@ -39,41 +34,6 @@ class RunStatus:
|
|||
def interrupt(self):
|
||||
vars.interrupted.set()
|
||||
|
||||
def _save_error_report(
|
||||
exception: Exception,
|
||||
*,
|
||||
path: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
保存错误报告
|
||||
|
||||
:param path: 保存的路径。若为 `None`,则保存到 `./reports/{YY-MM-DD HH-MM-SS}.zip`。
|
||||
:return: 保存的路径
|
||||
"""
|
||||
from kotonebot import device
|
||||
try:
|
||||
if path is None:
|
||||
path = f'./reports/{datetime.now().strftime("%Y-%m-%d %H-%M-%S")}.zip'
|
||||
exception_msg = '\n'.join(traceback.format_exception(exception))
|
||||
task_callstack = '\n'.join([f'{i+1}. name={task.name} priority={task.priority}' for i, task in enumerate(current_callstack)])
|
||||
screenshot = device.screenshot()
|
||||
logs = log_stream.getvalue()
|
||||
with open('config.json', 'r', encoding='utf-8') as f:
|
||||
config_content = f.read()
|
||||
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
with zipfile.ZipFile(path, 'w') as zipf:
|
||||
zipf.writestr('exception.txt', exception_msg)
|
||||
zipf.writestr('task_callstack.txt', task_callstack)
|
||||
zipf.writestr('screenshot.png', cv2.imencode('.png', screenshot)[1].tobytes())
|
||||
zipf.writestr('config.json', config_content)
|
||||
zipf.writestr('logs.txt', logs)
|
||||
return path
|
||||
except Exception as e:
|
||||
logger.exception(f'Failed to save error report:')
|
||||
return ''
|
||||
|
||||
# Modified from https://stackoverflow.com/questions/70982565/how-do-i-make-an-event-listener-with-decorators-in-python
|
||||
Params = ParamSpec('Params')
|
||||
Return = TypeVar('Return')
|
||||
|
@ -125,11 +85,12 @@ class KotoneBot:
|
|||
def __init__(
|
||||
self,
|
||||
module: str,
|
||||
config_path: str,
|
||||
config_type: type = dict[str, Any],
|
||||
*,
|
||||
debug: bool = False,
|
||||
resume_on_error: bool = False,
|
||||
auto_save_error_report: bool = True,
|
||||
auto_save_error_report: bool = False,
|
||||
):
|
||||
"""
|
||||
初始化 KotoneBot。
|
||||
|
@ -141,6 +102,7 @@ class KotoneBot:
|
|||
:param auto_save_error_report: 是否自动保存错误报告。
|
||||
"""
|
||||
self.module = module
|
||||
self.config_path = config_path
|
||||
self.config_type = config_type
|
||||
# HACK: 硬编码
|
||||
self.current_config: int | str = 0
|
||||
|
@ -150,6 +112,9 @@ class KotoneBot:
|
|||
self.events = KotoneBotEvents()
|
||||
self.backend_instance: Instance | None = None
|
||||
|
||||
if self.auto_save_error_report:
|
||||
raise NotImplementedError('auto_save_error_report not implemented yet.')
|
||||
|
||||
def initialize(self):
|
||||
"""
|
||||
初始化并载入所有任务和动作。
|
||||
|
@ -176,7 +141,7 @@ class KotoneBot:
|
|||
from kotonebot.client.host import create_custom
|
||||
from kotonebot.config.manager import load_config
|
||||
# HACK: 硬编码
|
||||
config = load_config('config.json', type=self.config_type)
|
||||
config = load_config(self.config_path, type=self.config_type)
|
||||
config = config.user_configs[0]
|
||||
logger.info('Checking backend...')
|
||||
if config.backend.type == 'custom' and config.backend.check_emulator:
|
||||
|
@ -207,7 +172,7 @@ class KotoneBot:
|
|||
按优先级顺序运行所有任务。
|
||||
"""
|
||||
self.check_backend()
|
||||
init_context(config_type=self.config_type)
|
||||
init_context(config_path=self.config_path, config_type=self.config_type)
|
||||
vars.interrupted.clear()
|
||||
|
||||
if by_priority:
|
||||
|
@ -238,7 +203,7 @@ class KotoneBot:
|
|||
logger.exception(f'Error: ')
|
||||
report_path = None
|
||||
if self.auto_save_error_report:
|
||||
report_path = _save_error_report(e)
|
||||
raise NotImplementedError
|
||||
self.events.task_status_changed.trigger(task, 'error')
|
||||
if not self.resume_on_error:
|
||||
for task1 in tasks[tasks.index(task)+1:]:
|
||||
|
|
|
@ -601,9 +601,9 @@ class ContextDebug:
|
|||
|
||||
V = TypeVar('V')
|
||||
class ContextConfig(Generic[T]):
|
||||
def __init__(self, context: 'Context', config_type: Type[T] = dict[str, Any]):
|
||||
def __init__(self, context: 'Context', config_path: str = 'config.json', config_type: Type[T] = dict[str, Any]):
|
||||
self.context = context
|
||||
self.config_path: str = 'config.json'
|
||||
self.config_path: str = config_path
|
||||
self.current_key: int | str = 0
|
||||
self.config_type: Type = config_type
|
||||
self.root = load_config(self.config_path, type=config_type)
|
||||
|
@ -730,13 +730,13 @@ class ContextDevice(Device):
|
|||
|
||||
|
||||
class Context(Generic[T]):
|
||||
def __init__(self, config_type: Type[T], screenshot_impl: Optional[DeviceImpl] = None):
|
||||
def __init__(self, config_path: str, config_type: Type[T], screenshot_impl: Optional[DeviceImpl] = None):
|
||||
self.__ocr = ContextOcr(self)
|
||||
self.__image = ContextImage(self)
|
||||
self.__color = ContextColor(self)
|
||||
self.__vars = ContextGlobalVars()
|
||||
self.__debug = ContextDebug(self)
|
||||
self.__config = ContextConfig[T](self, config_type)
|
||||
self.__config = ContextConfig[T](self, config_path, config_type)
|
||||
|
||||
ip = self.config.current.backend.adb_ip
|
||||
port = self.config.current.backend.adb_port
|
||||
|
@ -851,6 +851,7 @@ next_wait_time: float = 0
|
|||
|
||||
def init_context(
|
||||
*,
|
||||
config_path: str = 'config.json',
|
||||
config_type: Type[T] = dict[str, Any],
|
||||
force: bool = False,
|
||||
screenshot_impl: Optional[DeviceImpl] = None
|
||||
|
@ -858,6 +859,7 @@ def init_context(
|
|||
"""
|
||||
初始化 Context 模块。
|
||||
|
||||
:param config_path: 配置文件路径。
|
||||
:param config_type: 配置数据类类型。
|
||||
配置数据类必须继承自 pydantic 的 `BaseModel`。
|
||||
默认为 `dict[str, Any]`,即普通的 JSON 数据,不包含任何类型信息。
|
||||
|
@ -869,7 +871,7 @@ def init_context(
|
|||
global _c, device, ocr, image, color, vars, debug, config
|
||||
if _c is not None and not force:
|
||||
return
|
||||
_c = Context(config_type=config_type, screenshot_impl=screenshot_impl)
|
||||
_c = Context(config_path=config_path, config_type=config_type, screenshot_impl=screenshot_impl)
|
||||
device._FORWARD_getter = lambda: _c.device # type: ignore
|
||||
ocr._FORWARD_getter = lambda: _c.ocr # type: ignore
|
||||
image._FORWARD_getter = lambda: _c.image # type: ignore
|
||||
|
|
|
@ -12,7 +12,7 @@ version = importlib.metadata.version('ksaa')
|
|||
# 主命令
|
||||
psr = argparse.ArgumentParser(description='Command-line interface for Kotone\'s Auto Assistant')
|
||||
psr.add_argument('-v', '--version', action='version', version='kaa v' + version)
|
||||
# psr.add_argument('-c', '--config', required=False, help='Path to the configuration file. Default: ./config.json')
|
||||
psr.add_argument('-c', '--config', default='./config.json', help='Path to the configuration file. Default: ./config.json')
|
||||
|
||||
# 子命令
|
||||
subparsers = psr.add_subparsers(dest='subcommands')
|
||||
|
@ -37,7 +37,7 @@ _kaa: Kaa | None = None
|
|||
def kaa() -> Kaa:
|
||||
global _kaa
|
||||
if _kaa is None:
|
||||
_kaa = Kaa()
|
||||
_kaa = Kaa(psr.parse_args().config)
|
||||
_kaa.initialize()
|
||||
return _kaa
|
||||
|
||||
|
@ -85,6 +85,8 @@ def main():
|
|||
sys.exit(task_invoke())
|
||||
elif args.task_command == 'list':
|
||||
sys.exit(task_list())
|
||||
else:
|
||||
raise ValueError(f'Unknown task command: {args.task_command}')
|
||||
elif args.subcommands == 'remote-server':
|
||||
sys.exit(remote_server())
|
||||
elif args.subcommands is None:
|
||||
|
|
|
@ -1178,7 +1178,7 @@ class KotoneBotUI:
|
|||
return app
|
||||
|
||||
def main(kaa: Kaa | None = None) -> None:
|
||||
kaa = kaa or Kaa()
|
||||
kaa = kaa or Kaa('./config.json')
|
||||
ui = KotoneBotUI(kaa)
|
||||
app = ui.create_ui()
|
||||
app.launch(inbrowser=True, show_error=True)
|
||||
|
|
|
@ -1,8 +1,13 @@
|
|||
import io
|
||||
import os
|
||||
import logging
|
||||
import importlib.metadata
|
||||
import traceback
|
||||
import zipfile
|
||||
from datetime import datetime
|
||||
|
||||
import cv2
|
||||
|
||||
from kotonebot import KotoneBot
|
||||
from ..common import BaseConfig, upgrade_config
|
||||
|
||||
|
@ -17,6 +22,10 @@ console_handler.setLevel(logging.CRITICAL)
|
|||
file_handler = logging.FileHandler(log_filename, encoding='utf-8')
|
||||
file_handler.setFormatter(log_formatter)
|
||||
|
||||
log_stream = io.StringIO()
|
||||
stream_handler = logging.StreamHandler(log_stream)
|
||||
stream_handler.setFormatter(logging.Formatter('[%(asctime)s] [%(levelname)s] [%(name)s] [%(filename)s:%(lineno)d] - %(message)s'))
|
||||
|
||||
root_logger = logging.getLogger()
|
||||
root_logger.setLevel(logging.INFO)
|
||||
root_logger.addHandler(console_handler)
|
||||
|
@ -33,11 +42,49 @@ class Kaa(KotoneBot):
|
|||
"""
|
||||
琴音小助手 kaa 主类。由其他 GUI/TUI 调用。
|
||||
"""
|
||||
def __init__(self):
|
||||
super().__init__(module='kotonebot.kaa.tasks', config_type=BaseConfig)
|
||||
def __init__(self, config_path: str):
|
||||
super().__init__(module='kotonebot.kaa.tasks', config_path=config_path, config_type=BaseConfig)
|
||||
self.upgrade_msg = upgrade_msg
|
||||
self.version = importlib.metadata.version('ksaa')
|
||||
logger.info('Version: %s', self.version)
|
||||
|
||||
def set_log_level(self, level: int):
|
||||
console_handler.setLevel(level)
|
||||
console_handler.setLevel(level)
|
||||
|
||||
def dump_error_report(
|
||||
self,
|
||||
exception: Exception,
|
||||
*,
|
||||
path: str | None = None
|
||||
) -> str:
|
||||
"""
|
||||
保存错误报告
|
||||
|
||||
:param path: 保存的路径。若为 `None`,则保存到 `./reports/{YY-MM-DD HH-MM-SS}.zip`。
|
||||
:return: 保存的路径
|
||||
"""
|
||||
from kotonebot import device
|
||||
from kotonebot.backend.context import current_callstack
|
||||
try:
|
||||
if path is None:
|
||||
path = f'./reports/{datetime.now().strftime("%Y-%m-%d %H-%M-%S")}.zip'
|
||||
exception_msg = '\n'.join(traceback.format_exception(exception))
|
||||
task_callstack = '\n'.join(
|
||||
[f'{i + 1}. name={task.name} priority={task.priority}' for i, task in enumerate(current_callstack)])
|
||||
screenshot = device.screenshot()
|
||||
logs = log_stream.getvalue()
|
||||
with open(self.config_path, 'r', encoding='utf-8') as f:
|
||||
config_content = f.read()
|
||||
|
||||
if not os.path.exists(os.path.dirname(path)):
|
||||
os.makedirs(os.path.dirname(path))
|
||||
with zipfile.ZipFile(path, 'w') as zipf:
|
||||
zipf.writestr('exception.txt', exception_msg)
|
||||
zipf.writestr('task_callstack.txt', task_callstack)
|
||||
zipf.writestr('screenshot.png', cv2.imencode('.png', screenshot)[1].tobytes())
|
||||
zipf.writestr('config.json', config_content)
|
||||
zipf.writestr('logs.txt', logs)
|
||||
return path
|
||||
except Exception as e:
|
||||
logger.exception(f'Failed to save error report:')
|
||||
return ''
|
Loading…
Reference in New Issue