feat(core): 引入 HintBox 并优化 OCR

1. 引入 HintBox 定义
2. OCR 函数支持指定识别区域与 HintBox
3. 优化小图 OCR 识别
This commit is contained in:
XcantloadX 2025-02-04 22:43:31 +08:00
parent 496b10cac3
commit feb1dedb69
3 changed files with 220 additions and 38 deletions

View File

@ -16,6 +16,7 @@ from typing import (
Generic, Generic,
Type, Type,
) )
from typing_extensions import deprecated
import cv2 import cv2
from cv2.typing import MatLike from cv2.typing import MatLike
@ -39,7 +40,7 @@ from kotonebot.backend.color import find_rgb
from kotonebot.backend.ocr import Ocr, OcrResult, jp, en, StringMatchFunction from kotonebot.backend.ocr import Ocr, OcrResult, jp, en, StringMatchFunction
from kotonebot.config.manager import load_config, save_config from kotonebot.config.manager import load_config, save_config
from kotonebot.config.base_config import UserConfig from kotonebot.config.base_config import UserConfig
from kotonebot.backend.core import Image from kotonebot.backend.core import Image, HintBox
OcrLanguage = Literal['jp', 'en'] OcrLanguage = Literal['jp', 'en']
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit'] ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
@ -237,6 +238,7 @@ class ContextOcr:
... ...
@overload @overload
@deprecated('使用 `ocr.raw().ocr()` 代替')
def ocr(self, img: 'MatLike') -> list[OcrResult]: def ocr(self, img: 'MatLike') -> list[OcrResult]:
"""OCR 指定图像。""" """OCR 指定图像。"""
... ...
@ -246,32 +248,32 @@ class ContextOcr:
if img is None: if img is None:
return self.__engine.ocr(ContextStackVars.ensure_current().screenshot) return self.__engine.ocr(ContextStackVars.ensure_current().screenshot)
return self.__engine.ocr(img) return self.__engine.ocr(img)
@overload
def find(self, pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
...
@overload def find(
def find(self, img: 'MatLike', pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None: self,
... pattern: str | re.Pattern | StringMatchFunction,
*,
def find(self, *args, **kwargs) -> OcrResult | None: hint: HintBox | None = None,
"""检查指定图像是否包含指定文本。""" rect: Rect | None = None,
if len(args) == 1 and len(kwargs) == 0: ) -> OcrResult | None:
ret = self.__engine.find(ContextStackVars.ensure_current().screenshot, args[0]) """检查当前设备画面是否包含指定文本。"""
self.context.device.last_find = ret ret = self.__engine.find(
return ret ContextStackVars.ensure_current().screenshot,
elif len(args) == 2 and len(kwargs) == 0: pattern,
ret = self.__engine.find(args[0], args[1]) hint=hint,
self.context.device.last_find = ret rect=rect,
return ret )
else: self.context.device.last_find = ret
raise ValueError("Invalid arguments") return ret
def expect( def expect(
self, self,
pattern: str | re.Pattern | StringMatchFunction pattern: str | re.Pattern | StringMatchFunction
) -> OcrResult: ) -> OcrResult:
""" """
检查当前设备画面是否包含指定文本 检查当前设备画面是否包含指定文本

View File

@ -1,10 +1,14 @@
import logging import logging
from typing import Callable, ParamSpec, TypeVar, overload from typing import Callable, ParamSpec, TypeVar, overload, TYPE_CHECKING
import cv2 import cv2
from cv2.typing import MatLike from cv2.typing import MatLike
if TYPE_CHECKING:
from kotonebot.backend.util import Rect
class Ocr: class Ocr:
def __init__( def __init__(
self, self,
@ -52,11 +56,58 @@ class Image:
else: else:
return f'<Image: "{self.name}" at {self.path}>' return f'<Image: "{self.name}" at {self.path}>'
class HintBox(tuple[int, int, int, int]):
def __new__(
cls,
x1: int,
y1: int,
x2: int,
y2: int,
*,
source_resolution: tuple[int, int],
):
return super().__new__(cls, [x1, y1, x2, y2])
def __init__(
self,
x1: int,
y1: int,
x2: int,
y2: int,
*,
name: str | None = None,
description: str | None = None,
source_resolution: tuple[int, int],
):
self.x1 = x1
self.y1 = y1
self.x2 = x2
self.y2 = y2
self.name = name
self.description = description
self.source_resolution = source_resolution
@property
def width(self) -> int:
return self.x2 - self.x1
@property
def height(self) -> int:
return self.y2 - self.y1
@property
def rect(self) -> 'Rect':
return self.x1, self.y1, self.width, self.height
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@overload @overload
def image(data: str) -> Image: def image(data: str) -> Image:
"""从文件路径创建 Image 对象。""" """从文件路径创建 Image 对象。"""
... ...
@overload @overload
@ -72,3 +123,10 @@ def image(data: str | MatLike) -> Image:
def ocr(text: str | Callable[[str], bool], language: str = 'jp') -> Ocr: def ocr(text: str | Callable[[str], bool], language: str = 'jp') -> Ocr:
return Ocr(text, language=language) return Ocr(text, language=language)
if __name__ == '__main__':
hint_box = HintBox(100, 100, 200, 200, source_resolution=(1920, 1080))
print(hint_box.rect)
print(hint_box.width)
print(hint_box.height)

View File

@ -3,14 +3,17 @@ import unicodedata
from functools import lru_cache from functools import lru_cache
from typing import Callable, NamedTuple from typing import Callable, NamedTuple
from .util import Rect, grayscaled, res_path
from .debug import result as debug_result, debug
import cv2 import cv2
import numpy as np
from cv2.typing import MatLike from cv2.typing import MatLike
from thefuzz import fuzz as _fuzz from thefuzz import fuzz as _fuzz
from rapidocr_onnxruntime import RapidOCR from rapidocr_onnxruntime import RapidOCR
from .util import Rect, grayscaled, res_path
from .debug import result as debug_result, debug
from .core import HintBox
_engine_jp = RapidOCR( _engine_jp = RapidOCR(
rec_model_path=res_path('res/models/japan_PP-OCRv3_rec_infer.onnx'), rec_model_path=res_path('res/models/japan_PP-OCRv3_rec_infer.onnx'),
use_det=True, use_det=True,
@ -89,6 +92,46 @@ def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
topleft, bottomright = _bounding_box(points) topleft, bottomright = _bounding_box(points)
return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1]) return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
def pad_to(img: MatLike, target_size: tuple[int, int], rgb: tuple[int, int, int] = (255, 255, 255)) -> MatLike:
"""将图像居中填充/缩放到指定大小。缺少部分使用指定颜色填充。"""
h, w = img.shape[:2]
tw, th = target_size
# 如果图像宽高都大于目标大小,则不进行填充
if h >= th and w >= tw:
return img
# 计算宽高比
aspect = w / h
target_aspect = tw / th
# 按比例缩放
if aspect > target_aspect:
# 图像较宽,以目标宽度为准
new_w = tw
new_h = int(tw / aspect)
else:
# 图像较高,以目标高度为准
new_h = th
new_w = int(th * aspect)
# 缩放图像
if new_w != w or new_h != h:
img = cv2.resize(img, (new_w, new_h))
# 创建目标画布并填充
ret = np.full((th, tw, 3), rgb, dtype=np.uint8)
# 计算需要填充的宽高
pad_h = th - new_h
pad_w = tw - new_w
# 将缩放后的图像居中放置
ret[
pad_h // 2:pad_h // 2 + new_h,
pad_w // 2:pad_w // 2 + new_w, :] = img
return ret
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike': def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
import numpy as np import numpy as np
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
@ -153,17 +196,40 @@ class Ocr:
def __init__(self, engine: RapidOCR): def __init__(self, engine: RapidOCR):
self.__engine = engine self.__engine = engine
# TODO: 考虑缓存 OCR 结果,避免重复调用。 # TODO: 考虑缓存 OCR 结果,避免重复调用。
def ocr(self, img: 'MatLike') -> list[OcrResult]: def ocr(
self,
img: 'MatLike',
*,
rect: Rect | None = None,
pad: bool = True,
) -> list[OcrResult]:
""" """
OCR 一个 cv2 的图像注意识别结果中的**全角字符会被转换为半角字符** OCR 一个 cv2 的图像注意识别结果中的**全角字符会被转换为半角字符**
:param rect: 如果指定则只识别指定矩形区域
:param pad:
是否将过小的图像尺寸 < 631x631的图像填充到 631x631
默认为 True
对于 PaddleOCR 模型图片尺寸太小会降低准确率
将图片周围填充放大有助于提高准确率降低耗时
:return: 所有识别结果 :return: 所有识别结果
""" """
if rect is not None:
x, y, w, h = rect
img = img[y:y+h, x:x+w]
original_img = img
if pad:
# TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
# https://blog.csdn.net/YY007H/article/details/124973777
original_img = img.copy()
img = pad_to(img, (631, 631))
img_content = grayscaled(img) img_content = grayscaled(img)
result, elapse = self.__engine(img_content) result, elapse = self.__engine(img_content)
if result is None: if result is None:
return [] return []
ret = [OcrResult( ret = [OcrResult(
text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉 text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉
@ -177,7 +243,7 @@ class Ocr:
result_image = _draw_result(img, ret) result_image = _draw_result(img, ret)
debug_result( debug_result(
'ocr', 'ocr',
[result_image, img], [result_image, original_img],
f"result: \n" + \ f"result: \n" + \
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \ "<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \ "\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \
@ -185,22 +251,50 @@ class Ocr:
) )
return ret return ret
def find(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult | None: def find(
self,
img: 'MatLike',
text: str | re.Pattern | StringMatchFunction,
*,
hint: HintBox | None = None,
rect: Rect | None = None,
pad: bool = True,
) -> OcrResult | None:
""" """
寻找指定文本 寻找指定文本
:param hint: 如果指定则首先只识别 HintBox 范围内的文本若未命中再全局寻找
:param rect: 如果指定则只识别指定矩形区域此参数优先级低于 `hint`
:param pad: `ocr` `pad` 参数
:return: 找到的文本如果未找到则返回 None :return: 找到的文本如果未找到则返回 None
""" """
for result in self.ocr(img): if hint is not None:
if ret := self.find(img, text, rect=hint):
return ret
for result in self.ocr(img, rect=rect, pad=pad):
if _is_match(result.text, text): if _is_match(result.text, text):
return result return result
return None return None
def expect(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult: def expect(
self,
img: 'MatLike',
text: str | re.Pattern | StringMatchFunction,
*,
hint: HintBox | None = None,
rect: Rect | None = None,
pad: bool = True,
) -> OcrResult:
""" """
寻找指定文本如果未找到则抛出异常 寻找指定文本如果未找到则抛出异常
:param hint: 如果指定则首先只识别 HintBox 范围内的文本若未命中再全局寻找
:param rect: 如果指定则只识别指定矩形区域此参数优先级高于 `hint`
:param pad: `ocr` `pad` 参数
:return: 找到的文本
""" """
ret = self.find(img, text) ret = self.find(img, text, hint=hint, rect=rect, pad=pad)
if ret is None: if ret is None:
raise TextNotFoundError(text, img) raise TextNotFoundError(text, img)
return ret return ret
@ -212,10 +306,38 @@ jp = Ocr(_engine_jp)
en = Ocr(_engine_en) en = Ocr(_engine_en)
"""英语 OCR 引擎。""" """英语 OCR 引擎。"""
if __name__ == '__main__': if __name__ == '__main__':
import time
from pprint import pprint as print from pprint import pprint as print
import cv2 import cv2
img_path = 'test_images/acquire_pdorinku.png' print('small')
img_path = r"C:\Users\user\Downloads\Screenshot_2025.01.28_21.00.40.172.png"
img = cv2.imread(img_path) img = cv2.imread(img_path)
time_start = time.time()
result1 = jp.ocr(img) result1 = jp.ocr(img)
print(result1) time_end = time.time()
print(time_end - time_start)
# print(result1)
for i in np.linspace(300, 1000, 20):
i = int(i)
print('small-pad: ' + f'{str(i)}x{str(i)}')
img = pad_to(img, (int(i), int(i)))
time_start = time.time()
result1 = jp.ocr(img)
time_end = time.time()
print(time_end - time_start)
# print(result1)
print('big')
img_path = r"C:\Users\user\Pictures\BlueStacks\Screenshot_2025.01.28_21.00.40.172.png"
img = cv2.imread(img_path)
time_start = time.time()
result1 = jp.ocr(img)
time_end = time.time()
print(time_end - time_start)
# print(result1)