From c9c67e65202f359aa5170875aef2c5679dedf1a3 Mon Sep 17 00:00:00 2001 From: XcantloadX <3188996979@qq.com> Date: Mon, 5 May 2025 17:39:34 +0800 Subject: [PATCH] =?UTF-8?q?refactor(core):=20OCR=20=E5=BC=95=E6=93=8E?= =?UTF-8?q?=E6=8E=A8=E8=BF=9F=E5=88=B0=E5=90=AF=E5=8A=A8=E8=84=9A=E6=9C=AC?= =?UTF-8?q?=E6=97=B6=E5=8A=A0=E8=BD=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kotonebot/backend/context/context.py | 6 ++-- kotonebot/backend/ocr.py | 54 +++++++++++++++++----------- kotonebot/primitives/__init__.py | 0 tests/core/test_ocr.py | 12 +++---- 4 files changed, 43 insertions(+), 29 deletions(-) create mode 100644 kotonebot/primitives/__init__.py diff --git a/kotonebot/backend/context/context.py b/kotonebot/backend/context/context.py index 778d526..11716f0 100644 --- a/kotonebot/backend/context/context.py +++ b/kotonebot/backend/context/context.py @@ -255,7 +255,7 @@ class ContextStackVars: class ContextOcr: def __init__(self, context: 'Context'): self.context = context - self.__engine = jp + self.__engine = jp() def raw(self, lang: OcrLanguage = 'jp') -> Ocr: """ @@ -264,9 +264,9 @@ class ContextOcr: """ match lang: case 'jp': - return jp + return jp() case 'en': - return en + return en() case _: raise ValueError(f"Invalid language: {lang}") diff --git a/kotonebot/backend/ocr.py b/kotonebot/backend/ocr.py index 69fe62c..3fcf9e2 100644 --- a/kotonebot/backend/ocr.py +++ b/kotonebot/backend/ocr.py @@ -20,21 +20,6 @@ from ..util import Rect, lf_path from .debug import result as debug_result, debug logger = logging.getLogger(__name__) - -# TODO: 这个路径需要能够独立设置 -_engine_jp = RapidOCR( - rec_model_path=lf_path('models/japan_PP-OCRv4_rec_infer.onnx'), - use_det=True, - use_cls=False, - use_rec=True, -) -_engine_en = RapidOCR( - rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'), - use_det=True, - use_cls=False, - use_rec=True, -) - StringMatchFunction = Callable[[str], bool] REGEX_NUMBERS = re.compile(r'\d+') @@ -483,13 +468,42 @@ class Ocr: raise TextNotFoundError(text, img) return ret +# TODO: 这个路径需要能够独立设置 +_engine_jp: RapidOCR | None = None +_engine_en: RapidOCR | None = RapidOCR( + rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'), + use_det=True, + use_cls=False, + use_rec=True, +) +def jp() -> Ocr: + """ + 日语 OCR 引擎。 + """ + global _engine_jp + if _engine_jp is None: + _engine_jp = RapidOCR( + rec_model_path=lf_path('models/japan_PP-OCRv3_rec_infer.onnx'), + use_det=True, + use_cls=False, + use_rec=True, + ) + return Ocr(_engine_jp) -jp = Ocr(_engine_jp) -"""日语 OCR 引擎。""" -en = Ocr(_engine_en) -"""英语 OCR 引擎。""" - +def en() -> Ocr: + """ + 英语 OCR 引擎。 + """ + global _engine_en + if _engine_en is None: + _engine_en = RapidOCR( + rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'), + use_det=True, + use_cls=False, + use_rec=True, + ) + return Ocr(_engine_en) if __name__ == '__main__': diff --git a/kotonebot/primitives/__init__.py b/kotonebot/primitives/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/test_ocr.py b/tests/core/test_ocr.py index 1a132d7..f3471d2 100644 --- a/tests/core/test_ocr.py +++ b/tests/core/test_ocr.py @@ -51,11 +51,11 @@ class TestOcr(unittest.TestCase): assert bounding_box(points) == (5, 5, 0, 0) def test_ocr_basic(self): - result = jp.ocr(self.img) + result = jp().ocr(self.img) self.assertGreater(len(result), 0) def test_ocr_rect(self): - result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=True) + result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=True) self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。') x, y, w, h = result[0].original_rect self.assertAlmostEqual(x, 147, delta=10) @@ -63,7 +63,7 @@ class TestOcr(unittest.TestCase): self.assertAlmostEqual(w, 417, delta=10) self.assertAlmostEqual(h, 32, delta=10) - result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=False) + result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=False) self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。') x, y, w, h = result[0].original_rect self.assertAlmostEqual(x, 147, delta=10) @@ -72,9 +72,9 @@ class TestOcr(unittest.TestCase): self.assertAlmostEqual(h, 32, delta=10) def test_find(self): - self.assertTrue(jp.find(self.img, '中間まで')) - self.assertTrue(jp.find(self.img, '受け取るPドリンクを選んでください。')) - self.assertTrue(jp.find(self.img, '受け取る')) + self.assertTrue(jp().find(self.img, '中間まで')) + self.assertTrue(jp().find(self.img, '受け取るPドリンクを選んでください。')) + self.assertTrue(jp().find(self.img, '受け取る')) class TestOcrResult(unittest.TestCase):