feat(core): 引入 HintBox 并优化 OCR

1. 引入 HintBox 定义
2. OCR 函数支持指定识别区域与 HintBox
3. 优化小图 OCR 识别
This commit is contained in:
XcantloadX 2025-02-04 22:43:31 +08:00
parent 496b10cac3
commit feb1dedb69
3 changed files with 220 additions and 38 deletions

View File

@ -16,6 +16,7 @@ from typing import (
Generic,
Type,
)
from typing_extensions import deprecated
import cv2
from cv2.typing import MatLike
@ -39,7 +40,7 @@ from kotonebot.backend.color import find_rgb
from kotonebot.backend.ocr import Ocr, OcrResult, jp, en, StringMatchFunction
from kotonebot.config.manager import load_config, save_config
from kotonebot.config.base_config import UserConfig
from kotonebot.backend.core import Image
from kotonebot.backend.core import Image, HintBox
OcrLanguage = Literal['jp', 'en']
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
@ -237,6 +238,7 @@ class ContextOcr:
...
@overload
@deprecated('使用 `ocr.raw().ocr()` 代替')
def ocr(self, img: 'MatLike') -> list[OcrResult]:
"""OCR 指定图像。"""
...
@ -246,32 +248,32 @@ class ContextOcr:
if img is None:
return self.__engine.ocr(ContextStackVars.ensure_current().screenshot)
return self.__engine.ocr(img)
@overload
def find(self, pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
...
@overload
def find(self, img: 'MatLike', pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
...
def find(self, *args, **kwargs) -> OcrResult | None:
"""检查指定图像是否包含指定文本。"""
if len(args) == 1 and len(kwargs) == 0:
ret = self.__engine.find(ContextStackVars.ensure_current().screenshot, args[0])
self.context.device.last_find = ret
return ret
elif len(args) == 2 and len(kwargs) == 0:
ret = self.__engine.find(args[0], args[1])
self.context.device.last_find = ret
return ret
else:
raise ValueError("Invalid arguments")
def find(
self,
pattern: str | re.Pattern | StringMatchFunction,
*,
hint: HintBox | None = None,
rect: Rect | None = None,
) -> OcrResult | None:
"""检查当前设备画面是否包含指定文本。"""
ret = self.__engine.find(
ContextStackVars.ensure_current().screenshot,
pattern,
hint=hint,
rect=rect,
)
self.context.device.last_find = ret
return ret
def expect(
self,
pattern: str | re.Pattern | StringMatchFunction
) -> OcrResult:
"""
检查当前设备画面是否包含指定文本

View File

@ -1,10 +1,14 @@
import logging
from typing import Callable, ParamSpec, TypeVar, overload
from typing import Callable, ParamSpec, TypeVar, overload, TYPE_CHECKING
import cv2
from cv2.typing import MatLike
if TYPE_CHECKING:
from kotonebot.backend.util import Rect
class Ocr:
def __init__(
self,
@ -52,11 +56,58 @@ class Image:
else:
return f'<Image: "{self.name}" at {self.path}>'
class HintBox(tuple[int, int, int, int]):
def __new__(
cls,
x1: int,
y1: int,
x2: int,
y2: int,
*,
source_resolution: tuple[int, int],
):
return super().__new__(cls, [x1, y1, x2, y2])
def __init__(
self,
x1: int,
y1: int,
x2: int,
y2: int,
*,
name: str | None = None,
description: str | None = None,
source_resolution: tuple[int, int],
):
self.x1 = x1
self.y1 = y1
self.x2 = x2
self.y2 = y2
self.name = name
self.description = description
self.source_resolution = source_resolution
@property
def width(self) -> int:
return self.x2 - self.x1
@property
def height(self) -> int:
return self.y2 - self.y1
@property
def rect(self) -> 'Rect':
return self.x1, self.y1, self.width, self.height
logger = logging.getLogger(__name__)
@overload
def image(data: str) -> Image:
"""从文件路径创建 Image 对象。"""
...
@overload
@ -72,3 +123,10 @@ def image(data: str | MatLike) -> Image:
def ocr(text: str | Callable[[str], bool], language: str = 'jp') -> Ocr:
return Ocr(text, language=language)
if __name__ == '__main__':
hint_box = HintBox(100, 100, 200, 200, source_resolution=(1920, 1080))
print(hint_box.rect)
print(hint_box.width)
print(hint_box.height)

View File

@ -3,14 +3,17 @@ import unicodedata
from functools import lru_cache
from typing import Callable, NamedTuple
from .util import Rect, grayscaled, res_path
from .debug import result as debug_result, debug
import cv2
import numpy as np
from cv2.typing import MatLike
from thefuzz import fuzz as _fuzz
from rapidocr_onnxruntime import RapidOCR
from .util import Rect, grayscaled, res_path
from .debug import result as debug_result, debug
from .core import HintBox
_engine_jp = RapidOCR(
rec_model_path=res_path('res/models/japan_PP-OCRv3_rec_infer.onnx'),
use_det=True,
@ -89,6 +92,46 @@ def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
topleft, bottomright = _bounding_box(points)
return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
def pad_to(img: MatLike, target_size: tuple[int, int], rgb: tuple[int, int, int] = (255, 255, 255)) -> MatLike:
"""将图像居中填充/缩放到指定大小。缺少部分使用指定颜色填充。"""
h, w = img.shape[:2]
tw, th = target_size
# 如果图像宽高都大于目标大小,则不进行填充
if h >= th and w >= tw:
return img
# 计算宽高比
aspect = w / h
target_aspect = tw / th
# 按比例缩放
if aspect > target_aspect:
# 图像较宽,以目标宽度为准
new_w = tw
new_h = int(tw / aspect)
else:
# 图像较高,以目标高度为准
new_h = th
new_w = int(th * aspect)
# 缩放图像
if new_w != w or new_h != h:
img = cv2.resize(img, (new_w, new_h))
# 创建目标画布并填充
ret = np.full((th, tw, 3), rgb, dtype=np.uint8)
# 计算需要填充的宽高
pad_h = th - new_h
pad_w = tw - new_w
# 将缩放后的图像居中放置
ret[
pad_h // 2:pad_h // 2 + new_h,
pad_w // 2:pad_w // 2 + new_w, :] = img
return ret
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
import numpy as np
from PIL import Image, ImageDraw, ImageFont
@ -153,17 +196,40 @@ class Ocr:
def __init__(self, engine: RapidOCR):
self.__engine = engine
# TODO: 考虑缓存 OCR 结果,避免重复调用。
def ocr(self, img: 'MatLike') -> list[OcrResult]:
def ocr(
self,
img: 'MatLike',
*,
rect: Rect | None = None,
pad: bool = True,
) -> list[OcrResult]:
"""
OCR 一个 cv2 的图像注意识别结果中的**全角字符会被转换为半角字符**
:param rect: 如果指定则只识别指定矩形区域
:param pad:
是否将过小的图像尺寸 < 631x631的图像填充到 631x631
默认为 True
对于 PaddleOCR 模型图片尺寸太小会降低准确率
将图片周围填充放大有助于提高准确率降低耗时
:return: 所有识别结果
"""
if rect is not None:
x, y, w, h = rect
img = img[y:y+h, x:x+w]
original_img = img
if pad:
# TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
# https://blog.csdn.net/YY007H/article/details/124973777
original_img = img.copy()
img = pad_to(img, (631, 631))
img_content = grayscaled(img)
result, elapse = self.__engine(img_content)
if result is None:
return []
ret = [OcrResult(
text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉
@ -177,7 +243,7 @@ class Ocr:
result_image = _draw_result(img, ret)
debug_result(
'ocr',
[result_image, img],
[result_image, original_img],
f"result: \n" + \
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \
@ -185,22 +251,50 @@ class Ocr:
)
return ret
def find(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
def find(
self,
img: 'MatLike',
text: str | re.Pattern | StringMatchFunction,
*,
hint: HintBox | None = None,
rect: Rect | None = None,
pad: bool = True,
) -> OcrResult | None:
"""
寻找指定文本
寻找指定文本
:param hint: 如果指定则首先只识别 HintBox 范围内的文本若未命中再全局寻找
:param rect: 如果指定则只识别指定矩形区域此参数优先级低于 `hint`
:param pad: `ocr` `pad` 参数
:return: 找到的文本如果未找到则返回 None
"""
for result in self.ocr(img):
if hint is not None:
if ret := self.find(img, text, rect=hint):
return ret
for result in self.ocr(img, rect=rect, pad=pad):
if _is_match(result.text, text):
return result
return None
def expect(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult:
def expect(
self,
img: 'MatLike',
text: str | re.Pattern | StringMatchFunction,
*,
hint: HintBox | None = None,
rect: Rect | None = None,
pad: bool = True,
) -> OcrResult:
"""
寻找指定文本如果未找到则抛出异常
寻找指定文本如果未找到则抛出异常
:param hint: 如果指定则首先只识别 HintBox 范围内的文本若未命中再全局寻找
:param rect: 如果指定则只识别指定矩形区域此参数优先级高于 `hint`
:param pad: `ocr` `pad` 参数
:return: 找到的文本
"""
ret = self.find(img, text)
ret = self.find(img, text, hint=hint, rect=rect, pad=pad)
if ret is None:
raise TextNotFoundError(text, img)
return ret
@ -212,10 +306,38 @@ jp = Ocr(_engine_jp)
en = Ocr(_engine_en)
"""英语 OCR 引擎。"""
if __name__ == '__main__':
import time
from pprint import pprint as print
import cv2
img_path = 'test_images/acquire_pdorinku.png'
print('small')
img_path = r"C:\Users\user\Downloads\Screenshot_2025.01.28_21.00.40.172.png"
img = cv2.imread(img_path)
time_start = time.time()
result1 = jp.ocr(img)
print(result1)
time_end = time.time()
print(time_end - time_start)
# print(result1)
for i in np.linspace(300, 1000, 20):
i = int(i)
print('small-pad: ' + f'{str(i)}x{str(i)}')
img = pad_to(img, (int(i), int(i)))
time_start = time.time()
result1 = jp.ocr(img)
time_end = time.time()
print(time_end - time_start)
# print(result1)
print('big')
img_path = r"C:\Users\user\Pictures\BlueStacks\Screenshot_2025.01.28_21.00.40.172.png"
img = cv2.imread(img_path)
time_start = time.time()
result1 = jp.ocr(img)
time_end = time.time()
print(time_end - time_start)
# print(result1)