refactor(core): OCR 引擎推迟到启动脚本时加载

This commit is contained in:
XcantloadX 2025-05-05 17:39:34 +08:00
parent 8dc76e0f92
commit c9c67e6520
4 changed files with 43 additions and 29 deletions

View File

@ -255,7 +255,7 @@ class ContextStackVars:
class ContextOcr:
def __init__(self, context: 'Context'):
self.context = context
self.__engine = jp
self.__engine = jp()
def raw(self, lang: OcrLanguage = 'jp') -> Ocr:
"""
@ -264,9 +264,9 @@ class ContextOcr:
"""
match lang:
case 'jp':
return jp
return jp()
case 'en':
return en
return en()
case _:
raise ValueError(f"Invalid language: {lang}")

View File

@ -20,21 +20,6 @@ from ..util import Rect, lf_path
from .debug import result as debug_result, debug
logger = logging.getLogger(__name__)
# TODO: 这个路径需要能够独立设置
_engine_jp = RapidOCR(
rec_model_path=lf_path('models/japan_PP-OCRv4_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
_engine_en = RapidOCR(
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
StringMatchFunction = Callable[[str], bool]
REGEX_NUMBERS = re.compile(r'\d+')
@ -483,13 +468,42 @@ class Ocr:
raise TextNotFoundError(text, img)
return ret
# TODO: 这个路径需要能够独立设置
_engine_jp: RapidOCR | None = None
_engine_en: RapidOCR | None = RapidOCR(
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
def jp() -> Ocr:
"""
日语 OCR 引擎
"""
global _engine_jp
if _engine_jp is None:
_engine_jp = RapidOCR(
rec_model_path=lf_path('models/japan_PP-OCRv3_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
return Ocr(_engine_jp)
jp = Ocr(_engine_jp)
"""日语 OCR 引擎。"""
en = Ocr(_engine_en)
"""英语 OCR 引擎。"""
def en() -> Ocr:
"""
英语 OCR 引擎
"""
global _engine_en
if _engine_en is None:
_engine_en = RapidOCR(
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
use_det=True,
use_cls=False,
use_rec=True,
)
return Ocr(_engine_en)
if __name__ == '__main__':

View File

View File

@ -51,11 +51,11 @@ class TestOcr(unittest.TestCase):
assert bounding_box(points) == (5, 5, 0, 0)
def test_ocr_basic(self):
result = jp.ocr(self.img)
result = jp().ocr(self.img)
self.assertGreater(len(result), 0)
def test_ocr_rect(self):
result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=True)
result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=True)
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
x, y, w, h = result[0].original_rect
self.assertAlmostEqual(x, 147, delta=10)
@ -63,7 +63,7 @@ class TestOcr(unittest.TestCase):
self.assertAlmostEqual(w, 417, delta=10)
self.assertAlmostEqual(h, 32, delta=10)
result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=False)
result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=False)
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
x, y, w, h = result[0].original_rect
self.assertAlmostEqual(x, 147, delta=10)
@ -72,9 +72,9 @@ class TestOcr(unittest.TestCase):
self.assertAlmostEqual(h, 32, delta=10)
def test_find(self):
self.assertTrue(jp.find(self.img, '中間まで'))
self.assertTrue(jp.find(self.img, '受け取るPドリンクを選んでください。'))
self.assertTrue(jp.find(self.img, '受け取る'))
self.assertTrue(jp().find(self.img, '中間まで'))
self.assertTrue(jp().find(self.img, '受け取るPドリンクを選んでください。'))
self.assertTrue(jp().find(self.img, '受け取る'))
class TestOcrResult(unittest.TestCase):