feat(task): 新增图像数据库 image_db 模块

This commit is contained in:
XcantloadX 2025-04-01 21:39:09 +08:00
parent f895716813
commit 3b2d60dcb9
4 changed files with 221 additions and 0 deletions

View File

@ -0,0 +1,4 @@
from .db import ImageDatabase, Db, DatabaseQueryResult
from .descriptors import HistDescriptor
__all__ = ['ImageDatabase', 'Db', 'DatabaseQueryResult', 'HistDescriptor']

View File

@ -0,0 +1,175 @@
import os
import pickle
import logging
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Iterator
import cv2
import numpy as np
from cv2.typing import MatLike
from .descriptors import HistDescriptor
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Internal format version stamped into newly created databases.
DATABASE_INTERNAL_VERSION = 0
@dataclass
class Db:
"""数据库"""
internal_version: int
"""数据库内部版本号"""
version: str | None
"""保留字段"""
name: str | None
"""数据库名称"""
data: dict[str, Any]
"""数据"""
def insert(self, key: str, value: Any):
self.data[key] = value
def count(self):
return len(self.data)
class DataSource(Protocol):
    """Structural type for any object that can be iterated to obtain ``(key, value)`` pairs."""

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Iterate over ``(key, value)`` pairs."""
        ...
class FileDataSource(DataSource):
    """Data source that yields the images stored directly inside one folder."""

    def __init__(self, folder_path: str, keep_ext: bool = True):
        """
        :param folder_path: Folder to scan; stored as an absolute path.
        :param keep_ext: When False, the file extension is stripped from the
            yielded key (the file itself is still read by its real name).
        """
        self.path = os.path.abspath(folder_path)
        self.keep_ext = keep_ext

    def __iter__(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(key, image)`` for every entry of the folder that OpenCV can read.

        Entries that ``cv2.imread`` cannot decode (sub-directories, non-image
        files, ...) are skipped with a warning instead of being yielded as
        ``None``.
        """
        for file in os.listdir(self.path):
            # Always read through the real file name. Stripping the extension
            # first (as the key may require) would make the path point at a
            # file that does not exist.
            full_path = os.path.join(self.path, file)
            image = cv2.imread(full_path)
            if image is None:
                logger.warning('Skipping unreadable file: %s', full_path)
                continue
            key = file if self.keep_ext else os.path.splitext(file)[0]
            yield key, image
class DatabaseQueryResult(NamedTuple):
    """One hit of a database query."""

    # Key of the matching record.
    key: str
    # Feature stored for that record.
    feature: Any
    # Distance between the query feature and the stored feature.
    distance: float
def chi2_distance(hist1: np.ndarray, hist2: np.ndarray, eps=1e-10):
    """Return the chi-squared distance between two histograms.

    :param hist1: First histogram.
    :param hist2: Second histogram; combined element-wise with ``hist1``.
    :param eps: Small constant added to the denominator to avoid division by zero.
    :return: ``0.5 * sum((hist1 - hist2)^2 / (hist1 + hist2 + eps))``
    """
    difference = hist1 - hist2
    denominator = hist1 + hist2 + eps
    per_bin = difference ** 2 / denominator
    return 0.5 * np.sum(per_bin)
class ImageDatabase:
    """Pickle-backed store of image features with distance-based lookup.

    Every record maps a string key to the feature that ``descriptor`` produced
    for one image. Queries are compared against stored features with
    ``chi2_distance``.
    """

    def __init__(
        self,
        source: DataSource,
        db_path: str,
        descriptor: HistDescriptor,
        *,
        name: str | None = None
    ):
        """Load (or create) the database at ``db_path`` and sync it with ``source``.

        :param source: Iterable of ``(key, image)`` pairs. Pairs whose key is
            not stored yet are described and inserted.
        :param db_path: Path of the pickle file used for persistence.
        :param descriptor: Callable turning a BGR image into a feature.
        :param name: Name given to a newly created database. It is not applied
            when an existing database file is loaded successfully.
        """
        self.db_path = db_path
        self.__db: Db | None = None
        self.descriptor = descriptor
        self.source = source

        # Load the database from disk, if there is one.
        logger.info('Loading database from %s...', db_path)
        if os.path.exists(db_path):
            try:
                # NOTE(security): pickle.load can execute arbitrary code
                # embedded in the file. Only point db_path at trusted files.
                with open(db_path, 'rb') as f:
                    self.__db = pickle.load(f)
                logger.info('Database loaded. Name=%s, version=%s, count=%d', self.db.name, self.db.version, self.db.count())
            except Exception as e:
                # Deliberate best effort: a broken file is replaced by a fresh
                # database instead of aborting start-up.
                logger.warning('Failed to load database from %s: %s', db_path, e)
                self.__db = None
        if self.__db is None:
            self.__db = Db(DATABASE_INTERNAL_VERSION, None, name, {})

        # Check the internal version; data from another version is discarded.
        if self.db.internal_version != DATABASE_INTERNAL_VERSION:
            logger.info('Database internal version is %d, expected %d. Clearing database...', self.db.internal_version, DATABASE_INTERNAL_VERSION)
            self.db.data.clear()
            self.db.internal_version = DATABASE_INTERNAL_VERSION

        # Load the data source; keys that already exist are left untouched.
        logger.debug('Loading data source...')
        for key, value in self.source:
            self.insert(key, value)
        self.save()

    @property
    def db(self) -> Db:
        """The underlying ``Db`` object.

        :raises RuntimeError: If no database has been loaded or created yet.
        """
        # Compare against None explicitly so the check does not depend on the
        # truthiness of the loaded object.
        if self.__db is None:
            raise RuntimeError('Database not loaded')
        return self.__db

    def save(self):
        """Write the database to ``db_path`` with pickle."""
        with open(self.db_path, 'wb') as f:
            pickle.dump(self.db, f)

    def insert(self, key: str, image: MatLike | str, *, overwrite: bool = False):
        """
        Insert one record into the image database.

        An image that cannot be read (``cv2.imread`` returned ``None``, or
        ``None`` was passed in) is skipped with a warning, so a single bad file
        does not abort a bulk load.

        :param key: ID of the image.
        :param image: Path of the image, or a MatLike.
            A MatLike must be in BGR format.
        :param overwrite: Whether to overwrite an existing record.
        """
        # Decide first whether anything has to be stored, so an image that is
        # already known is neither read from disk nor described again.
        if not overwrite and key in self.db.data:
            return
        if isinstance(image, str):
            path = image
            image = cv2.imread(path)
            if image is None:
                logger.warning('Failed to read image %s from %s, skipped.', key, path)
                return
        elif image is None:
            logger.warning('No image data for %s, skipped.', key)
            return
        self.db.insert(key, self.descriptor(image))
        logger.debug('Inserted image: %s', key)

    def insert_many(self, images: dict[str, str | MatLike], *, overwrite: bool = False):
        """
        Insert several records into the image database.

        :param images: Mapping of image ID to image path or MatLike.
            A MatLike must be in BGR format.
        :param overwrite: Whether to overwrite existing records.
        """
        for name, image in images.items():
            self.insert(name, image, overwrite=overwrite)

    def search(self, query: MatLike, threshold: float = 10) -> list[DatabaseQueryResult]:
        """
        Search for images and return every record within the threshold,
        most similar (smallest distance) first.

        :param query: Image to search for; must be in BGR format.
        :param threshold: Distance threshold. A larger value accepts less
            similar images.
        :return: Matching records sorted by ascending distance.
        """
        # NOTE(review): nothing here checks that stored features were produced
        # by a descriptor configured like the current one.
        query_feature = self.descriptor(query)
        results: list[DatabaseQueryResult] = []
        for key, feature in self.db.data.items():
            dist = chi2_distance(query_feature, feature)
            if dist < threshold:
                results.append(DatabaseQueryResult(key, feature, float(dist)))
        results.sort(key=lambda x: x.distance)
        return results

    def match(self, query: MatLike, threshold: float = 10) -> DatabaseQueryResult | None:
        """
        Find the stored image most similar to the query.

        :param query: Image to match; must be in BGR format.
        :param threshold: Distance threshold. A larger value accepts less
            similar images.
        :return: The closest record within the threshold, or ``None`` if there
            is none.
        """
        results = self.search(query, threshold)
        return results[0] if results else None
if __name__ == '__main__':
    # Rebinds the module-level name ``Db`` to the package-qualified class.
    # NOTE(review): presumably so that pickled databases reference
    # kotonebot.tasks.db.image_db.db.Db rather than __main__.Db. Confirm.
    from kotonebot.tasks.db.image_db.db import Db

    logging.basicConfig(
        level=logging.DEBUG,
        format='[%(asctime)s] [%(levelname)s] [%(name)s] [%(funcName)s] [%(lineno)d] %(message)s',
    )

    # Developer-machine paths used for this manual smoke test.
    imgs_path = r'E:\GithubRepos\KotonesAutoAssistant.worktrees\dev\kotonebot\tasks\resources\idol_cards'
    needle_path = r'D:\05.png'

    # The constructor already ingests every image of the folder through
    # FileDataSource, so no manual insert step is needed here.
    db = ImageDatabase(
        FileDataSource(imgs_path),
        r'D:\idols.pkl',
        HistDescriptor(8),
        name='idols',
    )

    needle = cv2.imread(needle_path)
    result = db.match(needle)
    print(result)

View File

@ -0,0 +1,3 @@
from .hist import HistDescriptor
__all__ = ['HistDescriptor']

View File

@ -0,0 +1,39 @@
import cv2
import numpy as np
from cv2.typing import MatLike
class HistDescriptor:
    """Describe an image by the HSV colour histograms of its 3x3 grid cells."""

    def __init__(self, bin_count: int):
        """
        :param bin_count: Number of histogram bins used for each of the three
            HSV channels.
        """
        self.bin_count = bin_count

    def __call__(self, image: MatLike):
        """Compute the feature of ``image``.

        :param image: Image in BGR format.
        :return: 1-D float64 array holding the nine flattened, normalized
            region histograms, regions ordered row by row.
        """
        hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
        height, width = hsv.shape[:2]

        # Boundaries that split the image evenly into three rows / columns.
        row_edges = [k * height // 3 for k in range(4)]
        col_edges = [k * width // 3 for k in range(4)]

        bins = [self.bin_count, self.bin_count, self.bin_count]
        region_hists = []
        for top, bottom in zip(row_edges, row_edges[1:]):
            for left, right in zip(col_edges, col_edges[1:]):
                # Mask that selects only the current grid cell.
                region_mask = np.zeros(hsv.shape[:2], dtype=np.uint8)
                region_mask[top:bottom, left:right] = 255
                region_hist = cv2.calcHist(
                    [hsv],
                    [0, 1, 2],
                    region_mask,
                    bins,
                    [0, 180, 0, 256, 0, 256]
                )
                region_hist = cv2.normalize(region_hist, region_hist)
                region_hists.append(region_hist.flatten())

        # Join all cells into one vector, kept as float64 like the original
        # accumulation into an empty float64 array.
        return np.concatenate(region_hists).astype(np.float64)
if __name__ == '__main__':
    # Manual smoke test: print the feature of one sample card image.
    descriptor = HistDescriptor(8)
    sample = cv2.imread(
        r'E:\GithubRepos\KotonesAutoAssistant.worktrees\dev\kotonebot\tasks\resources\idol_cards\i_card-amao-2-000_1.png'
    )
    print(descriptor(sample))
    # NOTE(review): no window is shown above, so this wait looks like a leftover. Confirm.
    cv2.waitKey(0)