feat(core): 新增图像预处理器
1. Image/ContextImage 对象新增 `preprocessors` 参数 2. 新增 HsvColorFilter,用于过滤出指定颜色
This commit is contained in:
parent
d08af8f715
commit
b831e9e2bd
|
@ -42,13 +42,16 @@ import kotonebot.backend.color as raw_color
|
|||
from kotonebot.backend.color import (
|
||||
find as color_find, find_all as color_find_all
|
||||
)
|
||||
from kotonebot.backend.ocr import Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction
|
||||
from kotonebot.backend.ocr import (
|
||||
Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction
|
||||
)
|
||||
from kotonebot.client.factory import create_device
|
||||
from kotonebot.config.manager import load_config, save_config
|
||||
from kotonebot.config.base_config import UserConfig
|
||||
from kotonebot.backend.core import Image, HintBox
|
||||
from kotonebot.errors import KotonebotWarning
|
||||
from kotonebot.client.factory import DeviceImpl
|
||||
from kotonebot.backend.preprocessor import PreprocessorProtocol
|
||||
|
||||
OcrLanguage = Literal['jp', 'en']
|
||||
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
|
||||
|
@ -396,6 +399,7 @@ class ContextImage:
|
|||
*,
|
||||
transparent: bool = False,
|
||||
interval: float = DEFAULT_INTERVAL,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> TemplateMatchResult | None:
|
||||
"""
|
||||
等待指定图像出现。
|
||||
|
@ -406,7 +410,14 @@ class ContextImage:
|
|||
while True:
|
||||
if is_manual:
|
||||
device.screenshot()
|
||||
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
|
||||
ret = self.find(
|
||||
template,
|
||||
mask,
|
||||
transparent=transparent,
|
||||
threshold=threshold,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
if ret is not None:
|
||||
self.context.device.last_find = ret
|
||||
return ret
|
||||
|
@ -423,7 +434,8 @@ class ContextImage:
|
|||
colored: bool = False,
|
||||
*,
|
||||
transparent: bool = False,
|
||||
interval: float = DEFAULT_INTERVAL
|
||||
interval: float = DEFAULT_INTERVAL,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
):
|
||||
"""
|
||||
等待指定图像中的任意一个出现。
|
||||
|
@ -439,7 +451,14 @@ class ContextImage:
|
|||
if is_manual:
|
||||
device.screenshot()
|
||||
for template, mask in zip(templates, _masks):
|
||||
if self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored):
|
||||
if self.find(
|
||||
template,
|
||||
mask,
|
||||
transparent=transparent,
|
||||
threshold=threshold,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
):
|
||||
return True
|
||||
if time.time() - start_time > timeout:
|
||||
return False
|
||||
|
@ -454,7 +473,8 @@ class ContextImage:
|
|||
colored: bool = False,
|
||||
*,
|
||||
transparent: bool = False,
|
||||
interval: float = DEFAULT_INTERVAL
|
||||
interval: float = DEFAULT_INTERVAL,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> TemplateMatchResult:
|
||||
"""
|
||||
等待指定图像出现。
|
||||
|
@ -465,7 +485,14 @@ class ContextImage:
|
|||
while True:
|
||||
if is_manual:
|
||||
device.screenshot()
|
||||
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
|
||||
ret = self.find(
|
||||
template,
|
||||
mask,
|
||||
transparent=transparent,
|
||||
threshold=threshold,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
if ret is not None:
|
||||
self.context.device.last_find = ret
|
||||
return ret
|
||||
|
@ -482,7 +509,8 @@ class ContextImage:
|
|||
colored: bool = False,
|
||||
*,
|
||||
transparent: bool = False,
|
||||
interval: float = DEFAULT_INTERVAL
|
||||
interval: float = DEFAULT_INTERVAL,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> TemplateMatchResult:
|
||||
"""
|
||||
等待指定图像中的任意一个出现。
|
||||
|
@ -498,7 +526,14 @@ class ContextImage:
|
|||
if is_manual:
|
||||
device.screenshot()
|
||||
for template, mask in zip(templates, _masks):
|
||||
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
|
||||
ret = self.find(
|
||||
template,
|
||||
mask,
|
||||
transparent=transparent,
|
||||
threshold=threshold,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
if ret is not None:
|
||||
self.context.device.last_find = ret
|
||||
return ret
|
||||
|
|
|
@ -10,6 +10,7 @@ from skimage.metrics import structural_similarity
|
|||
from .core import Image, unify_image
|
||||
from ..util import Rect, Point
|
||||
from .debug import result as debug_result, debug, img
|
||||
from .preprocessor import PreprocessorProtocol
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
@ -124,6 +125,8 @@ def _results2str(results: Sequence[TemplateMatchResult | MultipleTemplateMatchRe
|
|||
return 'None'
|
||||
return ', '.join([_result2str(result) for result in results])
|
||||
|
||||
# TODO: 应该把 template_match 和 find、wait、expect 等函数的公共参数提取出来
|
||||
# TODO: 需要在调试结果中输出 preprocessors 处理后的图像
|
||||
def template_match(
|
||||
template: MatLike | str | Image,
|
||||
image: MatLike | str | Image,
|
||||
|
@ -134,6 +137,7 @@ def template_match(
|
|||
max_results: int = 5,
|
||||
remove_duplicate: bool = True,
|
||||
colored: bool = False,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> list[TemplateMatchResult]:
|
||||
"""
|
||||
寻找模板在图像中的位置。
|
||||
|
@ -150,6 +154,7 @@ def template_match(
|
|||
:param max_results: 最大结果数,默认为 1。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
"""
|
||||
# 统一参数
|
||||
template = unify_image(template, transparent)
|
||||
|
@ -164,6 +169,13 @@ def template_match(
|
|||
# 从透明图像中提取 alpha 通道作为 mask
|
||||
mask = cv2.threshold(template[:, :, 3], 0, 255, cv2.THRESH_BINARY)[1]
|
||||
template = template[:, :, :3]
|
||||
# 预处理
|
||||
if preprocessors is not None:
|
||||
for preprocessor in preprocessors:
|
||||
image = preprocessor.process(image)
|
||||
template = preprocessor.process(template)
|
||||
if mask is not None:
|
||||
mask = preprocessor.process(mask)
|
||||
# 匹配模板
|
||||
if mask is not None:
|
||||
# https://stackoverflow.com/questions/35642497/python-opencv-cv2-matchtemplate-with-transparency
|
||||
|
@ -302,6 +314,7 @@ def find_all_crop(
|
|||
*,
|
||||
colored: bool = False,
|
||||
remove_duplicate: bool = True,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> list[CropResult]:
|
||||
"""
|
||||
指定一个模板,在输入图像中寻找其出现的所有位置,并裁剪出结果。
|
||||
|
@ -313,6 +326,7 @@ def find_all_crop(
|
|||
:param threshold: 阈值,默认为 0.8。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
"""
|
||||
matches = template_match(
|
||||
template,
|
||||
|
@ -323,6 +337,7 @@ def find_all_crop(
|
|||
max_results=-1,
|
||||
remove_duplicate=remove_duplicate,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
# logger.debug(
|
||||
# f'find_all_crop(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||
|
@ -345,6 +360,7 @@ def find(
|
|||
debug_output: bool = True,
|
||||
colored: bool = False,
|
||||
remove_duplicate: bool = True,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> TemplateMatchResult | None:
|
||||
"""
|
||||
指定一个模板,在输入图像中寻找其出现的第一个位置。
|
||||
|
@ -357,6 +373,7 @@ def find(
|
|||
:param debug_output: 是否输出调试信息,默认为 True。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
"""
|
||||
matches = template_match(
|
||||
template,
|
||||
|
@ -367,6 +384,7 @@ def find(
|
|||
max_results=1,
|
||||
remove_duplicate=remove_duplicate,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
# logger.debug(
|
||||
# f'find(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||
|
@ -396,6 +414,7 @@ def find_all(
|
|||
remove_duplicate: bool = True,
|
||||
colored: bool = False,
|
||||
debug_output: bool = True,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> list[TemplateMatchResult]:
|
||||
"""
|
||||
指定一个模板,在输入图像中寻找其出现的所有位置。
|
||||
|
@ -407,6 +426,7 @@ def find_all(
|
|||
:param threshold: 阈值,默认为 0.8。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
"""
|
||||
results = template_match(
|
||||
template,
|
||||
|
@ -417,6 +437,7 @@ def find_all(
|
|||
max_results=-1,
|
||||
remove_duplicate=remove_duplicate,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
# logger.debug(
|
||||
# f'find_all(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||
|
@ -441,6 +462,7 @@ def find_multi(
|
|||
threshold: float = 0.8,
|
||||
colored: bool = False,
|
||||
remove_duplicate: bool = True,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> MultipleTemplateMatchResult | None:
|
||||
"""
|
||||
指定多个模板,在输入图像中逐个寻找模板,返回第一个匹配到的结果。
|
||||
|
@ -452,6 +474,7 @@ def find_multi(
|
|||
:param threshold: 阈值,默认为 0.8。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
"""
|
||||
ret = None
|
||||
if masks is None:
|
||||
|
@ -468,6 +491,7 @@ def find_multi(
|
|||
colored=colored,
|
||||
debug_output=False,
|
||||
remove_duplicate=remove_duplicate,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
# 调试输出
|
||||
if find_result is not None:
|
||||
|
@ -508,6 +532,7 @@ def find_all_multi(
|
|||
threshold: float = 0.8,
|
||||
colored: bool = False,
|
||||
remove_duplicate: bool = True,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> list[MultipleTemplateMatchResult]:
|
||||
"""
|
||||
指定多个模板,在输入图像中逐个寻找模板,返回所有匹配到的结果。
|
||||
|
@ -526,6 +551,7 @@ def find_all_multi(
|
|||
:param threshold: 阈值,默认为 0.8。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
:return: 匹配到的一维结果列表。
|
||||
"""
|
||||
ret: list[MultipleTemplateMatchResult] = []
|
||||
|
@ -544,6 +570,7 @@ def find_all_multi(
|
|||
colored=colored,
|
||||
remove_duplicate=remove_duplicate,
|
||||
debug_output=False,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
ret.extend([
|
||||
MultipleTemplateMatchResult.from_template_match_result(r, index)
|
||||
|
@ -591,6 +618,7 @@ def count(
|
|||
threshold: float = 0.8,
|
||||
remove_duplicate: bool = True,
|
||||
colored: bool = False,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
指定一个模板,统计其出现的次数。
|
||||
|
@ -602,6 +630,7 @@ def count(
|
|||
:param threshold: 阈值,默认为 0.8。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
"""
|
||||
results = template_match(
|
||||
template,
|
||||
|
@ -612,6 +641,7 @@ def count(
|
|||
max_results=-1,
|
||||
remove_duplicate=remove_duplicate,
|
||||
colored=colored,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
# logger.debug(
|
||||
# f'count(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||
|
@ -641,6 +671,7 @@ def expect(
|
|||
threshold: float = 0.8,
|
||||
colored: bool = False,
|
||||
remove_duplicate: bool = True,
|
||||
preprocessors: list[PreprocessorProtocol] | None = None,
|
||||
) -> TemplateMatchResult:
|
||||
"""
|
||||
指定一个模板,寻找其出现的第一个位置。若未找到,则抛出异常。
|
||||
|
@ -652,6 +683,7 @@ def expect(
|
|||
:param threshold: 阈值,默认为 0.8。
|
||||
:param colored: 是否匹配颜色,默认为 False。
|
||||
:param remove_duplicate: 是否移除重复结果,默认为 True。
|
||||
:param preprocessors: 预处理列表,默认为 None。
|
||||
"""
|
||||
ret = find(
|
||||
image,
|
||||
|
@ -662,6 +694,7 @@ def expect(
|
|||
colored=colored,
|
||||
remove_duplicate=remove_duplicate,
|
||||
debug_output=False,
|
||||
preprocessors=preprocessors,
|
||||
)
|
||||
# logger.debug(
|
||||
# f'expect(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
|
||||
from typing import Protocol
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from cv2.typing import MatLike
|
||||
|
||||
class PreprocessorProtocol(Protocol):
|
||||
"""预处理协议。用于 Image 与 Ocr 中的 `preprocessor` 参数。"""
|
||||
def process(self, image: MatLike) -> MatLike:
|
||||
"""
|
||||
预处理图像。
|
||||
|
||||
:param image: 输入图像,格式为 BGR。
|
||||
:return: 预处理后的图像,格式不限。
|
||||
"""
|
||||
...
|
||||
|
||||
class HsvColorFilter(PreprocessorProtocol):
|
||||
"""HSV 颜色过滤器。用于保留指定颜色。"""
|
||||
def __init__(
|
||||
self,
|
||||
lower: tuple[int, int, int],
|
||||
upper: tuple[int, int, int],
|
||||
*,
|
||||
name: str | None = None,
|
||||
):
|
||||
self.lower = np.array(lower)
|
||||
self.upper = np.array(upper)
|
||||
self.name = name
|
||||
|
||||
def process(self, image: MatLike) -> MatLike:
|
||||
hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||
mask = cv2.inRange(hsv, self.lower, self.upper)
|
||||
return mask
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'HsvColorFilter(for color "{self.name}" with range {self.lower} - {self.upper})'
|
Loading…
Reference in New Issue