fix(core): 修复 OCR 识别时若传入了 rect/hint 参数,最终结果坐标不正确的问题
This commit is contained in:
parent
83a2b9ff13
commit
6a47d7d878
|
@ -256,7 +256,6 @@ class ContextOcr:
|
|||
"""OCR 当前设备画面或指定图像。"""
|
||||
return self.__engine.ocr(ContextStackVars.ensure_current().screenshot, rect=rect)
|
||||
|
||||
|
||||
def find(
|
||||
self,
|
||||
pattern: str | re.Pattern | StringMatchFunction,
|
||||
|
@ -271,7 +270,7 @@ class ContextOcr:
|
|||
hint=hint,
|
||||
rect=rect,
|
||||
)
|
||||
self.context.device.last_find = ret
|
||||
self.context.device.last_find = ret.original_rect if ret else None
|
||||
return ret
|
||||
|
||||
def find_all(
|
||||
|
@ -300,7 +299,7 @@ class ContextOcr:
|
|||
与 `find()` 的区别在于,`expect()` 未找到时会抛出异常。
|
||||
"""
|
||||
ret = self.__engine.expect(ContextStackVars.ensure_current().screenshot, pattern)
|
||||
self.context.device.last_find = ret
|
||||
self.context.device.last_find = ret.original_rect if ret else None
|
||||
return ret
|
||||
|
||||
def expect_wait(
|
||||
|
@ -320,7 +319,7 @@ class ContextOcr:
|
|||
result = self.find(pattern)
|
||||
|
||||
if result is not None:
|
||||
self.context.device.last_find = result
|
||||
self.context.device.last_find = result.original_rect if result else None
|
||||
return result
|
||||
if time.time() - start_time > timeout:
|
||||
raise TimeoutError(f"Timeout waiting for {pattern}")
|
||||
|
@ -342,7 +341,7 @@ class ContextOcr:
|
|||
while True:
|
||||
result = self.find(pattern)
|
||||
if result is not None:
|
||||
self.context.device.last_find = result
|
||||
self.context.device.last_find = result.original_rect if result else None
|
||||
return result
|
||||
if time.time() - start_time > timeout:
|
||||
return None
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
import re
|
||||
import time
|
||||
import logging
|
||||
import unicodedata
|
||||
from functools import lru_cache
|
||||
|
@ -39,6 +40,12 @@ class OcrResult:
|
|||
text: str
|
||||
rect: Rect
|
||||
confidence: float
|
||||
original_rect: Rect
|
||||
"""
|
||||
识别结果在原图中的区域坐标。
|
||||
|
||||
如果识别时没有设置 `rect` 或 `hint` 参数,则此属性值与 `rect` 相同。
|
||||
"""
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'OcrResult(text="{self.text}", rect={self.rect}, confidence={self.confidence})'
|
||||
|
@ -70,7 +77,7 @@ class OcrResultList(list[OcrResult]):
|
|||
将所有识别结果合并为一个大结果。
|
||||
"""
|
||||
if not self:
|
||||
return OcrResult('', (0, 0, 0, 0), 0)
|
||||
return OcrResult('', (0, 0, 0, 0), 0, (0, 0, 0, 0))
|
||||
text = [r.text for r in self]
|
||||
confidence = sum(r.confidence for r in self) / len(self)
|
||||
points = []
|
||||
|
@ -87,6 +94,7 @@ class OcrResultList(list[OcrResult]):
|
|||
text=text,
|
||||
rect=rect,
|
||||
confidence=confidence,
|
||||
original_rect=rect,
|
||||
)
|
||||
|
||||
def first(self) -> OcrResult | None:
|
||||
|
@ -180,63 +188,48 @@ def _bounding_box(points):
|
|||
|
||||
def bounding_box(points: list[tuple[int, int]]) -> tuple[int, int, int, int]:
|
||||
"""
|
||||
计算点集的外接矩形
|
||||
计算点集的外接矩形。
|
||||
|
||||
:param points: 点集
|
||||
:param points: 点集。以左上角为原点,向下向右为正方向。
|
||||
:return: 外接矩形的左上角坐标和宽高
|
||||
"""
|
||||
topleft, bottomright = _bounding_box(points)
|
||||
return (topleft[0], topleft[1], bottomright[0] - topleft[0], bottomright[1] - topleft[1])
|
||||
|
||||
def pad_to(img: MatLike, target_size: tuple[int, int], rgb: tuple[int, int, int] = (255, 255, 255)) -> MatLike:
|
||||
"""将图像居中填充/缩放到指定大小。缺少部分使用指定颜色填充。"""
|
||||
def pad_to(img: MatLike, target_size: int, rgb: tuple[int, int, int] = (255, 255, 255)) -> tuple[MatLike, tuple[int, int]]:
|
||||
"""
|
||||
将图像居中填充到指定大小。缺少部分使用指定颜色填充。
|
||||
|
||||
:return: 填充后的图像和填充的偏移量 (x, y)。
|
||||
"""
|
||||
h, w = img.shape[:2]
|
||||
tw, th = target_size
|
||||
|
||||
# 如果图像宽高都大于目标大小,则不进行填充
|
||||
if h >= th and w >= tw:
|
||||
return img
|
||||
|
||||
# 计算宽高比
|
||||
aspect = w / h
|
||||
target_aspect = tw / th
|
||||
# 计算需要填充的宽高
|
||||
pad_h = max(0, target_size - h)
|
||||
pad_w = max(0, target_size - w)
|
||||
|
||||
# 按比例缩放
|
||||
if aspect > target_aspect:
|
||||
# 图像较宽,以目标宽度为准
|
||||
new_w = tw
|
||||
new_h = int(tw / aspect)
|
||||
else:
|
||||
# 图像较高,以目标高度为准
|
||||
new_h = th
|
||||
new_w = int(th * aspect)
|
||||
|
||||
# 缩放图像
|
||||
if new_w != w or new_h != h:
|
||||
img = cv2.resize(img, (new_w, new_h))
|
||||
# 如果不需要填充则直接返回
|
||||
if pad_h == 0 and pad_w == 0:
|
||||
return img, (0, 0)
|
||||
|
||||
# 创建目标画布并填充
|
||||
if len(img.shape) == 2:
|
||||
# 灰度图像
|
||||
ret = np.full((th, tw), rgb[0], dtype=np.uint8)
|
||||
ret = np.full((h + pad_h, w + pad_w), rgb[0], dtype=np.uint8)
|
||||
else:
|
||||
# RGB图像
|
||||
ret = np.full((th, tw, 3), rgb, dtype=np.uint8)
|
||||
ret = np.full((h + pad_h, w + pad_w, 3), rgb, dtype=np.uint8)
|
||||
|
||||
# 计算需要填充的宽高
|
||||
pad_h = th - new_h
|
||||
pad_w = tw - new_w
|
||||
|
||||
# 将缩放后的图像居中放置
|
||||
# 将原图像居中放置
|
||||
if len(img.shape) == 2:
|
||||
ret[
|
||||
pad_h // 2:pad_h // 2 + new_h,
|
||||
pad_w // 2:pad_w // 2 + new_w] = img
|
||||
pad_h // 2:pad_h // 2 + h,
|
||||
pad_w // 2:pad_w // 2 + w] = img
|
||||
else:
|
||||
ret[
|
||||
pad_h // 2:pad_h // 2 + new_h,
|
||||
pad_w // 2:pad_w // 2 + new_w, :] = img
|
||||
return ret
|
||||
pad_h // 2:pad_h // 2 + h,
|
||||
pad_w // 2:pad_w // 2 + w, :] = img
|
||||
return ret, (pad_w // 2, pad_h // 2)
|
||||
|
||||
def _draw_result(image: 'MatLike', result: list[OcrResult]) -> 'MatLike':
|
||||
import numpy as np
|
||||
|
@ -331,19 +324,38 @@ class Ocr:
|
|||
# TODO: 详细研究哪个尺寸最佳,以及背景颜色、图片位置是否对准确率与耗时有影响
|
||||
# https://blog.csdn.net/YY007H/article/details/124973777
|
||||
original_img = img.copy()
|
||||
img = pad_to(img, (631, 631))
|
||||
img, pos_in_padded_img = pad_to(img, 631)
|
||||
else:
|
||||
pos_in_padded_img = (0, 0)
|
||||
img_content = img
|
||||
result, elapse = self.__engine(img_content)
|
||||
if result is None:
|
||||
return OcrResultList()
|
||||
ret = [OcrResult(
|
||||
text=unicodedata.normalize('NFKC', r[1]).replace('ą', 'a'), # HACK: 识别结果中包含奇怪的符号,暂时替换掉
|
||||
ret = []
|
||||
for r in result:
|
||||
# HACK: 识别结果中包含奇怪的符号,暂时替换掉
|
||||
text = unicodedata.normalize('NFKC', r[1]).replace('ą', 'a')
|
||||
# r[0] = [左上, 右上, 右下, 左下]
|
||||
# 这里有个坑,返回的点不一定是矩形,只能保证是四边形
|
||||
# 所以这里需要计算出四个点的外接矩形
|
||||
rect=tuple(int(x) for x in bounding_box(r[0])), # type: ignore
|
||||
confidence=r[2] # type: ignore
|
||||
) for r in result] # type: ignore
|
||||
result_rect = tuple(int(x) for x in bounding_box(r[0])) # type: ignore
|
||||
# result_rect (x, y, w, h)
|
||||
if rect is not None:
|
||||
original_rect = (
|
||||
result_rect[0] + rect[0] - pos_in_padded_img[0],
|
||||
result_rect[1] + rect[1] - pos_in_padded_img[1],
|
||||
result_rect[2],
|
||||
result_rect[3]
|
||||
)
|
||||
else:
|
||||
original_rect = result_rect
|
||||
confidence = float(r[2])
|
||||
ret.append(OcrResult(
|
||||
text=text,
|
||||
rect=result_rect,
|
||||
original_rect=original_rect,
|
||||
confidence=confidence
|
||||
))
|
||||
ret = OcrResultList(ret)
|
||||
if debug.enabled:
|
||||
result_image = _draw_result(img, ret)
|
||||
|
@ -356,7 +368,7 @@ class Ocr:
|
|||
f"elapsed: det={elapse[0]:.3f}s cls={elapse[1]:.3f}s rec={elapse[2]:.3f}s\n" + \
|
||||
f"result: \n" + \
|
||||
"<table class='result-table'><tr><th>Text</th><th>Confidence</th></tr>" + \
|
||||
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.2f}</td></tr>" for r in ret]) + \
|
||||
"\n".join([f"<tr><td>{r.text}</td><td>{r.confidence:.3f}</td></tr>" for r in ret]) + \
|
||||
"</table>"
|
||||
)
|
||||
return ret
|
||||
|
@ -380,11 +392,23 @@ class Ocr:
|
|||
"""
|
||||
if hint is not None:
|
||||
if ret := self.find(img, text, rect=hint):
|
||||
logger.debug(f"find: {text} with hint={hint} SUCCESS")
|
||||
return ret
|
||||
for result in self.ocr(img, rect=rect, pad=pad):
|
||||
logger.debug(f"find: {text} with hint={hint} FAILED. Retrying on whole image...")
|
||||
|
||||
start_time = time.time()
|
||||
results = self.ocr(img, rect=rect, pad=pad)
|
||||
end_time = time.time()
|
||||
target = None
|
||||
for result in results:
|
||||
if _is_match(result.text, text):
|
||||
return result
|
||||
return None
|
||||
target = result
|
||||
break
|
||||
logger.debug(
|
||||
f"find: {text} with rect={rect} elapsed={end_time - start_time:.3f}s " + \
|
||||
f"{'SUCCESS' if target else 'FAILED'}"
|
||||
)
|
||||
return target
|
||||
|
||||
def find_all(
|
||||
self,
|
||||
|
|
|
@ -50,11 +50,28 @@ class TestOcr(unittest.TestCase):
|
|||
points = [(5, 5), (5, 5), (5, 5)]
|
||||
assert bounding_box(points) == (5, 5, 0, 0)
|
||||
|
||||
def test_ocr_ocr(self):
|
||||
def test_ocr_basic(self):
|
||||
result = jp.ocr(self.img)
|
||||
self.assertGreater(len(result), 0)
|
||||
|
||||
def test_ocr_find(self):
|
||||
def test_ocr_rect(self):
|
||||
result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=True)
|
||||
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
|
||||
x, y, w, h = result[0].original_rect
|
||||
self.assertAlmostEqual(x, 147, delta=10)
|
||||
self.assertAlmostEqual(y, 614, delta=10)
|
||||
self.assertAlmostEqual(w, 417, delta=10)
|
||||
self.assertAlmostEqual(h, 32, delta=10)
|
||||
|
||||
result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=False)
|
||||
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
|
||||
x, y, w, h = result[0].original_rect
|
||||
self.assertAlmostEqual(x, 147, delta=10)
|
||||
self.assertAlmostEqual(y, 614, delta=10)
|
||||
self.assertAlmostEqual(w, 417, delta=10)
|
||||
self.assertAlmostEqual(h, 32, delta=10)
|
||||
|
||||
def test_find(self):
|
||||
self.assertTrue(jp.find(self.img, '中間まで'))
|
||||
self.assertTrue(jp.find(self.img, '受け取るPドリンクを選んでください。'))
|
||||
self.assertTrue(jp.find(self.img, '受け取る'))
|
||||
|
@ -62,25 +79,25 @@ class TestOcr(unittest.TestCase):
|
|||
|
||||
class TestOcrResult(unittest.TestCase):
|
||||
def test_regex(self):
|
||||
result = OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95)
|
||||
result = OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
|
||||
self.assertEqual(result.regex(r'\d+'), ['123', '4567', '890'])
|
||||
self.assertEqual(result.regex(re.compile(r'\d+')), ['123', '4567', '890'])
|
||||
|
||||
def test_numbers(self):
|
||||
result = OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95)
|
||||
result = OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
|
||||
self.assertEqual(result.numbers(), [123, 4567, 890])
|
||||
result2 = OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95)
|
||||
result2 = OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
|
||||
self.assertEqual(result2.numbers(), [])
|
||||
result3 = OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95)
|
||||
result3 = OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
|
||||
self.assertEqual(result3.numbers(), [1234567890])
|
||||
|
||||
|
||||
class TestOcrResultList(unittest.TestCase):
|
||||
def test_list_compatibility(self):
|
||||
result = OcrResultList([
|
||||
OcrResult(text='abc', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='def', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='ghi', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='abc', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
OcrResult(text='def', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
OcrResult(text='ghi', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
])
|
||||
|
||||
self.assertEqual(result[0].text, 'abc')
|
||||
|
@ -105,17 +122,17 @@ class TestOcrResultList(unittest.TestCase):
|
|||
|
||||
def test_where(self):
|
||||
result = OcrResultList([
|
||||
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
])
|
||||
self.assertEqual(result.where(lambda x: x.startswith('123')), [result[0], result[2]])
|
||||
|
||||
def test_first(self):
|
||||
result = OcrResultList([
|
||||
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95),
|
||||
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
|
||||
])
|
||||
self.assertEqual(result.first(), result[0])
|
||||
result2 = OcrResultList()
|
||||
|
|
Loading…
Reference in New Issue