feat(core): 引入 HintBox 并优化 OCR
1. 引入 HintBox 定义 2. OCR 函数支持指定识别区域与 HintBox 3. 优化小图 OCR 识别
This commit is contained in:
parent
496b10cac3
commit
feb1dedb69
|
@ -16,6 +16,7 @@ from typing import (
|
|||
Generic,
|
||||
Type,
|
||||
)
|
||||
from typing_extensions import deprecated
|
||||
|
||||
import cv2
|
||||
from cv2.typing import MatLike
|
||||
|
@ -39,7 +40,7 @@ from kotonebot.backend.color import find_rgb
|
|||
from kotonebot.backend.ocr import Ocr, OcrResult, jp, en, StringMatchFunction
|
||||
from kotonebot.config.manager import load_config, save_config
|
||||
from kotonebot.config.base_config import UserConfig
|
||||
from kotonebot.backend.core import Image
|
||||
from kotonebot.backend.core import Image, HintBox
|
||||
|
||||
OcrLanguage = Literal['jp', 'en']
|
||||
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
|
||||
|
@ -237,6 +238,7 @@ class ContextOcr:
|
|||
...
|
||||
|
||||
@overload
|
||||
@deprecated('使用 `ocr.raw().ocr()` 代替')
|
||||
def ocr(self, img: 'MatLike') -> list[OcrResult]:
|
||||
"""OCR 指定图像。"""
|
||||
...
|
||||
|
@ -246,32 +248,32 @@ class ContextOcr:
|
|||
if img is None:
|
||||
return self.__engine.ocr(ContextStackVars.ensure_current().screenshot)
|
||||
return self.__engine.ocr(img)
|
||||
|
||||
@overload
|
||||
def find(self, pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
|
||||
...
|
||||
|
||||
@overload
|
||||
def find(self, img: 'MatLike', pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
|
||||
...
|
||||
|
||||
def find(self, *args, **kwargs) -> OcrResult | None:
|
||||
"""检查指定图像是否包含指定文本。"""
|
||||
if len(args) == 1 and len(kwargs) == 0:
|
||||
ret = self.__engine.find(ContextStackVars.ensure_current().screenshot, args[0])
|
||||
self.context.device.last_find = ret
|
||||
return ret
|
||||
elif len(args) == 2 and len(kwargs) == 0:
|
||||
ret = self.__engine.find(args[0], args[1])
|
||||
self.context.device.last_find = ret
|
||||
return ret
|
||||
else:
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
def find(
|
||||
self,
|
||||
pattern: str | re.Pattern | StringMatchFunction,
|
||||
*,
|
||||
hint: HintBox | None = None,
|
||||
rect: Rect | None = None,
|
||||
) -> OcrResult | None:
|
||||
"""检查当前设备画面是否包含指定文本。"""
|
||||
ret = self.__engine.find(
|
||||
ContextStackVars.ensure_current().screenshot,
|
||||
pattern,
|
||||
hint=hint,
|
||||
rect=rect,
|
||||
)
|
||||
self.context.device.last_find = ret
|
||||
return ret
|
||||
|
||||
|
||||
|
||||
|
||||
def expect(
|
||||
self,
|
||||
pattern: str | re.Pattern | StringMatchFunction
|
||||
) -> OcrResult:
|
||||
|
||||
"""
|
||||
检查当前设备画面是否包含指定文本。
|
||||
|
||||
|
|
|
@ -1,10 +1,14 @@
|
|||
import logging
|
||||
|
||||
from typing import Callable, ParamSpec, TypeVar, overload
|
||||
from typing import Callable, ParamSpec, TypeVar, overload, TYPE_CHECKING
|
||||
|
||||
import cv2
|
||||
from cv2.typing import MatLike
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from kotonebot.backend.util import Rect
|
||||
|
||||
|
||||
class Ocr:
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -52,11 +56,58 @@ class Image:
|
|||
else:
|
||||
return f'<Image: "{self.name}" at {self.path}>'
|
||||
|
||||
|
||||
class HintBox(tuple[int, int, int, int]):
|
||||
def __new__(
|
||||
cls,
|
||||
x1: int,
|
||||
y1: int,
|
||||
x2: int,
|
||||
y2: int,
|
||||
*,
|
||||
source_resolution: tuple[int, int],
|
||||
):
|
||||
return super().__new__(cls, [x1, y1, x2, y2])
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
x1: int,
|
||||
y1: int,
|
||||
x2: int,
|
||||
y2: int,
|
||||
*,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
source_resolution: tuple[int, int],
|
||||
):
|
||||
self.x1 = x1
|
||||
self.y1 = y1
|
||||
self.x2 = x2
|
||||
self.y2 = y2
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.source_resolution = source_resolution
|
||||
|
||||
@property
|
||||
def width(self) -> int:
|
||||
return self.x2 - self.x1
|
||||
|
||||
@property
|
||||
def height(self) -> int:
|
||||
return self.y2 - self.y1
|
||||
|
||||
@property
|
||||
def rect(self) -> 'Rect':
|
||||
return self.x1, self.y1, self.width, self.height
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@overload
|
||||
def image(data: str) -> Image:
|
||||
|
||||
|
||||
"""从文件路径创建 Image 对象。"""
|
||||
...
|
||||
@overload
|
||||
|
@ -72,3 +123,10 @@ def image(data: str | MatLike) -> Image:
|
|||
|
||||
def ocr(text: str | Callable[[str], bool], language: str = 'jp') -> Ocr:
|
||||
return Ocr(text, language=language)
|
||||
|
||||
if __name__ == '__main__':
|
||||
hint_box = HintBox(100, 100, 200, 200, source_resolution=(1920, 1080))
|
||||
print(hint_box.rect)
|
||||
print(hint_box.width)
|
||||
print(hint_box.height)
|
||||
|
||||
|
|
|
@ -3,14 +3,17 @@ import unicodedata
|
|||
from functools import lru_cache
|
||||
from typing import Callable, NamedTuple
|
||||
|
||||
from .util import Rect, grayscaled, res_path
|
||||
from .debug import result as debug_result, debug
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from cv2.typing import MatLike
|
||||
from thefuzz import fuzz as _fuzz
|
||||
from rapidocr_onnxruntime import RapidOCR
|
||||
|
||||
from .util import Rect, grayscaled, res_path
|
||||
from .debug import result as debug_result, debug
|
||||
from .core import HintBox
|
||||
|
||||
|
||||
_engine_jp = RapidOCR(
|
||||
rec_model_path=res_path('res/models/japan_PP-OCRv3_rec_infer.onnx'),
|
||||
use_det=True,
|
||||
|
@ -89,6 +92,46 @@ def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
|
|||
topleft, bottomright = _bounding_box(points)
|
||||
return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
|
||||
|
||||
def pad_to(img: MatLike, target_size: tuple[int, int], rgb: tuple[int, int, int] = (255, 255, 255)) -> MatLike:
|
||||
"""将图像居中填充/缩放到指定大小。缺少部分使用指定颜色填充。"""
|
||||
h, w = img.shape[:2]
|
||||
tw, th = target_size
|
||||
|
||||
# 如果图像宽高都大于目标大小,则不进行填充
|
||||
if h >= th and w >= tw:
|
||||
return img
|
||||
|
||||
# 计算宽高比
|
||||
aspect = w / h
|
||||
target_aspect = tw / th
|
||||
|
||||
# 按比例缩放
|
||||
if aspect > target_aspect:
|
||||
# 图像较宽,以目标宽度为准
|
||||
new_w = tw
|
||||
new_h = int(tw / aspect)
|
||||
else:
|
||||
# 图像较高,以目标高度为准
|
||||
new_h = th
|
||||
new_w = int(th * aspect)
|
||||
|
||||
# 缩放图像
|
||||
if new_w != w or new_h != h:
|
||||
img = cv2.resize(img, (new_w, new_h))
|
||||
|
||||
# 创建目标画布并填充
|
||||
ret = np.full((th, tw, 3), rgb, dtype=np.uint8)
|
||||
|
||||
# 计算需要填充的宽高
|
||||
pad_h = th - new_h
|
||||
pad_w = tw - new_w
|
||||
|
||||
# 将缩放后的图像居中放置
|
||||
ret[
|
||||
pad_h // 2:pad_h // 2 + new_h,
|
||||
pad_w // 2:pad_w // 2 + new_w, :] = img
|
||||
return ret
|
||||
|
||||
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
|
||||
import numpy as np
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
@ -153,17 +196,40 @@ class Ocr:
|
|||
def __init__(self, engine: RapidOCR):
|
||||
self.__engine = engine
|
||||
|
||||
|
||||
# TODO: 考虑缓存 OCR 结果,避免重复调用。
|
||||
def ocr(self, img: 'MatLike') -> list[OcrResult]:
|
||||
def ocr(
|
||||
self,
|
||||
img: 'MatLike',
|
||||
*,
|
||||
rect: Rect | None = None,
|
||||
pad: bool = True,
|
||||
) -> list[OcrResult]:
|
||||
"""
|
||||
OCR 一个 cv2 的图像。注意识别结果中的**全角字符会被转换为半角字符**。
|
||||
|
||||
|
||||
:param rect: 如果指定,则只识别指定矩形区域。
|
||||
:param pad:
|
||||
是否将过小的图像(尺寸 < 631x631)的图像填充到 631x631。
|
||||
默认为 True。
|
||||
|
||||
对于 PaddleOCR 模型,图片尺寸太小会降低准确率。
|
||||
将图片周围填充放大,有助于提高准确率,降低耗时。
|
||||
:return: 所有识别结果
|
||||
"""
|
||||
if rect is not None:
|
||||
x, y, w, h = rect
|
||||
img = img[y:y+h, x:x+w]
|
||||
original_img = img
|
||||
if pad:
|
||||
# TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
|
||||
# https://blog.csdn.net/YY007H/article/details/124973777
|
||||
original_img = img.copy()
|
||||
img = pad_to(img, (631, 631))
|
||||
img_content = grayscaled(img)
|
||||
result, elapse = self.__engine(img_content)
|
||||
if result is None:
|
||||
|
||||
return []
|
||||
ret = [OcrResult(
|
||||
text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉
|
||||
|
@ -177,7 +243,7 @@ class Ocr:
|
|||
result_image = _draw_result(img, ret)
|
||||
debug_result(
|
||||
'ocr',
|
||||
[result_image, img],
|
||||
[result_image, original_img],
|
||||
f"result: \n" + \
|
||||
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
|
||||
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \
|
||||
|
@ -185,22 +251,50 @@ class Ocr:
|
|||
)
|
||||
return ret
|
||||
|
||||
def find(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
|
||||
def find(
|
||||
self,
|
||||
img: 'MatLike',
|
||||
text: str | re.Pattern | StringMatchFunction,
|
||||
*,
|
||||
hint: HintBox | None = None,
|
||||
rect: Rect | None = None,
|
||||
pad: bool = True,
|
||||
) -> OcrResult | None:
|
||||
"""
|
||||
寻找指定文本
|
||||
|
||||
寻找指定文本。
|
||||
|
||||
:param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
|
||||
:param rect: 如果指定,则只识别指定矩形区域。此参数优先级低于 `hint`。
|
||||
:param pad: 见 `ocr` 的 `pad` 参数。
|
||||
:return: 找到的文本,如果未找到则返回 None
|
||||
"""
|
||||
for result in self.ocr(img):
|
||||
if hint is not None:
|
||||
if ret := self.find(img, text, rect=hint):
|
||||
return ret
|
||||
for result in self.ocr(img, rect=rect, pad=pad):
|
||||
if _is_match(result.text, text):
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def expect(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult:
|
||||
def expect(
|
||||
self,
|
||||
img: 'MatLike',
|
||||
text: str | re.Pattern | StringMatchFunction,
|
||||
*,
|
||||
hint: HintBox | None = None,
|
||||
rect: Rect | None = None,
|
||||
pad: bool = True,
|
||||
) -> OcrResult:
|
||||
"""
|
||||
寻找指定文本,如果未找到则抛出异常
|
||||
寻找指定文本,如果未找到则抛出异常。
|
||||
|
||||
:param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
|
||||
:param rect: 如果指定,则只识别指定矩形区域。此参数优先级高于 `hint`。
|
||||
:param pad: 见 `ocr` 的 `pad` 参数。
|
||||
:return: 找到的文本
|
||||
"""
|
||||
ret = self.find(img, text)
|
||||
ret = self.find(img, text, hint=hint, rect=rect, pad=pad)
|
||||
if ret is None:
|
||||
raise TextNotFoundError(text, img)
|
||||
return ret
|
||||
|
@ -212,10 +306,38 @@ jp = Ocr(_engine_jp)
|
|||
en = Ocr(_engine_en)
|
||||
"""英语 OCR 引擎。"""
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import time
|
||||
from pprint import pprint as print
|
||||
import cv2
|
||||
img_path = 'test_images/acquire_pdorinku.png'
|
||||
print('small')
|
||||
img_path = r"C:\Users\user\Downloads\Screenshot_2025.01.28_21.00.40.172.png"
|
||||
img = cv2.imread(img_path)
|
||||
time_start = time.time()
|
||||
result1 = jp.ocr(img)
|
||||
print(result1)
|
||||
time_end = time.time()
|
||||
print(time_end - time_start)
|
||||
# print(result1)
|
||||
|
||||
for i in np.linspace(300, 1000, 20):
|
||||
i = int(i)
|
||||
print('small-pad: ' + f'{str(i)}x{str(i)}')
|
||||
img = pad_to(img, (int(i), int(i)))
|
||||
time_start = time.time()
|
||||
result1 = jp.ocr(img)
|
||||
time_end = time.time()
|
||||
print(time_end - time_start)
|
||||
|
||||
# print(result1)
|
||||
|
||||
|
||||
print('big')
|
||||
img_path = r"C:\Users\user\Pictures\BlueStacks\Screenshot_2025.01.28_21.00.40.172.png"
|
||||
img = cv2.imread(img_path)
|
||||
time_start = time.time()
|
||||
result1 = jp.ocr(img)
|
||||
time_end = time.time()
|
||||
print(time_end - time_start)
|
||||
# print(result1)
|
||||
|
|
Loading…
Reference in New Issue