kotones-auto-assistant/experiments/moondream_test_gradio.py

140 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
from PIL import Image
import torch
import numpy as np
from PIL import ImageDraw
# 加载模型
model_id = "vikhyatk/moondream2"
revision = "2025-01-09" # Pin to specific version
def load_model():
model = AutoModelForCausalLM.from_pretrained(
model_id, trust_remote_code=True, revision=revision,
torch_dtype=torch.float16, device_map={"": "cuda"}
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=revision)
return model, tokenizer
model, tokenizer = load_model()
print('Model loaded successfully')
def answer_question(image, question):
enc_image = model.encode_image(image)
answer = model.answer_question(enc_image, question, tokenizer)
return answer
def generate_caption(image, length="normal"):
result = model.caption(image, length=length)
return result["caption"]
def visual_query(image, query):
result = model.query(image, query)
return result["answer"]
def detect_objects(image, object_type):
# 获取检测结果
result = model.detect(image, object_type)
objects = result["objects"]
print("Detection result:", result) # 添加调试信息
# 在图像上标注检测结果
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
width, height = image.size
# 为每个检测到的对象画框
for obj in objects:
# 将相对坐标转换为像素坐标
x_min = int(float(obj['x_min']) * width)
y_min = int(float(obj['y_min']) * height)
x_max = int(float(obj['x_max']) * width)
y_max = int(float(obj['y_max']) * height)
box = [x_min, y_min, x_max, y_max]
print(f"Object box (pixels): {box}") # 添加每个对象的框信息
draw.rectangle(box, outline="red", width=3)
# 返回详细的检测结果
result_text = f"Found {len(objects)} {object_type}(s)\n"
for i, obj in enumerate(objects, 1):
result_text += f"Object {i}: x={obj['x_min']:.3f}-{obj['x_max']:.3f}, y={obj['y_min']:.3f}-{obj['y_max']:.3f}\n"
return result_text, img_draw
def point_objects(image, object_type):
# 获取点定位结果
result = model.point(image, object_type)
points = result["points"]
print("Points result:", points) # 添加调试信息
# 在图像上标注点
img_draw = image.copy()
draw = ImageDraw.Draw(img_draw)
width, height = image.size
# 为每个点画圆
result_text = f"Found {len(points)} {object_type}(s)\n"
for i, point in enumerate(points, 1):
if isinstance(point, dict) and 'x' in point and 'y' in point:
# 坐标是相对值0-1需要转换为实际像素坐标
x = int(float(point['x']) * width)
y = int(float(point['y']) * height)
result_text += f"Point {i}: x={point['x']:.3f}, y={point['y']:.3f}\n"
radius = 10
draw.ellipse([x-radius, y-radius, x+radius, y+radius], fill="red")
else:
print(f"Unexpected point format: {point}")
result_text += f"Point {i}: Invalid format\n"
return result_text, img_draw
# 创建 Gradio 界面
with gr.Blocks() as demo:
gr.Markdown("# Moondream2 视觉分析演示")
with gr.Tab("问答"):
with gr.Row():
image_input1 = gr.Image(type="pil", label="上传图片")
question_input = gr.Textbox(label="输入问题")
answer_output = gr.Textbox(label="回答")
answer_button = gr.Button("获取回答")
answer_button.click(answer_question, inputs=[image_input1, question_input], outputs=answer_output)
with gr.Tab("图像描述"):
with gr.Row():
image_input2 = gr.Image(type="pil", label="上传图片")
length_input = gr.Radio(["short", "normal"], label="描述长度", value="normal")
caption_output = gr.Textbox(label="描述结果")
caption_button = gr.Button("生成描述")
caption_button.click(generate_caption, inputs=[image_input2, length_input], outputs=caption_output)
with gr.Tab("视觉查询"):
with gr.Row():
image_input3 = gr.Image(type="pil", label="上传图片")
query_input = gr.Textbox(label="输入查询")
query_output = gr.Textbox(label="查询结果")
query_button = gr.Button("执行查询")
query_button.click(visual_query, inputs=[image_input3, query_input], outputs=query_output)
with gr.Tab("物体检测"):
with gr.Row():
image_input4 = gr.Image(type="pil", label="上传图片")
object_input = gr.Textbox(label="输入要检测的物体类型")
with gr.Row():
detect_text_output = gr.Textbox(label="检测结果")
detect_image_output = gr.Image(type="pil", label="可视化结果")
detect_button = gr.Button("开始检测")
detect_button.click(detect_objects, inputs=[image_input4, object_input], outputs=[detect_text_output, detect_image_output])
with gr.Tab("点定位"):
with gr.Row():
image_input5 = gr.Image(type="pil", label="上传图片")
point_object_input = gr.Textbox(label="输入要定位的物体类型")
with gr.Row():
point_text_output = gr.Textbox(label="定位结果")
point_image_output = gr.Image(type="pil", label="可视化结果")
point_button = gr.Button("开始定位")
point_button.click(point_objects, inputs=[image_input5, point_object_input], outputs=[point_text_output, point_image_output])
demo.launch(share=True)