mirror of https://github.com/OpenSPG/KAG
feat(solver): support kag thinker (#640)
* feat(kag): update to v0.7 (#456) * add think cost * update csv scanner * add final rerank * add reasoner * add iterative planner * fix dpr search * fix dpr search * add reference data * move odps import * update requirement.txt * update 2wiki * add missing file * fix markdown reader * add iterative planning * update version * update runner * update 2wiki example * update bridge * merge solver and solver_new * add cur day * writer delete * update multi process * add missing files * fix report * add chunk retrieved executor * update try in stream runner result * add path * add math executor * update hotpotqa example * remove log * fix python coder solver * update hotpotqa example * fix python coder solver * update config * fix bad * add log * remove unused code * commit with task thought * move kag model to common * add default chat llm * fix * use static planner * support chunk graph node * add args * support naive rag * llm client support tool calls * add default async * add openai * fix result * fix markdown reader * fix thinker * update asyncio interface * feat(solver): add mcp support (#444) * 上传mcp client相关代码 * 1、完成一套mcp client的调用,从pipeline到planner、executor 2、允许json中传入多个mcp_server,通过大模型进行调用并选择 3、调通baidu_map_mcp的使用 * 1、schema * bugfix:删减冗余代码 --------- Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com> * fix affairqa after solver refactor * fix affairqa after solver refactor * fix readme * add params * update version * update mcp executor * update mcp executor * solver add mcp executor * add missing file * add mpc executor * add executor * x * update * fix requirement * fix main llm config * fix solver * bugfix:修复invoke函数调用逻辑 * chg eva * update example * add kag layer * add step task * support dot refresh * support dot refresh * support dot refresh * support dot refresh * add retrieved num * add retrieved num * add pipelineconf * update ppr * update musique prompts * update * add to_dict for BuilderComponentData * async build * add deduce prompt * add deduce prompt * add deduce prompt * fix reader * add deduce prompt * add page thinker report * modify prmpt * add step status * add self cognition * add self cognition * add memory graph storage * add now time * update memory config * add now time * chg graph loader * 添加prqa数据集和代码 * bugfix:prqa调用逻辑修复 * optimize:优化代码逻辑,生成答案规范化 * add retry py code * update memory graph * update memory graph * fix * fix ner * add with_out_refer generator prompt * fix * close ckpt * fix query * fix query * update version * add llm checker * add llm checker * 1、上传evalutor.py以及修改gold_answer.json格式 2、优化代码逻辑 3、修改README.md文件 * update exp * update exp * rerank support * add static rewrite query * recall more chunks * fix graph load * add static rewrite query * fix bugs * add finish check * add finish check * add finish check * add finish check * 1、上传evalutor.py的结果 2、优化代码逻辑,优化readme文件 * add lf retry * add memory graph api * fix reader api * add ner * add metrics * fix bug * remove ner * add reraise fo retry * add edge prop to memory graph * add memory graph * 1、评测数据集结果修正 2、优化evaluator.py代码 3、删除结果不存在而gold_answer中有答案的问题 * 删除评测结果文件 * fix knext host addr * async eva * add lf prompt * add lf prompt * add config * add retry * add unknown check * add rc result * add rc result * add rc result * add rc result * 依据kag pipeline格式修改代码逻辑并通过测试 * bugfix:删除冗余代码 * fix report prompt * bugfix:触发重试机制 * bugfix:中文符号错误 * fix rethinker prompt * update version to 0.6.2b78 * update version * 1、修改evaluator.py,通过大模型计算准确率,符合最新调用逻辑 2、修改prompt,让没有回答的结果重复测试 * update affairqa for evaluate * update affairqa for evaluate * bugfix:修正数据集 * bugfix:修正数据集 * bugfix:修正数据集 * fix name conflict * bugfix:删除错误问题 * bugfix:文件名命名错误导致evaluator失败 * update for affairqa eval * bugfix:修改代码保持evaluate逻辑一致 * x * update for affairqa readme * remove temp eval scripts * bugfix for math deduce * merge 0.6.2_dev * merge 0.6.2_dev * fix * update client addr * updated version * update for affairqa eval * evaUtils 支持中文 * fix affairqa eval: * remove unused example * update kag config * fix default value * update readme * fix init * 注释信息修改,并添加部分class说明 * update example config * Tc 0.7.0 (#459) * 提交affairQA 代码 * fix affairqa eval --------- Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com> * fix all examples * reformat --------- Co-authored-by: peilong <peilong.zpl@antgroup.com> Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com> Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com> Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com> * update chunk metadata * update chunk metadata * add debug reporter * update table text * add server * fix math executor * update api-key for openai vec * update * fix naive rag bug * format code * fix --------- Co-authored-by: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com> Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com> Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com> Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
This commit is contained in:
parent
9b2d894295
commit
e1012d39e4
|
@ -1 +1 @@
|
|||
0.8.0
|
||||
0.8.0
|
|
@ -463,9 +463,17 @@ def resolve_instance(
|
|||
|
||||
|
||||
def extract_tag_content(text):
|
||||
# 匹配<tag>和</tag>之间的内容,支持任意标签名
|
||||
matches = re.findall(r"<([^>]+)>(.*?)</\1>", text, flags=re.DOTALL)
|
||||
return [(tag, content.strip()) for tag, content in matches]
|
||||
pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
|
||||
results = []
|
||||
for match in re.finditer(pattern, text, re.DOTALL):
|
||||
tag1, content1, tag2, content2, raw_text = match.groups()
|
||||
if tag1:
|
||||
results.append((tag1, content1)) # 保留原始内容(含空格)
|
||||
elif tag2:
|
||||
results.append((tag2, content2)) # 保留原始内容(含空格)
|
||||
elif raw_text:
|
||||
results.append(("", raw_text)) # 保留原始空格
|
||||
return results
|
||||
|
||||
|
||||
def extract_specific_tag_content(text, tag):
|
||||
|
|
|
@ -131,9 +131,11 @@ class PyBasedMathExecutor(ExecutorABC):
|
|||
)
|
||||
|
||||
parent_results = format_task_dep_context(task.parents)
|
||||
parent_results = "\n".join(parent_results)
|
||||
coder_content = context.kwargs.get("planner_thought", "") + "\n\n".join(
|
||||
parent_results
|
||||
)
|
||||
|
||||
parent_results += "\n\n" + contents
|
||||
coder_content += "\n\n" + contents
|
||||
tries = self.tries
|
||||
error = None
|
||||
|
||||
|
@ -141,7 +143,7 @@ class PyBasedMathExecutor(ExecutorABC):
|
|||
tries -= 1
|
||||
rst, error, code = self.run_once(
|
||||
math_query,
|
||||
parent_results,
|
||||
coder_content,
|
||||
error,
|
||||
segment_name=tag_id,
|
||||
tag_name=f"{task_query}_code_generator",
|
||||
|
|
|
@ -42,6 +42,15 @@ from kag.solver.utils import init_prompt_with_fallback
|
|||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def _wrapped_invoke(retriever, task, context, segment_name, kwargs):
|
||||
start_time = time.time()
|
||||
output = retriever.invoke(
|
||||
task, context=context, segment_name=segment_name, **kwargs
|
||||
)
|
||||
elapsed_time = time.time() - start_time
|
||||
return output, elapsed_time
|
||||
|
||||
|
||||
@ExecutorABC.register("kag_hybrid_retrieval_executor")
|
||||
class KAGHybridRetrievalExecutor(ExecutorABC):
|
||||
def __init__(
|
||||
|
@ -76,6 +85,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
|
|||
self.context_select_prompt = context_select_prompt or PromptABC.from_config(
|
||||
{"type": "context_select_prompt"}
|
||||
)
|
||||
self.with_llm_select = kwargs.get("with_llm_select", True)
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1))
|
||||
def context_select_call(self, variables):
|
||||
|
@ -152,22 +162,30 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
|
|||
"FINISH",
|
||||
component_name=retriever.name,
|
||||
)
|
||||
|
||||
# Record start time before submitting the task
|
||||
start_time = time.time()
|
||||
# Prepare function and submit to thread pool
|
||||
func = partial(
|
||||
retriever.invoke,
|
||||
_wrapped_invoke,
|
||||
retriever,
|
||||
task,
|
||||
context=context,
|
||||
segment_name=tag_id,
|
||||
**kwargs,
|
||||
context,
|
||||
tag_id,
|
||||
kwargs.copy(),
|
||||
)
|
||||
future = executor.submit(func)
|
||||
# Save future, retriever, and start_time together
|
||||
futures.append((future, retriever))
|
||||
|
||||
# Collect results from each future
|
||||
for future, retriever in futures:
|
||||
try:
|
||||
output = future.result() # Wait for result
|
||||
output, elapsed_time = future.result() # Wait for result
|
||||
|
||||
# Log the elapsed time for this retriever
|
||||
logger.info(
|
||||
f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds"
|
||||
)
|
||||
outputs.append(output)
|
||||
|
||||
# Log data report after successful execution
|
||||
|
@ -241,13 +259,18 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
|
|||
selected_rel = list(set(selected_rel))
|
||||
formatted_docs = [str(rel) for rel in selected_rel]
|
||||
if retrieved_data.chunks:
|
||||
try:
|
||||
selected_chunks = self.context_select(task_query, retrieved_data.chunks)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"select context failed {e}, we use default top 10 to summary",
|
||||
exc_info=True,
|
||||
)
|
||||
if self.with_llm_select:
|
||||
try:
|
||||
selected_chunks = self.context_select(
|
||||
task_query, retrieved_data.chunks
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"select context failed {e}, we use default top 10 to summary",
|
||||
exc_info=True,
|
||||
)
|
||||
selected_chunks = retrieved_data.chunks[:10]
|
||||
else:
|
||||
selected_chunks = retrieved_data.chunks[:10]
|
||||
for doc in selected_chunks:
|
||||
formatted_docs.append(f"{doc.content}")
|
||||
|
@ -280,68 +303,81 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
|
|||
task_query = task.arguments["query"]
|
||||
|
||||
tag_id = f"{task_query}_begin_task"
|
||||
self.report_content(reporter, "thinker", tag_id, "", "FINISH", step=task.name)
|
||||
self.report_content(reporter, "thinker", tag_id, "", "INIT", step=task.name)
|
||||
try:
|
||||
retrieved_data = self.do_main(task_query, tag_id, task, context, **kwargs)
|
||||
except Exception as e:
|
||||
logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
|
||||
retrieved_data = RetrieverOutput(
|
||||
retriever_method=self.schema().get("name", ""), err_msg=str(e)
|
||||
)
|
||||
|
||||
self.report_content(
|
||||
reporter,
|
||||
"reference",
|
||||
f"{task_query}_kag_retriever_result",
|
||||
retrieved_data,
|
||||
"FINISH",
|
||||
)
|
||||
|
||||
retrieved_data.task = task
|
||||
logical_node = task.arguments.get("logic_form_node", None)
|
||||
if (
|
||||
logical_node
|
||||
and isinstance(logical_node, GetSPONode)
|
||||
and retrieved_data.summary
|
||||
):
|
||||
if isinstance(retrieved_data.summary, str):
|
||||
target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
|
||||
s_entities = context.variables_graph.get_entity_by_alias(
|
||||
logical_node.s.alias_name
|
||||
try:
|
||||
retrieved_data = self.do_main(
|
||||
task_query, tag_id, task, context, **kwargs
|
||||
)
|
||||
if (
|
||||
not s_entities
|
||||
and not logical_node.s.get_mention_name()
|
||||
and isinstance(logical_node.s, SPOEntity)
|
||||
):
|
||||
logical_node.s.entity_name = target_answer
|
||||
context.kwargs[logical_node.s.alias_name] = logical_node.s
|
||||
o_entities = context.variables_graph.get_entity_by_alias(
|
||||
logical_node.o.alias_name
|
||||
except Exception as e:
|
||||
logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
|
||||
retrieved_data = RetrieverOutput(
|
||||
retriever_method=self.schema().get("name", ""), err_msg=str(e)
|
||||
)
|
||||
if (
|
||||
not o_entities
|
||||
and not logical_node.o.get_mention_name()
|
||||
and isinstance(logical_node.o, SPOEntity)
|
||||
):
|
||||
logical_node.o.entity_name = target_answer
|
||||
context.kwargs[logical_node.o.alias_name] = logical_node.o
|
||||
|
||||
context.variables_graph.add_answered_alias(
|
||||
logical_node.s.alias_name.alias_name, retrieved_data.summary
|
||||
)
|
||||
context.variables_graph.add_answered_alias(
|
||||
logical_node.p.alias_name.alias_name, retrieved_data.summary
|
||||
)
|
||||
context.variables_graph.add_answered_alias(
|
||||
logical_node.o.alias_name.alias_name, retrieved_data.summary
|
||||
self.report_content(
|
||||
reporter,
|
||||
"reference",
|
||||
f"{task_query}_kag_retriever_result",
|
||||
retrieved_data,
|
||||
"FINISH",
|
||||
)
|
||||
|
||||
task.update_result(retrieved_data)
|
||||
logger.debug(
|
||||
f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
|
||||
)
|
||||
return retrieved_data
|
||||
retrieved_data.task = task
|
||||
logical_node = task.arguments.get("logic_form_node", None)
|
||||
if (
|
||||
logical_node
|
||||
and isinstance(logical_node, GetSPONode)
|
||||
and retrieved_data.summary
|
||||
):
|
||||
if isinstance(retrieved_data.summary, str):
|
||||
target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
|
||||
s_entities = context.variables_graph.get_entity_by_alias(
|
||||
logical_node.s.alias_name
|
||||
)
|
||||
if (
|
||||
not s_entities
|
||||
and not logical_node.s.get_mention_name()
|
||||
and isinstance(logical_node.s, SPOEntity)
|
||||
):
|
||||
logical_node.s.entity_name = target_answer
|
||||
context.kwargs[logical_node.s.alias_name] = logical_node.s
|
||||
o_entities = context.variables_graph.get_entity_by_alias(
|
||||
logical_node.o.alias_name
|
||||
)
|
||||
if (
|
||||
not o_entities
|
||||
and not logical_node.o.get_mention_name()
|
||||
and isinstance(logical_node.o, SPOEntity)
|
||||
):
|
||||
logical_node.o.entity_name = target_answer
|
||||
context.kwargs[logical_node.o.alias_name] = logical_node.o
|
||||
|
||||
context.variables_graph.add_answered_alias(
|
||||
logical_node.s.alias_name.alias_name, retrieved_data.summary
|
||||
)
|
||||
context.variables_graph.add_answered_alias(
|
||||
logical_node.p.alias_name.alias_name, retrieved_data.summary
|
||||
)
|
||||
context.variables_graph.add_answered_alias(
|
||||
logical_node.o.alias_name.alias_name, retrieved_data.summary
|
||||
)
|
||||
|
||||
task.update_result(retrieved_data)
|
||||
logger.debug(
|
||||
f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
|
||||
)
|
||||
return retrieved_data
|
||||
finally:
|
||||
self.report_content(
|
||||
reporter,
|
||||
"thinker",
|
||||
tag_id,
|
||||
"",
|
||||
"FINISH",
|
||||
step=task.name,
|
||||
overwrite=False,
|
||||
)
|
||||
|
||||
def schema(self) -> dict:
|
||||
"""Function schema definition for OpenAI Function Calling
|
||||
|
@ -403,7 +439,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
|
|||
node_type=chunk.properties.get("__labels__"),
|
||||
)
|
||||
entity_prop = dict(chunk.properties) if chunk.properties else {}
|
||||
entity_prop["content"] = chunk.content
|
||||
entity_prop["content"] = f"{chunk.content[:10]}..."
|
||||
entity_prop["score"] = chunk.score
|
||||
entity.prop = Prop.from_dict(entity_prop, "Chunk", None)
|
||||
chunk_graph.append(entity)
|
||||
|
|
|
@ -140,8 +140,6 @@ def get_pipeline_conf(use_pipeline_name, config):
|
|||
raise RuntimeError("mcpServers not found in config.")
|
||||
default_solver_pipeline["executors"] = mcp_executors
|
||||
|
||||
# update KAG_CONFIG
|
||||
KAG_CONFIG.update_conf(default_pipeline_conf)
|
||||
return default_solver_pipeline
|
||||
|
||||
|
||||
|
@ -167,8 +165,11 @@ async def do_qa_pipeline(
|
|||
f"Knowledge base with id {kb_project_id} not found in qa_config['kb']"
|
||||
)
|
||||
continue
|
||||
|
||||
for index_name in matched_kb.get("index_list", []):
|
||||
index_list = matched_kb.get("index_list", [])
|
||||
if use_pipeline in ["default_pipeline"]:
|
||||
# we only use chunk index
|
||||
index_list = ["chunk_index"]
|
||||
for index_name in index_list:
|
||||
index_manager = KAGIndexManager.from_config(
|
||||
{
|
||||
"type": index_name,
|
||||
|
@ -339,7 +340,7 @@ class SolverMain:
|
|||
def invoke(
|
||||
self,
|
||||
project_id: int,
|
||||
task_id: int,
|
||||
task_id,
|
||||
query: str,
|
||||
session_id: str = "0",
|
||||
is_report=True,
|
||||
|
|
|
@ -3,20 +3,17 @@ pipeline_name: default_pipeline
|
|||
|
||||
#------------kag-solver configuration start----------------#
|
||||
|
||||
|
||||
chunk_retrieved_executor: &chunk_retrieved_executor_conf
|
||||
type: chunk_retrieved_executor
|
||||
top_k: 10
|
||||
retriever:
|
||||
type: vector_chunk_retriever
|
||||
score_threshold: 0.65
|
||||
vectorize_model: "{vectorize_model}"
|
||||
|
||||
kag_retriever_executor: &kag_retriever_executor_conf
|
||||
type: kag_hybrid_retrieval_executor
|
||||
retrievers: "{retrievers}"
|
||||
merger:
|
||||
type: kag_merger
|
||||
enable_summary: false
|
||||
|
||||
solver_pipeline:
|
||||
type: naive_rag_pipeline
|
||||
executors:
|
||||
- *chunk_retrieved_executor_conf
|
||||
- *kag_retriever_executor_conf
|
||||
generator:
|
||||
type: llm_index_generator
|
||||
llm_client: "{chat_llm}"
|
||||
|
|
|
@ -186,6 +186,7 @@ class KAGModelPlanner(PlannerABC):
|
|||
.replace("</answer>", "")
|
||||
.strip()
|
||||
)
|
||||
context.kwargs["planner_thought"] = logic_form_response
|
||||
|
||||
sub_queries, logic_forms = parse_logic_form_with_str(logic_form_str)
|
||||
logic_forms = self.logic_node_parser.parse_logic_form_set(
|
||||
|
|
|
@ -11,42 +11,49 @@ logger = logging.getLogger(__name__)
|
|||
class ExpressionBuildr(PromptABC):
|
||||
template_zh = (
|
||||
f"今天是{get_now(language='zh')}。"
|
||||
+ """\n# instruction
|
||||
+ """
|
||||
# instruction
|
||||
根据给出的问题和数据,编写python代码,输出问题结果。
|
||||
为了便于理解,输出从context中提取的数据,输出中间计算过程和结果。
|
||||
注意严格根据输入内容进行编写代码,不允许进行假设
|
||||
例如伤残等级如果context中未提及,则认为没有被认定为残疾
|
||||
如果无法回答问题,直接返回:I don't know.
|
||||
从context中提取的数据必须显式赋值,所有计算步骤必须用代码实现,不得隐含推断。
|
||||
必须输出中间计算过程和结果,格式为print语句。
|
||||
如果context未提供必要数据或无法计算,直接打印"I don't know."
|
||||
|
||||
# output format
|
||||
直接输出python代码,python版本为3.10,不要包含任何其他信息
|
||||
严格输出以下结构的python代码(版本3.10):
|
||||
1. 数据提取部分:代码中涉及输入的数值需要从context及question中提取,不允许进行假设
|
||||
2. 计算过程:分步实现所有数学运算,每个步骤对应独立变量
|
||||
3. 输出:每个中间变量和最终结果必须用print语句输出
|
||||
|
||||
# examples
|
||||
## 例子1
|
||||
### input
|
||||
#### question
|
||||
47000元按照万分之1.5一共612天,计算利息,一共多少钱?
|
||||
4百万元按照日利率万分之1.5,一共612天,计算利息,一共多少钱?
|
||||
#### context
|
||||
日利率万分之1.5
|
||||
### output
|
||||
```python
|
||||
# 初始本金
|
||||
principal = 47000
|
||||
# 初始本金(单位:百万)
|
||||
principal = 4 # 单位:百万
|
||||
|
||||
# 利率(万分之1.5)
|
||||
rate = 1.5 / 10000
|
||||
# 日利率计算(万分之1.5)
|
||||
daily_rate = 1.5 / 10000
|
||||
|
||||
# 天数
|
||||
# 计算周期
|
||||
days = 612
|
||||
|
||||
# 计算年利率
|
||||
annual_rate = rate * 365
|
||||
# 单日利息计算
|
||||
daily_interest = principal * daily_rate
|
||||
|
||||
# 计算利息
|
||||
interest = principal * (annual_rate / 365) * days
|
||||
# 累计利息计算
|
||||
total_interest = daily_interest * days
|
||||
|
||||
# 输出总金额(本金+利息)
|
||||
total_amount = principal + interest
|
||||
# 总金额计算
|
||||
total_amount = principal + total_interest
|
||||
|
||||
print(f"总金额:{total_amount:.2f}元")
|
||||
print(f"单日利息:{daily_interest:.2f}百万")
|
||||
print(f"累计利息:{total_interest:.2f}百万")
|
||||
print(f"总金额:{total_amount:.2f}百万")
|
||||
```
|
||||
|
||||
## 例子2
|
||||
|
@ -70,13 +77,26 @@ revenue_2020 = revenue_2019 * (1 + growth_rate)
|
|||
print(f"2020年的预计收入为: {revenue_2020:.2f}万")
|
||||
```
|
||||
|
||||
## 例子3
|
||||
### input
|
||||
#### question
|
||||
47000元按照612天计算利息,本息一共多少钱?
|
||||
#### content
|
||||
|
||||
### output
|
||||
```python
|
||||
# 未给出利率,无法计算
|
||||
print("未给出利率,无法计算")
|
||||
```
|
||||
# input
|
||||
## question
|
||||
$question
|
||||
## context
|
||||
$context
|
||||
## error
|
||||
$error"""
|
||||
$error
|
||||
## output
|
||||
"""
|
||||
)
|
||||
template_en = (
|
||||
f"Today is {get_now(language='en')}。\n"
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import logging
|
||||
import re
|
||||
import time
|
||||
|
||||
from kag.common.conf import KAG_PROJECT_CONF
|
||||
from kag.common.parser.logic_node_parser import extract_steps_and_actions
|
||||
|
@ -72,14 +73,16 @@ def process_tag_template(text):
|
|||
}
|
||||
clean_text = ""
|
||||
for tag_info in all_tags:
|
||||
content = tag_info[1]
|
||||
if tag_info[0] in xml_tag_template:
|
||||
content = tag_info[1]
|
||||
if "search" == tag_info[0]:
|
||||
content = process_planning(content)
|
||||
clean_text += xml_tag_template[tag_info[0]][
|
||||
KAG_PROJECT_CONF.language
|
||||
].format_map(SafeDict({"content": content}))
|
||||
return remove_xml_tags(clean_text)
|
||||
else:
|
||||
clean_text += content
|
||||
text = remove_xml_tags(clean_text)
|
||||
return text
|
||||
|
||||
|
||||
|
|
|
@ -159,12 +159,12 @@ def render_jinja2_template(template_str, context):
|
|||
"""
|
||||
try:
|
||||
template = Template(template_str, undefined=SilentUndefined)
|
||||
return template.render(**context).strip()
|
||||
return template.render(**context)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Jinja2 rendering failed: {e}, Original template: {template_str}"
|
||||
)
|
||||
return template_str.strip() # Fallback to raw template string on failure
|
||||
return template_str # Fallback to raw template string on failure
|
||||
|
||||
|
||||
@ReporterABC.register("open_spg_reporter")
|
||||
|
@ -264,12 +264,12 @@ Rerank the documents and take the top {{ chunk_num }}.
|
|||
}
|
||||
self.tag_mapping = {
|
||||
"Graph Show": {
|
||||
"en": "{content}",
|
||||
"zh": "{content}",
|
||||
"en": "{{ content }}",
|
||||
"zh": "{{ content }}",
|
||||
},
|
||||
"Rewrite query": {
|
||||
"en": "Rethinking question using LLM: {content}",
|
||||
"zh": "根据依赖问题重写子问题: {content}",
|
||||
"en": "Rethinking question using LLM: {{ content }}",
|
||||
"zh": "根据依赖问题重写子问题: {{ content }}",
|
||||
},
|
||||
"language_setting": {
|
||||
"en": "",
|
||||
|
@ -277,125 +277,153 @@ Rerank the documents and take the top {{ chunk_num }}.
|
|||
},
|
||||
"Iterative planning": {
|
||||
"en": """
|
||||
<step status="{status}" title="Global planning">
|
||||
<step status="{{status}}" title="Global planning">
|
||||
|
||||
{content}
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
"zh": """
|
||||
<step status="{status}" title="思考当前步骤">
|
||||
<step status="{{status}}" title="思考当前步骤">
|
||||
|
||||
{content}
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
},
|
||||
"Static planning": {
|
||||
"en": """
|
||||
<step status="{status}" title="Global planning">
|
||||
<step status="{{status}}" title="Global planning">
|
||||
|
||||
{content}
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
"zh": """
|
||||
<step status="{status}" title="思考全局步骤">
|
||||
<step status="{{status}}" title="思考全局步骤">
|
||||
|
||||
{content}
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
},
|
||||
"begin_sub_kag_retriever": {
|
||||
"en": "Starting {component_name}: {content} {desc}",
|
||||
"zh": "执行{component_name}: {content} {desc}",
|
||||
"en": "Starting {{component_name}}: {{content}} {{desc}}",
|
||||
"zh": "执行{{component_name}}: {{content}} {{desc}}",
|
||||
},
|
||||
"end_sub_kag_retriever": {
|
||||
"en": " {content}",
|
||||
"zh": " {content}",
|
||||
"en": " {{ content }}",
|
||||
"zh": " {{ content }}",
|
||||
},
|
||||
"rc_retriever_rewrite": {
|
||||
"en": """
|
||||
<step status="{status}" title="Rewriting chunk retriever query">
|
||||
<step status="{{status}}" title="Rewriting chunk retriever query">
|
||||
|
||||
Rewritten question:\n{content}
|
||||
Rewritten question:
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
"zh": """
|
||||
<step status="{status}" title="正在根据依赖问题重写检索子问题">
|
||||
<step status="{{status}}" title="正在根据依赖问题重写检索子问题">
|
||||
|
||||
重写问题为:\n\n{content}
|
||||
重写问题为:
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
},
|
||||
"rc_retriever_summary": {
|
||||
"en": "Summarizing retrieved documents,{content}",
|
||||
"zh": "对文档进行总结,{content}",
|
||||
"en": "Summarizing retrieved documents,{{ content }}",
|
||||
"zh": "对文档进行总结,{{ content }}",
|
||||
},
|
||||
"kg_retriever_summary": {
|
||||
"en": "Summarizing retrieved graph,{content}",
|
||||
"zh": "对召回的知识进行总结,{content}",
|
||||
"en": "Summarizing retrieved graph,{{ content }}",
|
||||
"zh": "对召回的知识进行总结,{{ content }}",
|
||||
},
|
||||
"retriever_summary": {
|
||||
"en": "Summarizing retrieved documents,{content}",
|
||||
"zh": "对文档进行总结,{content}",
|
||||
"en": "Summarizing retrieved documents,{{ content }}",
|
||||
"zh": "对文档进行总结,{{ content }}",
|
||||
},
|
||||
"begin_summary": {
|
||||
"en": "Summarizing retrieved information, {content}",
|
||||
"zh": "对检索的信息进行总结, {content}",
|
||||
"en": "Summarizing retrieved information, {{ content }}",
|
||||
"zh": "对检索的信息进行总结, {{ content }}",
|
||||
},
|
||||
"begin_task": {
|
||||
"en": """
|
||||
<step status="{status}" title="Starting Task {step}">
|
||||
<step status="{{status}}" title="Starting Task {{step}}">
|
||||
|
||||
{content}
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
"zh": """
|
||||
<step status="{status}" title="执行 {step}">
|
||||
<step status="{{status}}" title="执行 {{step}}">
|
||||
|
||||
{content}
|
||||
{{ content }}
|
||||
|
||||
</step>""",
|
||||
{% if status == 'success' %}
|
||||
</step>
|
||||
{% endif %}""",
|
||||
},
|
||||
"logic_node": {
|
||||
"en": """Translate query to logic form expression
|
||||
|
||||
|
||||
```json
|
||||
{content}
|
||||
{{ content }}
|
||||
```""",
|
||||
"zh": """将query转换成逻辑形式表达
|
||||
|
||||
|
||||
```json
|
||||
{content}
|
||||
{{ content }}
|
||||
```""",
|
||||
},
|
||||
"kag_retriever_result": {
|
||||
"en": "Retrieved documents\n\n{content}",
|
||||
"zh": "检索到的文档\n\n{content}",
|
||||
"en": """Retrieved documents
|
||||
{{ content }}""",
|
||||
"zh": """检索到的文档
|
||||
{{ content }}""",
|
||||
},
|
||||
"failed_kag_retriever": {
|
||||
"en": """KAG retriever failed
|
||||
|
||||
|
||||
```json
|
||||
{content}
|
||||
{{ content }}
|
||||
```
|
||||
""",
|
||||
"zh": """KAG检索失败
|
||||
|
||||
|
||||
```json
|
||||
{content}
|
||||
{{ content }}
|
||||
```
|
||||
""",
|
||||
},
|
||||
"end_math_executor": {
|
||||
"en": "Math executor completed\n\n{content}",
|
||||
"zh": "计算结束\n\n{content}",
|
||||
"en": """Math executor completed
|
||||
{{ content }}""",
|
||||
"zh": """计算结束
|
||||
{{ content }}""",
|
||||
},
|
||||
"code_generator": {
|
||||
"en": "Generating code\n \n{content}\n",
|
||||
"zh": "正在生成代码\n \n{content}\n",
|
||||
"en": """Generating code
|
||||
{{ content }}
|
||||
|
||||
""",
|
||||
"zh": """正在生成代码
|
||||
{{ content }}
|
||||
|
||||
""",
|
||||
},
|
||||
}
|
||||
task_id = kwargs.get(KAGConstants.KAG_QA_TASK_CONFIG_KEY, None)
|
||||
|
@ -425,7 +453,11 @@ Rewritten question:\n{content}
|
|||
if tpl:
|
||||
format_params = {"content": datas}
|
||||
format_params.update(content_params)
|
||||
datas = tpl.format_map(SafeDict(format_params))
|
||||
if "{" in tpl or "%}" in tpl:
|
||||
rendered = render_jinja2_template(tpl, format_params)
|
||||
else:
|
||||
rendered = tpl.format_map(SafeDict(format_params))
|
||||
datas = rendered
|
||||
elif str(datas).strip() != "":
|
||||
output = str(datas).strip()
|
||||
if output != "":
|
||||
|
@ -516,7 +548,7 @@ Rewritten question:\n{content}
|
|||
if self.last_report.to_dict() == request.to_dict():
|
||||
return
|
||||
logger.info(
|
||||
f"do_report: {content.answer} think={content.think} status={status_enum} ret={ret}"
|
||||
f"do_report: think={content.think} {content.answer} status={status_enum} ret={ret}"
|
||||
)
|
||||
self.last_report = request
|
||||
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
import concurrent.futures
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
import logging
|
||||
from cachetools import TTLCache
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
class AsyncTaskManager:
|
||||
def __init__(self, max_workers=10, ttl=3600):
|
||||
"""
|
||||
Initialize async task manager
|
||||
|
||||
Args:
|
||||
max_workers (int): Maximum number of worker threads
|
||||
ttl (int): Time-to-live for task results in seconds
|
||||
"""
|
||||
self.max_workers = max_workers
|
||||
self.task_queue = queue.Queue()
|
||||
self.result_cache = TTLCache(maxsize=1000, ttl=ttl)
|
||||
self.result_cache_lock = threading.Lock() # Protect cache from race conditions
|
||||
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
|
||||
self.workers = [
|
||||
threading.Thread(target=self.worker, daemon=True)
|
||||
for _ in range(max_workers)
|
||||
]
|
||||
for w in self.workers:
|
||||
w.start()
|
||||
|
||||
def worker(self):
|
||||
"""Worker thread main loop that processes tasks"""
|
||||
while True:
|
||||
try:
|
||||
# Get next task from queue with timeout to allow shutdown detection
|
||||
task = self.task_queue.get()
|
||||
task_id, func, args, kwargs = task
|
||||
logger.info(f"Processing task {task_id}")
|
||||
# finish flag
|
||||
if task_id is None:
|
||||
self.task_queue.task_done()
|
||||
break
|
||||
|
||||
# Update cache with running status
|
||||
with self.result_cache_lock:
|
||||
self.result_cache[task_id] = {
|
||||
"task_id": task_id,
|
||||
"status": "running",
|
||||
"result": None,
|
||||
}
|
||||
|
||||
# Execute task
|
||||
future = self.executor.submit(func, *args, **kwargs)
|
||||
result = future.result()
|
||||
status = "completed"
|
||||
|
||||
except queue.Empty:
|
||||
# Handle queue empty timeout (normal operation)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
# Handle task execution errors
|
||||
result = str(e)
|
||||
status = "failed"
|
||||
logger.error(f"Task {task_id} failed with error: {e}", exc_info=True)
|
||||
|
||||
# Store final result in cache
|
||||
try:
|
||||
with self.result_cache_lock:
|
||||
self.result_cache[task_id] = {
|
||||
"task_id": task_id,
|
||||
"status": status,
|
||||
"result": result,
|
||||
}
|
||||
logger.info(f"Task {task_id} completed with status: {status}")
|
||||
finally:
|
||||
# Always mark task as done
|
||||
self.task_queue.task_done()
|
||||
|
||||
def submit_task(self, func, *args, **kwargs):
|
||||
"""
|
||||
Submit a new task to the queue
|
||||
|
||||
Args:
|
||||
func: Callable function to execute
|
||||
*args: Positional arguments for the function
|
||||
**kwargs: Keyword arguments for the function
|
||||
|
||||
Returns:
|
||||
str: Unique task ID
|
||||
"""
|
||||
task_id = str(uuid.uuid4())
|
||||
self.task_queue.put((task_id, func, args, kwargs))
|
||||
return task_id
|
||||
|
||||
def get_task_result(self, task_id):
|
||||
"""
|
||||
Get result for a specific task
|
||||
|
||||
Args:
|
||||
task_id (str): Unique task identifier
|
||||
|
||||
Returns:
|
||||
dict: Task result information or expired status
|
||||
"""
|
||||
with self.result_cache_lock:
|
||||
return self.result_cache.get(
|
||||
task_id,
|
||||
{
|
||||
"task_id": task_id,
|
||||
"status": "failed",
|
||||
"result": "Result not found or expired",
|
||||
},
|
||||
)
|
||||
|
||||
def shutdown(self):
|
||||
"""Gracefully shutdown all worker threads and executors"""
|
||||
# Send shutdown signals
|
||||
for _ in range(self.max_workers):
|
||||
self.task_queue.put((None, None, (), {}))
|
||||
|
||||
# Wait for queue to empty and workers to terminate
|
||||
self.task_queue.join()
|
||||
|
||||
# Shutdown executors
|
||||
self.executor.shutdown(wait=True)
|
||||
for worker in self.workers:
|
||||
worker.join(timeout=5)
|
||||
|
||||
|
||||
# Global async task manager instance
|
||||
asyn_task = AsyncTaskManager()
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Configure logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Create task manager instance
|
||||
task_manager = AsyncTaskManager(max_workers=5, ttl=600)
|
||||
|
||||
# Example task function
|
||||
def example_task(x, y):
|
||||
time.sleep(1) # Simulate work
|
||||
return x
|
||||
|
||||
# Submit test tasks
|
||||
task_ids = [task_manager.submit_task(example_task, i, i + 1) for i in range(6)]
|
||||
|
||||
# Monitor task progress
|
||||
try:
|
||||
while True:
|
||||
time.sleep(1)
|
||||
if all(
|
||||
"completed" in task_manager.get_task_result(tid)["status"]
|
||||
for tid in task_ids
|
||||
):
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Shutting down due to user interrupt")
|
||||
|
||||
# Print results
|
||||
for task_id in task_ids:
|
||||
print(f"Task {task_id} result: {task_manager.get_task_result(task_id)}")
|
||||
|
||||
# Clean up resources
|
||||
task_manager.shutdown()
|
|
@ -0,0 +1,63 @@
|
|||
from fastapi import FastAPI
|
||||
import uvicorn
|
||||
|
||||
from kag.solver.main_solver import SolverMain
|
||||
from kag.solver.server.asyn_task_manager import AsyncTaskManager
|
||||
from kag.solver.server.model.task_req import FeatureRequest, TaskReq
|
||||
|
||||
|
||||
def run_main_solver(task: TaskReq):
|
||||
return SolverMain().invoke(
|
||||
project_id=task.project_id,
|
||||
task_id=task.req_id,
|
||||
query=task.req.query,
|
||||
is_report=task.req.report,
|
||||
host_addr=task.req.host_addr,
|
||||
app_id=task.app_id,
|
||||
params=task.config,
|
||||
)
|
||||
|
||||
|
||||
class KAGSolverServer:
|
||||
def __init__(self, service_name: str):
|
||||
"""
|
||||
Initialize a FastAPI service instance
|
||||
|
||||
Args:
|
||||
service_name (str): Service name, determines which routing logic to load
|
||||
"""
|
||||
self.service_name = service_name
|
||||
self.app = FastAPI(title=f"{service_name} API")
|
||||
|
||||
# Bind routes according to service name
|
||||
self._setup_routes()
|
||||
self.async_manager = AsyncTaskManager()
|
||||
|
||||
def sync_task(self, task: TaskReq):
|
||||
if task.cmd == "submit":
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"status": "init",
|
||||
"result": self.async_manager.submit_task(run_main_solver, task),
|
||||
}
|
||||
elif task.cmd == "query":
|
||||
return self.async_manager.get_task_result(task_id=task.req_id)
|
||||
else:
|
||||
return {
|
||||
"success": False,
|
||||
"status": "failed",
|
||||
"result": f"invalid input cmd {task.cmd}",
|
||||
}
|
||||
|
||||
def _setup_routes(self):
|
||||
"""Dynamically bind routes according to service name"""
|
||||
|
||||
@self.app.post("/process")
|
||||
def process(req: FeatureRequest):
|
||||
return self.sync_task(task=req.features.task_req)
|
||||
|
||||
def run(self, host="0.0.0.0", port=8000):
|
||||
"""Start the service"""
|
||||
print(f"Starting {self.service_name} service on {host}:{port}")
|
||||
uvicorn.run(self.app, host=host, port=port)
|
|
@ -0,0 +1,122 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, model_validator, field_serializer
|
||||
|
||||
|
||||
class ReqBody(BaseModel):
|
||||
"""Request body model containing query parameters"""
|
||||
|
||||
query: str = ""
|
||||
report: bool = True
|
||||
host_addr: str = ""
|
||||
|
||||
|
||||
class TaskReq(BaseModel):
|
||||
"""Task request model with validation logic"""
|
||||
|
||||
app_id: int = ""
|
||||
project_id: int = 0
|
||||
req_id: str = ""
|
||||
cmd: str = ""
|
||||
mode: str = ""
|
||||
req: str = None
|
||||
config: str = "{}"
|
||||
|
||||
@model_validator(mode="after")
|
||||
def parse_req_to_req_body(self):
|
||||
"""Parse req string to ReqBody object and process config field"""
|
||||
try:
|
||||
import json
|
||||
|
||||
if isinstance(self.req, str):
|
||||
req_body_dict = json.loads(self.req)
|
||||
self.req = ReqBody(**req_body_dict)
|
||||
if isinstance(self.config, str) and self.config:
|
||||
config_dict = json.loads(self.config)
|
||||
self.config = config_dict
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse 'req' field to ReqBody: {e}")
|
||||
return self
|
||||
|
||||
@field_serializer("req")
|
||||
def serialize_req(self, value: object) -> object:
|
||||
"""Serialize ReqBody back to JSON string"""
|
||||
if isinstance(value, ReqBody):
|
||||
return value.model_dump_json()
|
||||
return value # Return as-is if already a string
|
||||
|
||||
|
||||
# Request model with TaskReq parsing capability
|
||||
class Request(BaseModel):
|
||||
"""Container model for task request data"""
|
||||
|
||||
in_string: str
|
||||
task_req: Optional[TaskReq] = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def parse_in_string_to_task_req(self):
|
||||
"""Convert in_string JSON string to TaskReq object"""
|
||||
try:
|
||||
import json
|
||||
|
||||
task_req_dict = json.loads(self.in_string)
|
||||
self.task_req = TaskReq(**task_req_dict)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid TaskReq JSON string: {e}")
|
||||
return self
|
||||
|
||||
|
||||
class FeatureRequest(BaseModel):
|
||||
"""Top-level request wrapper with features container"""
|
||||
|
||||
features: Request
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def feature_request_parsing():
|
||||
"""Demonstrate nested model parsing workflow"""
|
||||
# Build innermost ReqBody JSON string
|
||||
req_body = ReqBody(
|
||||
query="阿里巴巴财报中,2024年-截至9月30日止六个月的收入是多少?其中云智能集团收入是多少?占比是多少",
|
||||
report=True,
|
||||
host_addr="https://spg.alipay.com",
|
||||
)
|
||||
req_body_json = json.dumps(req_body.model_dump())
|
||||
|
||||
# Build TaskReq dictionary and serialize to string
|
||||
task_req = TaskReq(
|
||||
req_id="9400110",
|
||||
cmd="submit",
|
||||
mode="async",
|
||||
req=req_body_json,
|
||||
app_id="app_id",
|
||||
project_id=4200050,
|
||||
config={"timeout": 10},
|
||||
)
|
||||
task_req_json = json.dumps(task_req.model_dump())
|
||||
|
||||
# Construct final FeatureRequest JSON string
|
||||
input_data = {"features": {"in_string": task_req_json}}
|
||||
|
||||
# Deserialize to FeatureRequest model
|
||||
feature_request = FeatureRequest(**input_data)
|
||||
|
||||
# Validate in_string parsed to TaskReq
|
||||
assert isinstance(feature_request.features.task_req, TaskReq)
|
||||
assert feature_request.features.task_req.req_id == "abc123"
|
||||
assert feature_request.features.task_req.cmd == "run"
|
||||
assert feature_request.features.task_req.mode == "sync"
|
||||
assert feature_request.features.task_req.config == {"timeout": 10}
|
||||
|
||||
# Validate TaskReq.req parsed to ReqBody
|
||||
req_body_parsed = feature_request.features.task_req.req
|
||||
assert isinstance(req_body_parsed, ReqBody)
|
||||
assert req_body_parsed.query == "What is AI?"
|
||||
assert req_body_parsed.report is True
|
||||
assert req_body_parsed.host_addr == "localhost"
|
||||
|
||||
print("✅ All assertions passed!")
|
||||
|
||||
feature_request_parsing()
|
|
@ -0,0 +1 @@
|
|||
fastapi
|
|
@ -48,13 +48,16 @@ class Environment:
|
|||
@property
|
||||
def config(self):
|
||||
|
||||
closest_config = self._closest_config()
|
||||
if not hasattr(self, "_config_path") or self._config_path != closest_config:
|
||||
self._config_path = closest_config
|
||||
self._config = self.get_config()
|
||||
try:
|
||||
closest_config = self._closest_config()
|
||||
if not hasattr(self, "_config_path") or self._config_path != closest_config:
|
||||
self._config_path = closest_config
|
||||
self._config = self.get_config()
|
||||
|
||||
if self._config is None:
|
||||
self._config = self.get_config()
|
||||
if self._config is None:
|
||||
self._config = self.get_config()
|
||||
except:
|
||||
return {}
|
||||
|
||||
return self._config
|
||||
|
||||
|
|
|
@ -22,7 +22,7 @@ numpy>=1.23.1
|
|||
pypdf
|
||||
pandas
|
||||
pycryptodome
|
||||
markdown
|
||||
markdown==3.7
|
||||
bs4
|
||||
protobuf==3.20.1
|
||||
neo4j
|
||||
|
|
|
@ -0,0 +1,64 @@
|
|||
from kag.common.utils import extract_tag_content
|
||||
|
||||
|
||||
def run_extra_tag():
|
||||
test_cases = [
|
||||
{
|
||||
"input": "<tag1>abced</tag1>some word<tag2>other tags</tag2>",
|
||||
"expected": [("tag1", "abced"), ("", "some word"), ("tag2", "other tags")],
|
||||
"description": "基本闭合标签与无标签文本混合",
|
||||
},
|
||||
{
|
||||
"input": "<p>Hello <b>world</b> this is <i>test</i>",
|
||||
"expected": [
|
||||
("p", "Hello "),
|
||||
("b", "world"),
|
||||
("", " this is "),
|
||||
("i", "test"),
|
||||
],
|
||||
"description": "混合闭合与未闭合标签",
|
||||
},
|
||||
{
|
||||
"input": "plain text without any tags",
|
||||
"expected": [("", "plain text without any tags")],
|
||||
"description": "纯文本无标签",
|
||||
},
|
||||
{
|
||||
"input": "<div>\n Line 1\n <span>Line 2</span>\n Line 3\n</div>",
|
||||
"expected": [
|
||||
("div", "\n Line 1\n <span>Line 2</span>\n Line 3\n")
|
||||
],
|
||||
"description": "多行内容和空白处理",
|
||||
},
|
||||
{
|
||||
"input": "<a>A</a><b>B</b><c>C</c>",
|
||||
"expected": [("a", "A"), ("b", "B"), ("c", "C")],
|
||||
"description": "连续多个闭合标签",
|
||||
},
|
||||
{
|
||||
"input": "<title>My Document</title><content>This is the content",
|
||||
"expected": [("title", "My Document"), ("content", "This is the content")],
|
||||
"description": "未闭合标签(EOF结尾)",
|
||||
},
|
||||
{
|
||||
"input": "<log>Error: &*^%$#@!;</log><note>End of log</note>",
|
||||
"expected": [("log", "Error: &*^%$#@!;"), ("note", "End of log")],
|
||||
"description": "含特殊字符的内容",
|
||||
},
|
||||
{
|
||||
"input": "",
|
||||
"expected": [],
|
||||
"description": "空字符串输入",
|
||||
},
|
||||
]
|
||||
|
||||
for i, test in enumerate(test_cases):
|
||||
result = extract_tag_content(test["input"])
|
||||
assert (
|
||||
result == test["expected"]
|
||||
), f"Test {i+1} failed: {test['description']}\nGot: {result}\nExpected: {test['expected']}"
|
||||
print(f"Test {i+1} passed: {test['description']}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_extra_tag()
|
Loading…
Reference in New Issue