refactor(core): OCR 引擎推迟到启动脚本时加载
This commit is contained in:
parent
8dc76e0f92
commit
c9c67e6520
|
@ -255,7 +255,7 @@ class ContextStackVars:
|
||||||
class ContextOcr:
|
class ContextOcr:
|
||||||
def __init__(self, context: 'Context'):
|
def __init__(self, context: 'Context'):
|
||||||
self.context = context
|
self.context = context
|
||||||
self.__engine = jp
|
self.__engine = jp()
|
||||||
|
|
||||||
def raw(self, lang: OcrLanguage = 'jp') -> Ocr:
|
def raw(self, lang: OcrLanguage = 'jp') -> Ocr:
|
||||||
"""
|
"""
|
||||||
|
@ -264,9 +264,9 @@ class ContextOcr:
|
||||||
"""
|
"""
|
||||||
match lang:
|
match lang:
|
||||||
case 'jp':
|
case 'jp':
|
||||||
return jp
|
return jp()
|
||||||
case 'en':
|
case 'en':
|
||||||
return en
|
return en()
|
||||||
case _:
|
case _:
|
||||||
raise ValueError(f"Invalid language: {lang}")
|
raise ValueError(f"Invalid language: {lang}")
|
||||||
|
|
||||||
|
|
|
@ -20,21 +20,6 @@ from ..util import Rect, lf_path
|
||||||
from .debug import result as debug_result, debug
|
from .debug import result as debug_result, debug
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# TODO: 这个路径需要能够独立设置
|
|
||||||
_engine_jp = RapidOCR(
|
|
||||||
rec_model_path=lf_path('models/japan_PP-OCRv4_rec_infer.onnx'),
|
|
||||||
use_det=True,
|
|
||||||
use_cls=False,
|
|
||||||
use_rec=True,
|
|
||||||
)
|
|
||||||
_engine_en = RapidOCR(
|
|
||||||
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
|
|
||||||
use_det=True,
|
|
||||||
use_cls=False,
|
|
||||||
use_rec=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
StringMatchFunction = Callable[[str], bool]
|
StringMatchFunction = Callable[[str], bool]
|
||||||
REGEX_NUMBERS = re.compile(r'\d+')
|
REGEX_NUMBERS = re.compile(r'\d+')
|
||||||
|
|
||||||
|
@ -483,13 +468,42 @@ class Ocr:
|
||||||
raise TextNotFoundError(text, img)
|
raise TextNotFoundError(text, img)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
# TODO: 这个路径需要能够独立设置
|
||||||
|
_engine_jp: RapidOCR | None = None
|
||||||
|
_engine_en: RapidOCR | None = RapidOCR(
|
||||||
|
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
|
||||||
|
use_det=True,
|
||||||
|
use_cls=False,
|
||||||
|
use_rec=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def jp() -> Ocr:
|
||||||
|
"""
|
||||||
|
日语 OCR 引擎。
|
||||||
|
"""
|
||||||
|
global _engine_jp
|
||||||
|
if _engine_jp is None:
|
||||||
|
_engine_jp = RapidOCR(
|
||||||
|
rec_model_path=lf_path('models/japan_PP-OCRv3_rec_infer.onnx'),
|
||||||
|
use_det=True,
|
||||||
|
use_cls=False,
|
||||||
|
use_rec=True,
|
||||||
|
)
|
||||||
|
return Ocr(_engine_jp)
|
||||||
|
|
||||||
jp = Ocr(_engine_jp)
|
def en() -> Ocr:
|
||||||
"""日语 OCR 引擎。"""
|
"""
|
||||||
en = Ocr(_engine_en)
|
英语 OCR 引擎。
|
||||||
"""英语 OCR 引擎。"""
|
"""
|
||||||
|
global _engine_en
|
||||||
|
if _engine_en is None:
|
||||||
|
_engine_en = RapidOCR(
|
||||||
|
rec_model_path=lf_path('models/en_PP-OCRv3_rec_infer.onnx'),
|
||||||
|
use_det=True,
|
||||||
|
use_cls=False,
|
||||||
|
use_rec=True,
|
||||||
|
)
|
||||||
|
return Ocr(_engine_en)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
|
|
@ -51,11 +51,11 @@ class TestOcr(unittest.TestCase):
|
||||||
assert bounding_box(points) == (5, 5, 0, 0)
|
assert bounding_box(points) == (5, 5, 0, 0)
|
||||||
|
|
||||||
def test_ocr_basic(self):
|
def test_ocr_basic(self):
|
||||||
result = jp.ocr(self.img)
|
result = jp().ocr(self.img)
|
||||||
self.assertGreater(len(result), 0)
|
self.assertGreater(len(result), 0)
|
||||||
|
|
||||||
def test_ocr_rect(self):
|
def test_ocr_rect(self):
|
||||||
result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=True)
|
result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=True)
|
||||||
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
|
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
|
||||||
x, y, w, h = result[0].original_rect
|
x, y, w, h = result[0].original_rect
|
||||||
self.assertAlmostEqual(x, 147, delta=10)
|
self.assertAlmostEqual(x, 147, delta=10)
|
||||||
|
@ -63,7 +63,7 @@ class TestOcr(unittest.TestCase):
|
||||||
self.assertAlmostEqual(w, 417, delta=10)
|
self.assertAlmostEqual(w, 417, delta=10)
|
||||||
self.assertAlmostEqual(h, 32, delta=10)
|
self.assertAlmostEqual(h, 32, delta=10)
|
||||||
|
|
||||||
result = jp.ocr(self.img, rect=(147, 614, 417, 32), pad=False)
|
result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=False)
|
||||||
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
|
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
|
||||||
x, y, w, h = result[0].original_rect
|
x, y, w, h = result[0].original_rect
|
||||||
self.assertAlmostEqual(x, 147, delta=10)
|
self.assertAlmostEqual(x, 147, delta=10)
|
||||||
|
@ -72,9 +72,9 @@ class TestOcr(unittest.TestCase):
|
||||||
self.assertAlmostEqual(h, 32, delta=10)
|
self.assertAlmostEqual(h, 32, delta=10)
|
||||||
|
|
||||||
def test_find(self):
|
def test_find(self):
|
||||||
self.assertTrue(jp.find(self.img, '中間まで'))
|
self.assertTrue(jp().find(self.img, '中間まで'))
|
||||||
self.assertTrue(jp.find(self.img, '受け取るPドリンクを選んでください。'))
|
self.assertTrue(jp().find(self.img, '受け取るPドリンクを選んでください。'))
|
||||||
self.assertTrue(jp.find(self.img, '受け取る'))
|
self.assertTrue(jp().find(self.img, '受け取る'))
|
||||||
|
|
||||||
|
|
||||||
class TestOcrResult(unittest.TestCase):
|
class TestOcrResult(unittest.TestCase):
|
||||||
|
|
Loading…
Reference in New Issue