feat(core): 新增图像预处理器

1. Image/ContextImage 对象新增 `preprocessors` 参数
2. 新增 HsvColorFilter,用于过滤出指定颜色
This commit is contained in:
XcantloadX 2025-03-16 11:27:56 +08:00
parent d08af8f715
commit b831e9e2bd
3 changed files with 114 additions and 8 deletions

View File

@ -42,13 +42,16 @@ import kotonebot.backend.color as raw_color
from kotonebot.backend.color import (
find as color_find, find_all as color_find_all
)
from kotonebot.backend.ocr import Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction
from kotonebot.backend.ocr import (
Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction
)
from kotonebot.client.factory import create_device
from kotonebot.config.manager import load_config, save_config
from kotonebot.config.base_config import UserConfig
from kotonebot.backend.core import Image, HintBox
from kotonebot.errors import KotonebotWarning
from kotonebot.client.factory import DeviceImpl
from kotonebot.backend.preprocessor import PreprocessorProtocol
OcrLanguage = Literal['jp', 'en']
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
@ -396,6 +399,7 @@ class ContextImage:
*,
transparent: bool = False,
interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult | None:
"""
等待指定图像出现
@ -406,7 +410,14 @@ class ContextImage:
while True:
if is_manual:
device.screenshot()
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
ret = self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
)
if ret is not None:
self.context.device.last_find = ret
return ret
@ -423,7 +434,8 @@ class ContextImage:
colored: bool = False,
*,
transparent: bool = False,
interval: float = DEFAULT_INTERVAL
interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
):
"""
等待指定图像中的任意一个出现
@ -439,7 +451,14 @@ class ContextImage:
if is_manual:
device.screenshot()
for template, mask in zip(templates, _masks):
if self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored):
if self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
):
return True
if time.time() - start_time > timeout:
return False
@ -454,7 +473,8 @@ class ContextImage:
colored: bool = False,
*,
transparent: bool = False,
interval: float = DEFAULT_INTERVAL
interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult:
"""
等待指定图像出现
@ -465,7 +485,14 @@ class ContextImage:
while True:
if is_manual:
device.screenshot()
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
ret = self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
)
if ret is not None:
self.context.device.last_find = ret
return ret
@ -482,7 +509,8 @@ class ContextImage:
colored: bool = False,
*,
transparent: bool = False,
interval: float = DEFAULT_INTERVAL
interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult:
"""
等待指定图像中的任意一个出现
@ -498,7 +526,14 @@ class ContextImage:
if is_manual:
device.screenshot()
for template, mask in zip(templates, _masks):
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored)
ret = self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
)
if ret is not None:
self.context.device.last_find = ret
return ret

View File

@ -10,6 +10,7 @@ from skimage.metrics import structural_similarity
from .core import Image, unify_image
from ..util import Rect, Point
from .debug import result as debug_result, debug, img
from .preprocessor import PreprocessorProtocol
logger = getLogger(__name__)
@ -124,6 +125,8 @@ def _results2str(results: Sequence[TemplateMatchResult | MultipleTemplateMatchRe
return 'None'
return ', '.join([_result2str(result) for result in results])
# TODO: 应该把 template_match 和 find、wait、expect 等函数的公共参数提取出来
# TODO: 需要在调试结果中输出 preprocessors 处理后的图像
def template_match(
template: MatLike | str | Image,
image: MatLike | str | Image,
@ -134,6 +137,7 @@ def template_match(
max_results: int = 5,
remove_duplicate: bool = True,
colored: bool = False,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[TemplateMatchResult]:
"""
寻找模板在图像中的位置
@ -150,6 +154,7 @@ def template_match(
:param max_results: 最大结果数，默认为 5
:param remove_duplicate: 是否移除重复结果默认为 True
:param colored: 是否匹配颜色默认为 False
:param preprocessors: 预处理列表默认为 None
"""
# 统一参数
template = unify_image(template, transparent)
@ -164,6 +169,13 @@ def template_match(
# 从透明图像中提取 alpha 通道作为 mask
mask = cv2.threshold(template[:, :, 3], 0, 255, cv2.THRESH_BINARY)[1]
template = template[:, :, :3]
# 预处理
if preprocessors is not None:
for preprocessor in preprocessors:
image = preprocessor.process(image)
template = preprocessor.process(template)
if mask is not None:
mask = preprocessor.process(mask)
# 匹配模板
if mask is not None:
# https://stackoverflow.com/questions/35642497/python-opencv-cv2-matchtemplate-with-transparency
@ -302,6 +314,7 @@ def find_all_crop(
*,
colored: bool = False,
remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[CropResult]:
"""
指定一个模板在输入图像中寻找其出现的所有位置并裁剪出结果
@ -313,6 +326,7 @@ def find_all_crop(
:param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
"""
matches = template_match(
template,
@ -323,6 +337,7 @@ def find_all_crop(
max_results=-1,
remove_duplicate=remove_duplicate,
colored=colored,
preprocessors=preprocessors,
)
# logger.debug(
# f'find_all_crop(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -345,6 +360,7 @@ def find(
debug_output: bool = True,
colored: bool = False,
remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult | None:
"""
指定一个模板在输入图像中寻找其出现的第一个位置
@ -357,6 +373,7 @@ def find(
:param debug_output: 是否输出调试信息默认为 True
:param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
"""
matches = template_match(
template,
@ -367,6 +384,7 @@ def find(
max_results=1,
remove_duplicate=remove_duplicate,
colored=colored,
preprocessors=preprocessors,
)
# logger.debug(
# f'find(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -396,6 +414,7 @@ def find_all(
remove_duplicate: bool = True,
colored: bool = False,
debug_output: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[TemplateMatchResult]:
"""
指定一个模板在输入图像中寻找其出现的所有位置
@ -407,6 +426,7 @@ def find_all(
:param threshold: 阈值默认为 0.8
:param remove_duplicate: 是否移除重复结果默认为 True
:param colored: 是否匹配颜色默认为 False
:param preprocessors: 预处理列表默认为 None
"""
results = template_match(
template,
@ -417,6 +437,7 @@ def find_all(
max_results=-1,
remove_duplicate=remove_duplicate,
colored=colored,
preprocessors=preprocessors,
)
# logger.debug(
# f'find_all(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -441,6 +462,7 @@ def find_multi(
threshold: float = 0.8,
colored: bool = False,
remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> MultipleTemplateMatchResult | None:
"""
指定多个模板在输入图像中逐个寻找模板返回第一个匹配到的结果
@ -452,6 +474,7 @@ def find_multi(
:param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
"""
ret = None
if masks is None:
@ -468,6 +491,7 @@ def find_multi(
colored=colored,
debug_output=False,
remove_duplicate=remove_duplicate,
preprocessors=preprocessors,
)
# 调试输出
if find_result is not None:
@ -508,6 +532,7 @@ def find_all_multi(
threshold: float = 0.8,
colored: bool = False,
remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[MultipleTemplateMatchResult]:
"""
指定多个模板在输入图像中逐个寻找模板返回所有匹配到的结果
@ -526,6 +551,7 @@ def find_all_multi(
:param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
:return: 匹配到的一维结果列表
"""
ret: list[MultipleTemplateMatchResult] = []
@ -544,6 +570,7 @@ def find_all_multi(
colored=colored,
remove_duplicate=remove_duplicate,
debug_output=False,
preprocessors=preprocessors,
)
ret.extend([
MultipleTemplateMatchResult.from_template_match_result(r, index)
@ -591,6 +618,7 @@ def count(
threshold: float = 0.8,
remove_duplicate: bool = True,
colored: bool = False,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> int:
"""
指定一个模板统计其出现的次数
@ -602,6 +630,7 @@ def count(
:param threshold: 阈值默认为 0.8
:param remove_duplicate: 是否移除重复结果默认为 True
:param colored: 是否匹配颜色默认为 False
:param preprocessors: 预处理列表默认为 None
"""
results = template_match(
template,
@ -612,6 +641,7 @@ def count(
max_results=-1,
remove_duplicate=remove_duplicate,
colored=colored,
preprocessors=preprocessors,
)
# logger.debug(
# f'count(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -641,6 +671,7 @@ def expect(
threshold: float = 0.8,
colored: bool = False,
remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult:
"""
指定一个模板寻找其出现的第一个位置若未找到则抛出异常
@ -652,6 +683,7 @@ def expect(
:param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
"""
ret = find(
image,
@ -662,6 +694,7 @@ def expect(
colored=colored,
remove_duplicate=remove_duplicate,
debug_output=False,
preprocessors=preprocessors,
)
# logger.debug(
# f'expect(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '

View File

@ -0,0 +1,38 @@
from typing import Protocol
import cv2
import numpy as np
from cv2.typing import MatLike
class PreprocessorProtocol(Protocol):
    """Structural interface for image preprocessors.

    Any object that exposes a compatible ``process`` method satisfies this
    protocol. It is the element type of the ``preprocessors`` parameter
    accepted by the image matching functions (the original note also
    mentions OCR as an intended consumer).
    """

    def process(self, image: MatLike) -> MatLike:
        """Preprocess a single image.

        :param image: Input image, in BGR format.
        :return: The preprocessed image; its format is not restricted.
        """
        ...
class HsvColorFilter(PreprocessorProtocol):
    """HSV color filter that keeps only pixels inside a given color range.

    The input is converted from BGR to HSV and compared against the
    ``[lower, upper]`` bounds with ``cv2.inRange``; the resulting mask is
    what ``process`` returns.
    """

    def __init__(
        self,
        lower: tuple[int, int, int],
        upper: tuple[int, int, int],
        *,
        name: str | None = None,
    ):
        """Create a filter for one HSV range.

        :param lower: Lower HSV bound as ``(h, s, v)``.
        :param upper: Upper HSV bound as ``(h, s, v)``.
        :param name: Optional human-readable color name; only used by
            ``__repr__``.
        """
        # Stored as arrays because they are handed to cv2.inRange as-is.
        self.lower = np.array(lower)
        self.upper = np.array(upper)
        self.name = name

    def process(self, image: MatLike) -> MatLike:
        """Return the mask of pixels whose HSV value lies within the range.

        :param image: Input image, in BGR format.
        :return: The mask produced by ``cv2.inRange`` (not a BGR image).
        """
        # NOTE(review): cv2.cvtColor with COLOR_BGR2HSV presumably rejects
        # single-channel input (e.g. an already-binary mask) - confirm
        # before feeding masks through this filter.
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        return cv2.inRange(hsv, self.lower, self.upper)

    def __repr__(self) -> str:
        bounds = f'range {self.lower} - {self.upper}'
        if self.name is None:
            # Avoid printing a misleading color literally called "None".
            return f'HsvColorFilter(with {bounds})'
        return f'HsvColorFilter(for color "{self.name}" with {bounds})'