From 3b2d60dcb95ed2a53f842513ea32af13ad1e4d11 Mon Sep 17 00:00:00 2001 From: XcantloadX <3188996979@qq.com> Date: Tue, 1 Apr 2025 21:39:09 +0800 Subject: [PATCH] =?UTF-8?q?feat(task):=20=E6=96=B0=E5=A2=9E=E5=9B=BE?= =?UTF-8?q?=E5=83=8F=E6=95=B0=E6=8D=AE=E5=BA=93=20image=5Fdb=20=E6=A8=A1?= =?UTF-8?q?=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- kotonebot/tasks/image_db/__init__.py | 4 + kotonebot/tasks/image_db/db.py | 175 ++++++++++++++++++ .../tasks/image_db/descriptors/__init__.py | 3 + kotonebot/tasks/image_db/descriptors/hist.py | 39 ++++ 4 files changed, 221 insertions(+) create mode 100644 kotonebot/tasks/image_db/__init__.py create mode 100644 kotonebot/tasks/image_db/db.py create mode 100644 kotonebot/tasks/image_db/descriptors/__init__.py create mode 100644 kotonebot/tasks/image_db/descriptors/hist.py diff --git a/kotonebot/tasks/image_db/__init__.py b/kotonebot/tasks/image_db/__init__.py new file mode 100644 index 0000000..c409d2e --- /dev/null +++ b/kotonebot/tasks/image_db/__init__.py @@ -0,0 +1,4 @@ +from .db import ImageDatabase, Db, DatabaseQueryResult +from .descriptors import HistDescriptor + +__all__ = ['ImageDatabase', 'Db', 'DatabaseQueryResult', 'HistDescriptor'] diff --git a/kotonebot/tasks/image_db/db.py b/kotonebot/tasks/image_db/db.py new file mode 100644 index 0000000..0e3988f --- /dev/null +++ b/kotonebot/tasks/image_db/db.py @@ -0,0 +1,175 @@ +import os +import pickle +import logging +from dataclasses import dataclass +from typing import Any, NamedTuple, Protocol, Iterator + +import cv2 +import numpy as np +from cv2.typing import MatLike + +from .descriptors import HistDescriptor + +logger = logging.getLogger(__name__) + +DATABASE_INTERNAL_VERSION = 0 + +@dataclass +class Db: + """数据库""" + internal_version: int + """数据库内部版本号""" + version: str | None + """保留字段""" + name: str | None + """数据库名称""" + data: dict[str, Any] + """数据""" + + def insert(self, key: str, value: Any): + self.data[key] = value + + def count(self): + return len(self.data) + +class DataSource(Protocol): + def __iter__(self) -> Iterator[tuple[str, Any]]: + ... + +class FileDataSource(DataSource): + def __init__(self, folder_path: str, keep_ext: bool = True): + self.path = os.path.abspath(folder_path) + self.keep_ext = keep_ext + + def __iter__(self) -> Iterator[tuple[str, Any]]: + for file in os.listdir(self.path): + if not self.keep_ext: + file = os.path.splitext(file)[0] + yield file, cv2.imread(os.path.join(self.path, file)) + +class DatabaseQueryResult(NamedTuple): + key: str + feature: Any + distance: float + +def chi2_distance(hist1: np.ndarray, hist2: np.ndarray, eps=1e-10): + return 0.5 * np.sum((hist1 - hist2) ** 2 / (hist1 + hist2 + eps)) + +class ImageDatabase: + def __init__( + self, + source: DataSource, + db_path: str, + descriptor: HistDescriptor, + *, + name: str | None = None + ): + self.db_path = db_path + self.__db: Db | None = None + self.descriptor = descriptor + self.source = source + + # 载入数据库 + logger.info('Loading database from %s...', db_path) + if os.path.exists(db_path): + try: + with open(db_path, 'rb') as f: + self.__db = pickle.load(f) + logger.info('Database loaded. Name=%s, version=%s, count=%d', self.db.name, self.db.version, self.db.count()) + except Exception as e: + logger.warning('Failed to load database from %s: %s', db_path, e) + self.__db = None + if self.__db is None: + self.__db = Db(DATABASE_INTERNAL_VERSION, None, name, {}) + + # 检查版本 + if self.db.internal_version != DATABASE_INTERNAL_VERSION: + logger.info('Database internal version is %d, expected %d. Clearing database...', self.db.internal_version, DATABASE_INTERNAL_VERSION) + self.db.data.clear() + self.db.internal_version = DATABASE_INTERNAL_VERSION + + # 载入数据源 + logger.debug('Loading data source...') + for key, value in self.source: + self.insert(key, value) + self.save() + + @property + def db(self) -> Db: + if not self.__db: + raise RuntimeError('Database not loaded') + return self.__db + + def save(self): + with open(self.db_path, 'wb') as f: + pickle.dump(self.db, f) + + def insert(self, key: str, image: MatLike | str, *, overwrite: bool = False): + """ + 向图像数据库中插入一条新记录。 + + :param key: 图片的 ID。 + :param image: 图片的路径或 MatLike。 + 若为 MatLike,必须为 BGR 格式。 + :param overwrite: 是否覆盖已存在的记录。 + """ + if isinstance(image, str): + image = cv2.imread(image) + if overwrite or key not in self.db.data: + self.db.insert(key, self.descriptor(image)) + logger.debug('Inserted image: %s', key) + + def insert_many(self, images: dict[str, str | MatLike], *, overwrite: bool = False): + """ + 向图像数据库中插入多条新记录。 + + :param images: 图片。key 为图片的 ID,value 为图片的路径或 MatLike。 + 若为 MatLike,必须为 BGR 格式。 + :param overwrite: 是否覆盖已存在的记录。 + """ + for name, image in images.items(): + self.insert(name, image, overwrite=overwrite) + + def search(self, query: MatLike, threshold: float = 10) -> list[DatabaseQueryResult]: + """ + 搜索图片,返回所有符合阈值要求的图片,并按相似度降序排序。 + + :param image: 待搜索的图片。必须为 BGR 格式。 + :param threshold: 距离阈值。阈值越大,对相似度的要求越低。 + :return: 搜索结果。 + """ + query_feature = self.descriptor(query) + results = list[DatabaseQueryResult]() + for key, feature in self.db.data.items(): + dist = chi2_distance(query_feature, feature) + if dist < threshold: + results.append(DatabaseQueryResult(key, feature, float(dist))) + results.sort(key=lambda x: x.distance) + return results + + def match(self, query: MatLike, threshold: float = 10) -> DatabaseQueryResult | None: + """ + 匹配图片,寻找与输入图片最相似的图片。 + + :param image: 待匹配的图片。必须为 BGR 格式。 + :param threshold: 距离阈值。阈值越大,对相似度的要求越低。 + :return: 匹配结果。 + """ + results = self.search(query, threshold) + if len(results) > 0: + return results[0] + else: + return None + + +if __name__ == '__main__': + from kotonebot.tasks.db.image_db.db import Db + logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] [%(levelname)s] [%(name)s] [%(funcName)s] [%(lineno)d] %(message)s') + imgs_path = r'E:\GithubRepos\KotonesAutoAssistant.worktrees\dev\kotonebot\tasks\resources\idol_cards' + needle_path = r'D:\05.png' + db = ImageDatabase(FileDataSource(imgs_path), r'D:\idols.pkl', HistDescriptor(8), name='idols') + # if db.db.count() == 0: + # db.insert({file: os.path.join(imgs_path, file) for file in os.listdir(imgs_path)}) + needle = cv2.imread(needle_path) + result = db.match(needle) + print(result) diff --git a/kotonebot/tasks/image_db/descriptors/__init__.py b/kotonebot/tasks/image_db/descriptors/__init__.py new file mode 100644 index 0000000..bc40097 --- /dev/null +++ b/kotonebot/tasks/image_db/descriptors/__init__.py @@ -0,0 +1,3 @@ +from .hist import HistDescriptor + +__all__ = ['HistDescriptor'] diff --git a/kotonebot/tasks/image_db/descriptors/hist.py b/kotonebot/tasks/image_db/descriptors/hist.py new file mode 100644 index 0000000..90c85aa --- /dev/null +++ b/kotonebot/tasks/image_db/descriptors/hist.py @@ -0,0 +1,39 @@ +import cv2 +import numpy as np +from cv2.typing import MatLike + +class HistDescriptor: + def __init__(self, bin_count: int): + self.bin_count = bin_count + + def __call__(self, image: MatLike): + img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV) + # 将图像均分为九个区域 + masks = [] + height, width = img.shape[:2] + for i in range(3): + for j in range(3): + start_row, start_col = i * height // 3, j * width // 3 + end_row, end_col = (i + 1) * height // 3, (j + 1) * width // 3 + mask = np.zeros(img.shape[:2], dtype=np.uint8) + mask[start_row:end_row, start_col:end_col] = 255 + masks.append(mask) + # 依次计算九个区域的直方图 + features = np.array([]) + for mask in masks: + hist = cv2.calcHist( + [img], + [0, 1, 2], + mask, + [self.bin_count, self.bin_count, self.bin_count], + [0, 180, 0, 256, 0, 256] + ) + hist = cv2.normalize(hist, hist) + features = np.append(features, hist.flatten()) + return features + +if __name__ == '__main__': + d = HistDescriptor(8) + img = cv2.imread(r'E:\GithubRepos\KotonesAutoAssistant.worktrees\dev\kotonebot\tasks\resources\idol_cards\i_card-amao-2-000_1.png') + print(d(img)) + cv2.waitKey(0)