fix(core): 修复 OCR 识别时若传入了 rect/hint 参数,最终结果坐标不正确的问题

This commit is contained in:
XcantloadX 2025-02-11 14:23:27 +08:00
parent 83a2b9ff13
commit 6a47d7d878
3 changed files with 108 additions and 68 deletions

View File

@ -256,7 +256,6 @@ class ContextOcr:
"""OCR 当前设备画面或指定图像。"""
return self.__engine.ocr(ContextStackVars.ensure_current().screenshot, rect=rect)
def find(
self,
pattern: str | re.Pattern | StringMatchFunction,
@ -271,7 +270,7 @@ class ContextOcr:
hint=hint,
rect=rect,
)
self.context.device.last_find = ret
self.context.device.last_find = ret.original_rect if ret else None
return ret
def find_all(
@ -300,7 +299,7 @@ class ContextOcr:
`find()` 的区别在于`expect()` 未找到时会抛出异常
"""
ret = self.__engine.expect(ContextStackVars.ensure_current().screenshot, pattern)
self.context.device.last_find = ret
self.context.device.last_find = ret.original_rect if ret else None
return ret
def expect_wait(
@ -320,7 +319,7 @@ class ContextOcr:
result = self.find(pattern)
if result is not None:
self.context.device.last_find = result
self.context.device.last_find = result.original_rect if result else None
return result
if time.time() - start_time > timeout:
raise TimeoutError(f"Timeout waiting for {pattern}")
@ -342,7 +341,7 @@ class ContextOcr:
while True:
result = self.find(pattern)
if result is not None:
self.context.device.last_find = result
self.context.device.last_find = result.original_rect if result else None
return result
if time.time() - start_time > timeout:
return None

View File

@ -1,4 +1,5 @@
import re
import time
import logging
import unicodedata
from functools import lru_cache
@ -39,6 +40,12 @@ class OcrResult:
text: str
rect: Rect
confidence: float
original_rect: Rect
"""
识别结果在原图中的区域坐标
如果识别时没有设置 `rect` `hint` 参数则此属性值与 `rect` 相同
"""
def __repr__(self) -> str:
return f'OcrResult(text="{self.text}", rect={self.rect}, confidence={self.confidence})'
@ -70,7 +77,7 @@ class OcrResultList(list[OcrResult]):
将所有识别结果合并为一个大结果
"""
if not self:
return OcrResult('', (0, 0, 0, 0), 0)
return OcrResult('', (0, 0, 0, 0), 0, (0, 0, 0, 0))
text = [r.text for r in self]
confidence = sum(r.confidence for r in self) / len(self)
points = []
@ -87,6 +94,7 @@ class OcrResultList(list[OcrResult]):
text=text,
rect=rect,
confidence=confidence,
original_rect=rect,
)
def first(self) -> OcrResult | None:
@ -180,63 +188,48 @@ def _bounding_box(points):
def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
    """
    Return the bounding rectangle of a set of points as ``(x, y, w, h)``.

    :param points: Points in a coordinate system whose origin is the top-left
        corner, with x growing to the right and y growing downwards.
    :return: Top-left corner of the bounding rectangle followed by its width
        and height.
    """
    # `_bounding_box` yields the two opposite corners (top-left, bottom-right).
    near_corner, far_corner = _bounding_box(points)
    left, top = near_corner[0], near_corner[1]
    width = far_corner[0] - left
    height = far_corner[1] - top
    return (left, top, width, height)
def pad_to(img: MatLike, target_size: int, rgb: tuple[int, int, int] = (255, 255, 255)) -> tuple[MatLike, tuple[int, int]]:
    """
    Center ``img`` on a canvas that is at least ``target_size`` pixels per side.

    The image is never scaled. A dimension that already reaches ``target_size``
    is left untouched, so the result can be larger than ``target_size`` and is
    not necessarily square. When neither dimension needs padding the input
    array itself is returned (no copy).

    :param img: Image array shaped ``(h, w)`` or ``(h, w, channels)``.
    :param target_size: Minimum width and height of the returned image.
    :param rgb: Fill value for the added border. A 2-D image uses only
        ``rgb[0]``. For other channel counts the tuple is truncated, or
        extended with 255, to match the number of channels.
    :return: ``(padded, (x, y))`` where ``(x, y)`` is the position of the
        original image's top-left corner inside ``padded``. Subtract it from
        coordinates measured on ``padded`` to map them back onto ``img``.
    """
    h, w = img.shape[:2]
    pad_h = max(0, target_size - h)
    pad_w = max(0, target_size - w)

    # Nothing to add: hand the input back with a zero offset.
    if pad_h == 0 and pad_w == 0:
        return img, (0, 0)

    if img.ndim == 2:
        # Single-plane image: only one fill component is meaningful.
        fill = rgb[0]
    else:
        # Match the fill to the real channel count instead of assuming 3, so
        # that e.g. 4-channel or (h, w, 1) inputs do not fail on the paste.
        channels = img.shape[2]
        fill = (tuple(rgb) + (255,) * channels)[:channels]

    # Keep the input dtype so pasting the image does not silently cast it.
    canvas_shape = (h + pad_h, w + pad_w) + tuple(img.shape[2:])
    canvas = np.full(canvas_shape, fill, dtype=img.dtype)

    # Integer halves: any odd leftover pixel ends up on the bottom/right side.
    offset_x, offset_y = pad_w // 2, pad_h // 2
    canvas[offset_y:offset_y + h, offset_x:offset_x + w] = img
    return canvas, (offset_x, offset_y)
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
import numpy as np
@ -331,19 +324,38 @@ class Ocr:
# TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
# https://blog.csdn.net/YY007H/article/details/124973777
original_img = img.copy()
img = pad_to(img, (631, 631))
img, pos_in_padded_img = pad_to(img, 631)
else:
pos_in_padded_img = (0, 0)
img_content = img
result, elapse = self.__engine(img_content)
if result is None:
return OcrResultList()
ret = [OcrResult(
text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉
ret = []
for r in result:
# HACK: 识别结果中包含奇怪的符号,暂时替换掉
text = unicodedata.normalize('NFKC', r[1]).replace('ą', 'a')
# r[0] = [左上, 右上, 右下, 左下]
# 这里有个坑,返回的点不一定是矩形,只能保证是四边形
# 所以这里需要计算出四个点的外接矩形
rect=tuple(int(x) for x in bounding_box(r[0])), # type: ignore
confidence=r[2] # type: ignore
) for r in result] # type: ignore
result_rect = tuple(int(x) for x in bounding_box(r[0])) # type: ignore
# result_rect (x, y, w, h)
if rect is not None:
original_rect = (
result_rect[0] + rect[0] - pos_in_padded_img[0],
result_rect[1] + rect[1] - pos_in_padded_img[1],
result_rect[2],
result_rect[3]
)
else:
original_rect = result_rect
confidence = float(r[2])
ret.append(OcrResult(
text=text,
rect=result_rect,
original_rect=original_rect,
confidence=confidence
))
ret = OcrResultList(ret)
if debug.enabled:
result_image = _draw_result(img, ret)
@ -356,7 +368,7 @@ class Ocr:
f"elapsed: det={elapse[0]:.3f}s cls={elapse[1]:.3f}s rec={elapse[2]:.3f}s\n" + \
f"result: \n" + \
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.3f}</td></tr>" for r in ret]) + \
"</table>"
)
return ret
@ -380,11 +392,23 @@ class Ocr:
"""
if hint is not None:
if ret := self.find(img, text, rect=hint):
logger.debug(f"find: {text} with hint={hint} SUCCESS")
return ret
for result in self.ocr(img, rect=rect, pad=pad):
logger.debug(f"find: {text} with hint={hint} FAILED. Retrying on whole image...")
start_time = time.time()
results = self.ocr(img, rect=rect, pad=pad)
end_time = time.time()
target = None
for result in results:
if _is_match(result.text, text):
return result
return None
target = result
break
logger.debug(
f"find: {text} with rect={rect} elapsed={end_time - start_time:.3f}s " + \
f"{'SUCCESS' if target else 'FAILED'}"
)
return target
def find_all(
self,

View File

@ -50,11 +50,28 @@ class TestOcr(unittest.TestCase):
points = [(5, 5), (5, 5), (5, 5)]
assert bounding_box(points) == (5, 5, 0, 0)
def test_ocr_ocr(self):
def test_ocr_basic(self):
result = jp.ocr(self.img)
self.assertGreater(len(result), 0)
def test_ocr_find(self):
def test_ocr_rect(self):
    """
    OCR limited to a region must report `original_rect` in the coordinates
    of the full image, both with and without padding.
    """
    region = (147, 614, 417, 32)
    # Padded run first, then unpadded, each checked the same way.
    for use_padding in (True, False):
        found = jp.ocr(self.img, rect=region, pad=use_padding)
        self.assertEqual(found[0].text, '受け取るPドリンクを選んでください。')
        x, y, w, h = found[0].original_rect
        # The detected box should roughly coincide with the requested region.
        for actual, wanted in zip((x, y, w, h), region):
            self.assertAlmostEqual(actual, wanted, delta=10)
def test_find(self):
self.assertTrue(jp.find(self.img, '中間まで'))
self.assertTrue(jp.find(self.img, '受け取るPドリンクを選んでください。'))
self.assertTrue(jp.find(self.img, '受け取る'))
@ -62,25 +79,25 @@ class TestOcr(unittest.TestCase):
class TestOcrResult(unittest.TestCase):
    """Checks for the text-extraction helpers of `OcrResult`."""

    @staticmethod
    def _result(text):
        """Build an `OcrResult` with fixed geometry; only `text` varies."""
        box = (0, 0, 100, 100)
        return OcrResult(text=text, rect=box, confidence=0.95, original_rect=box)

    def test_regex(self):
        """`regex()` accepts both a pattern string and a compiled pattern."""
        sample = self._result('123dd4567rr890')
        digit_runs = ['123', '4567', '890']
        self.assertEqual(sample.regex(r'\d+'), digit_runs)
        self.assertEqual(sample.regex(re.compile(r'\d+')), digit_runs)

    def test_numbers(self):
        """`numbers()` returns the digit runs of the text as integers."""
        cases = [
            ('123dd4567rr890', [123, 4567, 890]),
            ('aaa', []),
            ('1234567890', [1234567890]),
        ]
        for text, wanted in cases:
            self.assertEqual(self._result(text).numbers(), wanted)
class TestOcrResultList(unittest.TestCase):
def test_list_compatibility(self):
result = OcrResultList([
OcrResult(text='abc', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='def', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='ghi', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='abc', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='def', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='ghi', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
])
self.assertEqual(result[0].text, 'abc')
@ -105,17 +122,17 @@ class TestOcrResultList(unittest.TestCase):
def test_where(self):
result = OcrResultList([
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
])
self.assertEqual(result.where(lambda x: x.startswith('123')), [result[0], result[2]])
def test_first(self):
result = OcrResultList([
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95),
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
])
self.assertEqual(result.first(), result[0])
result2 = OcrResultList()