feat(task): 新增图像数据库 image_db 模块
This commit is contained in:
parent
f895716813
commit
3b2d60dcb9
|
@ -0,0 +1,4 @@
|
||||||
|
from .db import ImageDatabase, Db, DatabaseQueryResult
|
||||||
|
from .descriptors import HistDescriptor
|
||||||
|
|
||||||
|
__all__ = ['ImageDatabase', 'Db', 'DatabaseQueryResult', 'HistDescriptor']
|
|
@ -0,0 +1,175 @@
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, NamedTuple, Protocol, Iterator
|
||||||
|
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from cv2.typing import MatLike
|
||||||
|
|
||||||
|
from .descriptors import HistDescriptor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DATABASE_INTERNAL_VERSION = 0
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Db:
|
||||||
|
"""数据库"""
|
||||||
|
internal_version: int
|
||||||
|
"""数据库内部版本号"""
|
||||||
|
version: str | None
|
||||||
|
"""保留字段"""
|
||||||
|
name: str | None
|
||||||
|
"""数据库名称"""
|
||||||
|
data: dict[str, Any]
|
||||||
|
"""数据"""
|
||||||
|
|
||||||
|
def insert(self, key: str, value: Any):
|
||||||
|
self.data[key] = value
|
||||||
|
|
||||||
|
def count(self):
|
||||||
|
return len(self.data)
|
||||||
|
|
||||||
|
class DataSource(Protocol):
|
||||||
|
def __iter__(self) -> Iterator[tuple[str, Any]]:
|
||||||
|
...
|
||||||
|
|
||||||
|
class FileDataSource(DataSource):
|
||||||
|
def __init__(self, folder_path: str, keep_ext: bool = True):
|
||||||
|
self.path = os.path.abspath(folder_path)
|
||||||
|
self.keep_ext = keep_ext
|
||||||
|
|
||||||
|
def __iter__(self) -> Iterator[tuple[str, Any]]:
|
||||||
|
for file in os.listdir(self.path):
|
||||||
|
if not self.keep_ext:
|
||||||
|
file = os.path.splitext(file)[0]
|
||||||
|
yield file, cv2.imread(os.path.join(self.path, file))
|
||||||
|
|
||||||
|
class DatabaseQueryResult(NamedTuple):
|
||||||
|
key: str
|
||||||
|
feature: Any
|
||||||
|
distance: float
|
||||||
|
|
||||||
|
def chi2_distance(hist1: np.ndarray, hist2: np.ndarray, eps=1e-10):
|
||||||
|
return 0.5 * np.sum((hist1 - hist2) ** 2 / (hist1 + hist2 + eps))
|
||||||
|
|
||||||
|
class ImageDatabase:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
source: DataSource,
|
||||||
|
db_path: str,
|
||||||
|
descriptor: HistDescriptor,
|
||||||
|
*,
|
||||||
|
name: str | None = None
|
||||||
|
):
|
||||||
|
self.db_path = db_path
|
||||||
|
self.__db: Db | None = None
|
||||||
|
self.descriptor = descriptor
|
||||||
|
self.source = source
|
||||||
|
|
||||||
|
# 载入数据库
|
||||||
|
logger.info('Loading database from %s...', db_path)
|
||||||
|
if os.path.exists(db_path):
|
||||||
|
try:
|
||||||
|
with open(db_path, 'rb') as f:
|
||||||
|
self.__db = pickle.load(f)
|
||||||
|
logger.info('Database loaded. Name=%s, version=%s, count=%d', self.db.name, self.db.version, self.db.count())
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning('Failed to load database from %s: %s', db_path, e)
|
||||||
|
self.__db = None
|
||||||
|
if self.__db is None:
|
||||||
|
self.__db = Db(DATABASE_INTERNAL_VERSION, None, name, {})
|
||||||
|
|
||||||
|
# 检查版本
|
||||||
|
if self.db.internal_version != DATABASE_INTERNAL_VERSION:
|
||||||
|
logger.info('Database internal version is %d, expected %d. Clearing database...', self.db.internal_version, DATABASE_INTERNAL_VERSION)
|
||||||
|
self.db.data.clear()
|
||||||
|
self.db.internal_version = DATABASE_INTERNAL_VERSION
|
||||||
|
|
||||||
|
# 载入数据源
|
||||||
|
logger.debug('Loading data source...')
|
||||||
|
for key, value in self.source:
|
||||||
|
self.insert(key, value)
|
||||||
|
self.save()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def db(self) -> Db:
|
||||||
|
if not self.__db:
|
||||||
|
raise RuntimeError('Database not loaded')
|
||||||
|
return self.__db
|
||||||
|
|
||||||
|
def save(self):
|
||||||
|
with open(self.db_path, 'wb') as f:
|
||||||
|
pickle.dump(self.db, f)
|
||||||
|
|
||||||
|
def insert(self, key: str, image: MatLike | str, *, overwrite: bool = False):
|
||||||
|
"""
|
||||||
|
向图像数据库中插入一条新记录。
|
||||||
|
|
||||||
|
:param key: 图片的 ID。
|
||||||
|
:param image: 图片的路径或 MatLike。
|
||||||
|
若为 MatLike,必须为 BGR 格式。
|
||||||
|
:param overwrite: 是否覆盖已存在的记录。
|
||||||
|
"""
|
||||||
|
if isinstance(image, str):
|
||||||
|
image = cv2.imread(image)
|
||||||
|
if overwrite or key not in self.db.data:
|
||||||
|
self.db.insert(key, self.descriptor(image))
|
||||||
|
logger.debug('Inserted image: %s', key)
|
||||||
|
|
||||||
|
def insert_many(self, images: dict[str, str | MatLike], *, overwrite: bool = False):
|
||||||
|
"""
|
||||||
|
向图像数据库中插入多条新记录。
|
||||||
|
|
||||||
|
:param images: 图片。key 为图片的 ID,value 为图片的路径或 MatLike。
|
||||||
|
若为 MatLike,必须为 BGR 格式。
|
||||||
|
:param overwrite: 是否覆盖已存在的记录。
|
||||||
|
"""
|
||||||
|
for name, image in images.items():
|
||||||
|
self.insert(name, image, overwrite=overwrite)
|
||||||
|
|
||||||
|
def search(self, query: MatLike, threshold: float = 10) -> list[DatabaseQueryResult]:
|
||||||
|
"""
|
||||||
|
搜索图片,返回所有符合阈值要求的图片,并按相似度降序排序。
|
||||||
|
|
||||||
|
:param image: 待搜索的图片。必须为 BGR 格式。
|
||||||
|
:param threshold: 距离阈值。阈值越大,对相似度的要求越低。
|
||||||
|
:return: 搜索结果。
|
||||||
|
"""
|
||||||
|
query_feature = self.descriptor(query)
|
||||||
|
results = list[DatabaseQueryResult]()
|
||||||
|
for key, feature in self.db.data.items():
|
||||||
|
dist = chi2_distance(query_feature, feature)
|
||||||
|
if dist < threshold:
|
||||||
|
results.append(DatabaseQueryResult(key, feature, float(dist)))
|
||||||
|
results.sort(key=lambda x: x.distance)
|
||||||
|
return results
|
||||||
|
|
||||||
|
def match(self, query: MatLike, threshold: float = 10) -> DatabaseQueryResult | None:
|
||||||
|
"""
|
||||||
|
匹配图片,寻找与输入图片最相似的图片。
|
||||||
|
|
||||||
|
:param image: 待匹配的图片。必须为 BGR 格式。
|
||||||
|
:param threshold: 距离阈值。阈值越大,对相似度的要求越低。
|
||||||
|
:return: 匹配结果。
|
||||||
|
"""
|
||||||
|
results = self.search(query, threshold)
|
||||||
|
if len(results) > 0:
|
||||||
|
return results[0]
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
from kotonebot.tasks.db.image_db.db import Db
|
||||||
|
logging.basicConfig(level=logging.DEBUG, format='[%(asctime)s] [%(levelname)s] [%(name)s] [%(funcName)s] [%(lineno)d] %(message)s')
|
||||||
|
imgs_path = r'E:\GithubRepos\KotonesAutoAssistant.worktrees\dev\kotonebot\tasks\resources\idol_cards'
|
||||||
|
needle_path = r'D:\05.png'
|
||||||
|
db = ImageDatabase(FileDataSource(imgs_path), r'D:\idols.pkl', HistDescriptor(8), name='idols')
|
||||||
|
# if db.db.count() == 0:
|
||||||
|
# db.insert({file: os.path.join(imgs_path, file) for file in os.listdir(imgs_path)})
|
||||||
|
needle = cv2.imread(needle_path)
|
||||||
|
result = db.match(needle)
|
||||||
|
print(result)
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .hist import HistDescriptor
|
||||||
|
|
||||||
|
__all__ = ['HistDescriptor']
|
|
@ -0,0 +1,39 @@
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from cv2.typing import MatLike
|
||||||
|
|
||||||
|
class HistDescriptor:
|
||||||
|
def __init__(self, bin_count: int):
|
||||||
|
self.bin_count = bin_count
|
||||||
|
|
||||||
|
def __call__(self, image: MatLike):
|
||||||
|
img = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
|
||||||
|
# 将图像均分为九个区域
|
||||||
|
masks = []
|
||||||
|
height, width = img.shape[:2]
|
||||||
|
for i in range(3):
|
||||||
|
for j in range(3):
|
||||||
|
start_row, start_col = i * height // 3, j * width // 3
|
||||||
|
end_row, end_col = (i + 1) * height // 3, (j + 1) * width // 3
|
||||||
|
mask = np.zeros(img.shape[:2], dtype=np.uint8)
|
||||||
|
mask[start_row:end_row, start_col:end_col] = 255
|
||||||
|
masks.append(mask)
|
||||||
|
# 依次计算九个区域的直方图
|
||||||
|
features = np.array([])
|
||||||
|
for mask in masks:
|
||||||
|
hist = cv2.calcHist(
|
||||||
|
[img],
|
||||||
|
[0, 1, 2],
|
||||||
|
mask,
|
||||||
|
[self.bin_count, self.bin_count, self.bin_count],
|
||||||
|
[0, 180, 0, 256, 0, 256]
|
||||||
|
)
|
||||||
|
hist = cv2.normalize(hist, hist)
|
||||||
|
features = np.append(features, hist.flatten())
|
||||||
|
return features
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
d = HistDescriptor(8)
|
||||||
|
img = cv2.imread(r'E:\GithubRepos\KotonesAutoAssistant.worktrees\dev\kotonebot\tasks\resources\idol_cards\i_card-amao-2-000_1.png')
|
||||||
|
print(d(img))
|
||||||
|
cv2.waitKey(0)
|
Loading…
Reference in New Issue