feat(core): 优化 OCR 匹配函数的调试日志显示

This commit is contained in:
XcantloadX 2025-02-11 15:47:32 +08:00
parent 6a47d7d878
commit 1016ef6532
2 changed files with 40 additions and 30 deletions

View File

@ -290,7 +290,10 @@ class ContextOcr:
def expect(
self,
pattern: str | re.Pattern | StringMatchFunction
pattern: str | re.Pattern | StringMatchFunction,
*,
rect: Rect | None = None,
hint: HintBox | None = None,
) -> OcrResult:
"""
@ -298,7 +301,7 @@ class ContextOcr:
`find()` 的区别在于`expect()` 未找到时会抛出异常
"""
ret = self.__engine.expect(ContextStackVars.ensure_current().screenshot, pattern)
ret = self.__engine.expect(ContextStackVars.ensure_current().screenshot, pattern, rect=rect, hint=hint)
self.context.device.last_find = ret.original_rect if ret else None
return ret
@ -307,7 +310,9 @@ class ContextOcr:
pattern: str | re.Pattern | StringMatchFunction,
timeout: float = DEFAULT_TIMEOUT,
*,
interval: float = DEFAULT_INTERVAL
interval: float = DEFAULT_INTERVAL,
rect: Rect | None = None,
hint: HintBox | None = None,
) -> OcrResult:
"""
等待指定文本出现
@ -316,7 +321,7 @@ class ContextOcr:
start_time = time.time()
while True:
result = self.find(pattern)
result = self.find(pattern, rect=rect, hint=hint)
if result is not None:
self.context.device.last_find = result.original_rect if result else None
@ -330,7 +335,9 @@ class ContextOcr:
pattern: str | re.Pattern | StringMatchFunction,
timeout: float = DEFAULT_TIMEOUT,
*,
interval: float = DEFAULT_INTERVAL
interval: float = DEFAULT_INTERVAL,
rect: Rect | None = None,
hint: HintBox | None = None,
) -> OcrResult | None:
"""
等待指定文本出现
@ -339,7 +346,7 @@ class ContextOcr:
start_time = time.time()
while True:
result = self.find(pattern)
result = self.find(pattern, rect=rect, hint=hint)
if result is not None:
self.context.device.last_find = result.original_rect if result else None
return result

View File

@ -118,30 +118,35 @@ class TextNotFoundError(Exception):
else:
super().__init__(f"Expected text not found: {pattern.__name__}")
class TextComparator:
def __init__(self, name: str, text: str, func: Callable[[str], bool]):
self.name = name
self.text = text
self.func = func
def __call__(self, text: str) -> bool:
return self.func(text)
def __repr__(self) -> str:
return f'{self.name}("{self.text}")'
@lru_cache(maxsize=1000)
def fuzz(text: str) -> Callable[[str], bool]:
def fuzz(text: str) -> TextComparator:
"""返回 fuzzy 算法的字符串匹配函数。"""
f = lambda s: _fuzz.ratio(s, text) > 90
f.__repr__ = lambda: f"fuzzy({text})"
f.__name__ = f"fuzzy({text})"
return f
func = lambda s: _fuzz.ratio(s, text) > 90
return TextComparator("fuzzy", text, func)
@lru_cache(maxsize=1000)
def regex(regex: str) -> Callable[[str], bool]:
def regex(regex: str) -> TextComparator:
"""返回正则表达式字符串匹配函数。"""
f = lambda s: re.match(regex, s) is not None
f.__repr__ = lambda: f"regex('{regex}')"
f.__name__ = f"regex('{regex}')"
return f
func = lambda s: re.match(regex, s) is not None
return TextComparator("regex", regex, func)
@lru_cache(maxsize=1000)
def contains(text: str) -> Callable[[str], bool]:
def contains(text: str) -> TextComparator:
"""返回包含指定文本的函数。"""
f = lambda s: text in s
f.__repr__ = lambda: f"contains('{text}')"
f.__name__ = f"contains('{text}')"
return f
func = lambda s: text in s
return TextComparator("contains", text, func)
@lru_cache(maxsize=1000)
def equals(
@ -149,7 +154,7 @@ def equals(
*,
remove_space: bool = False,
ignore_case: bool = True,
) -> Callable[[str], bool]:
) -> TextComparator:
"""
返回等于指定文本的函数
@ -168,11 +173,9 @@ def equals(
s = s.replace(' ', '').replace(' ', '')
return text == s
compare.__repr__ = lambda: f"equals('{text}')"
compare.__name__ = f"equals('{text}')"
return compare
return TextComparator("equals", text, compare)
def _is_match(text: str, pattern: re.Pattern | str | StringMatchFunction) -> bool:
def _is_match(text: str, pattern: re.Pattern | str | StringMatchFunction | TextComparator) -> bool:
if isinstance(pattern, re.Pattern):
return pattern.match(text) is not None
elif callable(pattern):
@ -392,9 +395,9 @@ class Ocr:
"""
if hint is not None:
if ret := self.find(img, text, rect=hint):
logger.debug(f"find: {text} with hint={hint} SUCCESS")
logger.debug(f"find: {text} SUCCESS [hint={hint}]")
return ret
logger.debug(f"find: {text} with hint={hint} FAILED. Retrying on whole image...")
logger.debug(f"find: {text} FAILED [hint={hint}]")
start_time = time.time()
results = self.ocr(img, rect=rect, pad=pad)
@ -405,8 +408,8 @@ class Ocr:
target = result
break
logger.debug(
f"find: {text} with rect={rect} elapsed={end_time - start_time:.3f}s " + \
f"{'SUCCESS' if target else 'FAILED'}"
f"find: {text} {'SUCCESS' if target else 'FAILED'} " + \
f"[elapsed={end_time - start_time:.3f}s] [rect={rect}]"
)
return target