kotones-auto-assistant/kotonebot/backend/ocr.py

512 lines
17 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import re
import time
import logging
import unicodedata
from functools import lru_cache
from dataclasses import dataclass
import warnings
from typing_extensions import Self, deprecated
from typing import Callable, NamedTuple
import cv2
import numpy as np
from cv2.typing import MatLike
from thefuzz import fuzz as _fuzz
from rapidocr_onnxruntime import RapidOCR
from ..util import lf_path
from ..primitives import Rect, Point
from .core import HintBox, Image, unify_image
from .debug import result as debug_result, debug
logger = logging.getLogger(__name__)
StringMatchFunction = Callable[[str], bool]
REGEX_NUMBERS = re.compile(r'\d+')
@dataclass
class OcrResult:
text: str
rect: Rect
confidence: float
original_rect: Rect
"""
识别结果在原图中的区域坐标。
如果识别时没有设置 `rect` 或 `hint` 参数,则此属性值与 `rect` 相同。
"""
def __repr__(self) -> str:
return f'OcrResult(text="{self.text}", rect={self.rect}, confidence={self.confidence})'
def replace(self, old: str, new: str, count: int = -1) -> Self:
"""
替换识别结果中的文本。
"""
self.text = self.text.replace(old, new, count)
return self
def regex(self, pattern: re.Pattern | str) -> list[str]:
"""
提取识别结果中符合正则表达式的文本。
"""
if isinstance(pattern, str):
pattern = re.compile(pattern)
return pattern.findall(self.text)
def numbers(self) -> list[int]:
"""
提取识别结果中的数字。
"""
return [int(x) for x in REGEX_NUMBERS.findall(self.text)]
class OcrResultList(list[OcrResult]):
def squash(self, remove_newlines: bool = True) -> OcrResult:
"""
将所有识别结果合并为一个大结果。
"""
if not self:
return OcrResult('', Rect(0, 0, 0, 0), 0, Rect(0, 0, 0, 0))
text = [r.text for r in self]
confidence = sum(r.confidence for r in self) / len(self)
points = []
for r in self:
points.append(Point(r.rect.x1, r.rect.y1))
points.append(Point(r.rect.x1 + r.rect.w, r.rect.y1))
points.append(Point(r.rect.x1, r.rect.y1 + r.rect.h))
points.append(Point(r.rect.x1 + r.rect.w, r.rect.y1 + r.rect.h))
rect = Rect(xywh=bounding_box(points))
text = '\n'.join(text)
if remove_newlines:
text = text.replace('\n', '')
return OcrResult(
text=text,
rect=rect,
confidence=confidence,
original_rect=rect,
)
def first(self) -> OcrResult | None:
"""
返回第一个识别结果。
"""
return self[0] if self else None
def where(self, pattern: StringMatchFunction) -> 'OcrResultList':
"""
返回符合条件的识别结果。
"""
return OcrResultList([x for x in self if pattern(x.text)])
class TextNotFoundError(Exception):
def __init__(self, pattern: str | re.Pattern | StringMatchFunction, image: 'MatLike'):
self.pattern = pattern
self.image = image
super().__init__(f"Expected text not found: {pattern}")
class TextComparator:
def __init__(self, name: str, text: str, func: Callable[[str], bool]):
self.name = name
self.text = text
self.func = func
def __call__(self, text: str) -> bool:
return self.func(text)
def __repr__(self) -> str:
return f'{self.name}("{self.text}")'
@deprecated("即将移除")
@lru_cache(maxsize=1000)
def fuzz(text: str) -> TextComparator:
"""返回 fuzzy 算法的字符串匹配函数。"""
func = lambda s: _fuzz.ratio(s, text) > 90
return TextComparator("fuzzy", text, func)
@lru_cache(maxsize=1000)
def regex(regex: str) -> TextComparator:
"""返回正则表达式字符串匹配函数。"""
func = lambda s: re.match(regex, s) is not None
return TextComparator("regex", regex, func)
@lru_cache(maxsize=1000)
def contains(text: str, *, ignore_case: bool = False) -> TextComparator:
"""返回包含指定文本的函数。"""
if ignore_case:
func = lambda s: text.lower() in s.lower()
else:
func = lambda s: text in s
return TextComparator("contains", text, func)
@lru_cache(maxsize=1000)
def equals(
text: str,
*,
remove_space: bool = False,
ignore_case: bool = True,
) -> TextComparator:
"""
返回等于指定文本的函数。
:param text: 要比较的文本。
:param remove_space: 是否忽略空格。默认为 False。
:param ignore_case: 是否忽略大小写。默认为 True。
"""
def compare(s: str) -> bool:
nonlocal text
if ignore_case:
text = text.lower()
s = s.lower()
if remove_space:
text = text.replace(' ', '').replace(' ', '')
s = s.replace(' ', '').replace(' ', '')
return text == s
return TextComparator("equals", text, compare)
def grayscaled(img: 'MatLike | str | Image') -> MatLike:
img = unify_image(img)
return cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
def _is_match(text: str, pattern: re.Pattern | str | StringMatchFunction | TextComparator) -> bool:
if isinstance(pattern, re.Pattern):
return pattern.match(text) is not None
elif callable(pattern):
return pattern(text)
else:
return text == pattern
# https://stackoverflow.com/questions/46335488/how-to-efficiently-find-the-bounding-box-of-a-collection-of-points
def _bounding_box(points):
x_coordinates, y_coordinates = zip(*points)
return [(min(x_coordinates), min(y_coordinates)), (max(x_coordinates), max(y_coordinates))]
def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
"""
计算点集的外接矩形。
:param points: 点集。以左上角为原点,向下向右为正方向。
:return: 外接矩形的左上角坐标和宽高
"""
topleft, bottomright = _bounding_box(points)
return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
def pad_to(img: MatLike, target_size: int, rgb: tuple[int, int, int] = (255, 255, 255)) -> tuple[MatLike, tuple[int, int]]:
"""
将图像居中填充到指定大小。缺少部分使用指定颜色填充。
:return: 填充后的图像和填充的偏移量 (x, y)。
"""
h, w = img.shape[:2]
# 计算需要填充的宽高
pad_h = max(0, target_size - h)
pad_w = max(0, target_size - w)
# 如果不需要填充则直接返回
if pad_h == 0 and pad_w == 0:
return img, (0, 0)
# 创建目标画布并填充
if len(img.shape) == 2:
# 灰度图像
ret = np.full((h + pad_h, w + pad_w), rgb[0], dtype=np.uint8)
else:
# RGB图像
ret = np.full((h + pad_h, w + pad_w, 3), rgb, dtype=np.uint8)
# 将原图像居中放置
if len(img.shape) == 2:
ret[
pad_h // 2:pad_h // 2 + h,
pad_w // 2:pad_w // 2 + w] = img
else:
ret[
pad_h // 2:pad_h // 2 + h,
pad_w // 2:pad_w // 2 + w, :] = img
return ret, (pad_w // 2, pad_h // 2)
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
import numpy as np
from PIL import Image, ImageDraw, ImageFont
# 转换为PIL图像
result_image = cv2.cvtColor(image.copy(), cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(result_image)
draw = ImageDraw.Draw(pil_image, 'RGBA')
# 加载字体
try:
font = ImageFont.truetype(lf_path('res/fonts/SourceHanSansHW-Regular.otf'), 16)
except:
font = ImageFont.load_default()
for r in result:
# 画矩形框
draw.rectangle(
[r.rect.x1, r.rect.y1, r.rect.x1 + r.rect.w, r.rect.y1 + r.rect.h],
outline=(255, 0, 0),
width=2
)
# 获取文本大小
text = r.text + f" ({r.confidence:.2f})" # 添加置信度显示
text_bbox = draw.textbbox((0, 0), text, font=font)
text_width = text_bbox[2] - text_bbox[0]
text_height = text_bbox[3] - text_bbox[1]
# 计算文本位置
text_x = r.rect.x1
text_y = r.rect.y1 - text_height - 5 if r.rect.y1 > text_height + 5 else r.rect.y1 + r.rect.h + 5
# 添加padding
padding = 4
bg_rect = [
text_x - padding,
text_y - padding,
text_x + text_width + padding,
text_y + text_height + padding
]
# 画半透明背景
draw.rectangle(
bg_rect,
fill=(0, 0, 0, 128)
)
# 画文字
draw.text(
(text_x, text_y),
text,
font=font,
fill=(255, 255, 255)
)
# 转回OpenCV格式
result_image = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
return result_image
class Ocr:
def __init__(self, engine: RapidOCR):
self.__engine = engine
# TODO: 考虑缓存 OCR 结果,避免重复调用。
def ocr(
self,
img: 'MatLike',
*,
rect: Rect | None = None,
pad: bool = True,
) -> OcrResultList:
"""
OCR 一个 cv2 的图像。注意识别结果中的**全角字符会被转换为半角字符**。
:param rect: 如果指定,则只识别指定矩形区域。
:param pad:
是否将过小的图像(尺寸 < 631x631的图像填充到 631x631。
默认为 True。
对于 PaddleOCR 模型,图片尺寸太小会降低准确率。
将图片周围填充放大,有助于提高准确率,降低耗时。
:return: 所有识别结果
"""
if rect is not None:
x, y, w, h = rect.xywh
img = img[y:y+h, x:x+w]
original_img = img
if pad:
# TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
# https://blog.csdn.net/YY007H/article/details/124973777
original_img = img.copy()
img, pos_in_padded_img = pad_to(img, 631)
else:
pos_in_padded_img = (0, 0)
img_content = img
result, elapse = self.__engine(img_content)
if result is None:
return OcrResultList()
ret = []
for r in result:
# HACK: 识别结果中包含奇怪的符号,暂时替换掉
text = unicodedata.normalize('NFKC', r[1]).replace('ą', 'a')
# r[0] = [左上, 右上, 右下, 左下]
# 这里有个坑,返回的点不一定是矩形,只能保证是四边形
# 所以这里需要计算出四个点的外接矩形
result_rect = tuple(int(x) for x in bounding_box(r[0])) # type: ignore
# result_rect (x, y, w, h)
if rect is not None:
original_rect = (
result_rect[0] + rect.x1 - pos_in_padded_img[0],
result_rect[1] + rect.y1 - pos_in_padded_img[1],
result_rect[2],
result_rect[3]
)
else:
original_rect = result_rect
if not len(original_rect) == 4:
raise ValueError(f'Invalid original_rect: {original_rect}')
if not len(result_rect) == 4:
raise ValueError(f'Invalid result_rect: {result_rect}')
confidence = float(r[2])
ret.append(OcrResult(
text=text,
rect=Rect(xywh=result_rect),
original_rect=Rect(xywh=original_rect),
confidence=confidence
))
ret = OcrResultList(ret)
if debug.enabled:
result_image = _draw_result(img, ret)
elapse = elapse or [0, 0, 0]
debug_result(
'ocr',
[result_image, original_img],
f"pad={pad}\n" + \
f"rect={rect}\n" + \
f"elapsed: det={elapse[0]:.3f}s cls={elapse[1]:.3f}s rec={elapse[2]:.3f}s\n" + \
f"result: \n" + \
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.3f}</td></tr>" for r in ret]) + \
"</table>"
)
return ret
def find(
self,
img: 'MatLike',
text: str | re.Pattern | StringMatchFunction,
*,
hint: HintBox | None = None,
rect: Rect | None = None,
pad: bool = True,
) -> OcrResult | None:
"""
识别图像中的文本,并寻找满足指定要求的文本。
:param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
:param rect: 如果指定,则只识别指定矩形区域。此参数优先级低于 `hint`。
:param pad: 见 `ocr` 的 `pad` 参数。
:return: 找到的文本,如果未找到则返回 None
"""
if hint is not None:
warnings.warn("使用 `rect` 参数代替")
if ret := self.find(img, text, rect=Rect(xywh=hint.rect)):
logger.debug(f"find: {text} SUCCESS [hint={hint}]")
return ret
logger.debug(f"find: {text} FAILED [hint={hint}]")
start_time = time.time()
results = self.ocr(img, rect=rect, pad=pad)
end_time = time.time()
target = None
for result in results:
if _is_match(result.text, text):
target = result
break
logger.debug(
f"find: {text} {'SUCCESS' if target else 'FAILED'} " + \
f"[elapsed={end_time - start_time:.3f}s] [rect={rect}]"
)
return target
def find_all(
self,
img: 'MatLike',
texts: list[str | re.Pattern | StringMatchFunction],
*,
hint: HintBox | None = None,
rect: Rect | None = None,
pad: bool = True,
) -> list[OcrResult | None]:
"""
识别图像中的文本,并寻找多个满足指定要求的文本。
:return:
所有找到的文本,结果顺序与输入顺序相同。
若某个文本未找到,则该位置为 None。
"""
# HintBox 处理
if hint is not None:
warnings.warn("使用 `rect` 参数代替")
result = self.find_all(img, texts, rect=Rect(xywh=hint.rect), pad=pad)
if all(result):
return result
ret: list[OcrResult | None] = []
ocr_results = self.ocr(img, rect=rect, pad=pad)
logger.debug(f"ocr_results: {ocr_results}")
for text in texts:
for result in ocr_results:
if _is_match(result.text, text):
ret.append(result)
break
else:
ret.append(None)
return ret
def expect(
self,
img: 'MatLike',
text: str | re.Pattern | StringMatchFunction,
*,
hint: HintBox | None = None,
rect: Rect | None = None,
pad: bool = True,
) -> OcrResult:
"""
识别图像中的文本,并寻找满足指定要求的文本。如果未找到则抛出异常。
:param hint: 如果指定,则首先只识别 HintBox 范围内的文本,若未命中,再全局寻找。
:param rect: 如果指定,则只识别指定矩形区域。此参数优先级高于 `hint`。
:param pad: 见 `ocr` 的 `pad` 参数。
:return: 找到的文本
"""
ret = self.find(img, text, hint=hint, rect=rect, pad=pad)
if ret is None:
raise TextNotFoundError(text, img)
return ret
# TODO: 这个路径需要能够独立设置
_engine_jp: RapidOCR | None = None
_engine_en: RapidOCR | None = RapidOCR(
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
def jp() -> Ocr:
"""
日语 OCR 引擎。
"""
global _engine_jp
if _engine_jp is None:
_engine_jp = RapidOCR(
rec_model_path=lf_path('models/japan_PP-OCRv3_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
return Ocr(_engine_jp)
def en() -> Ocr:
"""
英语 OCR 引擎。
"""
global _engine_en
if _engine_en is None:
_engine_en = RapidOCR(
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
return Ocr(_engine_en)
if __name__ == '__main__':
pass