feat(core): 新增图像预处理器

1. Image/ContextImage 对象新增 `preprocessors` 参数
2. 新增 HsvColorFilter,用于过滤出指定颜色
This commit is contained in:
XcantloadX 2025-03-16 11:27:56 +08:00
parent d08af8f715
commit b831e9e2bd
3 changed files with 114 additions and 8 deletions

View File

@ -42,13 +42,16 @@ import kotonebot.backend.color as raw_color
from kotonebot.backend.color import ( from kotonebot.backend.color import (
find as color_find, find_all as color_find_all find as color_find, find_all as color_find_all
) )
from kotonebot.backend.ocr import Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction from kotonebot.backend.ocr import (
Ocr, OcrResult, OcrResultList, jp, en, StringMatchFunction
)
from kotonebot.client.factory import create_device from kotonebot.client.factory import create_device
from kotonebot.config.manager import load_config, save_config from kotonebot.config.manager import load_config, save_config
from kotonebot.config.base_config import UserConfig from kotonebot.config.base_config import UserConfig
from kotonebot.backend.core import Image, HintBox from kotonebot.backend.core import Image, HintBox
from kotonebot.errors import KotonebotWarning from kotonebot.errors import KotonebotWarning
from kotonebot.client.factory import DeviceImpl from kotonebot.client.factory import DeviceImpl
from kotonebot.backend.preprocessor import PreprocessorProtocol
OcrLanguage = Literal['jp', 'en'] OcrLanguage = Literal['jp', 'en']
ScreenshotMode = Literal['auto', 'manual', 'manual-inherit'] ScreenshotMode = Literal['auto', 'manual', 'manual-inherit']
@ -396,6 +399,7 @@ class ContextImage:
*, *,
transparent: bool = False, transparent: bool = False,
interval: float = DEFAULT_INTERVAL, interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult | None: ) -> TemplateMatchResult | None:
""" """
等待指定图像出现 等待指定图像出现
@ -406,7 +410,14 @@ class ContextImage:
while True: while True:
if is_manual: if is_manual:
device.screenshot() device.screenshot()
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored) ret = self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
)
if ret is not None: if ret is not None:
self.context.device.last_find = ret self.context.device.last_find = ret
return ret return ret
@ -423,7 +434,8 @@ class ContextImage:
colored: bool = False, colored: bool = False,
*, *,
transparent: bool = False, transparent: bool = False,
interval: float = DEFAULT_INTERVAL interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
): ):
""" """
等待指定图像中的任意一个出现 等待指定图像中的任意一个出现
@ -439,7 +451,14 @@ class ContextImage:
if is_manual: if is_manual:
device.screenshot() device.screenshot()
for template, mask in zip(templates, _masks): for template, mask in zip(templates, _masks):
if self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored): if self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
):
return True return True
if time.time() - start_time > timeout: if time.time() - start_time > timeout:
return False return False
@ -454,7 +473,8 @@ class ContextImage:
colored: bool = False, colored: bool = False,
*, *,
transparent: bool = False, transparent: bool = False,
interval: float = DEFAULT_INTERVAL interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult: ) -> TemplateMatchResult:
""" """
等待指定图像出现 等待指定图像出现
@ -465,7 +485,14 @@ class ContextImage:
while True: while True:
if is_manual: if is_manual:
device.screenshot() device.screenshot()
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored) ret = self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
)
if ret is not None: if ret is not None:
self.context.device.last_find = ret self.context.device.last_find = ret
return ret return ret
@ -482,7 +509,8 @@ class ContextImage:
colored: bool = False, colored: bool = False,
*, *,
transparent: bool = False, transparent: bool = False,
interval: float = DEFAULT_INTERVAL interval: float = DEFAULT_INTERVAL,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult: ) -> TemplateMatchResult:
""" """
等待指定图像中的任意一个出现 等待指定图像中的任意一个出现
@ -498,7 +526,14 @@ class ContextImage:
if is_manual: if is_manual:
device.screenshot() device.screenshot()
for template, mask in zip(templates, _masks): for template, mask in zip(templates, _masks):
ret = self.find(template, mask, transparent=transparent, threshold=threshold, colored=colored) ret = self.find(
template,
mask,
transparent=transparent,
threshold=threshold,
colored=colored,
preprocessors=preprocessors,
)
if ret is not None: if ret is not None:
self.context.device.last_find = ret self.context.device.last_find = ret
return ret return ret

View File

@ -10,6 +10,7 @@ from skimage.metrics import structural_similarity
from .core import Image, unify_image from .core import Image, unify_image
from ..util import Rect, Point from ..util import Rect, Point
from .debug import result as debug_result, debug, img from .debug import result as debug_result, debug, img
from .preprocessor import PreprocessorProtocol
logger = getLogger(__name__) logger = getLogger(__name__)
@ -124,6 +125,8 @@ def _results2str(results: Sequence[TemplateMatchResult | MultipleTemplateMatchRe
return 'None' return 'None'
return ', '.join([_result2str(result) for result in results]) return ', '.join([_result2str(result) for result in results])
# TODO: 应该把 template_match 和 find、wait、expect 等函数的公共参数提取出来
# TODO: 需要在调试结果中输出 preprocessors 处理后的图像
def template_match( def template_match(
template: MatLike | str | Image, template: MatLike | str | Image,
image: MatLike | str | Image, image: MatLike | str | Image,
@ -134,6 +137,7 @@ def template_match(
max_results: int = 5, max_results: int = 5,
remove_duplicate: bool = True, remove_duplicate: bool = True,
colored: bool = False, colored: bool = False,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[TemplateMatchResult]: ) -> list[TemplateMatchResult]:
""" """
寻找模板在图像中的位置 寻找模板在图像中的位置
@ -150,6 +154,7 @@ def template_match(
:param max_results: 最大结果数默认为 1 :param max_results: 最大结果数默认为 1
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param preprocessors: 预处理列表默认为 None
""" """
# 统一参数 # 统一参数
template = unify_image(template, transparent) template = unify_image(template, transparent)
@ -164,6 +169,13 @@ def template_match(
# 从透明图像中提取 alpha 通道作为 mask # 从透明图像中提取 alpha 通道作为 mask
mask = cv2.threshold(template[:, :, 3], 0, 255, cv2.THRESH_BINARY)[1] mask = cv2.threshold(template[:, :, 3], 0, 255, cv2.THRESH_BINARY)[1]
template = template[:, :, :3] template = template[:, :, :3]
# 预处理
if preprocessors is not None:
for preprocessor in preprocessors:
image = preprocessor.process(image)
template = preprocessor.process(template)
if mask is not None:
mask = preprocessor.process(mask)
# 匹配模板 # 匹配模板
if mask is not None: if mask is not None:
# https://stackoverflow.com/questions/35642497/python-opencv-cv2-matchtemplate-with-transparency # https://stackoverflow.com/questions/35642497/python-opencv-cv2-matchtemplate-with-transparency
@ -302,6 +314,7 @@ def find_all_crop(
*, *,
colored: bool = False, colored: bool = False,
remove_duplicate: bool = True, remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[CropResult]: ) -> list[CropResult]:
""" """
指定一个模板在输入图像中寻找其出现的所有位置并裁剪出结果 指定一个模板在输入图像中寻找其出现的所有位置并裁剪出结果
@ -313,6 +326,7 @@ def find_all_crop(
:param threshold: 阈值默认为 0.8 :param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
""" """
matches = template_match( matches = template_match(
template, template,
@ -323,6 +337,7 @@ def find_all_crop(
max_results=-1, max_results=-1,
remove_duplicate=remove_duplicate, remove_duplicate=remove_duplicate,
colored=colored, colored=colored,
preprocessors=preprocessors,
) )
# logger.debug( # logger.debug(
# f'find_all_crop(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} ' # f'find_all_crop(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -345,6 +360,7 @@ def find(
debug_output: bool = True, debug_output: bool = True,
colored: bool = False, colored: bool = False,
remove_duplicate: bool = True, remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult | None: ) -> TemplateMatchResult | None:
""" """
指定一个模板在输入图像中寻找其出现的第一个位置 指定一个模板在输入图像中寻找其出现的第一个位置
@ -357,6 +373,7 @@ def find(
:param debug_output: 是否输出调试信息默认为 True :param debug_output: 是否输出调试信息默认为 True
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
""" """
matches = template_match( matches = template_match(
template, template,
@ -367,6 +384,7 @@ def find(
max_results=1, max_results=1,
remove_duplicate=remove_duplicate, remove_duplicate=remove_duplicate,
colored=colored, colored=colored,
preprocessors=preprocessors,
) )
# logger.debug( # logger.debug(
# f'find(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} ' # f'find(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -396,6 +414,7 @@ def find_all(
remove_duplicate: bool = True, remove_duplicate: bool = True,
colored: bool = False, colored: bool = False,
debug_output: bool = True, debug_output: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[TemplateMatchResult]: ) -> list[TemplateMatchResult]:
""" """
指定一个模板在输入图像中寻找其出现的所有位置 指定一个模板在输入图像中寻找其出现的所有位置
@ -407,6 +426,7 @@ def find_all(
:param threshold: 阈值默认为 0.8 :param threshold: 阈值默认为 0.8
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param preprocessors: 预处理列表默认为 None
""" """
results = template_match( results = template_match(
template, template,
@ -417,6 +437,7 @@ def find_all(
max_results=-1, max_results=-1,
remove_duplicate=remove_duplicate, remove_duplicate=remove_duplicate,
colored=colored, colored=colored,
preprocessors=preprocessors,
) )
# logger.debug( # logger.debug(
# f'find_all(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} ' # f'find_all(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -441,6 +462,7 @@ def find_multi(
threshold: float = 0.8, threshold: float = 0.8,
colored: bool = False, colored: bool = False,
remove_duplicate: bool = True, remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> MultipleTemplateMatchResult | None: ) -> MultipleTemplateMatchResult | None:
""" """
指定多个模板在输入图像中逐个寻找模板返回第一个匹配到的结果 指定多个模板在输入图像中逐个寻找模板返回第一个匹配到的结果
@ -452,6 +474,7 @@ def find_multi(
:param threshold: 阈值默认为 0.8 :param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
""" """
ret = None ret = None
if masks is None: if masks is None:
@ -468,6 +491,7 @@ def find_multi(
colored=colored, colored=colored,
debug_output=False, debug_output=False,
remove_duplicate=remove_duplicate, remove_duplicate=remove_duplicate,
preprocessors=preprocessors,
) )
# 调试输出 # 调试输出
if find_result is not None: if find_result is not None:
@ -508,6 +532,7 @@ def find_all_multi(
threshold: float = 0.8, threshold: float = 0.8,
colored: bool = False, colored: bool = False,
remove_duplicate: bool = True, remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> list[MultipleTemplateMatchResult]: ) -> list[MultipleTemplateMatchResult]:
""" """
指定多个模板在输入图像中逐个寻找模板返回所有匹配到的结果 指定多个模板在输入图像中逐个寻找模板返回所有匹配到的结果
@ -526,6 +551,7 @@ def find_all_multi(
:param threshold: 阈值默认为 0.8 :param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
:return: 匹配到的一维结果列表 :return: 匹配到的一维结果列表
""" """
ret: list[MultipleTemplateMatchResult] = [] ret: list[MultipleTemplateMatchResult] = []
@ -544,6 +570,7 @@ def find_all_multi(
colored=colored, colored=colored,
remove_duplicate=remove_duplicate, remove_duplicate=remove_duplicate,
debug_output=False, debug_output=False,
preprocessors=preprocessors,
) )
ret.extend([ ret.extend([
MultipleTemplateMatchResult.from_template_match_result(r, index) MultipleTemplateMatchResult.from_template_match_result(r, index)
@ -591,6 +618,7 @@ def count(
threshold: float = 0.8, threshold: float = 0.8,
remove_duplicate: bool = True, remove_duplicate: bool = True,
colored: bool = False, colored: bool = False,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> int: ) -> int:
""" """
指定一个模板统计其出现的次数 指定一个模板统计其出现的次数
@ -602,6 +630,7 @@ def count(
:param threshold: 阈值默认为 0.8 :param threshold: 阈值默认为 0.8
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param preprocessors: 预处理列表默认为 None
""" """
results = template_match( results = template_match(
template, template,
@ -612,6 +641,7 @@ def count(
max_results=-1, max_results=-1,
remove_duplicate=remove_duplicate, remove_duplicate=remove_duplicate,
colored=colored, colored=colored,
preprocessors=preprocessors,
) )
# logger.debug( # logger.debug(
# f'count(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} ' # f'count(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '
@ -641,6 +671,7 @@ def expect(
threshold: float = 0.8, threshold: float = 0.8,
colored: bool = False, colored: bool = False,
remove_duplicate: bool = True, remove_duplicate: bool = True,
preprocessors: list[PreprocessorProtocol] | None = None,
) -> TemplateMatchResult: ) -> TemplateMatchResult:
""" """
指定一个模板寻找其出现的第一个位置若未找到则抛出异常 指定一个模板寻找其出现的第一个位置若未找到则抛出异常
@ -652,6 +683,7 @@ def expect(
:param threshold: 阈值默认为 0.8 :param threshold: 阈值默认为 0.8
:param colored: 是否匹配颜色默认为 False :param colored: 是否匹配颜色默认为 False
:param remove_duplicate: 是否移除重复结果默认为 True :param remove_duplicate: 是否移除重复结果默认为 True
:param preprocessors: 预处理列表默认为 None
""" """
ret = find( ret = find(
image, image,
@ -662,6 +694,7 @@ def expect(
colored=colored, colored=colored,
remove_duplicate=remove_duplicate, remove_duplicate=remove_duplicate,
debug_output=False, debug_output=False,
preprocessors=preprocessors,
) )
# logger.debug( # logger.debug(
# f'expect(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} ' # f'expect(): template: {_img2str(template)} image: {_img2str(image)} mask: {_img2str(mask)} '

View File

@ -0,0 +1,38 @@
from typing import Protocol
import cv2
import numpy as np
from cv2.typing import MatLike
class PreprocessorProtocol(Protocol):
    """
    Structural interface for image preprocessors.

    Objects supplied through a `preprocessors` parameter are expected to
    provide a `process` method with the signature below. The original note
    says this is used by both Image and Ocr.
    """

    def process(self, image: MatLike) -> MatLike:
        """
        Preprocess an image.

        :param image: Input image, in BGR format.
        :return: The preprocessed image; its format is not restricted.
        """
        ...
class HsvColorFilter(PreprocessorProtocol):
    """
    HSV color filter that keeps only pixels of a given color range.

    :param lower: Lower HSV bound as an (H, S, V) tuple.
    :param upper: Upper HSV bound as an (H, S, V) tuple.
    :param name: Optional label for the color; only used in `__repr__`.
    """

    def __init__(
        self,
        lower: tuple[int, int, int],
        upper: tuple[int, int, int],
        *,
        name: str | None = None,
    ):
        self.name = name
        # Bounds are stored as arrays so they can be handed to cv2.inRange.
        self.lower = np.array(lower)
        self.upper = np.array(upper)

    def process(self, image: MatLike) -> MatLike:
        """
        Convert the image from BGR to HSV and threshold it by the range.

        :param image: Input image, in BGR format.
        :return: The mask produced by `cv2.inRange` for `lower`..`upper`.
        """
        return cv2.inRange(
            cv2.cvtColor(image, cv2.COLOR_BGR2HSV),
            self.lower,
            self.upper,
        )

    def __repr__(self) -> str:
        """Describe the filter by its color name and HSV range."""
        return f'HsvColorFilter(for color "{self.name}" with range {self.lower} - {self.upper})'