kotones-auto-assistant/tests/core/test_ocr.py

140 lines
5.8 KiB
Python

import re
import unittest
from kotonebot.backend.ocr import jp
from kotonebot.backend.ocr import OcrResult, OcrResultList, bounding_box
import cv2
class TestOcr(unittest.TestCase):
def setUp(self):
self.img = cv2.imread('tests/images/acquire_pdorinku.png')
def test_bounding_box(self):
# 测试基本情况
# 矩形
points = [(0, 0), (10, 0), (10, 10), (0, 10)]
assert bounding_box(points) == (0, 0, 10, 10)
# 三角形
points = [(0, 0), (10, 0), (5, 10)]
assert bounding_box(points) == (0, 0, 10, 10)
# 不规则四边形
points = [(0, 0), (5, 2), (8, 6), (2, 8)]
assert bounding_box(points) == (0, 0, 8, 8)
# 测试单点
points = [(5, 5)]
assert bounding_box(points) == (5, 5, 0, 0)
# 测试负坐标
points = [(-1, -1), (2, 3)]
assert bounding_box(points) == (-1, -1, 3, 4)
# 测试零宽度矩形
points = [(0, 0), (0, 10)]
assert bounding_box(points) == (0, 0, 0, 10)
# 测试零高度矩形
points = [(0, 0), (10, 0)]
assert bounding_box(points) == (0, 0, 10, 0)
# 测试大数值
points = [(1000000, 1000000), (2000000, 2000000)]
assert bounding_box(points) == (1000000, 1000000, 1000000, 1000000)
# 测试点的无序排列
points = [(10, 0), (0, 10), (0, 0), (10, 10)]
assert bounding_box(points) == (0, 0, 10, 10)
# 测试重复点
points = [(5, 5), (5, 5), (5, 5)]
assert bounding_box(points) == (5, 5, 0, 0)
def test_ocr_basic(self):
result = jp().ocr(self.img)
self.assertGreater(len(result), 0)
def test_ocr_rect(self):
result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=True)
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
x, y, w, h = result[0].original_rect
self.assertAlmostEqual(x, 147, delta=10)
self.assertAlmostEqual(y, 614, delta=10)
self.assertAlmostEqual(w, 417, delta=10)
self.assertAlmostEqual(h, 32, delta=10)
result = jp().ocr(self.img, rect=(147, 614, 417, 32), pad=False)
self.assertEqual(result[0].text, '受け取るPドリンクを選んでください。')
x, y, w, h = result[0].original_rect
self.assertAlmostEqual(x, 147, delta=10)
self.assertAlmostEqual(y, 614, delta=10)
self.assertAlmostEqual(w, 417, delta=10)
self.assertAlmostEqual(h, 32, delta=10)
def test_find(self):
self.assertTrue(jp().find(self.img, '中間まで'))
self.assertTrue(jp().find(self.img, '受け取るPドリンクを選んでください。'))
self.assertTrue(jp().find(self.img, '受け取る'))
class TestOcrResult(unittest.TestCase):
def test_regex(self):
result = OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
self.assertEqual(result.regex(r'\d+'), ['123', '4567', '890'])
self.assertEqual(result.regex(re.compile(r'\d+')), ['123', '4567', '890'])
def test_numbers(self):
result = OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
self.assertEqual(result.numbers(), [123, 4567, 890])
result2 = OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
self.assertEqual(result2.numbers(), [])
result3 = OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100))
self.assertEqual(result3.numbers(), [1234567890])
class TestOcrResultList(unittest.TestCase):
def test_list_compatibility(self):
result = OcrResultList([
OcrResult(text='abc', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='def', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='ghi', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
])
self.assertEqual(result[0].text, 'abc')
self.assertEqual(result[-1].text, 'ghi')
self.assertEqual(len(result[1:]), 2)
self.assertEqual(result[1:][0].text, 'def')
self.assertEqual([r.text for r in result], ['abc', 'def', 'ghi'])
self.assertEqual(len(result), 3)
self.assertTrue(result[0] in result)
# 空列表
result = OcrResultList()
self.assertEqual(result, [])
self.assertEqual(bool(result), False)
self.assertEqual(len(result), 0)
with self.assertRaises(IndexError):
result[0]
with self.assertRaises(IndexError):
result[-1]
self.assertEqual(result[1:], [])
self.assertEqual([r.text for r in result], [])
def test_where(self):
result = OcrResultList([
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
])
self.assertEqual(result.where(lambda x: x.startswith('123')), [result[0], result[2]])
def test_first(self):
result = OcrResultList([
OcrResult(text='123dd4567rr890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='aaa', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
OcrResult(text='1234567890', rect=(0, 0, 100, 100), confidence=0.95, original_rect=(0, 0, 100, 100)),
])
self.assertEqual(result.first(), result[0])
result2 = OcrResultList()
self.assertIsNone(result2.first())