183 lines
5.5 KiB
Python
183 lines
5.5 KiB
Python
import html
import re
import time
import unicodedata
from os import PathLike
from typing import TYPE_CHECKING, Callable, NamedTuple, overload

import cv2
from cv2.typing import MatLike
from rapidocr_onnxruntime import RapidOCR

from .debug import result as debug_result, debug
from .util import Rect, grayscaled

def _create_engine(rec_model_path: str) -> RapidOCR:
    """Build a RapidOCR engine with text detection and recognition enabled and
    the angle classifier disabled.

    :param rec_model_path: path to the ONNX recognition model. It is a relative
        path, presumably resolved against the current working directory —
        NOTE(review): confirm how rapidocr resolves relative model paths.
    """
    return RapidOCR(
        rec_model_path=rec_model_path,
        use_det=True,
        use_cls=False,
        use_rec=True,
    )


# Forward slashes are accepted by both Windows and POSIX; the previous
# backslash-separated paths only resolved on Windows.
_engine_jp = _create_engine('res/models/japan_PP-OCRv3_rec_infer.onnx')
_engine_en = _create_engine('res/models/en_PP-OCRv3_rec_infer.onnx')

# A predicate deciding whether a recognized string is the one being searched for.
StringMatchFunction = Callable[[str], bool]

class OcrResult(NamedTuple):
    """One text region recognized by the OCR engine."""

    # Recognized text (NFKC-normalized by ``Ocr.ocr``).
    text: str
    # Bounding box as (x, y, width, height) in image pixels.
    rect: Rect
    # Recognition confidence score reported by the engine.
    confidence: float

class TextNotFoundError(Exception):
    """Raised when an expected piece of text is absent from the OCR results.

    :ivar pattern: the string, compiled regex, or predicate that failed to match.
    :ivar image: the image that was searched.
    """

    def __init__(self, pattern: 'str | re.Pattern | StringMatchFunction', image: 'MatLike'):
        self.pattern = pattern
        self.image = image
        if isinstance(pattern, (str, re.Pattern)):
            description = str(pattern)
        else:
            # Predicates such as functools.partial objects or callable class
            # instances have no __name__; fall back to repr() rather than raising
            # AttributeError while building the exception (which would hide the
            # original "text not found" failure).
            description = getattr(pattern, '__name__', None) or repr(pattern)
        super().__init__(f"Expected text not found: {description}")

def _is_match(text: str, pattern: re.Pattern | str | StringMatchFunction) -> bool:
|
|
if isinstance(pattern, re.Pattern):
|
|
return pattern.match(text) is not None
|
|
elif callable(pattern):
|
|
return pattern(text)
|
|
else:
|
|
return text == pattern
|
|
|
|
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
    """Render OCR results onto a copy of ``image`` for debugging.

    Each result gets a red bounding box and a "<text> (<confidence>)" label on
    a semi-transparent black background. The label sits above the box when
    there is room and below it otherwise.

    :param image: BGR (or single-channel grayscale) image the results came from.
    :param result: OCR results whose ``rect`` is (x, y, width, height).
    :return: a new BGR image with the annotations drawn; ``image`` is not modified.
    """
    import numpy as np
    from PIL import Image, ImageDraw, ImageFont

    # Convert to a PIL RGB image. A 2-D grayscale array has no channel axis,
    # which COLOR_BGR2RGB rejects, so expand it to RGB instead.
    if len(image.shape) == 2:
        rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    else:
        rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    pil_image = Image.fromarray(rgb)
    # 'RGBA' drawing mode lets the label background be alpha-blended.
    draw = ImageDraw.Draw(pil_image, 'RGBA')

    # Load the CJK-capable font. Fall back to PIL's built-in font when the file
    # is missing or unreadable (OSError) or FreeType support is unavailable
    # (ImportError). Other errors are no longer silently swallowed.
    try:
        font = ImageFont.truetype('res/fonts/SourceHanSansHW-Regular.otf', 16)
    except (OSError, ImportError):
        font = ImageFont.load_default()

    padding = 4
    for r in result:
        x, y, w, h = r.rect[0], r.rect[1], r.rect[2], r.rect[3]

        # Bounding box.
        draw.rectangle([x, y, x + w, y + h], outline=(255, 0, 0), width=2)

        # Label text includes the confidence score.
        text = f"{r.text} ({r.confidence:.2f})"
        _, measured_top, _, measured_bottom = draw.textbbox((0, 0), text, font=font)
        text_height = measured_bottom - measured_top

        # Place the label above the box if it fits, otherwise below it.
        text_x = x
        text_y = y - text_height - 5 if y > text_height + 5 else y + h + 5

        # Measure the text at the position it is actually drawn. textbbox()
        # includes the font's top bearing, so padding the real extents keeps
        # the background aligned with the glyphs.
        left, top, right, bottom = draw.textbbox((text_x, text_y), text, font=font)
        draw.rectangle(
            [left - padding, top - padding, right + padding, bottom + padding],
            fill=(0, 0, 0, 128),
        )
        draw.text((text_x, text_y), text, font=font, fill=(255, 255, 255))

    # Back to OpenCV's BGR layout.
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)

class Ocr:
    """OCR front-end wrapping a single ``RapidOCR`` engine instance."""

    def __init__(self, engine: RapidOCR):
        """:param engine: the RapidOCR engine used for every recognition call."""
        self.__engine = engine

    @staticmethod
    def _to_result(item) -> OcrResult:
        """Convert one raw engine entry ``[box, text, score]`` into an ``OcrResult``."""
        box = item[0]
        xs = [point[0] for point in box]
        ys = [point[1] for point in box]
        left, top = min(xs), min(ys)
        return OcrResult(
            # NFKC turns full-width characters into half-width ones.
            # HACK: results contain a stray 'ą'; replace it for now.
            text=unicodedata.normalize('NFKC', item[1]).replace('ą', 'a'),
            # Axis-aligned bounding rect (x, y, width, height) of the detected
            # quadrilateral. min/max over all corners keeps width and height
            # non-negative even when the detected box is rotated.
            rect=(int(left), int(top), int(max(xs) - left), int(max(ys) - top)),
            # float() guards against engines reporting the score as a string or
            # numpy scalar — NOTE(review): confirm against the installed
            # rapidocr_onnxruntime version.
            confidence=float(item[2]),
        )

    # TODO: consider caching OCR results to avoid repeated calls.
    def ocr(self, img: 'MatLike') -> list[OcrResult]:
        """
        Run OCR on a cv2 image. Note that recognized text is NFKC-normalized,
        so **full-width characters are converted to half-width characters**.

        :return: every recognition result (an empty list when nothing is found).
        """
        img_content = grayscaled(img)
        # The engine returns a pair; only the first element (the results) is used.
        result, _ = self.__engine(img_content)
        if result is None:
            return []
        ret = [self._to_result(item) for item in result]
        if debug.enabled:
            result_image = _draw_result(img, ret)
            # Escape recognized text before embedding it in the HTML report so
            # characters such as '<' or '&' cannot break the table markup.
            rows = "\n".join(
                f"<tr><td>{html.escape(r.text)}</td><td>{r.confidence:.2f}</td></tr>"
                for r in ret
            )
            debug_result(
                'ocr',
                [result_image, img],
                "result: \n"
                "<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>"
                + rows
                + "</table>",
            )
        return ret

    def find(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult | None:
        """
        Find the first OCR result matching ``text``.

        :param text: exact string, compiled regex (matched at the start), or predicate.
        :return: the matching result, or None if nothing matches.
        """
        for result in self.ocr(img):
            if _is_match(result.text, text):
                return result
        return None

    def expect(self, img: 'MatLike', text: str | re.Pattern | StringMatchFunction) -> OcrResult:
        """
        Find the first OCR result matching ``text``.

        :raises TextNotFoundError: if no result matches.
        """
        ret = self.find(img, text)
        if ret is None:
            raise TextNotFoundError(text, img)
        return ret

#: Japanese OCR engine.
jp = Ocr(_engine_jp)

#: English OCR engine.
en = Ocr(_engine_en)

if __name__ == '__main__':
    # Manual smoke test: OCR a sample screenshot and pretty-print the results.
    # pprint is used under its own name rather than shadowing the builtin print.
    from pprint import pprint

    img_path = 'test_images/acquire_pdorinku.png'
    img = cv2.imread(img_path)
    if img is None:
        # cv2.imread returns None (rather than raising) for a missing or
        # unreadable file; fail with a clear message instead of deep inside OCR.
        raise FileNotFoundError(f"Cannot read image: {img_path}")
    result1 = jp.ocr(img)
    pprint(result1)