mirror of https://github.com/HKUDS/LightRAG.git
Merge pull request #1897 from danielaskdd/json-repair
refactor: improve JSON parsing reliability with json-repair library
This commit is contained in:
commit f6b90fe482
@@ -272,7 +272,6 @@ if __name__ == "__main__":
| **enable_llm_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **enable_llm_cache_for_entity_extract** | `bool` | If `TRUE`, stores LLM results for entity extraction in cache; suitable for beginners debugging their application | `TRUE` |
| **addon_params** | `dict` | Additional parameters, e.g. `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"]}`: sets the example limit, output language, and batch size for document processing | `example_number: all examples, language: English` |
| **convert_response_to_json_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding_cache_config** | `dict` | Configuration for question-answer caching. Contains three parameters: `enabled`: Boolean value to enable/disable cache lookup. When enabled, the system checks for a cached response before generating a new answer. `similarity_threshold`: Float value (0-1), the similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer is returned directly without calling the LLM. `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, the LLM is used as a secondary check to verify the similarity between questions before a cached answer is returned. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
</details>
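The caching switches above work together. As a rough illustration only (a minimal sketch that assumes the table rows map one-to-one onto `LightRAG` constructor keyword arguments, with placeholder values), enabling question-answer cache lookup might look like this:

```python
from lightrag import LightRAG

# Illustrative values only; llm_model_func and embedding_func are omitted for brevity.
rag = LightRAG(
    working_dir="./rag_storage",        # placeholder path
    enable_llm_cache=True,              # reuse cached LLM responses for repeated prompts
    embedding_cache_config={
        "enabled": True,                # check the cache before generating a new answer
        "similarity_threshold": 0.95,   # reuse a cached answer above this similarity
        "use_llm_check": False,         # skip the secondary LLM similarity verification
    },
)
```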
@@ -279,7 +279,6 @@ A full list of LightRAG init parameters:
| **enable_llm_cache** | `bool` | If `TRUE`, stores LLM results in cache; repeated prompts return cached responses | `TRUE` |
| **enable_llm_cache_for_entity_extract** | `bool` | If `TRUE`, stores LLM results in cache for entity extraction; Good for beginners to debug your application | `TRUE` |
| **addon_params** | `dict` | Additional parameters, e.g., `{"example_number": 1, "language": "Simplified Chinese", "entity_types": ["organization", "person", "geo", "event"]}`: sets example limit, entity/relation extraction output language | `example_number: all examples, language: English` |
| **convert_response_to_json_func** | `callable` | Not used | `convert_response_to_json` |
| **embedding_cache_config** | `dict` | Configuration for question-answer caching. Contains three parameters: `enabled`: Boolean value to enable/disable cache lookup functionality. When enabled, the system will check cached responses before generating new answers. `similarity_threshold`: Float value (0-1), similarity threshold. When a new question's similarity with a cached question exceeds this threshold, the cached answer will be returned directly without calling the LLM. `use_llm_check`: Boolean value to enable/disable LLM similarity verification. When enabled, LLM will be used as a secondary check to verify the similarity between questions before returning cached answers. | Default: `{"enabled": False, "similarity_threshold": 0.95, "use_llm_check": False}` |
</details>
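Similarly, a hedged sketch of the `addon_params` entry above (the keys come from the table; everything else is a placeholder):

```python
from lightrag import LightRAG

# Illustrative values only; llm_model_func and embedding_func are omitted for brevity.
rag = LightRAG(
    working_dir="./rag_storage",              # placeholder path
    addon_params={
        "example_number": 1,                  # limit the number of few-shot examples per prompt
        "language": "Simplified Chinese",     # output language for entity/relation extraction
        "entity_types": ["organization", "person", "geo", "event"],
    },
)
```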
@@ -81,7 +81,6 @@ from .utils import (
    EmbeddingFunc,
    always_get_an_event_loop,
    compute_mdhash_id,
    convert_response_to_json,
    lazy_external_import,
    priority_limit_async_func_call,
    get_content_summary,
@@ -341,15 +340,6 @@ class LightRAG:
    # Storages Management
    # ---

    convert_response_to_json_func: Callable[[str], dict[str, Any]] = field(
        default_factory=lambda: convert_response_to_json
    )
    """
    Custom function for converting LLM responses to JSON format.

    The default function is :func:`.utils.convert_response_to_json`.
    """

    cosine_better_than_threshold: float = field(
        default=float(os.getenv("COSINE_THRESHOLD", 0.2))
    )
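As a side note on the retained `cosine_better_than_threshold` field above, here is a minimal sketch (independent of this diff) of tuning it through the environment. Because the value is captured in `field(default=...)` when the class body is evaluated, the variable has to be set before `lightrag` is imported:

```python
import os

# Illustrative value; COSINE_THRESHOLD is read once at import time via field(default=...).
os.environ["COSINE_THRESHOLD"] = "0.35"

from lightrag import LightRAG  # noqa: E402  # the field default now reflects 0.35
```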
@@ -23,7 +23,6 @@ from tenacity import (

from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
)

@@ -108,7 +107,7 @@ async def azure_openai_complete_if_cache(
async def azure_openai_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    kwargs.pop("keyword_extraction", None)
    result = await azure_openai_complete_if_cache(
        os.getenv("LLM_MODEL", "gpt-4o-mini"),
        prompt,

@@ -116,8 +115,6 @@ async def azure_openai_complete(
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction: # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
@@ -15,10 +15,6 @@ from tenacity import (
    retry_if_exception_type,
)

from lightrag.utils import (
    locate_json_string_body_from_string,
)


class BedrockError(Exception):
    """Generic error for issues related to Amazon Bedrock"""

@@ -96,7 +92,7 @@ async def bedrock_complete_if_cache(
async def bedrock_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    kwargs.pop("keyword_extraction", None)
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await bedrock_complete_if_cache(
        model_name,

@@ -105,8 +101,6 @@ async def bedrock_complete(
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction: # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
@@ -24,9 +24,6 @@ from lightrag.exceptions import (
    RateLimitError,
    APITimeoutError,
)
from lightrag.utils import (
    locate_json_string_body_from_string,
)
import torch
import numpy as np

@@ -119,7 +116,7 @@ async def hf_model_if_cache(
async def hf_model_complete(
    prompt, system_prompt=None, history_messages=[], keyword_extraction=False, **kwargs
) -> str:
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    kwargs.pop("keyword_extraction", None)
    model_name = kwargs["hashing_kv"].global_config["llm_model_name"]
    result = await hf_model_if_cache(
        model_name,

@@ -128,8 +125,6 @@ async def hf_model_complete(
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction: # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
@@ -21,7 +21,6 @@ from tenacity import (
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
)
from lightrag.exceptions import (
    APIConnectionError,

@@ -157,7 +156,7 @@ async def llama_index_complete(
    if history_messages is None:
        history_messages = []

    keyword_extraction = kwargs.pop("keyword_extraction", None)
    kwargs.pop("keyword_extraction", None)
    result = await llama_index_complete_if_cache(
        kwargs.get("llm_instance"),
        prompt,

@@ -165,8 +164,6 @@ async def llama_index_complete(
        history_messages=history_messages,
        **kwargs,
    )
    if keyword_extraction:
        return locate_json_string_body_from_string(result)
    return result
@@ -27,7 +27,6 @@ from tenacity import (
)
from lightrag.utils import (
    wrap_embedding_func_with_attrs,
    locate_json_string_body_from_string,
    safe_unicode_decode,
    logger,
)

@@ -418,7 +417,7 @@ async def nvidia_openai_complete(
) -> str:
    if history_messages is None:
        history_messages = []
    keyword_extraction = kwargs.pop("keyword_extraction", None)
    kwargs.pop("keyword_extraction", None)
    result = await openai_complete_if_cache(
        "nvidia/llama-3.1-nemotron-70b-instruct", # context length 128k
        prompt,

@@ -427,8 +426,6 @@ async def nvidia_openai_complete(
        base_url="https://integrate.api.nvidia.com/v1",
        **kwargs,
    )
    if keyword_extraction: # TODO: use JSON API
        return locate_json_string_body_from_string(result)
    return result
@@ -5,6 +5,7 @@ import asyncio
import json
import re
import os
import json_repair
from typing import Any, AsyncIterator
from collections import Counter, defaultdict

@@ -1781,10 +1782,10 @@ async def extract_keywords_only(
    )
    if cached_response is not None:
        try:
            keywords_data = json.loads(cached_response)
            return keywords_data["high_level_keywords"], keywords_data[
                "low_level_keywords"
            ]
            keywords_data = json_repair.loads(cached_response)
            return keywords_data.get("high_level_keywords", []), keywords_data.get(
                "low_level_keywords", []
            )
        except (json.JSONDecodeError, KeyError):
            logger.warning(
                "Invalid cache format for keywords, proceeding with extraction"

@@ -1832,12 +1833,11 @@ async def extract_keywords_only(

    # 6. Parse out JSON from the LLM response
    result = remove_think_tags(result)
    match = re.search(r"\{.*?\}", result, re.DOTALL)
    if not match:
        logger.error("No JSON-like structure found in the LLM respond.")
        return [], []
    try:
        keywords_data = json.loads(match.group(0))
        keywords_data = json_repair.loads(result)
        if not keywords_data:
            logger.error("No JSON-like structure found in the LLM respond.")
            return [], []
    except json.JSONDecodeError as e:
        logger.error(f"JSON parsing error: {e}")
        logger.error(f"LLM respond: {result}")
@@ -248,43 +248,6 @@ class EmbeddingFunc:
        return await self.func(*args, **kwargs)


def locate_json_string_body_from_string(content: str) -> str | None:
    """Locate the JSON string body from a string"""
    try:
        maybe_json_str = re.search(r"{.*}", content, re.DOTALL)
        if maybe_json_str is not None:
            maybe_json_str = maybe_json_str.group(0)
            maybe_json_str = maybe_json_str.replace("\\n", "")
            maybe_json_str = maybe_json_str.replace("\n", "")
            maybe_json_str = maybe_json_str.replace("'", '"')
            # json.loads(maybe_json_str) # don't check here, cannot validate schema after all
            return maybe_json_str
    except Exception:
        pass
    # try:
    #     content = (
    #         content.replace(kw_prompt[:-1], "")
    #         .replace("user", "")
    #         .replace("model", "")
    #         .strip()
    #     )
    #     maybe_json_str = "{" + content.split("{")[1].split("}")[0] + "}"
    #     json.loads(maybe_json_str)

    return None


def convert_response_to_json(response: str) -> dict[str, Any]:
    json_str = locate_json_string_body_from_string(response)
    assert json_str is not None, f"Unable to parse JSON from response: {response}"
    try:
        data = json.loads(json_str)
        return data
    except json.JSONDecodeError as e:
        logger.error(f"Failed to parse JSON: {json_str}")
        raise e from None


def compute_args_hash(*args: Any) -> str:
    """Compute a hash for the given arguments.

    Args:
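The removed helper above rewrites every single quote in the matched `{...}` span, which is exactly what breaks when a value legitimately contains an apostrophe. A standalone illustration of that failure mode (not part of the diff):

```python
import json
import re

content = '{"entity_name": "O\'Reilly", "description": "a publisher"}'

# The removed helper's approach: grab the {...} span, then swap quote characters.
candidate = re.search(r"{.*}", content, re.DOTALL).group(0).replace("'", '"')

try:
    json.loads(candidate)
except json.JSONDecodeError:
    print("the apostrophe in O'Reilly became an unbalanced quote")
```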
@@ -25,6 +25,7 @@ dependencies = [
    "configparser",
    "dotenv",
    "future",
    "json-repair",
    "nano-vectordb",
    "networkx",
    "numpy",