kotones-auto-assistant/kotonebot/backend/loop.py

277 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import time
from functools import lru_cache, partial
from typing import Callable, Any, overload, Literal, Generic, TypeVar, cast, get_args, get_origin
from cv2.typing import MatLike
from kotonebot.util import Interval
from kotonebot import device, image, ocr
from kotonebot.backend.core import Image
from kotonebot.backend.ocr import TextComparator
from kotonebot.client.protocol import ClickableObjectProtocol
class LoopAction:
    """
    A single detection step registered inside a ``Loop`` iteration.

    Wraps a zero-argument finder ``func``. ``do()`` runs it and stores the
    outcome in ``result``; the follow-up helpers (``click()``, ``call()``)
    only act when something was found.
    """

    def __init__(self, loop: 'Loop', func: Callable[[], ClickableObjectProtocol | None]):
        # Owning loop; its ``running`` and ``found_anything`` flags gate ``do()``.
        self.loop = loop
        # Finder invoked by ``do()``; returns the found object or None.
        self.func = func
        # Outcome of the last executed ``do()``; None means nothing found.
        self.result: ClickableObjectProtocol | None = None

    @property
    def found(self):
        """
        Whether a result was found. Returns False if the parent Loop is not running.
        """
        if not self.loop.running:
            return False
        return bool(self.result)

    def __bool__(self):
        return self.found

    def reset(self):
        """
        Clear the stored result so this LoopAction can be reused.
        """
        self.result = None

    def do(self):
        """
        Run the finder and store its outcome in ``self.result``.

        Detection is skipped (``self.result`` is left untouched) when the parent
        loop is not running, or when another action already found something
        during the current iteration. A truthy result sets
        ``loop.found_anything`` so later actions in this iteration are skipped.
        Returns nothing; read ``result`` / ``found`` afterwards.
        """
        if not self.loop.running:
            return
        if self.loop.found_anything:
            # Something was already found this iteration; no further detection needed.
            return
        self.result = self.func()
        if self.result:
            self.loop.found_anything = True

    def click(self, *, at: tuple[int, int] | None = None):
        """
        Click the found result, or the given position. Skipped if nothing was found.

        :param at: ``(x, y)`` to click instead of the found object. None clicks the result.
        """
        if self.result:
            if at is not None:
                device.click(*at)
            else:
                device.click(self.result)

    def call(self, func: Callable[[ClickableObjectProtocol], Any]):
        """
        Invoke ``func`` with the found result. Skipped if nothing was found.

        Previously a no-op stub, which made ``Loop.until()`` (documented as
        ``loop.when(...).call(lambda _: loop.exit())``) never end the loop.

        :param func: Callback receiving the found object.
        """
        if self.result:
            func(self.result)
class Loop:
    """
    A polling loop meant to be driven as ``for l in Loop(): ...``.

    Each iteration waits on the configured interval, optionally takes a
    screenshot, and clears the per-iteration ``found_anything`` flag. Inside
    the body, ``when()``, ``until()`` and ``click_if()`` run detection steps;
    once one of them finds something, the remaining ones in the same iteration
    skip detection. Iteration stops after ``exit()`` is called.
    """

    def __init__(
        self,
        *,
        timeout: float = 300,
        interval: float = 0.3,
        auto_screenshot: bool = True
    ):
        """
        :param timeout: Accepted for API compatibility; not currently enforced by this class.
        :param interval: Passed to ``Interval``; ``tick()`` waits on it every iteration.
        :param auto_screenshot: Whether ``tick()`` captures a screenshot every iteration.
        """
        # Set to False by exit(); checked by __next__ and by LoopAction.
        self.running = True
        # Set by LoopAction.do() once any action found something this iteration.
        self.found_anything = False
        # Whether to take a screenshot at the start of every iteration (when tick() is called).
        self.auto_screenshot = auto_screenshot
        # Timestamp of the latest tick (written here, not read within this class).
        self.__last_loop: float = -1
        self.__interval = Interval(interval)
        # Image data from the most recent screenshot (None until one is taken).
        self.screenshot: MatLike | None = None

    def __iter__(self):
        # Reset the interval timer before the first tick.
        self.__interval.reset()
        return self

    def __next__(self):
        if not self.running:
            raise StopIteration
        # tick() clears found_anything and records the iteration timestamp.
        return self.tick()

    def tick(self):
        """
        Start a new iteration: wait on the interval, refresh ``screenshot`` if
        ``auto_screenshot`` is enabled, and clear ``found_anything``.

        :return: ``self``, so ``for l in loop`` binds the loop itself.
        """
        self.__interval.wait()
        if self.auto_screenshot:
            self.screenshot = device.screenshot()
        self.__last_loop = time.time()
        self.found_anything = False
        return self

    def exit(self):
        """
        End the loop; iteration stops before the next tick.
        """
        self.running = False

    @overload
    def when(self, condition: Image) -> LoopAction:
        ...

    @overload
    def when(self, condition: TextComparator) -> LoopAction:
        ...

    def when(self, condition: Any):
        """
        Check whether ``condition`` is currently present.

        :param condition: An ``Image`` (searched with ``image.find``) or a
            ``TextComparator`` (searched with ``ocr.find``).
        :return: A ``LoopAction`` whose ``result`` holds what was found, if anything.
        :raises ValueError: If ``condition`` is of any other type.
        """
        if isinstance(condition, Image):
            func = partial(image.find, condition)
        elif isinstance(condition, TextComparator):
            func = partial(ocr.find, condition)
        else:
            raise ValueError('Invalid condition type.')
        # A freshly built LoopAction already has result=None, so no reset() is needed.
        action = LoopAction(self, func)
        action.do()
        return action

    def until(self, condition: Any):
        """
        End the loop once ``condition`` is found.

        Equivalent to ``loop.when(...).call(lambda _: loop.exit())``. Detection
        steps later in the same iteration are skipped because the loop is no
        longer running, and no further iteration starts.
        """
        # Implemented directly rather than via LoopAction.call(), which must not
        # be a silent no-op for this to work.
        action = self.when(condition)
        if action.result:
            self.exit()

    def click_if(self, condition: Any, *, at: tuple[int, int] | None = None):
        """
        If the given target is present, click it or the given position.

        ``click_if()`` is equivalent to ``loop.when(...).click(...)``.

        :param condition: Detection target.
        :param at: Position to click. If None, click the found target.
        """
        return self.when(condition).click(at=at)
# Type parameter of StatedLoop: the (Literal) type of its state values.
StateType = TypeVar('StateType')


class StatedLoop(Loop, Generic[StateType]):
    """
    A ``Loop`` that also tracks a current ``state`` chosen from a fixed set of states.

    The set of states comes from the generic argument
    (``StatedLoop[Literal['a', 'b']]()``) or, if the class was not
    parameterized that way, from the ``states`` argument. The states are
    resolved when iteration starts and exposed as ``self.states`` (tuple) and
    ``self.state`` (current value).
    """

    def __init__(
        self,
        states: list[Any] | None = None,
        initial_state: StateType | None = None,
        *,
        timeout: float = 300,
        interval: float = 0.3,
        auto_screenshot: bool = True
    ):
        """
        :param states: Explicit state values, used when the instance was not
            created as ``StatedLoop[Literal[...]]()``.
        :param initial_state: State to start in. None means the first state.
        """
        self.__tmp_states = states
        self.__tmp_initial_state = initial_state
        # Annotation only: ``state`` is assigned when iteration starts.
        self.state: StateType
        super().__init__(timeout=timeout, interval=interval, auto_screenshot=auto_screenshot)

    def __iter__(self):
        # Must not run inside __init__: typing attaches __orig_class__ to the
        # instance only after __init__ returns.
        self.__retrive_state_values()
        return super().__iter__()

    def __retrive_state_values(self):
        """
        Resolve the state values and set ``self.states`` and ``self.state``.

        :return: Tuple of resolved state values.
        :raises ValueError: If no states can be determined, or the generic
            argument is not a ``Literal[...]`` type.
        """
        # HACK: __orig_class__ is an undocumented attribute; it is absent when the
        # class was not instantiated through a parameterized alias.
        orig_class = getattr(self, '__orig_class__', None)
        if orig_class is not None:
            generic_type_args = get_args(orig_class)
            if len(generic_type_args) != 1:
                raise ValueError('StatedLoop must have exactly one generic type argument.')
            state_type = generic_type_args[0]
            if get_origin(state_type) is not Literal:
                raise ValueError('StatedLoop generic type argument must be a Literal[...] type.')
            state_values = get_args(state_type)
        elif self.__tmp_states is not None:
            state_values = tuple(self.__tmp_states)
        else:
            raise ValueError('Either specify `states` or use StatedLoop[Literal[...]] syntax.')
        if not state_values:
            raise ValueError('StatedLoop must have at least one state value.')
        self.states = cast(tuple[StateType, ...], state_values)
        # Compare against None so falsy states (0, '', False) are honoured.
        initial = self.__tmp_initial_state
        self.state = initial if initial is not None else self.states[0]
        return state_values
def StatedLoop2(states: StateType) -> StatedLoop[StateType]:
    """
    Experimental factory: build a ``StatedLoop`` from a ``Literal[...]`` type passed at runtime.

    ``StatedLoop2(Literal['a', 'b'])`` is equivalent to ``StatedLoop[Literal['a', 'b']]()``.
    Previously this returned a plain ``Loop`` cast to ``StatedLoop``, so the
    result never had ``states`` / ``state``.

    :param states: A ``Literal[...]`` type listing the state values.
    """
    # Subscripting at runtime makes typing attach __orig_class__ to the new
    # instance, which StatedLoop reads to discover the state values.
    return cast(StatedLoop[StateType], StatedLoop[states]())  # type: ignore
if __name__ == '__main__':
    # Manual scratch area. R, contains, manual_context and init_context are only
    # used by the commented-out usage sketch at the bottom.
    from kotonebot.kaa.tasks import R
    from kotonebot.backend.ocr import contains
    from kotonebot.backend.context import manual_context, init_context

    # An earlier standalone prototype lived here: a Generic class reading
    # self.__orig_class__ and returning get_args() of a Literal type argument.
    # The same introspection is what StatedLoop.__retrive_state_values() does.

    from typing_extensions import reveal_type

    # Live demo: show the inferred type of `.states` for a Literal-parameterized loop.
    # NOTE: nothing below calls exit(), so this loop only stops when interrupted;
    # with the default auto_screenshot=True each iteration also calls
    # device.screenshot() (presumably needs init_context() first — TODO confirm).
    demo_loop = StatedLoop[Literal['p0', 'p1', 'p2', 'p3', 'ap']]()
    for current in demo_loop:
        reveal_type(current.states)

    # Sketch of the intended API (not runnable yet: LoopAction has no goto(),
    # and call() takes no `once` argument):
    # init_context()
    # manual_context().begin()
    # for l in Loop():
    #     l.when(R.Produce.ButtonUse).click()
    #     l.when(R.Produce.ButtonRefillAP).click()
    #     l.when(contains("123")).click()
    #     l.click_if(contains("!23"), at=(1, 2))
    # State = Literal['p0', 'p1', 'p2', 'p3', 'ap']
    # for sl in StatedLoop[State]():
    #     match sl.state:
    #         case 'p0':
    #             sl.click_if(R.Produce.ButtonProduce)
    #             sl.click_if(contains('master'))
    #             sl.when(R.Produce.ButtonPIdolOverview).goto('p1')
    #             # Not enough AP
    #             sl.when(R.Produce.TextAPInsufficient).goto('ap')
    #         case 'ap':
    #             pass
    #         # p1: select an idol
    #         case 'p1':
    #             sl.call(lambda _: select_idol(idol_skin_id), once=True)
    #             sl.when(R.Produce.TextAnotherIdolAvailableDialog).call(dialog.no)
    #             sl.click_if(R.Common.ButtonNextNoIcon)
    #             sl.until(R.Produce.TextStepIndicator2).goto('p2')
    #         case 'p2':
    #             sl.when(contains("123")).click()
    #         case 'p3':
    #             sl.click_if(contains("!23"), at=(1, 2))
    #         case _:
    #             assert_never(sl.state)