import json
import logging
import os
import re
import shutil

import backoff
import fitz
import numpy as np
from openai import OpenAI
from scipy.spatial.distance import cosine

# Configure logging
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

client = OpenAI(base_url="@base_url@", api_key="@api_key@")

# Input and output paths
input_file_path = "@input_file_path@"
output_dir = "saveChunk/"
output_file_path = "@output_file_path@"

# Maximum size of each text chunk (nominally tokens; character length is used as a proxy)
chunk_max_length = @chunk_max_length@
# Length threshold above which a text is split into chunks
start_chunk_threshold = @start_chunk_threshold@
# Semantic similarity threshold
similarity_threshold = @similarity_threshold@
# Number of entries to generate per chunk file
entries_per_file = @entries_per_file@

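# Illustrative values for the template placeholders above (assumptions, not taken
# from the source): chunk_max_length = 1024, start_chunk_threshold = 1500,
# similarity_threshold = 0.5, entries_per_file = 2.
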
def read_file(file_path: str) -> str:
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


@backoff.on_exception(backoff.expo, Exception, max_tries=3)
def generate_single_entry(text: str):
    prompt = f"""
    基于以下文本,生成1个用于指令数据集的高质量条目。条目应该直接关联到给定的文本内容,提出相关的问题或任务。
    请确保生成多样化的指令类型,例如:
    - 分析类:"分析..."
    - 比较类:"比较..."
    - 解释类:"解释..."
    - 评价类:"评价..."
    - 问答类:"为什么..."

    文本内容:
    {text}

    请以下面的格式生成条目,确保所有字段都有适当的内容:
    {{
        "instruction": "使用上述多样化的指令类型之一,提出一个具体的、与文本相关的问题或任务",
        "input": "如果需要额外的上下文信息,请在这里提供,否则跟上面的instruction保持一致",
        "output": "对instruction的详细回答或任务的完成结果"
    }}
    确保所有生成的内容都与给定的文本直接相关,生成的是完整、有效的JSON格式,并且内容高质量、准确、详细,当有多个json时,用空行分隔。
    """

    try:
        resp = client.chat.completions.create(
            model="glm4-chat01",
            messages=[
                {"role": "system", "content": "你是一个指令生成专家"},
                {"role": "user", "content": prompt},
            ],
        )
        response = resp.choices[0].message.content

        # The model may return several JSON objects separated by blank lines;
        # split on the closing brace and re-append it to recover each object.
        result = ""
        json_str_list = response.split("}")
        for item in json_str_list:
            if not item.strip():
                continue
            json_str = item + "}"
            try:
                data = json.loads(json_str)
                # Check that the required keys are present
                required_keys = {"instruction", "input", "output"}
                if not required_keys.issubset(data.keys()):
                    logger.error(f"Generated entry is missing required fields: {required_keys - data.keys()}")
                    continue
                result = result + item + "},"
            except json.JSONDecodeError as e:
                logger.error(f"Failed to parse JSON string: {str(e)}, output: {json_str}")

        logger.info(f"output: {result}")
        return result

    except Exception as e:
        logger.error(f"Error while generating an entry: {str(e)}")
        raise


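# For illustration only (hypothetical model output): a response such as
#
#   {"instruction": "分析...", "input": "...", "output": "..."}
#
#   {"instruction": "解释...", "input": "...", "output": "..."}
#
# is split on "}" above, each fragment is re-closed and validated with
# json.loads, and the valid entries are concatenated as `{...},{...},` so that
# generate_dataset() below can wrap them into a JSON array.
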
def generate_dataset(folder_path: str, output_file_path, entries_per_file: int = 2):
    result = "[\n"
    for filename in os.listdir(folder_path):
        if filename.endswith(".txt"):
            file_path = os.path.join(folder_path, filename)
            logger.info(f"Processing file: {filename}")
            text = read_file(file_path)
            for j in range(entries_per_file):
                logger.info(f"  Generating entry {j + 1}/{entries_per_file}")
                entry = generate_single_entry(text)
                if not entry:
                    logger.error("Failed to generate an entry, skipping it")
                    continue
                result = result + entry

    # Drop the trailing comma and close the JSON array
    result = result[:-1] + "\n]"
    # Write the result to the output file
    with open(output_file_path, "w", encoding="utf-8") as f:
        f.write(result)
    return result


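# Minimal usage sketch (hypothetical paths): after chunks have been written to
# saveChunk/, something like
#
#   generate_dataset("saveChunk/", "dataset/book.txt.json", entries_per_file=2)
#
# produces a JSON array of {"instruction", "input", "output"} entries.
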
def get_sentence_embedding(sentence, client):
    """
    Get the embedding of a sentence.

    Args:
        sentence (str): the input sentence
        client: an OpenAI client instance

    Returns:
        numpy.ndarray: the embedding vector of the sentence
    """
    # Use the embedding API (served through Xinference) to embed the sentence
    try:
        response = client.embeddings.create(model="bge-large-zh-v1.5", input=sentence)
        embedding = response.data[0].embedding
    except Exception as e:
        logger.error(e)
        raise
    return np.array(embedding)


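# For reference, the `1 - cosine(a, b)` expression used in split_text_by_semantic
# below is the cosine similarity of the two embeddings. A minimal NumPy-only
# sketch (assuming non-zero vectors) would be:
#
#   def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
#       return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
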
def split_text_by_semantic(text, chunk_max_length, similarity_threshold=0.5):
    """
    Split text into chunks based on semantic similarity.

    Args:
        text (str): the long input text
        chunk_max_length (int): maximum length of each chunk (len() is used as an
            approximation of the token count)
        similarity_threshold (float): semantic similarity threshold, default 0.5

    Returns:
        list: the list of text chunks
    """
    chunks = []

    # Alternative: split into sentences on common Chinese punctuation:
    # sentences = re.split(r"(。|!|?|;)", text)
    # sentences = [s + p for s, p in zip(sentences[::2], sentences[1::2]) if s]
    # Currently the text is split on blank lines (paragraphs) instead:
    sentences = re.split(r"\n\s*\n", text)
    sentences = [s for s in sentences if s.strip()]
    if not sentences:
        return chunks

    current_chunk = sentences[0]
    # Embedding of the current chunk
    current_embedding = get_sentence_embedding(current_chunk, client)

    for sentence in sentences[1:]:
        # Skip empty fragments
        if not sentence.strip():
            continue
        # Remove blank lines inside the fragment
        sentence = re.sub(r"\n\s*\n", "", sentence)
        # Embedding of the current fragment
        sentence_embedding = get_sentence_embedding(sentence, client)
        # Cosine similarity between the current chunk and the fragment
        similarity = 1 - cosine(current_embedding, sentence_embedding)
        logger.info(f"similarity: {similarity}, and sentence: {sentence}")

        # Merge if the similarity is above the threshold and the merged chunk
        # stays within the maximum length
        if similarity > similarity_threshold and len(current_chunk + sentence) <= chunk_max_length:
            current_chunk += sentence
            # Update the chunk embedding with a running average
            current_embedding = (current_embedding + sentence_embedding) / 2
        else:
            # Otherwise, save the current chunk and start a new one
            chunks.append(current_chunk)
            current_chunk = sentence
            current_embedding = sentence_embedding

    # Append the last chunk
    if current_chunk:
        chunks.append(current_chunk)

    return chunks


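# Minimal usage sketch (illustrative values, not from the source):
#
#   long_text = read_text_file("example.txt")
#   if len(long_text) > 1500:
#       chunks = split_text_by_semantic(long_text, chunk_max_length=1024, similarity_threshold=0.5)
#   else:
#       chunks = [long_text]
#   save_chunks_to_files(chunks, "saveChunk/")
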
def read_text_file(file_path):
    with open(file_path, "r", encoding="utf-8") as file:
        return file.read()


def save_chunks_to_files(chunks, output_dir):
    """
    Save the text chunks to individual files.

    Args:
        chunks (list): the list of text chunks
        output_dir (str): the output directory path
    """
    # Create the output directory if it does not exist
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # Save each chunk as a separate file
    for i, chunk in enumerate(chunks):
        chunk_file_path = os.path.join(output_dir, f"chunk_{i + 1}.txt")
        with open(chunk_file_path, "w", encoding="utf-8") as file:
            file.write(chunk)
        logger.info(f"Saved chunk {i + 1} to {chunk_file_path}")


def pdf_to_text(pdf_path, txt_path):
    pdf_document = fitz.open(pdf_path)
    with open(txt_path, "w", encoding="utf-8") as text_file:
        for page_num in range(len(pdf_document)):
            page = pdf_document.load_page(page_num)
            text = page.get_text()
            text_file.write(text)
    pdf_document.close()


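# Note: `fitz` is the import name of PyMuPDF; pdf_to_text() extracts the plain
# text of every page in order and writes it to the given .txt file, e.g.
# pdf_to_text("report.pdf", "report.pdf.txt") (illustrative file names).
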
def clean_dir(directory):
    try:
        shutil.rmtree(directory)
        logger.info(f"Removed directory and its contents: {directory}")
    except FileNotFoundError:
        logger.info(f"Directory {directory} does not exist")
    except PermissionError:
        logger.error(f"No permission to remove directory {directory}")
    except Exception as e:
        logger.error(f"Failed to remove directory: {e}")

    try:
        os.makedirs(directory, exist_ok=True)
        logger.info(f"Created directory: {directory}")
    except OSError as e:
        logger.error(f"Failed to create directory: {e}")


def get_file_type(file_path):
    _, ext = os.path.splitext(file_path)
    ext = ext.lower()

    if ext == ".txt":
        return "txt"
    elif ext == ".pdf":
        return "pdf"
    else:
        return "unknown"


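# End-to-end flow of the script below: walk input_file_path, convert PDFs to .txt,
# split long texts into semantic chunks under saveChunk/, and generate one
# <filename>.json instruction dataset per input file via the chat model.
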
if __name__ == "__main__":
    os.makedirs(output_file_path, exist_ok=True)

    for root, dirs, files in os.walk(input_file_path):
        for file in files:
            # Start every input file with an empty chunk directory so that chunks
            # from previously processed files do not leak into this file's dataset
            clean_dir(output_dir)

            input_file = os.path.join(root, file)
            if get_file_type(input_file) == "pdf":
                pdf_to_text(input_file, input_file + ".txt")
                input_file = input_file + ".txt"
            elif get_file_type(input_file) == "unknown":
                raise ValueError("Unsupported input file type; please provide a text file or a PDF file")
            # Read the long text
            long_text = read_text_file(input_file)
            text_chunks = [long_text]

            # Only texts longer than the threshold are split into semantic chunks
            if len(long_text) > start_chunk_threshold:
                text_chunks = split_text_by_semantic(long_text, chunk_max_length, similarity_threshold)

            save_chunks_to_files(text_chunks, output_dir)

            logger.info("Starting dataset generation")
            output_file = os.path.join(output_file_path, file + ".json")
            dataset = generate_dataset(output_dir, output_file, entries_per_file)
            logger.info(f"Dataset generated and saved to {output_file}")