feat(core): 新增图像预处理器
1. Image/ContextImage 对象新增 `preprocessors` 参数 2. 新增 HsvColorFilter,用于过滤出指定颜色
This commit is contained in:
parent
d08af8f715
commit
b831e9e2bd
|
@ -42,13 +42,16 @@ import kotonebot.backend.color as raw_color
|
||||||
from kotonebot.backend.color import (
|
from kotonebot.backend.color import (
|
||||||
find as color_find, find_all as color_find_all
|
find as color_find, find_all as color_find_all
|
||||||
)
|
)
|
||||||
from kotonebot.backend.ocr import Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction
|
from kotonebot.backend.ocr import (
|
||||||
|
Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction
|
||||||
|
)
|
||||||
from kotonebot.client.factory import create_device
|
from kotonebot.client.factory import create_device
|
||||||
from kotonebot.config.manager import load_config, save_config
|
from kotonebot.config.manager import load_config, save_config
|
||||||
from kotonebot.config.base_config import UserConfig
|
from kotonebot.config.base_config import UserConfig
|
||||||
from kotonebot.backend.core import Image, HintBox
|
from kotonebot.backend.core import Image, HintBox
|
||||||
from kotonebot.errors import KotonebotWarning
|
from kotonebot.errors import KotonebotWarning
|
||||||
from kotonebot.client.factory import DeviceImpl
|
from kotonebot.client.factory import DeviceImpl
|
||||||
|
from kotonebot.backend.preprocessor import PreprocessorProtocol
|
||||||
|
|
||||||
OcrLanguage = Literal['jp', 'en']
|
OcrLanguage = Literal['jp', 'en']
|
||||||
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
|
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
|
||||||
|
@ -396,6 +399,7 @@ class ContextImage:
|
||||||
*,
|
*,
|
||||||
transparent: bool = False,
|
transparent: bool = False,
|
||||||
interval: float = DEFAULT_INTERVAL,
|
interval: float = DEFAULT_INTERVAL,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> TemplateMatchResult | None:
|
) -> TemplateMatchResult | None:
|
||||||
"""
|
"""
|
||||||
等待指定图像出现。
|
等待指定图像出现。
|
||||||
|
@ -406,7 +410,14 @@ class ContextImage:
|
||||||
while True:
|
while True:
|
||||||
if is_manual:
|
if is_manual:
|
||||||
device.screenshot()
|
device.screenshot()
|
||||||
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
|
ret = self.find(
|
||||||
|
template,
|
||||||
|
mask,
|
||||||
|
transparent=transparent,
|
||||||
|
threshold=threshold,
|
||||||
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
|
)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
self.context.device.last_find = ret
|
self.context.device.last_find = ret
|
||||||
return ret
|
return ret
|
||||||
|
@ -423,7 +434,8 @@ class ContextImage:
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
*,
|
*,
|
||||||
transparent: bool = False,
|
transparent: bool = False,
|
||||||
interval: float = DEFAULT_INTERVAL
|
interval: float = DEFAULT_INTERVAL,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
等待指定图像中的任意一个出现。
|
等待指定图像中的任意一个出现。
|
||||||
|
@ -439,7 +451,14 @@ class ContextImage:
|
||||||
if is_manual:
|
if is_manual:
|
||||||
device.screenshot()
|
device.screenshot()
|
||||||
for template, mask in zip(templates, _masks):
|
for template, mask in zip(templates, _masks):
|
||||||
if self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored):
|
if self.find(
|
||||||
|
template,
|
||||||
|
mask,
|
||||||
|
transparent=transparent,
|
||||||
|
threshold=threshold,
|
||||||
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
|
):
|
||||||
return True
|
return True
|
||||||
if time.time() - start_time > timeout:
|
if time.time() - start_time > timeout:
|
||||||
return False
|
return False
|
||||||
|
@ -454,7 +473,8 @@ class ContextImage:
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
*,
|
*,
|
||||||
transparent: bool = False,
|
transparent: bool = False,
|
||||||
interval: float = DEFAULT_INTERVAL
|
interval: float = DEFAULT_INTERVAL,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> TemplateMatchResult:
|
) -> TemplateMatchResult:
|
||||||
"""
|
"""
|
||||||
等待指定图像出现。
|
等待指定图像出现。
|
||||||
|
@ -465,7 +485,14 @@ class ContextImage:
|
||||||
while True:
|
while True:
|
||||||
if is_manual:
|
if is_manual:
|
||||||
device.screenshot()
|
device.screenshot()
|
||||||
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
|
ret = self.find(
|
||||||
|
template,
|
||||||
|
mask,
|
||||||
|
transparent=transparent,
|
||||||
|
threshold=threshold,
|
||||||
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
|
)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
self.context.device.last_find = ret
|
self.context.device.last_find = ret
|
||||||
return ret
|
return ret
|
||||||
|
@ -482,7 +509,8 @@ class ContextImage:
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
*,
|
*,
|
||||||
transparent: bool = False,
|
transparent: bool = False,
|
||||||
interval: float = DEFAULT_INTERVAL
|
interval: float = DEFAULT_INTERVAL,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> TemplateMatchResult:
|
) -> TemplateMatchResult:
|
||||||
"""
|
"""
|
||||||
等待指定图像中的任意一个出现。
|
等待指定图像中的任意一个出现。
|
||||||
|
@ -498,7 +526,14 @@ class ContextImage:
|
||||||
if is_manual:
|
if is_manual:
|
||||||
device.screenshot()
|
device.screenshot()
|
||||||
for template, mask in zip(templates, _masks):
|
for template, mask in zip(templates, _masks):
|
||||||
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
|
ret = self.find(
|
||||||
|
template,
|
||||||
|
mask,
|
||||||
|
transparent=transparent,
|
||||||
|
threshold=threshold,
|
||||||
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
|
)
|
||||||
if ret is not None:
|
if ret is not None:
|
||||||
self.context.device.last_find = ret
|
self.context.device.last_find = ret
|
||||||
return ret
|
return ret
|
||||||
|
|
|
@ -10,6 +10,7 @@ from skimage.metrics import structural_similarity
|
||||||
from .core import Image, unify_image
|
from .core import Image, unify_image
|
||||||
from ..util import Rect, Point
|
from ..util import Rect, Point
|
||||||
from .debug import result as debug_result, debug, img
|
from .debug import result as debug_result, debug, img
|
||||||
|
from .preprocessor import PreprocessorProtocol
|
||||||
|
|
||||||
logger = getLogger(__name__)
|
logger = getLogger(__name__)
|
||||||
|
|
||||||
|
@ -124,6 +125,8 @@ def _results2str(results: Sequence[TemplateMatchResult | MultipleTemplateMatchRe
|
||||||
return 'None'
|
return 'None'
|
||||||
return ', '.join([_result2str(result) for result in results])
|
return ', '.join([_result2str(result) for result in results])
|
||||||
|
|
||||||
|
# TODO: 应该把 template_match 和 find、wait、expect 等函数的公共参数提取出来
|
||||||
|
# TODO: 需要在调试结果中输出 preprocessors 处理后的图像
|
||||||
def template_match(
|
def template_match(
|
||||||
template: MatLike | str | Image,
|
template: MatLike | str | Image,
|
||||||
image: MatLike | str | Image,
|
image: MatLike | str | Image,
|
||||||
|
@ -134,6 +137,7 @@ def template_match(
|
||||||
max_results: int = 5,
|
max_results: int = 5,
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> list[TemplateMatchResult]:
|
) -> list[TemplateMatchResult]:
|
||||||
"""
|
"""
|
||||||
寻找模板在图像中的位置。
|
寻找模板在图像中的位置。
|
||||||
|
@ -150,6 +154,7 @@ def template_match(
|
||||||
:param max_results: 最大结果数,默认为 1。
|
:param max_results: 最大结果数,默认为 1。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
"""
|
"""
|
||||||
# 统一参数
|
# 统一参数
|
||||||
template = unify_image(template, transparent)
|
template = unify_image(template, transparent)
|
||||||
|
@ -164,6 +169,13 @@ def template_match(
|
||||||
# 从透明图像中提取 alpha 通道作为 mask
|
# 从透明图像中提取 alpha 通道作为 mask
|
||||||
mask = cv2.threshold(template[:, :, 3], 0, 255, cv2.THRESH_BINARY)[1]
|
mask = cv2.threshold(template[:, :, 3], 0, 255, cv2.THRESH_BINARY)[1]
|
||||||
template = template[:, :, :3]
|
template = template[:, :, :3]
|
||||||
|
# 预处理
|
||||||
|
if preprocessors is not None:
|
||||||
|
for preprocessor in preprocessors:
|
||||||
|
image = preprocessor.process(image)
|
||||||
|
template = preprocessor.process(template)
|
||||||
|
if mask is not None:
|
||||||
|
mask = preprocessor.process(mask)
|
||||||
# 匹配模板
|
# 匹配模板
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
# https://stackoverflow.com/questions/35642497/python-opencv-cv2-matchtemplate-with-transparency
|
# https://stackoverflow.com/questions/35642497/python-opencv-cv2-matchtemplate-with-transparency
|
||||||
|
@ -302,6 +314,7 @@ def find_all_crop(
|
||||||
*,
|
*,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> list[CropResult]:
|
) -> list[CropResult]:
|
||||||
"""
|
"""
|
||||||
指定一个模板,在输入图像中寻找其出现的所有位置,并裁剪出结果。
|
指定一个模板,在输入图像中寻找其出现的所有位置,并裁剪出结果。
|
||||||
|
@ -313,6 +326,7 @@ def find_all_crop(
|
||||||
:param threshold: 阈值,默认为 0.8。
|
:param threshold: 阈值,默认为 0.8。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
"""
|
"""
|
||||||
matches = template_match(
|
matches = template_match(
|
||||||
template,
|
template,
|
||||||
|
@ -323,6 +337,7 @@ def find_all_crop(
|
||||||
max_results=-1,
|
max_results=-1,
|
||||||
remove_duplicate=remove_duplicate,
|
remove_duplicate=remove_duplicate,
|
||||||
colored=colored,
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
)
|
)
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f'find_all_crop(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
# f'find_all_crop(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||||
|
@ -345,6 +360,7 @@ def find(
|
||||||
debug_output: bool = True,
|
debug_output: bool = True,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> TemplateMatchResult | None:
|
) -> TemplateMatchResult | None:
|
||||||
"""
|
"""
|
||||||
指定一个模板,在输入图像中寻找其出现的第一个位置。
|
指定一个模板,在输入图像中寻找其出现的第一个位置。
|
||||||
|
@ -357,6 +373,7 @@ def find(
|
||||||
:param debug_output: 是否输出调试信息,默认为 True。
|
:param debug_output: 是否输出调试信息,默认为 True。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
"""
|
"""
|
||||||
matches = template_match(
|
matches = template_match(
|
||||||
template,
|
template,
|
||||||
|
@ -367,6 +384,7 @@ def find(
|
||||||
max_results=1,
|
max_results=1,
|
||||||
remove_duplicate=remove_duplicate,
|
remove_duplicate=remove_duplicate,
|
||||||
colored=colored,
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
)
|
)
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f'find(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
# f'find(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||||
|
@ -396,6 +414,7 @@ def find_all(
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
debug_output: bool = True,
|
debug_output: bool = True,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> list[TemplateMatchResult]:
|
) -> list[TemplateMatchResult]:
|
||||||
"""
|
"""
|
||||||
指定一个模板,在输入图像中寻找其出现的所有位置。
|
指定一个模板,在输入图像中寻找其出现的所有位置。
|
||||||
|
@ -407,6 +426,7 @@ def find_all(
|
||||||
:param threshold: 阈值,默认为 0.8。
|
:param threshold: 阈值,默认为 0.8。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
"""
|
"""
|
||||||
results = template_match(
|
results = template_match(
|
||||||
template,
|
template,
|
||||||
|
@ -417,6 +437,7 @@ def find_all(
|
||||||
max_results=-1,
|
max_results=-1,
|
||||||
remove_duplicate=remove_duplicate,
|
remove_duplicate=remove_duplicate,
|
||||||
colored=colored,
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
)
|
)
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f'find_all(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
# f'find_all(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||||
|
@ -441,6 +462,7 @@ def find_multi(
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> MultipleTemplateMatchResult | None:
|
) -> MultipleTemplateMatchResult | None:
|
||||||
"""
|
"""
|
||||||
指定多个模板,在输入图像中逐个寻找模板,返回第一个匹配到的结果。
|
指定多个模板,在输入图像中逐个寻找模板,返回第一个匹配到的结果。
|
||||||
|
@ -452,6 +474,7 @@ def find_multi(
|
||||||
:param threshold: 阈值,默认为 0.8。
|
:param threshold: 阈值,默认为 0.8。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
"""
|
"""
|
||||||
ret = None
|
ret = None
|
||||||
if masks is None:
|
if masks is None:
|
||||||
|
@ -468,6 +491,7 @@ def find_multi(
|
||||||
colored=colored,
|
colored=colored,
|
||||||
debug_output=False,
|
debug_output=False,
|
||||||
remove_duplicate=remove_duplicate,
|
remove_duplicate=remove_duplicate,
|
||||||
|
preprocessors=preprocessors,
|
||||||
)
|
)
|
||||||
# 调试输出
|
# 调试输出
|
||||||
if find_result is not None:
|
if find_result is not None:
|
||||||
|
@ -508,6 +532,7 @@ def find_all_multi(
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> list[MultipleTemplateMatchResult]:
|
) -> list[MultipleTemplateMatchResult]:
|
||||||
"""
|
"""
|
||||||
指定多个模板,在输入图像中逐个寻找模板,返回所有匹配到的结果。
|
指定多个模板,在输入图像中逐个寻找模板,返回所有匹配到的结果。
|
||||||
|
@ -526,6 +551,7 @@ def find_all_multi(
|
||||||
:param threshold: 阈值,默认为 0.8。
|
:param threshold: 阈值,默认为 0.8。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
:return: 匹配到的一维结果列表。
|
:return: 匹配到的一维结果列表。
|
||||||
"""
|
"""
|
||||||
ret: list[MultipleTemplateMatchResult] = []
|
ret: list[MultipleTemplateMatchResult] = []
|
||||||
|
@ -544,6 +570,7 @@ def find_all_multi(
|
||||||
colored=colored,
|
colored=colored,
|
||||||
remove_duplicate=remove_duplicate,
|
remove_duplicate=remove_duplicate,
|
||||||
debug_output=False,
|
debug_output=False,
|
||||||
|
preprocessors=preprocessors,
|
||||||
)
|
)
|
||||||
ret.extend([
|
ret.extend([
|
||||||
MultipleTemplateMatchResult.from_template_match_result(r, index)
|
MultipleTemplateMatchResult.from_template_match_result(r, index)
|
||||||
|
@ -591,6 +618,7 @@ def count(
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""
|
"""
|
||||||
指定一个模板,统计其出现的次数。
|
指定一个模板,统计其出现的次数。
|
||||||
|
@ -602,6 +630,7 @@ def count(
|
||||||
:param threshold: 阈值,默认为 0.8。
|
:param threshold: 阈值,默认为 0.8。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
"""
|
"""
|
||||||
results = template_match(
|
results = template_match(
|
||||||
template,
|
template,
|
||||||
|
@ -612,6 +641,7 @@ def count(
|
||||||
max_results=-1,
|
max_results=-1,
|
||||||
remove_duplicate=remove_duplicate,
|
remove_duplicate=remove_duplicate,
|
||||||
colored=colored,
|
colored=colored,
|
||||||
|
preprocessors=preprocessors,
|
||||||
)
|
)
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f'count(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
# f'count(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||||
|
@ -641,6 +671,7 @@ def expect(
|
||||||
threshold: float = 0.8,
|
threshold: float = 0.8,
|
||||||
colored: bool = False,
|
colored: bool = False,
|
||||||
remove_duplicate: bool = True,
|
remove_duplicate: bool = True,
|
||||||
|
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||||
) -> TemplateMatchResult:
|
) -> TemplateMatchResult:
|
||||||
"""
|
"""
|
||||||
指定一个模板,寻找其出现的第一个位置。若未找到,则抛出异常。
|
指定一个模板,寻找其出现的第一个位置。若未找到,则抛出异常。
|
||||||
|
@ -652,6 +683,7 @@ def expect(
|
||||||
:param threshold: 阈值,默认为 0.8。
|
:param threshold: 阈值,默认为 0.8。
|
||||||
:param colored: 是否匹配颜色,默认为 False。
|
:param colored: 是否匹配颜色,默认为 False。
|
||||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||||
|
:param preprocessors: 预处理列表,默认为 None。
|
||||||
"""
|
"""
|
||||||
ret = find(
|
ret = find(
|
||||||
image,
|
image,
|
||||||
|
@ -662,6 +694,7 @@ def expect(
|
||||||
colored=colored,
|
colored=colored,
|
||||||
remove_duplicate=remove_duplicate,
|
remove_duplicate=remove_duplicate,
|
||||||
debug_output=False,
|
debug_output=False,
|
||||||
|
preprocessors=preprocessors,
|
||||||
)
|
)
|
||||||
# logger.debug(
|
# logger.debug(
|
||||||
# f'expect(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
# f'expect(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
|
||||||
|
from typing import Protocol
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from cv2.typing import MatLike
|
||||||
|
|
||||||
|
class PreprocessorProtocol(Protocol):
|
||||||
|
"""预处理协议。用于 Image 与 Ocr 中的 `preprocessor` 参数。"""
|
||||||
|
def process(self, image: MatLike) -> MatLike:
|
||||||
|
"""
|
||||||
|
预处理图像。
|
||||||
|
|
||||||
|
:param image: 输入图像,格式为 BGR。
|
||||||
|
:return: 预处理后的图像,格式不限。
|
||||||
|
"""
|
||||||
|
...
|
||||||
|
|
||||||
|
class HsvColorFilter(PreprocessorProtocol):
|
||||||
|
"""HSV 颜色过滤器。用于保留指定颜色。"""
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
lower: tuple[int, int, int],
|
||||||
|
upper: tuple[int, int, int],
|
||||||
|
*,
|
||||||
|
name: str | None = None,
|
||||||
|
):
|
||||||
|
self.lower = np.array(lower)
|
||||||
|
self.upper = np.array(upper)
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def process(self, image: MatLike) -> MatLike:
|
||||||
|
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||||
|
mask = cv2.inRange(hsv, self.lower, self.upper)
|
||||||
|
return mask
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'HsvColorFilter(for color "{self.name}" with range {self.lower} - {self.upper})'
|
Loading…
Reference in New Issue