feat(solver): support kag thinker (#640)

* feat(kag): update to v0.7 (#456)

* add think cost

* update csv scanner

* add final rerank

* add reasoner

* add iterative planner

* fix dpr search

* fix dpr search

* add reference data

* move odps import

* update requirement.txt

* update 2wiki

* add missing file

* fix markdown reader

* add iterative planning

* update version

* update runner

* update 2wiki example

* update bridge

* merge solver and solver_new

* add cur day

* writer delete

* update multi process

* add missing files

* fix report

* add chunk retrieved executor

* update try in stream runner result

* add path

* add math executor

* update hotpotqa example

* remove log

* fix python coder solver

* update hotpotqa example

* fix python coder solver

* update config

* fix bug

* add log

* remove unused code

* commit with task thought

* move kag model to common

* add default chat llm

* fix

* use static planner

* support chunk graph node

* add args

* support naive rag

* llm client support tool calls

* add default async

* add openai

* fix result

* fix markdown reader

* fix thinker

* update asyncio interface

* feat(solver): add mcp support (#444)

* upload mcp client related code

* 1. Implement a complete mcp client call chain, from pipeline to planner and executor
2. Allow multiple mcp_server entries in the JSON config, invoked and selected via the LLM
3. Verify baidu_map_mcp works end to end

* 1. schema

* bugfix: remove redundant code

---------

Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>

* fix affairqa after solver refactor

* fix affairqa after solver refactor

* fix readme

* add params

* update version

* update mcp executor

* update mcp executor

* add mcp executor to solver

* add missing file

* add mcp executor

* add executor

* x

* update

* fix requirement

* fix main llm config

* fix solver

* bugfix: fix the invoke function call logic

* change eval

* update example

* add kag layer

* add step task

* support dot refresh

* support dot refresh

* support dot refresh

* support dot refresh

* add retrieved num

* add retrieved num

* add pipelineconf

* update ppr

* update musique prompts

* update

* add to_dict for BuilderComponentData

* async build

* add deduce prompt

* add deduce prompt

* add deduce prompt

* fix reader

* add deduce prompt

* add page thinker report

* modify prompt

* add step status

* add self cognition

* add self cognition

* add memory graph storage

* add now time

* update memory config

* add now time

* change graph loader

* add prqa dataset and code

* bugfix: fix prqa invocation logic

* optimize: improve code logic and normalize generated answers

* add retry py code

* update memory graph

* update memory graph

* fix

* fix ner

* add with_out_refer generator prompt

* fix

* close ckpt

* fix query

* fix query

* update version

* add llm checker

* add llm checker

* 1. Upload evalutor.py and modify the gold_answer.json format
2. Improve code logic
3. Update the README.md file

* update exp

* update exp

* rerank support

* add static rewrite query

* recall more chunks

* fix graph load

* add static rewrite query

* fix bugs

* add finish check

* add finish check

* add finish check

* add finish check

* 1. Upload the evalutor.py results
2. Improve code logic and the readme file

* add lf retry

* add memory graph api

* fix reader api

* add ner

* add metrics

* fix bug

* remove ner

* add reraise for retry

* add edge prop to memory graph

* add memory graph

* 1. Correct the evaluation dataset results
2. Improve the evaluator.py code
3. Remove questions that have an answer in gold_answer but no result

* remove evaluation result files

* fix knext host addr

* async eva

* add lf prompt

* add lf prompt

* add config

* add retry

* add unknown check

* add rc result

* add rc result

* add rc result

* add rc result

* adapt code logic to the kag pipeline format and pass tests

* bugfix: remove redundant code

* fix report prompt

* bugfix: trigger the retry mechanism

* bugfix: fix incorrect Chinese punctuation

* fix rethinker prompt

* update version to 0.6.2b78

* update version

* 1. Modify evaluator.py to compute accuracy via the LLM, matching the latest invocation logic
2. Modify the prompt so unanswered results are re-tested

* update affairqa for evaluate

* update affairqa for evaluate

* bugfix: correct dataset

* bugfix: correct dataset

* bugfix: correct dataset

* fix name conflict

* bugfix: remove incorrect questions

* bugfix: evaluator failed due to a misnamed file

* update for affairqa eval

* bugfix: modify code to keep the evaluation logic consistent

* x

* update for affairqa readme

* remove temp eval scripts

* bugfix for math deduce

* merge 0.6.2_dev

* merge 0.6.2_dev

* fix

* update client addr

* updated version

* update for affairqa eval

* evaUtils: support Chinese

* fix affairqa eval

* remove unused example

* update kag config

* fix default value

* update readme

* fix init

* update comments and add descriptions for some classes

* update example config

* Tc 0.7.0 (#459)

* commit affairQA code

* fix affairqa eval

---------

Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>

* fix all examples

* reformat

---------

Co-authored-by: peilong <peilong.zpl@antgroup.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>

* update chunk metadata

* update chunk metadata

* add debug reporter

* update table text

* add server

* fix math executor

* update api-key for openai vec

* update

* fix naive rag bug

* format code

* fix

---------

Co-authored-by: zhuzhongshu123 <152354526+zhuzhongshu123@users.noreply.github.com>
Co-authored-by: 锦呈 <zhangxinhong.zxh@antgroup.com>
Co-authored-by: wanxingyu.wxy <wanxingyu.wxy@antgroup.com>
Co-authored-by: zhengke.gzk <zhengke.gzk@antgroup.com>
royzhao 2025-07-08 17:44:32 +08:00 committed by GitHub
parent 9b2d894295
commit e1012d39e4
20 changed files with 696 additions and 175 deletions

View File

@ -1 +1 @@
0.8.0
0.8.0

View File

@ -463,9 +463,17 @@ def resolve_instance(
def extract_tag_content(text):
# Match content between <tag> and </tag>, supporting arbitrary tag names
matches = re.findall(r"<([^>]+)>(.*?)</\1>", text, flags=re.DOTALL)
return [(tag, content.strip()) for tag, content in matches]
pattern = r"<(\w+)\b[^>]*>(.*?)</\1>|<(\w+)\b[^>]*>([^<]*)|([^<]+)"
results = []
for match in re.finditer(pattern, text, re.DOTALL):
tag1, content1, tag2, content2, raw_text = match.groups()
if tag1:
results.append((tag1, content1)) # keep original content (including whitespace)
elif tag2:
results.append((tag2, content2)) # keep original content (including whitespace)
elif raw_text:
results.append(("", raw_text)) # 保留原始空格
return results
def extract_specific_tag_content(text, tag):
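
A minimal usage sketch of the reworked extract_tag_content (import path taken from the test file added later in this commit); unlike the old regex, it also returns unclosed tags and untagged text segments:

```python
from kag.common.utils import extract_tag_content

# Closed tags, plain text, and unclosed tags are all returned in document order.
print(extract_tag_content("<tag1>abced</tag1>some word<tag2>other tags</tag2>"))
# [('tag1', 'abced'), ('', 'some word'), ('tag2', 'other tags')]

print(extract_tag_content("<title>My Document</title><content>This is the content"))
# [('title', 'My Document'), ('content', 'This is the content')]
```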

View File

@ -131,9 +131,11 @@ class PyBasedMathExecutor(ExecutorABC):
)
parent_results = format_task_dep_context(task.parents)
parent_results = "\n".join(parent_results)
coder_content = context.kwargs.get("planner_thought", "") + "\n\n".join(
parent_results
)
parent_results += "\n\n" + contents
coder_content += "\n\n" + contents
tries = self.tries
error = None
@ -141,7 +143,7 @@ class PyBasedMathExecutor(ExecutorABC):
tries -= 1
rst, error, code = self.run_once(
math_query,
parent_results,
coder_content,
error,
segment_name=tag_id,
tag_name=f"{task_query}_code_generator",

View File

@ -42,6 +42,15 @@ from kag.solver.utils import init_prompt_with_fallback
logger = logging.getLogger()
def _wrapped_invoke(retriever, task, context, segment_name, kwargs):
start_time = time.time()
output = retriever.invoke(
task, context=context, segment_name=segment_name, **kwargs
)
elapsed_time = time.time() - start_time
return output, elapsed_time
@ExecutorABC.register("kag_hybrid_retrieval_executor")
class KAGHybridRetrievalExecutor(ExecutorABC):
def __init__(
@ -76,6 +85,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
self.context_select_prompt = context_select_prompt or PromptABC.from_config(
{"type": "context_select_prompt"}
)
self.with_llm_select = kwargs.get("with_llm_select", True)
@retry(stop=stop_after_attempt(3), wait=wait_exponential(multiplier=1))
def context_select_call(self, variables):
@ -152,22 +162,30 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
"FINISH",
component_name=retriever.name,
)
# Record start time before submitting the task
start_time = time.time()
# Prepare function and submit to thread pool
func = partial(
retriever.invoke,
_wrapped_invoke,
retriever,
task,
context=context,
segment_name=tag_id,
**kwargs,
context,
tag_id,
kwargs.copy(),
)
future = executor.submit(func)
# Save future, retriever, and start_time together
futures.append((future, retriever))
# Collect results from each future
for future, retriever in futures:
try:
output = future.result() # Wait for result
output, elapsed_time = future.result() # Wait for result
# Log the elapsed time for this retriever
logger.info(
f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds"
)
outputs.append(output)
# Log data report after successful execution
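
A condensed, self-contained sketch of the new parallel retrieval path (retriever objects and kwargs are assumed; run_retrievers is an illustrative name, not part of the commit): each retriever call goes through _wrapped_invoke so its wall-clock time is returned with its output, and both are read back from the futures.

```python
import time
from functools import partial
from concurrent.futures import ThreadPoolExecutor


def _wrapped_invoke(retriever, task, context, segment_name, kwargs):
    # Time a single retriever call and return the output together with its duration.
    start_time = time.time()
    output = retriever.invoke(task, context=context, segment_name=segment_name, **kwargs)
    return output, time.time() - start_time


def run_retrievers(retrievers, task, context, tag_id, **kwargs):
    """Fan retrievers out to a thread pool and log per-retriever timings."""
    outputs = []
    with ThreadPoolExecutor() as executor:
        futures = [
            (executor.submit(partial(_wrapped_invoke, r, task, context, tag_id, kwargs.copy())), r)
            for r in retrievers
        ]
        for future, retriever in futures:
            output, elapsed_time = future.result()  # wait for result
            print(f"Retriever {retriever.name} executed in {elapsed_time:.2f} seconds")
            outputs.append(output)
    return outputs
```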
@ -241,13 +259,18 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
selected_rel = list(set(selected_rel))
formatted_docs = [str(rel) for rel in selected_rel]
if retrieved_data.chunks:
try:
selected_chunks = self.context_select(task_query, retrieved_data.chunks)
except Exception as e:
logger.warning(
f"select context failed {e}, we use default top 10 to summary",
exc_info=True,
)
if self.with_llm_select:
try:
selected_chunks = self.context_select(
task_query, retrieved_data.chunks
)
except Exception as e:
logger.warning(
f"select context failed {e}, we use default top 10 to summary",
exc_info=True,
)
selected_chunks = retrieved_data.chunks[:10]
else:
selected_chunks = retrieved_data.chunks[:10]
for doc in selected_chunks:
formatted_docs.append(f"{doc.content}")
@ -280,68 +303,81 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
task_query = task.arguments["query"]
tag_id = f"{task_query}_begin_task"
self.report_content(reporter, "thinker", tag_id, "", "FINISH", step=task.name)
self.report_content(reporter, "thinker", tag_id, "", "INIT", step=task.name)
try:
retrieved_data = self.do_main(task_query, tag_id, task, context, **kwargs)
except Exception as e:
logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
retrieved_data = RetrieverOutput(
retriever_method=self.schema().get("name", ""), err_msg=str(e)
)
self.report_content(
reporter,
"reference",
f"{task_query}_kag_retriever_result",
retrieved_data,
"FINISH",
)
retrieved_data.task = task
logical_node = task.arguments.get("logic_form_node", None)
if (
logical_node
and isinstance(logical_node, GetSPONode)
and retrieved_data.summary
):
if isinstance(retrieved_data.summary, str):
target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
s_entities = context.variables_graph.get_entity_by_alias(
logical_node.s.alias_name
try:
retrieved_data = self.do_main(
task_query, tag_id, task, context, **kwargs
)
if (
not s_entities
and not logical_node.s.get_mention_name()
and isinstance(logical_node.s, SPOEntity)
):
logical_node.s.entity_name = target_answer
context.kwargs[logical_node.s.alias_name] = logical_node.s
o_entities = context.variables_graph.get_entity_by_alias(
logical_node.o.alias_name
except Exception as e:
logger.warning(f"kag hybrid retrieval failed! {e}", exc_info=True)
retrieved_data = RetrieverOutput(
retriever_method=self.schema().get("name", ""), err_msg=str(e)
)
if (
not o_entities
and not logical_node.o.get_mention_name()
and isinstance(logical_node.o, SPOEntity)
):
logical_node.o.entity_name = target_answer
context.kwargs[logical_node.o.alias_name] = logical_node.o
context.variables_graph.add_answered_alias(
logical_node.s.alias_name.alias_name, retrieved_data.summary
)
context.variables_graph.add_answered_alias(
logical_node.p.alias_name.alias_name, retrieved_data.summary
)
context.variables_graph.add_answered_alias(
logical_node.o.alias_name.alias_name, retrieved_data.summary
self.report_content(
reporter,
"reference",
f"{task_query}_kag_retriever_result",
retrieved_data,
"FINISH",
)
task.update_result(retrieved_data)
logger.debug(
f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
)
return retrieved_data
retrieved_data.task = task
logical_node = task.arguments.get("logic_form_node", None)
if (
logical_node
and isinstance(logical_node, GetSPONode)
and retrieved_data.summary
):
if isinstance(retrieved_data.summary, str):
target_answer = retrieved_data.summary.split("Answer:")[-1].strip()
s_entities = context.variables_graph.get_entity_by_alias(
logical_node.s.alias_name
)
if (
not s_entities
and not logical_node.s.get_mention_name()
and isinstance(logical_node.s, SPOEntity)
):
logical_node.s.entity_name = target_answer
context.kwargs[logical_node.s.alias_name] = logical_node.s
o_entities = context.variables_graph.get_entity_by_alias(
logical_node.o.alias_name
)
if (
not o_entities
and not logical_node.o.get_mention_name()
and isinstance(logical_node.o, SPOEntity)
):
logical_node.o.entity_name = target_answer
context.kwargs[logical_node.o.alias_name] = logical_node.o
context.variables_graph.add_answered_alias(
logical_node.s.alias_name.alias_name, retrieved_data.summary
)
context.variables_graph.add_answered_alias(
logical_node.p.alias_name.alias_name, retrieved_data.summary
)
context.variables_graph.add_answered_alias(
logical_node.o.alias_name.alias_name, retrieved_data.summary
)
task.update_result(retrieved_data)
logger.debug(
f"kag hybrid retrieval {task_query} cost={time.time() - start_time}"
)
return retrieved_data
finally:
self.report_content(
reporter,
"thinker",
tag_id,
"",
"FINISH",
step=task.name,
overwrite=False,
)
def schema(self) -> dict:
"""Function schema definition for OpenAI Function Calling
@ -403,7 +439,7 @@ class KAGHybridRetrievalExecutor(ExecutorABC):
node_type=chunk.properties.get("__labels__"),
)
entity_prop = dict(chunk.properties) if chunk.properties else {}
entity_prop["content"] = chunk.content
entity_prop["content"] = f"{chunk.content[:10]}..."
entity_prop["score"] = chunk.score
entity.prop = Prop.from_dict(entity_prop, "Chunk", None)
chunk_graph.append(entity)

View File

@ -140,8 +140,6 @@ def get_pipeline_conf(use_pipeline_name, config):
raise RuntimeError("mcpServers not found in config.")
default_solver_pipeline["executors"] = mcp_executors
# update KAG_CONFIG
KAG_CONFIG.update_conf(default_pipeline_conf)
return default_solver_pipeline
@ -167,8 +165,11 @@ async def do_qa_pipeline(
f"Knowledge base with id {kb_project_id} not found in qa_config['kb']"
)
continue
for index_name in matched_kb.get("index_list", []):
index_list = matched_kb.get("index_list", [])
if use_pipeline in ["default_pipeline"]:
# we only use chunk index
index_list = ["chunk_index"]
for index_name in index_list:
index_manager = KAGIndexManager.from_config(
{
"type": index_name,
@ -339,7 +340,7 @@ class SolverMain:
def invoke(
self,
project_id: int,
task_id: int,
task_id,
query: str,
session_id: str = "0",
is_report=True,

View File

@ -3,20 +3,17 @@ pipeline_name: default_pipeline
#------------kag-solver configuration start----------------#
chunk_retrieved_executor: &chunk_retrieved_executor_conf
type: chunk_retrieved_executor
top_k: 10
retriever:
type: vector_chunk_retriever
score_threshold: 0.65
vectorize_model: "{vectorize_model}"
kag_retriever_executor: &kag_retriever_executor_conf
type: kag_hybrid_retrieval_executor
retrievers: "{retrievers}"
merger:
type: kag_merger
enable_summary: false
solver_pipeline:
type: naive_rag_pipeline
executors:
- *chunk_retrieved_executor_conf
- *kag_retriever_executor_conf
generator:
type: llm_index_generator
llm_client: "{chat_llm}"

View File

@ -186,6 +186,7 @@ class KAGModelPlanner(PlannerABC):
.replace("</answer>", "")
.strip()
)
context.kwargs["planner_thought"] = logic_form_response
sub_queries, logic_forms = parse_logic_form_with_str(logic_form_str)
logic_forms = self.logic_node_parser.parse_logic_form_set(

View File

@ -11,42 +11,49 @@ logger = logging.getLogger(__name__)
class ExpressionBuildr(PromptABC):
template_zh = (
f"今天是{get_now(language='zh')}"
+ """\n# instruction
+ """
# instruction
根据给出的问题和数据编写python代码输出问题结果
为了便于理解输出从context中提取的数据输出中间计算过程和结果
注意严格根据输入内容进行编写代码,不允许进行假设
例如伤残等级如果context中未提及,则认为没有被认定为残疾
如果无法回答问题直接返回I don't know.
从context中提取的数据必须显式赋值所有计算步骤必须用代码实现不得隐含推断
必须输出中间计算过程和结果格式为print语句
如果context未提供必要数据或无法计算直接打印"I don't know."
# output format
直接输出python代码python版本为3.10不要包含任何其他信息
严格输出以下结构的python代码版本3.10
1. 数据提取部分代码中涉及输入的数值需要从context及question中提取不允许进行假设
2. 计算过程分步实现所有数学运算每个步骤对应独立变量
3. 输出每个中间变量和最终结果必须用print语句输出
# examples
## 例子1
### input
#### question
47000元按照万分之1.5一共612天计算利息一共多少钱
4百万元按照日利率万分之1.5一共612天计算利息一共多少钱
#### context
日利率万分之1.5
### output
```python
# 初始本金
principal = 47000
# 初始本金(单位:百万)
principal = 4 # 单位:百万
# 利率万分之1.5
rate = 1.5 / 10000
# 利率计算万分之1.5
daily_rate = 1.5 / 10000
# 天数
# 计算周期
days = 612
# 计算年利率
annual_rate = rate * 365
# 单日利息计算
daily_interest = principal * daily_rate
# 利息
interest = principal * (annual_rate / 365) * days
# 计利息计算
total_interest = daily_interest * days
# 输出总金额(本金+利息)
total_amount = principal + interest
# 总金额计算
total_amount = principal + total_interest
print(f"总金额:{total_amount:.2f}")
print(f"单日利息:{daily_interest:.2f}百万")
print(f"累计利息:{total_interest:.2f}百万")
print(f"总金额:{total_amount:.2f}百万")
```
## 例子2
@ -70,13 +77,26 @@ revenue_2020 = revenue_2019 * (1 + growth_rate)
print(f"2020年的预计收入为: {revenue_2020:.2f}")
```
## 例子3
### input
#### question
47000元按照612天计算利息本息一共多少钱
#### content
### output
```python
# 未给出利率,无法计算
print("未给出利率,无法计算")
```
# input
## question
$question
## context
$context
## error
$error"""
$error
## output
"""
)
template_en = (
f"Today is {get_now(language='en')}\n"

View File

@ -1,5 +1,6 @@
import logging
import re
import time
from kag.common.conf import KAG_PROJECT_CONF
from kag.common.parser.logic_node_parser import extract_steps_and_actions
@ -72,14 +73,16 @@ def process_tag_template(text):
}
clean_text = ""
for tag_info in all_tags:
content = tag_info[1]
if tag_info[0] in xml_tag_template:
content = tag_info[1]
if "search" == tag_info[0]:
content = process_planning(content)
clean_text += xml_tag_template[tag_info[0]][
KAG_PROJECT_CONF.language
].format_map(SafeDict({"content": content}))
return remove_xml_tags(clean_text)
else:
clean_text += content
text = remove_xml_tags(clean_text)
return text

View File

@ -159,12 +159,12 @@ def render_jinja2_template(template_str, context):
"""
try:
template = Template(template_str, undefined=SilentUndefined)
return template.render(**context).strip()
return template.render(**context)
except Exception as e:
logging.error(
f"Jinja2 rendering failed: {e}, Original template: {template_str}"
)
return template_str.strip() # Fallback to raw template string on failure
return template_str # Fallback to raw template string on failure
@ReporterABC.register("open_spg_reporter")
@ -264,12 +264,12 @@ Rerank the documents and take the top {{ chunk_num }}.
}
self.tag_mapping = {
"Graph Show": {
"en": "{content}",
"zh": "{content}",
"en": "{{ content }}",
"zh": "{{ content }}",
},
"Rewrite query": {
"en": "Rethinking question using LLM: {content}",
"zh": "根据依赖问题重写子问题: {content}",
"en": "Rethinking question using LLM: {{ content }}",
"zh": "根据依赖问题重写子问题: {{ content }}",
},
"language_setting": {
"en": "",
@ -277,125 +277,153 @@ Rerank the documents and take the top {{ chunk_num }}.
},
"Iterative planning": {
"en": """
<step status="{status}" title="Global planning">
<step status="{{status}}" title="Global planning">
{content}
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
"zh": """
<step status="{status}" title="思考当前步骤">
<step status="{{status}}" title="思考当前步骤">
{content}
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
},
"Static planning": {
"en": """
<step status="{status}" title="Global planning">
<step status="{{status}}" title="Global planning">
{content}
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
"zh": """
<step status="{status}" title="思考全局步骤">
<step status="{{status}}" title="思考全局步骤">
{content}
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
},
"begin_sub_kag_retriever": {
"en": "Starting {component_name}: {content} {desc}",
"zh": "执行{component_name}: {content} {desc}",
"en": "Starting {{component_name}}: {{content}} {{desc}}",
"zh": "执行{{component_name}}: {{content}} {{desc}}",
},
"end_sub_kag_retriever": {
"en": " {content}",
"zh": " {content}",
"en": " {{ content }}",
"zh": " {{ content }}",
},
"rc_retriever_rewrite": {
"en": """
<step status="{status}" title="Rewriting chunk retriever query">
<step status="{{status}}" title="Rewriting chunk retriever query">
Rewritten question:\n{content}
Rewritten question:
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
"zh": """
<step status="{status}" title="正在根据依赖问题重写检索子问题">
<step status="{{status}}" title="正在根据依赖问题重写检索子问题">
重写问题为\n\n{content}
重写问题为
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
},
"rc_retriever_summary": {
"en": "Summarizing retrieved documents,{content}",
"zh": "对文档进行总结,{content}",
"en": "Summarizing retrieved documents,{{ content }}",
"zh": "对文档进行总结,{{ content }}",
},
"kg_retriever_summary": {
"en": "Summarizing retrieved graph,{content}",
"zh": "对召回的知识进行总结,{content}",
"en": "Summarizing retrieved graph,{{ content }}",
"zh": "对召回的知识进行总结,{{ content }}",
},
"retriever_summary": {
"en": "Summarizing retrieved documents,{content}",
"zh": "对文档进行总结,{content}",
"en": "Summarizing retrieved documents,{{ content }}",
"zh": "对文档进行总结,{{ content }}",
},
"begin_summary": {
"en": "Summarizing retrieved information, {content}",
"zh": "对检索的信息进行总结, {content}",
"en": "Summarizing retrieved information, {{ content }}",
"zh": "对检索的信息进行总结, {{ content }}",
},
"begin_task": {
"en": """
<step status="{status}" title="Starting Task {step}">
<step status="{{status}}" title="Starting Task {{step}}">
{content}
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
"zh": """
<step status="{status}" title="执行 {step}">
<step status="{{status}}" title="执行 {{step}}">
{content}
{{ content }}
</step>""",
{% if status == 'success' %}
</step>
{% endif %}""",
},
"logic_node": {
"en": """Translate query to logic form expression
```json
{content}
{{ content }}
```""",
"zh": """将query转换成逻辑形式表达
```json
{content}
{{ content }}
```""",
},
"kag_retriever_result": {
"en": "Retrieved documents\n\n{content}",
"zh": "检索到的文档\n\n{content}",
"en": """Retrieved documents
{{ content }}""",
"zh": """检索到的文档
{{ content }}""",
},
"failed_kag_retriever": {
"en": """KAG retriever failed
```json
{content}
{{ content }}
```
""",
"zh": """KAG检索失败
```json
{content}
{{ content }}
```
""",
},
"end_math_executor": {
"en": "Math executor completed\n\n{content}",
"zh": "计算结束\n\n{content}",
"en": """Math executor completed
{{ content }}""",
"zh": """计算结束
{{ content }}""",
},
"code_generator": {
"en": "Generating code\n \n{content}\n",
"zh": "正在生成代码\n \n{content}\n",
"en": """Generating code
{{ content }}
""",
"zh": """正在生成代码
{{ content }}
""",
},
}
task_id = kwargs.get(KAGConstants.KAG_QA_TASK_CONFIG_KEY, None)
@ -425,7 +453,11 @@ Rewritten question:\n{content}
if tpl:
format_params = {"content": datas}
format_params.update(content_params)
datas = tpl.format_map(SafeDict(format_params))
if "{" in tpl or "%}" in tpl:
rendered = render_jinja2_template(tpl, format_params)
else:
rendered = tpl.format_map(SafeDict(format_params))
datas = rendered
elif str(datas).strip() != "":
output = str(datas).strip()
if output != "":
@ -516,7 +548,7 @@ Rewritten question:\n{content}
if self.last_report.to_dict() == request.to_dict():
return
logger.info(
f"do_report: {content.answer} think={content.think} status={status_enum} ret={ret}"
f"do_report: think={content.think} {content.answer} status={status_enum} ret={ret}"
)
self.last_report = request
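
A small sketch of the template-rendering dispatch introduced above (helper names and the marker check are simplified; SafeDict stands in for the project's helper): templates written in the new Jinja2 style are rendered with Jinja2, while plain {placeholder} templates keep the old format_map path.

```python
from jinja2 import Template


class SafeDict(dict):
    """Leave unknown {placeholders} untouched when used with str.format_map."""

    def __missing__(self, key):
        return "{" + key + "}"


def render_tag_template(tpl: str, params: dict) -> str:
    # Jinja2-style templates ({{ ... }}, {% ... %}) go through Jinja2 rendering;
    # everything else keeps the legacy format_map behaviour.
    if "{{" in tpl or "{%" in tpl:
        return Template(tpl).render(**params)
    return tpl.format_map(SafeDict(params))


step_tpl = '<step status="{{status}}" title="Starting Task {{step}}">\n{{ content }}'
print(render_tag_template(step_tpl, {"status": "success", "step": "1", "content": "done"}))
```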

View File

View File

@ -0,0 +1,168 @@
import concurrent.futures
import queue
import threading
import time
import uuid
import logging
from cachetools import TTLCache
logger = logging.getLogger()
class AsyncTaskManager:
def __init__(self, max_workers=10, ttl=3600):
"""
Initialize async task manager
Args:
max_workers (int): Maximum number of worker threads
ttl (int): Time-to-live for task results in seconds
"""
self.max_workers = max_workers
self.task_queue = queue.Queue()
self.result_cache = TTLCache(maxsize=1000, ttl=ttl)
self.result_cache_lock = threading.Lock() # Protect cache from race conditions
self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
self.workers = [
threading.Thread(target=self.worker, daemon=True)
for _ in range(max_workers)
]
for w in self.workers:
w.start()
def worker(self):
"""Worker thread main loop that processes tasks"""
while True:
try:
# Block until the next task (or the shutdown sentinel) arrives
task = self.task_queue.get()
task_id, func, args, kwargs = task
logger.info(f"Processing task {task_id}")
# finish flag
if task_id is None:
self.task_queue.task_done()
break
# Update cache with running status
with self.result_cache_lock:
self.result_cache[task_id] = {
"task_id": task_id,
"status": "running",
"result": None,
}
# Execute task
future = self.executor.submit(func, *args, **kwargs)
result = future.result()
status = "completed"
except queue.Empty:
# Only reachable if a timeout is added to task_queue.get()
continue
except Exception as e:
# Handle task execution errors
result = str(e)
status = "failed"
logger.error(f"Task {task_id} failed with error: {e}", exc_info=True)
# Store final result in cache
try:
with self.result_cache_lock:
self.result_cache[task_id] = {
"task_id": task_id,
"status": status,
"result": result,
}
logger.info(f"Task {task_id} completed with status: {status}")
finally:
# Always mark task as done
self.task_queue.task_done()
def submit_task(self, func, *args, **kwargs):
"""
Submit a new task to the queue
Args:
func: Callable function to execute
*args: Positional arguments for the function
**kwargs: Keyword arguments for the function
Returns:
str: Unique task ID
"""
task_id = str(uuid.uuid4())
self.task_queue.put((task_id, func, args, kwargs))
return task_id
def get_task_result(self, task_id):
"""
Get result for a specific task
Args:
task_id (str): Unique task identifier
Returns:
dict: Task result information or expired status
"""
with self.result_cache_lock:
return self.result_cache.get(
task_id,
{
"task_id": task_id,
"status": "failed",
"result": "Result not found or expired",
},
)
def shutdown(self):
"""Gracefully shutdown all worker threads and executors"""
# Send shutdown signals
for _ in range(self.max_workers):
self.task_queue.put((None, None, (), {}))
# Wait for queue to empty and workers to terminate
self.task_queue.join()
# Shutdown executors
self.executor.shutdown(wait=True)
for worker in self.workers:
worker.join(timeout=5)
# Global async task manager instance
asyn_task = AsyncTaskManager()
if __name__ == "__main__":
# Configure logging
logging.basicConfig(level=logging.INFO)
# Create task manager instance
task_manager = AsyncTaskManager(max_workers=5, ttl=600)
# Example task function
def example_task(x, y):
time.sleep(1) # Simulate work
return x
# Submit test tasks
task_ids = [task_manager.submit_task(example_task, i, i + 1) for i in range(6)]
# Monitor task progress
try:
while True:
time.sleep(1)
if all(
"completed" in task_manager.get_task_result(tid)["status"]
for tid in task_ids
):
break
except KeyboardInterrupt:
logger.info("Shutting down due to user interrupt")
# Print results
for task_id in task_ids:
print(f"Task {task_id} result: {task_manager.get_task_result(task_id)}")
# Clean up resources
task_manager.shutdown()

View File

View File

@ -0,0 +1,63 @@
from fastapi import FastAPI
import uvicorn
from kag.solver.main_solver import SolverMain
from kag.solver.server.asyn_task_manager import AsyncTaskManager
from kag.solver.server.model.task_req import FeatureRequest, TaskReq
def run_main_solver(task: TaskReq):
return SolverMain().invoke(
project_id=task.project_id,
task_id=task.req_id,
query=task.req.query,
is_report=task.req.report,
host_addr=task.req.host_addr,
app_id=task.app_id,
params=task.config,
)
class KAGSolverServer:
def __init__(self, service_name: str):
"""
Initialize a FastAPI service instance
Args:
service_name (str): Service name, determines which routing logic to load
"""
self.service_name = service_name
self.app = FastAPI(title=f"{service_name} API")
# Bind routes according to service name
self._setup_routes()
self.async_manager = AsyncTaskManager()
def sync_task(self, task: TaskReq):
if task.cmd == "submit":
return {
"success": True,
"status": "init",
"result": self.async_manager.submit_task(run_main_solver, task),
}
elif task.cmd == "query":
return self.async_manager.get_task_result(task_id=task.req_id)
else:
return {
"success": False,
"status": "failed",
"result": f"invalid input cmd {task.cmd}",
}
def _setup_routes(self):
"""Dynamically bind routes according to service name"""
@self.app.post("/process")
def process(req: FeatureRequest):
return self.sync_task(task=req.features.task_req)
def run(self, host="0.0.0.0", port=8000):
"""Start the service"""
print(f"Starting {self.service_name} service on {host}:{port}")
uvicorn.run(self.app, host=host, port=port)
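
A hedged usage sketch for the new solver server (the commented-out import path and the requests dependency are assumptions for illustration): start the FastAPI app, then POST a FeatureRequest whose features.in_string wraps a TaskReq with cmd="submit"; the returned result field is the async task id, which can later be polled with cmd="query".

```python
import json

import requests  # assumed to be available for this example; not part of the commit

# Hypothetical import path; the class is defined in the file above.
# from kag.solver.server.kag_solver_server import KAGSolverServer
# KAGSolverServer("kag_solver").run(host="0.0.0.0", port=8000)  # blocking call

task_req = {
    "project_id": 1,
    "req_id": "demo-1",
    "cmd": "submit",
    "mode": "async",
    "req": json.dumps({"query": "what is KAG?", "report": False, "host_addr": ""}),
    "config": "{}",
}
resp = requests.post(
    "http://localhost:8000/process",
    json={"features": {"in_string": json.dumps(task_req)}},
)
print(resp.json())  # e.g. {"success": True, "status": "init", "result": "<task id>"}

# Poll later by sending the same structure with cmd="query" and
# req_id set to the task id returned above.
```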

View File

View File

@ -0,0 +1,122 @@
import json
from typing import Optional
from pydantic import BaseModel, model_validator, field_serializer
class ReqBody(BaseModel):
"""Request body model containing query parameters"""
query: str = ""
report: bool = True
host_addr: str = ""
class TaskReq(BaseModel):
"""Task request model with validation logic"""
app_id: int = ""
project_id: int = 0
req_id: str = ""
cmd: str = ""
mode: str = ""
req: str = None
config: str = "{}"
@model_validator(mode="after")
def parse_req_to_req_body(self):
"""Parse req string to ReqBody object and process config field"""
try:
import json
if isinstance(self.req, str):
req_body_dict = json.loads(self.req)
self.req = ReqBody(**req_body_dict)
if isinstance(self.config, str) and self.config:
config_dict = json.loads(self.config)
self.config = config_dict
except Exception as e:
raise ValueError(f"Failed to parse 'req' field to ReqBody: {e}")
return self
@field_serializer("req")
def serialize_req(self, value: object) -> object:
"""Serialize ReqBody back to JSON string"""
if isinstance(value, ReqBody):
return value.model_dump_json()
return value # Return as-is if already a string
# Request model with TaskReq parsing capability
class Request(BaseModel):
"""Container model for task request data"""
in_string: str
task_req: Optional[TaskReq] = None
@model_validator(mode="after")
def parse_in_string_to_task_req(self):
"""Convert in_string JSON string to TaskReq object"""
try:
import json
task_req_dict = json.loads(self.in_string)
self.task_req = TaskReq(**task_req_dict)
except Exception as e:
raise ValueError(f"Invalid TaskReq JSON string: {e}")
return self
class FeatureRequest(BaseModel):
"""Top-level request wrapper with features container"""
features: Request
if __name__ == "__main__":
def feature_request_parsing():
"""Demonstrate nested model parsing workflow"""
# Build innermost ReqBody JSON string
req_body = ReqBody(
query="阿里巴巴财报中2024年-截至9月30日止六个月的收入是多少其中云智能集团收入是多少占比是多少",
report=True,
host_addr="https://spg.alipay.com",
)
req_body_json = json.dumps(req_body.model_dump())
# Build TaskReq dictionary and serialize to string
task_req = TaskReq(
req_id="9400110",
cmd="submit",
mode="async",
req=req_body_json,
app_id="app_id",
project_id=4200050,
config={"timeout": 10},
)
task_req_json = json.dumps(task_req.model_dump())
# Construct final FeatureRequest JSON string
input_data = {"features": {"in_string": task_req_json}}
# Deserialize to FeatureRequest model
feature_request = FeatureRequest(**input_data)
# Validate in_string parsed to TaskReq
assert isinstance(feature_request.features.task_req, TaskReq)
assert feature_request.features.task_req.req_id == "9400110"
assert feature_request.features.task_req.cmd == "submit"
assert feature_request.features.task_req.mode == "async"
assert feature_request.features.task_req.config == {"timeout": 10}
# Validate TaskReq.req parsed to ReqBody
req_body_parsed = feature_request.features.task_req.req
assert isinstance(req_body_parsed, ReqBody)
assert req_body_parsed.query == "阿里巴巴财报中2024年-截至9月30日止六个月的收入是多少其中云智能集团收入是多少占比是多少"
assert req_body_parsed.report is True
assert req_body_parsed.host_addr == "https://spg.alipay.com"
print("✅ All assertions passed!")
feature_request_parsing()

View File

@ -0,0 +1 @@
fastapi

View File

@ -48,13 +48,16 @@ class Environment:
@property
def config(self):
closest_config = self._closest_config()
if not hasattr(self, "_config_path") or self._config_path != closest_config:
self._config_path = closest_config
self._config = self.get_config()
try:
closest_config = self._closest_config()
if not hasattr(self, "_config_path") or self._config_path != closest_config:
self._config_path = closest_config
self._config = self.get_config()
if self._config is None:
self._config = self.get_config()
if self._config is None:
self._config = self.get_config()
except:
return {}
return self._config

View File

@ -22,7 +22,7 @@ numpy>=1.23.1
pypdf
pandas
pycryptodome
markdown
markdown==3.7
bs4
protobuf==3.20.1
neo4j

View File

@ -0,0 +1,64 @@
from kag.common.utils import extract_tag_content
def run_extra_tag():
test_cases = [
{
"input": "<tag1>abced</tag1>some word<tag2>other tags</tag2>",
"expected": [("tag1", "abced"), ("", "some word"), ("tag2", "other tags")],
"description": "基本闭合标签与无标签文本混合",
},
{
"input": "<p>Hello <b>world</b> this is <i>test</i>",
"expected": [
("p", "Hello "),
("b", "world"),
("", " this is "),
("i", "test"),
],
"description": "混合闭合与未闭合标签",
},
{
"input": "plain text without any tags",
"expected": [("", "plain text without any tags")],
"description": "纯文本无标签",
},
{
"input": "<div>\n Line 1\n <span>Line 2</span>\n Line 3\n</div>",
"expected": [
("div", "\n Line 1\n <span>Line 2</span>\n Line 3\n")
],
"description": "多行内容和空白处理",
},
{
"input": "<a>A</a><b>B</b><c>C</c>",
"expected": [("a", "A"), ("b", "B"), ("c", "C")],
"description": "连续多个闭合标签",
},
{
"input": "<title>My Document</title><content>This is the content",
"expected": [("title", "My Document"), ("content", "This is the content")],
"description": "未闭合标签EOF结尾",
},
{
"input": "<log>Error: &*^%$#@!;</log><note>End of log</note>",
"expected": [("log", "Error: &*^%$#@!;"), ("note", "End of log")],
"description": "含特殊字符的内容",
},
{
"input": "",
"expected": [],
"description": "空字符串输入",
},
]
for i, test in enumerate(test_cases):
result = extract_tag_content(test["input"])
assert (
result == test["expected"]
), f"Test {i+1} failed: {test['description']}\nGot: {result}\nExpected: {test['expected']}"
print(f"Test {i+1} passed: {test['description']}")
if __name__ == "__main__":
run_extra_tag()