feat(core): 引入 HintBox 并优化 OCR
1. 引入 HintBox 定义 2. OCR 函数支持指定识别区域与 HintBox 3. 优化小图 OCR 识别
This commit is contained in:
parent
496b10cac3
commit
feb1dedb69
|
@ -16,6 +16,7 @@ from typing import (
|
||||||
Generic,
|
Generic,
|
||||||
Type,
|
Type,
|
||||||
)
|
)
|
||||||
|
from typing_extensions import deprecated
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from cv2.typing import MatLike
|
from cv2.typing import MatLike
|
||||||
|
@ -39,7 +40,7 @@ from kotonebot.backend.color import find_rgb
|
||||||
from kotonebot.backend.ocr import Ocr, OcrResult, jp, en, StringMatchFunction
|
from kotonebot.backend.ocr import Ocr, OcrResult, jp, en, StringMatchFunction
|
||||||
from kotonebot.config.manager import load_config, save_config
|
from kotonebot.config.manager import load_config, save_config
|
||||||
from kotonebot.config.base_config import UserConfig
|
from kotonebot.config.base_config import UserConfig
|
||||||
from kotonebot.backend.core import Image
|
from kotonebot.backend.core import Image, HintBox
|
||||||
|
|
||||||
OcrLanguage = Literal['jp', 'en']
|
OcrLanguage = Literal['jp', 'en']
|
||||||
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
|
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
|
||||||
|
@ -237,6 +238,7 @@ class ContextOcr:
|
||||||
...
|
...
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
|
@deprecated('使用 `ocr.raw().ocr()` 代替')
|
||||||
def ocr(self, img: 'MatLike') -> list[OcrResult]:
|
def ocr(self, img: 'MatLike') -> list[OcrResult]:
|
||||||
"""OCR 指定图像。"""
|
"""OCR 指定图像。"""
|
||||||
...
|
...
|
||||||
|
@ -246,32 +248,32 @@ class ContextOcr:
|
||||||
if img is None:
|
if img is None:
|
||||||
return self.__engine.ocr(ContextStackVars.ensure_current().screenshot)
|
return self.__engine.ocr(ContextStackVars.ensure_current().screenshot)
|
||||||
return self.__engine.ocr(img)
|
return self.__engine.ocr(img)
|
||||||
|
|
||||||
@overload
|
|
||||||
def find(self, pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
|
|
||||||
...
|
|
||||||
|
|
||||||
@overload
|
def find(
|
||||||
def find(self, img: 'MatLike', pattern: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
|
self,
|
||||||
...
|
pattern: str | re.Pattern | StringMatchFunction,
|
||||||
|
*,
|
||||||
def find(self, *args, **kwargs) -> OcrResult | None:
|
hint: HintBox | None = None,
|
||||||
"""检查指定图像是否包含指定文本。"""
|
rect: Rect | None = None,
|
||||||
if len(args) == 1 and len(kwargs) == 0:
|
) -> OcrResult | None:
|
||||||
ret = self.__engine.find(ContextStackVars.ensure_current().screenshot, args[0])
|
"""检查当前设备画面是否包含指定文本。"""
|
||||||
self.context.device.last_find = ret
|
ret = self.__engine.find(
|
||||||
return ret
|
ContextStackVars.ensure_current().screenshot,
|
||||||
elif len(args) == 2 and len(kwargs) == 0:
|
pattern,
|
||||||
ret = self.__engine.find(args[0], args[1])
|
hint=hint,
|
||||||
self.context.device.last_find = ret
|
rect=rect,
|
||||||
return ret
|
)
|
||||||
else:
|
self.context.device.last_find = ret
|
||||||
raise ValueError("Invalid arguments")
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def expect(
|
def expect(
|
||||||
self,
|
self,
|
||||||
pattern: str | re.Pattern | StringMatchFunction
|
pattern: str | re.Pattern | StringMatchFunction
|
||||||
) -> OcrResult:
|
) -> OcrResult:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
检查当前设备画面是否包含指定文本。
|
检查当前设备画面是否包含指定文本。
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,14 @@
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from typing import Callable, ParamSpec, TypeVar, overload
|
from typing import Callable, ParamSpec, TypeVar, overload, TYPE_CHECKING
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
from cv2.typing import MatLike
|
from cv2.typing import MatLike
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from kotonebot.backend.util import Rect
|
||||||
|
|
||||||
|
|
||||||
class Ocr:
|
class Ocr:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
@ -52,11 +56,58 @@ class Image:
|
||||||
else:
|
else:
|
||||||
return f'<Image: "{self.name}" at {self.path}>'
|
return f'<Image: "{self.name}" at {self.path}>'
|
||||||
|
|
||||||
|
|
||||||
|
class HintBox(tuple[int, int, int, int]):
|
||||||
|
def __new__(
|
||||||
|
cls,
|
||||||
|
x1: int,
|
||||||
|
y1: int,
|
||||||
|
x2: int,
|
||||||
|
y2: int,
|
||||||
|
*,
|
||||||
|
source_resolution: tuple[int, int],
|
||||||
|
):
|
||||||
|
return super().__new__(cls, [x1, y1, x2, y2])
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
x1: int,
|
||||||
|
y1: int,
|
||||||
|
x2: int,
|
||||||
|
y2: int,
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
description: str | None = None,
|
||||||
|
source_resolution: tuple[int, int],
|
||||||
|
):
|
||||||
|
self.x1 = x1
|
||||||
|
self.y1 = y1
|
||||||
|
self.x2 = x2
|
||||||
|
self.y2 = y2
|
||||||
|
self.name = name
|
||||||
|
self.description = description
|
||||||
|
self.source_resolution = source_resolution
|
||||||
|
|
||||||
|
@property
|
||||||
|
def width(self) -> int:
|
||||||
|
return self.x2 - self.x1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def height(self) -> int:
|
||||||
|
return self.y2 - self.y1
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rect(self) -> 'Rect':
|
||||||
|
return self.x1, self.y1, self.width, self.height
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def image(data: str) -> Image:
|
def image(data: str) -> Image:
|
||||||
|
|
||||||
|
|
||||||
"""从文件路径创建 Image 对象。"""
|
"""从文件路径创建 Image 对象。"""
|
||||||
...
|
...
|
||||||
@overload
|
@overload
|
||||||
|
@ -72,3 +123,10 @@ def image(data: str | MatLike) -> Image:
|
||||||
|
|
||||||
def ocr(text: str | Callable[[str], bool], language: str = 'jp') -> Ocr:
|
def ocr(text: str | Callable[[str], bool], language: str = 'jp') -> Ocr:
|
||||||
return Ocr(text, language=language)
|
return Ocr(text, language=language)
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
hint_box = HintBox(100, 100, 200, 200, source_resolution=(1920, 1080))
|
||||||
|
print(hint_box.rect)
|
||||||
|
print(hint_box.width)
|
||||||
|
print(hint_box.height)
|
||||||
|
|
||||||
|
|
|
@ -3,14 +3,17 @@ import unicodedata
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from typing import Callable, NamedTuple
|
from typing import Callable, NamedTuple
|
||||||
|
|
||||||
from .util import Rect, grayscaled, res_path
|
|
||||||
from .debug import result as debug_result, debug
|
|
||||||
|
|
||||||
import cv2
|
import cv2
|
||||||
|
import numpy as np
|
||||||
from cv2.typing import MatLike
|
from cv2.typing import MatLike
|
||||||
from thefuzz import fuzz as _fuzz
|
from thefuzz import fuzz as _fuzz
|
||||||
from rapidocr_onnxruntime import RapidOCR
|
from rapidocr_onnxruntime import RapidOCR
|
||||||
|
|
||||||
|
from .util import Rect, grayscaled, res_path
|
||||||
|
from .debug import result as debug_result, debug
|
||||||
|
from .core import HintBox
|
||||||
|
|
||||||
|
|
||||||
_engine_jp = RapidOCR(
|
_engine_jp = RapidOCR(
|
||||||
rec_model_path=res_path('res/models/japan_PP-OCRv3_rec_infer.onnx'),
|
rec_model_path=res_path('res/models/japan_PP-OCRv3_rec_infer.onnx'),
|
||||||
use_det=True,
|
use_det=True,
|
||||||
|
@ -89,6 +92,46 @@ def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
|
||||||
topleft, bottomright = _bounding_box(points)
|
topleft, bottomright = _bounding_box(points)
|
||||||
return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
|
return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
|
||||||
|
|
||||||
|
def pad_to(img: MatLike, target_size: tuple[int, int], rgb: tuple[int, int, int] = (255, 255, 255)) -> MatLike:
|
||||||
|
"""将图像居中填充/缩放到指定大小。缺少部分使用指定颜色填充。"""
|
||||||
|
h, w = img.shape[:2]
|
||||||
|
tw, th = target_size
|
||||||
|
|
||||||
|
# 如果图像宽高都大于目标大小,则不进行填充
|
||||||
|
if h >= th and w >= tw:
|
||||||
|
return img
|
||||||
|
|
||||||
|
# 计算宽高比
|
||||||
|
aspect = w / h
|
||||||
|
target_aspect = tw / th
|
||||||
|
|
||||||
|
# 按比例缩放
|
||||||
|
if aspect > target_aspect:
|
||||||
|
# 图像较宽,以目标宽度为准
|
||||||
|
new_w = tw
|
||||||
|
new_h = int(tw / aspect)
|
||||||
|
else:
|
||||||
|
# 图像较高,以目标高度为准
|
||||||
|
new_h = th
|
||||||
|
new_w = int(th * aspect)
|
||||||
|
|
||||||
|
# 缩放图像
|
||||||
|
if new_w != w or new_h != h:
|
||||||
|
img = cv2.resize(img, (new_w, new_h))
|
||||||
|
|
||||||
|
# 创建目标画布并填充
|
||||||
|
ret = np.full((th, tw, 3), rgb, dtype=np.uint8)
|
||||||
|
|
||||||
|
# 计算需要填充的宽高
|
||||||
|
pad_h = th - new_h
|
||||||
|
pad_w = tw - new_w
|
||||||
|
|
||||||
|
# 将缩放后的图像居中放置
|
||||||
|
ret[
|
||||||
|
pad_h // 2:pad_h // 2 + new_h,
|
||||||
|
pad_w // 2:pad_w // 2 + new_w, :] = img
|
||||||
|
return ret
|
||||||
|
|
||||||
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
|
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
@ -153,17 +196,40 @@ class Ocr:
|
||||||
def __init__(self, engine: RapidOCR):
|
def __init__(self, engine: RapidOCR):
|
||||||
self.__engine = engine
|
self.__engine = engine
|
||||||
|
|
||||||
|
|
||||||
# TODO: 考虑缓存 OCR 结果,避免重复调用。
|
# TODO: 考虑缓存 OCR 结果,避免重复调用。
|
||||||
def ocr(self, img: 'MatLike') -> list[OcrResult]:
|
def ocr(
|
||||||
|
self,
|
||||||
|
img: 'MatLike',
|
||||||
|
*,
|
||||||
|
rect: Rect | None = None,
|
||||||
|
pad: bool = True,
|
||||||
|
) -> list[OcrResult]:
|
||||||
"""
|
"""
|
||||||
OCR 一个 cv2 的图像。注意识别结果中的**全角字符会被转换为半角字符**。
|
OCR 一个 cv2 的图像。注意识别结果中的**全角字符会被转换为半角字符**。
|
||||||
|
|
||||||
|
|
||||||
|
:param rect: 如果指定,则只识别指定矩形区域。
|
||||||
|
:param pad:
|
||||||
|
是否将过小的图像(尺寸 < 631x631)的图像填充到 631x631。
|
||||||
|
默认为 True。
|
||||||
|
|
||||||
|
对于 PaddleOCR 模型,图片尺寸太小会降低准确率。
|
||||||
|
将图片周围填充放大,有助于提高准确率,降低耗时。
|
||||||
:return: 所有识别结果
|
:return: 所有识别结果
|
||||||
"""
|
"""
|
||||||
|
if rect is not None:
|
||||||
|
x, y, w, h = rect
|
||||||
|
img = img[y:y+h, x:x+w]
|
||||||
|
original_img = img
|
||||||
|
if pad:
|
||||||
|
# TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
|
||||||
|
# https://blog.csdn.net/YY007H/article/details/124973777
|
||||||
|
original_img = img.copy()
|
||||||
|
img = pad_to(img, (631, 631))
|
||||||
img_content = grayscaled(img)
|
img_content = grayscaled(img)
|
||||||
result, elapse = self.__engine(img_content)
|
result, elapse = self.__engine(img_content)
|
||||||
if result is None:
|
if result is None:
|
||||||
|
|
||||||
return []
|
return []
|
||||||
ret = [OcrResult(
|
ret = [OcrResult(
|
||||||
text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉
|
text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉
|
||||||
|
@ -177,7 +243,7 @@ class Ocr:
|
||||||
result_image = _draw_result(img, ret)
|
result_image = _draw_result(img, ret)
|
||||||
debug_result(
|
debug_result(
|
||||||
'ocr',
|
'ocr',
|
||||||
[result_image, img],
|
[result_image, original_img],
|
||||||
f"result: \n" + \
|
f"result: \n" + \
|
||||||
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
|
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
|
||||||
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \
|
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \
|
||||||
|
@ -185,22 +251,50 @@ class Ocr:
|
||||||
)
|
)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def find(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
|
def find(
|
||||||
|
self,
|
||||||
|
img: 'MatLike',
|
||||||
|
text: str | re.Pattern | StringMatchFunction,
|
||||||
|
*,
|
||||||
|
hint: HintBox | None = None,
|
||||||
|
rect: Rect | None = None,
|
||||||
|
pad: bool = True,
|
||||||
|
) -> OcrResult | None:
|
||||||
"""
|
"""
|
||||||
寻找指定文本
|
寻找指定文本。
|
||||||
|
|
||||||
|
:param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
|
||||||
|
:param rect: 如果指定,则只识别指定矩形区域。此参数优先级低于 `hint`。
|
||||||
|
:param pad: 见 `ocr` 的 `pad` 参数。
|
||||||
:return: 找到的文本,如果未找到则返回 None
|
:return: 找到的文本,如果未找到则返回 None
|
||||||
"""
|
"""
|
||||||
for result in self.ocr(img):
|
if hint is not None:
|
||||||
|
if ret := self.find(img, text, rect=hint):
|
||||||
|
return ret
|
||||||
|
for result in self.ocr(img, rect=rect, pad=pad):
|
||||||
if _is_match(result.text, text):
|
if _is_match(result.text, text):
|
||||||
return result
|
return result
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def expect(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult:
|
def expect(
|
||||||
|
self,
|
||||||
|
img: 'MatLike',
|
||||||
|
text: str | re.Pattern | StringMatchFunction,
|
||||||
|
*,
|
||||||
|
hint: HintBox | None = None,
|
||||||
|
rect: Rect | None = None,
|
||||||
|
pad: bool = True,
|
||||||
|
) -> OcrResult:
|
||||||
"""
|
"""
|
||||||
寻找指定文本,如果未找到则抛出异常
|
寻找指定文本,如果未找到则抛出异常。
|
||||||
|
|
||||||
|
:param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
|
||||||
|
:param rect: 如果指定,则只识别指定矩形区域。此参数优先级高于 `hint`。
|
||||||
|
:param pad: 见 `ocr` 的 `pad` 参数。
|
||||||
|
:return: 找到的文本
|
||||||
"""
|
"""
|
||||||
ret = self.find(img, text)
|
ret = self.find(img, text, hint=hint, rect=rect, pad=pad)
|
||||||
if ret is None:
|
if ret is None:
|
||||||
raise TextNotFoundError(text, img)
|
raise TextNotFoundError(text, img)
|
||||||
return ret
|
return ret
|
||||||
|
@ -212,10 +306,38 @@ jp = Ocr(_engine_jp)
|
||||||
en = Ocr(_engine_en)
|
en = Ocr(_engine_en)
|
||||||
"""英语 OCR 引擎。"""
|
"""英语 OCR 引擎。"""
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
import time
|
||||||
from pprint import pprint as print
|
from pprint import pprint as print
|
||||||
import cv2
|
import cv2
|
||||||
img_path = 'test_images/acquire_pdorinku.png'
|
print('small')
|
||||||
|
img_path = r"C:\Users\user\Downloads\Screenshot_2025.01.28_21.00.40.172.png"
|
||||||
img = cv2.imread(img_path)
|
img = cv2.imread(img_path)
|
||||||
|
time_start = time.time()
|
||||||
result1 = jp.ocr(img)
|
result1 = jp.ocr(img)
|
||||||
print(result1)
|
time_end = time.time()
|
||||||
|
print(time_end - time_start)
|
||||||
|
# print(result1)
|
||||||
|
|
||||||
|
for i in np.linspace(300, 1000, 20):
|
||||||
|
i = int(i)
|
||||||
|
print('small-pad: ' + f'{str(i)}x{str(i)}')
|
||||||
|
img = pad_to(img, (int(i), int(i)))
|
||||||
|
time_start = time.time()
|
||||||
|
result1 = jp.ocr(img)
|
||||||
|
time_end = time.time()
|
||||||
|
print(time_end - time_start)
|
||||||
|
|
||||||
|
# print(result1)
|
||||||
|
|
||||||
|
|
||||||
|
print('big')
|
||||||
|
img_path = r"C:\Users\user\Pictures\BlueStacks\Screenshot_2025.01.28_21.00.40.172.png"
|
||||||
|
img = cv2.imread(img_path)
|
||||||
|
time_start = time.time()
|
||||||
|
result1 = jp.ocr(img)
|
||||||
|
time_end = time.time()
|
||||||
|
print(time_end - time_start)
|
||||||
|
# print(result1)
|
||||||
|
|
Loading…
Reference in New Issue